Commit 2453725c authored by Joseph Noir's avatar Joseph Noir

The command_dispatcher now uses event callbacks

to handle results from kernel executions.
parent f75f328b
This diff is collapsed.
......@@ -37,33 +37,125 @@
#include "cppa/actor.hpp"
#include "cppa/opencl/global.hpp"
#include "cppa/response_handle.hpp"
#include "cppa/opencl/smart_ptr.hpp"
namespace cppa { namespace opencl {
class command {
class command : public ref_counted {
public:
command* next;
// bool is_new_job; // if false == call handle_cl_result
// struct callbacks {
// std::function<cl_job_id(cl_command_queue)> enqueue_cl_task;
// std::function<void()> handle_cl_result;
// };
// union { callbacks cbs; cl_job_id jid; };
virtual void enqueue (command_queue_ptr queue) = 0;
std::function<void(cl_command_queue)> fun;
};
class command_dummy : public command {
public:
void enqueue(command_queue_ptr) { }
};
template<typename T>
class command_impl : public command {
actor* sender;
public:
command_impl(response_handle handle,
kernel_ptr kernel,
std::vector<mem_ptr> arguments,
std::vector<size_t> global_dimensions,
std::vector<size_t> local_dimensions)
: m_number_of_values(1)
, m_handle(handle)
, m_kernel(kernel)
, m_arguments(arguments)
, m_global_dimensions(global_dimensions)
, m_local_dimensions(local_dimensions)
{
m_kernel_event.adopt(cl_event());
for (size_t s : m_global_dimensions) {
m_number_of_values *= s;
}
}
void enqueue (command_queue_ptr queue) {
this->ref();
cl_int err{0};
m_queue = queue;
auto ptr = m_kernel_event.get();
template<typename T>
inline command(T f, actor* sender)
: next(nullptr), fun(f), sender(sender) { }
/* enqueue kernel */
err = clEnqueueNDRangeKernel(m_queue.get(),
m_kernel.get(),
3,
NULL,
m_global_dimensions.data(),
m_local_dimensions.data(),
0,
nullptr,
&ptr);
if (err != CL_SUCCESS) {
throw std::runtime_error("[!!!] clEnqueueNDRangeKernel: '"
+ get_opencl_error(err)
+ "'.");
}
err = clSetEventCallback(ptr,
CL_COMPLETE,
[](cl_event, cl_int, void* data) {
auto cmd = reinterpret_cast<command_impl*>(data);
cmd->handle_results();
cmd->deref();
},
this);
if (err != CL_SUCCESS) {
throw std::runtime_error("[!!!] clSetEventCallback: '"
+ get_opencl_error(err)
+ "'.");
}
}
inline command() : next(nullptr), sender(nullptr) { }
private:
int m_number_of_values;
response_handle m_handle;
kernel_ptr m_kernel;
event_ptr m_kernel_event;
command_queue_ptr m_queue;
std::vector<mem_ptr> m_arguments;
std::vector<size_t> m_global_dimensions;
std::vector<size_t> m_local_dimensions;
void handle_results () {
/* get results from gpu */
cl_int err{0};
cl_event read_event;
T results(m_number_of_values);
err = clEnqueueReadBuffer(m_queue.get(),
m_arguments[0].get(),
CL_TRUE,
0,
sizeof(typename T::value_type) * m_number_of_values,
results.data(),
0,
NULL,
&read_event);
clReleaseEvent(read_event);
if (err != CL_SUCCESS) {
throw std::runtime_error("[!!!] clEnqueueReadBuffer: '"
+ get_opencl_error(err)
+ "'.");
}
reply_to(m_handle, results);
}
};
typedef intrusive_ptr<command> command_ptr;
} } // namespace cppa::opencl
#endif // CPPA_OPENCL_COMMAND_HPP
......@@ -48,6 +48,10 @@
namespace cppa { namespace opencl {
struct dereferencer {
inline void operator()(ref_counted* ptr) { ptr->deref(); }
};
#ifdef CPPA_OPENCL
class command_dispatcher {
......@@ -59,7 +63,8 @@ class command_dispatcher {
friend class program;
friend void enqueue_to_dispatcher(command_dispatcher*, command*);
friend void enqueue_to_dispatcher(command_dispatcher* dispatcher,
command_ptr cmd);
public:
......@@ -102,7 +107,7 @@ class command_dispatcher {
, max_itms_per_dim(std::move(max_itms_per_dim)) { }
};
typedef intrusive::blocking_single_reader_queue<command> job_queue;
typedef intrusive::blocking_single_reader_queue<command,dereferencer> job_queue;
static inline command_dispatcher* create_singleton() {
return new command_dispatcher;
......@@ -114,6 +119,7 @@ class command_dispatcher {
std::atomic<unsigned> dev_id_gen;
job_queue m_job_queue;
command_ptr m_dummy;
std::thread m_supervisor;
......@@ -121,7 +127,9 @@ class command_dispatcher {
context_ptr m_context;
static void worker_loop(worker*);
static void supervisor_loop(command_dispatcher *scheduler, job_queue*);
static void supervisor_loop(command_dispatcher *scheduler,
job_queue*,
command_ptr);
};
#else // CPPA_OPENCL
......
......@@ -35,8 +35,10 @@
namespace cppa { namespace opencl {
void enqueue_to_dispatcher(command_dispatcher* dispatcher,
command* cmd) {
dispatcher->m_job_queue.push_back(cmd);
command_ptr cmd) {
cmd->ref(); // implicit ref count of m_job_queue
dispatcher->m_job_queue.push_back(cmd.get());
}
} } // namespace cppa::opencl
......@@ -43,13 +43,14 @@ struct command_dispatcher::worker {
command_dispatcher* m_parent;
typedef unique_ptr<command> job_ptr;
typedef command_ptr job_ptr;
job_queue* m_job_queue;
thread m_thread;
job_ptr m_dummy;
worker(command_dispatcher* parent, job_queue* jq)
: m_parent(parent), m_job_queue(jq) { }
worker(command_dispatcher* parent, job_queue* jq, job_ptr dummy)
: m_parent(parent), m_job_queue(jq), m_dummy(dummy) { }
void start() {
m_thread = thread(&command_dispatcher::worker_loop, this);
......@@ -65,12 +66,13 @@ struct command_dispatcher::worker {
/* wait for device */
/* get results */
/* wait for job */
job.reset(m_job_queue->pop());
if(job->fun) {
// adopt reference count of job queue
job.adopt(m_job_queue->pop());
if(job != m_dummy) {
try {
cl_command_queue cmd_q =
m_parent->m_devices.front().cmd_queue.get();
job->fun(cmd_q);
job->enqueue(cmd_q);
}
catch (exception& e) {
cerr << e.what() << endl;
......@@ -92,9 +94,9 @@ void command_dispatcher::worker_loop(command_dispatcher::worker* w) {
}
void command_dispatcher::supervisor_loop(command_dispatcher* scheduler,
job_queue* jq) {
job_queue* jq, command_ptr m_dummy) {
unique_ptr<command_dispatcher::worker> worker;
worker.reset(new command_dispatcher::worker(scheduler, jq));
worker.reset(new command_dispatcher::worker(scheduler, jq, m_dummy));
worker->start();
worker->m_thread.join();
worker.reset();
......@@ -102,6 +104,9 @@ void command_dispatcher::supervisor_loop(command_dispatcher* scheduler,
}
void command_dispatcher::initialize() {
m_dummy = make_counted<command_dummy>();
cl_int err{0};
/* find up to two available platforms */
......@@ -217,11 +222,13 @@ void command_dispatcher::initialize() {
}
m_supervisor = thread(&command_dispatcher::supervisor_loop,
this,
&m_job_queue);
&m_job_queue,
m_dummy);
}
void command_dispatcher::destroy() {
m_job_queue.push_back(new command);
m_dummy->ref(); // reference of m_job_queue
m_job_queue.push_back(m_dummy.get());
m_supervisor.join();
delete this;
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment