From 7a9bbff7d641b82deae73e043fe1f02b7492993b Mon Sep 17 00:00:00 2001 From: Francisco Jerez Date: Mon, 16 Sep 2013 21:50:40 -0700 Subject: [PATCH] clover: Clean up the kernel and program object interface. [ Tom Stellard: Make sure to bind global arguments before retrieving handles. ] Tested-by: Tom Stellard --- .../state_trackers/clover/api/kernel.cpp | 34 +++----- .../state_trackers/clover/api/program.cpp | 36 ++++----- .../state_trackers/clover/core/kernel.cpp | 77 +++++++++++-------- .../state_trackers/clover/core/kernel.hpp | 23 ++++-- .../state_trackers/clover/core/module.cpp | 25 ------ .../state_trackers/clover/core/module.hpp | 29 ++----- .../state_trackers/clover/core/program.cpp | 30 ++++++-- .../state_trackers/clover/core/program.hpp | 21 +++-- 8 files changed, 139 insertions(+), 136 deletions(-) diff --git a/src/gallium/state_trackers/clover/api/kernel.cpp b/src/gallium/state_trackers/clover/api/kernel.cpp index dd742c3bb05..e55a55b546e 100644 --- a/src/gallium/state_trackers/clover/api/kernel.cpp +++ b/src/gallium/state_trackers/clover/api/kernel.cpp @@ -33,15 +33,12 @@ clCreateKernel(cl_program d_prog, const char *name, cl_int *r_errcode) try { if (!name) throw error(CL_INVALID_VALUE); - if (prog.binaries().empty()) - throw error(CL_INVALID_PROGRAM_EXECUTABLE); - - auto sym = prog.binaries().begin()->second.sym(name); + auto &sym = find(name_equals(name), prog.symbols()); ret_error(r_errcode, CL_SUCCESS); return new kernel(prog, name, range(sym.args)); -} catch (module::noent_error &e) { +} catch (std::out_of_range &e) { ret_error(r_errcode, CL_INVALID_KERNEL_NAME); return NULL; @@ -54,11 +51,7 @@ PUBLIC cl_int clCreateKernelsInProgram(cl_program d_prog, cl_uint count, cl_kernel *rd_kerns, cl_uint *r_count) try { auto &prog = obj(d_prog); - - if (prog.binaries().empty()) - throw error(CL_INVALID_PROGRAM_EXECUTABLE); - - auto &syms = prog.binaries().begin()->second.syms; + auto &syms = prog.symbols(); if (rd_kerns && count < syms.size()) throw error(CL_INVALID_VALUE); @@ -102,15 +95,12 @@ clReleaseKernel(cl_kernel d_kern) try { PUBLIC cl_int clSetKernelArg(cl_kernel d_kern, cl_uint idx, size_t size, const void *value) try { - auto &kern = obj(d_kern); - - if (idx >= kern.args.size()) - throw error(CL_INVALID_ARG_INDEX); - - kern.args[idx]->set(size, value); - + obj(d_kern).args().at(idx).set(size, value); return CL_SUCCESS; +} catch (std::out_of_range &e) { + return CL_INVALID_ARG_INDEX; + } catch (error &e) { return e.get(); } @@ -127,7 +117,7 @@ clGetKernelInfo(cl_kernel d_kern, cl_kernel_info param, break; case CL_KERNEL_NUM_ARGS: - buf.as_scalar() = kern.args.size(); + buf.as_scalar() = kern.args().size(); break; case CL_KERNEL_REFERENCE_COUNT: @@ -160,8 +150,8 @@ clGetKernelWorkGroupInfo(cl_kernel d_kern, cl_device_id d_dev, auto &kern = obj(d_kern); auto pdev = pobj(d_dev); - if ((!pdev && kern.prog.binaries().size() != 1) || - (pdev && !kern.prog.binaries().count(pdev))) + if ((!pdev && kern.prog.devices().size() != 1) || + (pdev && !count(*pdev, kern.prog.devices()))) throw error(CL_INVALID_DEVICE); switch (param) { @@ -210,10 +200,10 @@ namespace { if (any_of([](kernel::argument &arg) { return !arg.set(); - }, map(derefs(), kern.args))) + }, kern.args())) throw error(CL_INVALID_KERNEL_ARGS); - if (!kern.prog.binaries().count(&q.dev)) + if (!count(q.dev, kern.prog.devices())) throw error(CL_INVALID_PROGRAM_EXECUTABLE); } diff --git a/src/gallium/state_trackers/clover/api/program.cpp b/src/gallium/state_trackers/clover/api/program.cpp index 45640424a0d..01133813f98 100644 --- a/src/gallium/state_trackers/clover/api/program.cpp +++ b/src/gallium/state_trackers/clover/api/program.cpp @@ -173,11 +173,11 @@ clGetProgramInfo(cl_program d_prog, cl_program_info param, break; case CL_PROGRAM_NUM_DEVICES: - buf.as_scalar() = prog.binaries().size(); + buf.as_scalar() = prog.devices().size(); break; case CL_PROGRAM_DEVICES: - buf.as_vector() = map(keys(), prog.binaries()); + buf.as_vector() = descs(prog.devices()); break; case CL_PROGRAM_SOURCE: @@ -185,25 +185,23 @@ clGetProgramInfo(cl_program d_prog, cl_program_info param, break; case CL_PROGRAM_BINARY_SIZES: - buf.as_vector() = - map([](const std::pair &ent) { - compat::ostream::buffer_t bin; - compat::ostream s(bin); - ent.second.serialize(s); - return bin.size(); - }, - prog.binaries()); + buf.as_vector() = map([&](const device &dev) { + compat::ostream::buffer_t bin; + compat::ostream s(bin); + prog.binary(dev).serialize(s); + return bin.size(); + }, + prog.devices()); break; case CL_PROGRAM_BINARIES: - buf.as_matrix() = - map([](const std::pair &ent) { - compat::ostream::buffer_t bin; - compat::ostream s(bin); - ent.second.serialize(s); - return bin; - }, - prog.binaries()); + buf.as_matrix() = map([&](const device &dev) { + compat::ostream::buffer_t bin; + compat::ostream s(bin); + prog.binary(dev).serialize(s); + return bin; + }, + prog.devices()); break; default: @@ -224,7 +222,7 @@ clGetProgramBuildInfo(cl_program d_prog, cl_device_id d_dev, auto &prog = obj(d_prog); auto &dev = obj(d_dev); - if (!count(dev, prog.ctx.devs())) + if (!count(dev, prog.devices())) return CL_INVALID_DEVICE; switch (param) { diff --git a/src/gallium/state_trackers/clover/core/kernel.cpp b/src/gallium/state_trackers/clover/core/kernel.cpp index 7e18366c728..10eb0e0b2ef 100644 --- a/src/gallium/state_trackers/clover/core/kernel.cpp +++ b/src/gallium/state_trackers/clover/core/kernel.cpp @@ -27,36 +27,35 @@ using namespace clover; -kernel::kernel(program &prog, - const std::string &name, +kernel::kernel(program &prog, const std::string &name, const std::vector &margs) : prog(prog), _name(name), exec(*this) { - for (auto marg : margs) { + for (auto &marg : margs) { if (marg.type == module::argument::scalar) - args.emplace_back(new scalar_argument(marg.size)); + _args.emplace_back(new scalar_argument(marg.size)); else if (marg.type == module::argument::global) - args.emplace_back(new global_argument); + _args.emplace_back(new global_argument); else if (marg.type == module::argument::local) - args.emplace_back(new local_argument); + _args.emplace_back(new local_argument); else if (marg.type == module::argument::constant) - args.emplace_back(new constant_argument); + _args.emplace_back(new constant_argument); else if (marg.type == module::argument::image2d_rd || marg.type == module::argument::image3d_rd) - args.emplace_back(new image_rd_argument); + _args.emplace_back(new image_rd_argument); else if (marg.type == module::argument::image2d_wr || marg.type == module::argument::image3d_wr) - args.emplace_back(new image_wr_argument); + _args.emplace_back(new image_wr_argument); else if (marg.type == module::argument::sampler) - args.emplace_back(new sampler_argument); + _args.emplace_back(new sampler_argument); else throw error(CL_INVALID_KERNEL_DEFINITION); } } -template -static inline std::vector -pad_vector(command_queue &q, const V &v, T x) { - std::vector w { v.begin(), v.end() }; +template +static inline std::vector +pad_vector(command_queue &q, const V &v, uint x) { + std::vector w { v.begin(), v.end() }; w.resize(q.dev.max_block_size().size(), x); return w; } @@ -66,7 +65,13 @@ kernel::launch(command_queue &q, const std::vector &grid_offset, const std::vector &grid_size, const std::vector &block_size) { + const auto m = prog.binary(q.dev); + const auto reduced_grid_size = + map(divides(), grid_size, block_size); void *st = exec.bind(&q); + + // The handles are created during exec_context::bind(), so we need make + // sure to call exec_context::bind() before retrieving them. std::vector g_handles = map([&](size_t h) { return (uint32_t *)&exec.input[h]; }, exec.g_handles); @@ -84,9 +89,9 @@ kernel::launch(command_queue &q, exec.g_buffers.data(), g_handles.data()); q.pipe->launch_grid(q.pipe, - pad_vector(q, block_size, 1).data(), - pad_vector(q, grid_size, 1).data(), - module(q).sym(_name).offset, + pad_vector(q, block_size, 1).data(), + pad_vector(q, reduced_grid_size, 1).data(), + find(name_equals(_name), m.syms).offset, exec.input.data()); q.pipe->set_global_binding(q.pipe, 0, exec.g_buffers.size(), NULL, NULL); @@ -101,9 +106,9 @@ size_t kernel::mem_local() const { size_t sz = 0; - for (auto &arg : args) { - if (dynamic_cast(arg.get())) - sz += arg->storage(); + for (auto &arg : args()) { + if (dynamic_cast(&arg)) + sz += arg.storage(); } return sz; @@ -129,13 +134,23 @@ kernel::block_size() const { return { 0, 0, 0 }; } +kernel::argument_range +kernel::args() { + return map(derefs(), _args); +} + +kernel::const_argument_range +kernel::args() const { + return map(derefs(), _args); +} + const module & kernel::module(const command_queue &q) const { - return prog.binaries().find(&q.dev)->second; + return prog.binary(q.dev); } kernel::exec_context::exec_context(kernel &kern) : - kern(kern), q(NULL), mem_local(0), st(NULL) { + kern(kern), q(NULL), mem_local(0), st(NULL), cs() { } kernel::exec_context::~exec_context() { @@ -148,11 +163,13 @@ kernel::exec_context::bind(command_queue *_q) { std::swap(q, _q); // Bind kernel arguments. - auto margs = kern.module(*q).sym(kern.name()).args; - for_each([=](std::unique_ptr &karg, - const module::argument &marg) { - karg->bind(*this, marg); - }, kern.args, margs); + auto &m = kern.prog.binary(q->dev); + auto margs = find(name_equals(kern.name()), m.syms).args; + auto msec = find(type_equals(module::section::text), m.secs); + + for_each([=](kernel::argument &karg, const module::argument &marg) { + karg.bind(*this, marg); + }, kern.args(), margs); // Create a new compute state if anything changed. if (!st || q != _q || @@ -161,7 +178,7 @@ kernel::exec_context::bind(command_queue *_q) { if (st) _q->pipe->delete_compute_state(_q->pipe, st); - cs.prog = kern.module(*q).sec(module::section::text).data.begin(); + cs.prog = msec.data.begin(); cs.req_local_mem = mem_local; cs.req_input_mem = input.size(); st = q->pipe->create_compute_state(q->pipe, &cs); @@ -172,8 +189,8 @@ kernel::exec_context::bind(command_queue *_q) { void kernel::exec_context::unbind() { - for (auto &arg : kern.args) - arg->unbind(*this); + for (auto &arg : kern.args()) + arg.unbind(*this); input.clear(); samplers.clear(); diff --git a/src/gallium/state_trackers/clover/core/kernel.hpp b/src/gallium/state_trackers/clover/core/kernel.hpp index 907a9fdfbc5..41d7b6f2c91 100644 --- a/src/gallium/state_trackers/clover/core/kernel.hpp +++ b/src/gallium/state_trackers/clover/core/kernel.hpp @@ -95,9 +95,18 @@ namespace clover { bool _set; }; - kernel(program &prog, - const std::string &name, - const std::vector &margs); + private: + typedef adaptor_range< + derefs, std::vector> & + > argument_range; + + typedef adaptor_range< + derefs, const std::vector> & + > const_argument_range; + + public: + kernel(program &prog, const std::string &name, + const std::vector &margs); kernel(const kernel &kern) = delete; kernel & @@ -115,12 +124,13 @@ namespace clover { const std::string &name() const; std::vector block_size() const; + argument_range args(); + const_argument_range args() const; + program &prog; - std::vector> args; private: - const clover::module & - module(const command_queue &q) const; + const clover::module &module(const command_queue &q) const; class scalar_argument : public argument { public: @@ -208,6 +218,7 @@ namespace clover { void *st; }; + std::vector> _args; std::string _name; exec_context exec; }; diff --git a/src/gallium/state_trackers/clover/core/module.cpp b/src/gallium/state_trackers/clover/core/module.cpp index de38438c5a3..040f5215054 100644 --- a/src/gallium/state_trackers/clover/core/module.cpp +++ b/src/gallium/state_trackers/clover/core/module.cpp @@ -21,7 +21,6 @@ // #include -#include #include "core/module.hpp" @@ -147,28 +146,4 @@ namespace clover { module::deserialize(compat::istream &is) { return _proc(is); } - - const module::symbol & - module::sym(compat::string name) const { - auto it = std::find_if(syms.begin(), syms.end(), [&](const symbol &x) { - return compat::string(x.name) == name; - }); - - if (it == syms.end()) - throw noent_error(); - - return *it; - } - - const module::section & - module::sec(typename section::type type) const { - auto it = std::find_if(secs.begin(), secs.end(), [&](const section &x) { - return x.type == type; - }); - - if (it == secs.end()) - throw noent_error(); - - return *it; - } } diff --git a/src/gallium/state_trackers/clover/core/module.hpp b/src/gallium/state_trackers/clover/core/module.hpp index 03b8d3a0b03..4a8dbcb4949 100644 --- a/src/gallium/state_trackers/clover/core/module.hpp +++ b/src/gallium/state_trackers/clover/core/module.hpp @@ -27,11 +27,6 @@ namespace clover { struct module { - class noent_error { - public: - virtual ~noent_error() {} - }; - typedef uint32_t resource_id; typedef uint32_t size_t; @@ -45,14 +40,14 @@ namespace clover { }; section(resource_id id, enum type type, size_t size, - const clover::compat::vector &data) : + const compat::vector &data) : id(id), type(type), size(size), data(data) { } section() : id(0), type(text), size(0), data() { } resource_id id; type type; size_t size; - clover::compat::vector data; + compat::vector data; }; struct argument { @@ -97,30 +92,22 @@ namespace clover { }; struct symbol { - symbol(const clover::compat::vector &name, resource_id section, - size_t offset, const clover::compat::vector &args) : + symbol(const compat::vector &name, resource_id section, + size_t offset, const compat::vector &args) : name(name), section(section), offset(offset), args(args) { } symbol() : name(), section(0), offset(0), args() { } - clover::compat::vector name; + compat::vector name; resource_id section; size_t offset; - clover::compat::vector args; + compat::vector args; }; void serialize(compat::ostream &os) const; static module deserialize(compat::istream &is); - /// Look up a symbol by name. Throws module::noent_error if not - /// found. - const symbol &sym(compat::string name) const; - - /// Look up a section by type. Throws module::noent_error if not - /// found. - const section &sec(typename section::type type) const; - - clover::compat::vector syms; - clover::compat::vector
secs; + compat::vector syms; + compat::vector
secs; }; } diff --git a/src/gallium/state_trackers/clover/core/program.cpp b/src/gallium/state_trackers/clover/core/program.cpp index 8082cf0f6f6..6d4a9ba81ff 100644 --- a/src/gallium/state_trackers/clover/core/program.cpp +++ b/src/gallium/state_trackers/clover/core/program.cpp @@ -67,22 +67,38 @@ program::source() const { return _source; } -const std::map & -program::binaries() const { - return _binaries; +program::device_range +program::devices() const { + return map(derefs(), map(keys(), _binaries)); +} + +const module & +program::binary(const device &dev) const { + return _binaries.find(const_cast(&dev))->second; } cl_build_status -program::build_status(device &dev) const { - return _binaries.count(&dev) ? CL_BUILD_SUCCESS : CL_BUILD_NONE; +program::build_status(const device &dev) const { + if (_binaries.count(const_cast(&dev))) + return CL_BUILD_SUCCESS; + else + return CL_BUILD_NONE; } std::string -program::build_opts(device &dev) const { +program::build_opts(const device &dev) const { return _opts.count(&dev) ? _opts.find(&dev)->second : ""; } std::string -program::build_log(device &dev) const { +program::build_log(const device &dev) const { return _logs.count(&dev) ? _logs.find(&dev)->second : ""; } + +const compat::vector & +program::symbols() const { + if (_binaries.empty()) + throw error(CL_INVALID_PROGRAM_EXECUTABLE); + + return _binaries.begin()->second.syms; +} diff --git a/src/gallium/state_trackers/clover/core/program.hpp b/src/gallium/state_trackers/clover/core/program.hpp index b8531bce1e8..25ec239507d 100644 --- a/src/gallium/state_trackers/clover/core/program.hpp +++ b/src/gallium/state_trackers/clover/core/program.hpp @@ -31,6 +31,11 @@ namespace clover { class program : public ref_counter, public _cl_program { + private: + typedef adaptor_range< + derefs, adaptor_range< + keys, const std::map &>> device_range; + public: program(context &ctx, const std::string &source); @@ -45,18 +50,22 @@ namespace clover { void build(const ref_vector &devs, const char *opts); const std::string &source() const; - const std::map &binaries() const; - cl_build_status build_status(device &dev) const; - std::string build_opts(device &dev) const; - std::string build_log(device &dev) const; + device_range devices() const; + + const module &binary(const device &dev) const; + cl_build_status build_status(const device &dev) const; + std::string build_opts(const device &dev) const; + std::string build_log(const device &dev) const; + + const compat::vector &symbols() const; context &ctx; private: std::map _binaries; - std::map _logs; - std::map _opts; + std::map _logs; + std::map _opts; std::string _source; }; } -- 2.30.2