clover: Clean up the kernel and program object interface.
authorFrancisco Jerez <currojerez@riseup.net>
Tue, 17 Sep 2013 04:50:40 +0000 (21:50 -0700)
committerFrancisco Jerez <currojerez@riseup.net>
Mon, 21 Oct 2013 17:47:03 +0000 (10:47 -0700)
[ Tom Stellard: Make sure to bind global arguments before retrieving handles. ]
Tested-by: Tom Stellard <thomas.stellard@amd.com>
src/gallium/state_trackers/clover/api/kernel.cpp
src/gallium/state_trackers/clover/api/program.cpp
src/gallium/state_trackers/clover/core/kernel.cpp
src/gallium/state_trackers/clover/core/kernel.hpp
src/gallium/state_trackers/clover/core/module.cpp
src/gallium/state_trackers/clover/core/module.hpp
src/gallium/state_trackers/clover/core/program.cpp
src/gallium/state_trackers/clover/core/program.hpp

index dd742c3bb05ccad9528b244b96d28c619ca6ebe3..e55a55b546e3aa803761882691464389252c2c43 100644 (file)
@@ -33,15 +33,12 @@ clCreateKernel(cl_program d_prog, const char *name, cl_int *r_errcode) try {
    if (!name)
       throw error(CL_INVALID_VALUE);
 
-   if (prog.binaries().empty())
-      throw error(CL_INVALID_PROGRAM_EXECUTABLE);
-
-   auto sym = prog.binaries().begin()->second.sym(name);
+   auto &sym = find(name_equals(name), prog.symbols());
 
    ret_error(r_errcode, CL_SUCCESS);
    return new kernel(prog, name, range(sym.args));
 
-} catch (module::noent_error &e) {
+} catch (std::out_of_range &e) {
    ret_error(r_errcode, CL_INVALID_KERNEL_NAME);
    return NULL;
 
@@ -54,11 +51,7 @@ PUBLIC cl_int
 clCreateKernelsInProgram(cl_program d_prog, cl_uint count,
                          cl_kernel *rd_kerns, cl_uint *r_count) try {
    auto &prog = obj(d_prog);
-
-   if (prog.binaries().empty())
-      throw error(CL_INVALID_PROGRAM_EXECUTABLE);
-
-   auto &syms = prog.binaries().begin()->second.syms;
+   auto &syms = prog.symbols();
 
    if (rd_kerns && count < syms.size())
       throw error(CL_INVALID_VALUE);
@@ -102,15 +95,12 @@ clReleaseKernel(cl_kernel d_kern) try {
 PUBLIC cl_int
 clSetKernelArg(cl_kernel d_kern, cl_uint idx, size_t size,
                const void *value) try {
-   auto &kern = obj(d_kern);
-
-   if (idx >= kern.args.size())
-      throw error(CL_INVALID_ARG_INDEX);
-
-   kern.args[idx]->set(size, value);
-
+   obj(d_kern).args().at(idx).set(size, value);
    return CL_SUCCESS;
 
+} catch (std::out_of_range &e) {
+   return CL_INVALID_ARG_INDEX;
+
 } catch (error &e) {
    return e.get();
 }
@@ -127,7 +117,7 @@ clGetKernelInfo(cl_kernel d_kern, cl_kernel_info param,
       break;
 
    case CL_KERNEL_NUM_ARGS:
-      buf.as_scalar<cl_uint>() = kern.args.size();
+      buf.as_scalar<cl_uint>() = kern.args().size();
       break;
 
    case CL_KERNEL_REFERENCE_COUNT:
@@ -160,8 +150,8 @@ clGetKernelWorkGroupInfo(cl_kernel d_kern, cl_device_id d_dev,
    auto &kern = obj(d_kern);
    auto pdev = pobj(d_dev);
 
-   if ((!pdev && kern.prog.binaries().size() != 1) ||
-       (pdev && !kern.prog.binaries().count(pdev)))
+   if ((!pdev && kern.prog.devices().size() != 1) ||
+       (pdev && !count(*pdev, kern.prog.devices())))
       throw error(CL_INVALID_DEVICE);
 
    switch (param) {
@@ -210,10 +200,10 @@ namespace {
 
       if (any_of([](kernel::argument &arg) {
                return !arg.set();
-            }, map(derefs(), kern.args)))
+            }, kern.args()))
          throw error(CL_INVALID_KERNEL_ARGS);
 
-      if (!kern.prog.binaries().count(&q.dev))
+      if (!count(q.dev, kern.prog.devices()))
          throw error(CL_INVALID_PROGRAM_EXECUTABLE);
    }
 
index 45640424a0d4884ad169c7d0732f0ffbd3cd1f58..01133813f9868f69354876a2de561d824f91e180 100644 (file)
@@ -173,11 +173,11 @@ clGetProgramInfo(cl_program d_prog, cl_program_info param,
       break;
 
    case CL_PROGRAM_NUM_DEVICES:
-      buf.as_scalar<cl_uint>() = prog.binaries().size();
+      buf.as_scalar<cl_uint>() = prog.devices().size();
       break;
 
    case CL_PROGRAM_DEVICES:
-      buf.as_vector<cl_device_id>() = map(keys(), prog.binaries());
+      buf.as_vector<cl_device_id>() = descs(prog.devices());
       break;
 
    case CL_PROGRAM_SOURCE:
@@ -185,25 +185,23 @@ clGetProgramInfo(cl_program d_prog, cl_program_info param,
       break;
 
    case CL_PROGRAM_BINARY_SIZES:
-      buf.as_vector<size_t>() =
-         map([](const std::pair<device *, module> &ent) {
-               compat::ostream::buffer_t bin;
-               compat::ostream s(bin);
-               ent.second.serialize(s);
-               return bin.size();
-            },
-            prog.binaries());
+      buf.as_vector<size_t>() = map([&](const device &dev) {
+            compat::ostream::buffer_t bin;
+            compat::ostream s(bin);
+            prog.binary(dev).serialize(s);
+            return bin.size();
+         },
+         prog.devices());
       break;
 
    case CL_PROGRAM_BINARIES:
-      buf.as_matrix<unsigned char>() =
-         map([](const std::pair<device *, module> &ent) {
-               compat::ostream::buffer_t bin;
-               compat::ostream s(bin);
-               ent.second.serialize(s);
-               return bin;
-            },
-            prog.binaries());
+      buf.as_matrix<unsigned char>() = map([&](const device &dev) {
+            compat::ostream::buffer_t bin;
+            compat::ostream s(bin);
+            prog.binary(dev).serialize(s);
+            return bin;
+         },
+         prog.devices());
       break;
 
    default:
@@ -224,7 +222,7 @@ clGetProgramBuildInfo(cl_program d_prog, cl_device_id d_dev,
    auto &prog = obj(d_prog);
    auto &dev = obj(d_dev);
 
-   if (!count(dev, prog.ctx.devs()))
+   if (!count(dev, prog.devices()))
       return CL_INVALID_DEVICE;
 
    switch (param) {
index 7e18366c72883d6d5ec65d3c7aec47779da12ef9..10eb0e0b2ef5b0244054f460877b236de4a99532 100644 (file)
 
 using namespace clover;
 
-kernel::kernel(program &prog,
-               const std::string &name,
+kernel::kernel(program &prog, const std::string &name,
                const std::vector<module::argument> &margs) :
    prog(prog), _name(name), exec(*this) {
-   for (auto marg : margs) {
+   for (auto &marg : margs) {
       if (marg.type == module::argument::scalar)
-         args.emplace_back(new scalar_argument(marg.size));
+         _args.emplace_back(new scalar_argument(marg.size));
       else if (marg.type == module::argument::global)
-         args.emplace_back(new global_argument);
+         _args.emplace_back(new global_argument);
       else if (marg.type == module::argument::local)
-         args.emplace_back(new local_argument);
+         _args.emplace_back(new local_argument);
       else if (marg.type == module::argument::constant)
-         args.emplace_back(new constant_argument);
+         _args.emplace_back(new constant_argument);
       else if (marg.type == module::argument::image2d_rd ||
                marg.type == module::argument::image3d_rd)
-         args.emplace_back(new image_rd_argument);
+         _args.emplace_back(new image_rd_argument);
       else if (marg.type == module::argument::image2d_wr ||
                marg.type == module::argument::image3d_wr)
-         args.emplace_back(new image_wr_argument);
+         _args.emplace_back(new image_wr_argument);
       else if (marg.type == module::argument::sampler)
-         args.emplace_back(new sampler_argument);
+         _args.emplace_back(new sampler_argument);
       else
          throw error(CL_INVALID_KERNEL_DEFINITION);
    }
 }
 
-template<typename T, typename V>
-static inline std::vector<T>
-pad_vector(command_queue &q, const V &v, T x) {
-   std::vector<T> w { v.begin(), v.end() };
+template<typename V>
+static inline std::vector<uint>
+pad_vector(command_queue &q, const V &v, uint x) {
+   std::vector<uint> w { v.begin(), v.end() };
    w.resize(q.dev.max_block_size().size(), x);
    return w;
 }
@@ -66,7 +65,13 @@ kernel::launch(command_queue &q,
                const std::vector<size_t> &grid_offset,
                const std::vector<size_t> &grid_size,
                const std::vector<size_t> &block_size) {
+   const auto m = prog.binary(q.dev);
+   const auto reduced_grid_size =
+      map(divides(), grid_size, block_size);
    void *st = exec.bind(&q);
+
+   // The handles are created during exec_context::bind(), so we need make
+   // sure to call exec_context::bind() before retrieving them.
    std::vector<uint32_t *> g_handles = map([&](size_t h) {
          return (uint32_t *)&exec.input[h];
       }, exec.g_handles);
@@ -84,9 +89,9 @@ kernel::launch(command_queue &q,
                               exec.g_buffers.data(), g_handles.data());
 
    q.pipe->launch_grid(q.pipe,
-                       pad_vector<uint>(q, block_size, 1).data(),
-                       pad_vector<uint>(q, grid_size, 1).data(),
-                       module(q).sym(_name).offset,
+                       pad_vector(q, block_size, 1).data(),
+                       pad_vector(q, reduced_grid_size, 1).data(),
+                       find(name_equals(_name), m.syms).offset,
                        exec.input.data());
 
    q.pipe->set_global_binding(q.pipe, 0, exec.g_buffers.size(), NULL, NULL);
@@ -101,9 +106,9 @@ size_t
 kernel::mem_local() const {
    size_t sz = 0;
 
-   for (auto &arg : args) {
-      if (dynamic_cast<local_argument *>(arg.get()))
-         sz += arg->storage();
+   for (auto &arg : args()) {
+      if (dynamic_cast<local_argument *>(&arg))
+         sz += arg.storage();
    }
 
    return sz;
@@ -129,13 +134,23 @@ kernel::block_size() const {
    return { 0, 0, 0 };
 }
 
+kernel::argument_range
+kernel::args() {
+   return map(derefs(), _args);
+}
+
+kernel::const_argument_range
+kernel::args() const {
+   return map(derefs(), _args);
+}
+
 const module &
 kernel::module(const command_queue &q) const {
-   return prog.binaries().find(&q.dev)->second;
+   return prog.binary(q.dev);
 }
 
 kernel::exec_context::exec_context(kernel &kern) :
-   kern(kern), q(NULL), mem_local(0), st(NULL) {
+   kern(kern), q(NULL), mem_local(0), st(NULL), cs() {
 }
 
 kernel::exec_context::~exec_context() {
@@ -148,11 +163,13 @@ kernel::exec_context::bind(command_queue *_q) {
    std::swap(q, _q);
 
    // Bind kernel arguments.
-   auto margs = kern.module(*q).sym(kern.name()).args;
-   for_each([=](std::unique_ptr<kernel::argument> &karg,
-                const module::argument &marg) {
-               karg->bind(*this, marg);
-            }, kern.args, margs);
+   auto &m = kern.prog.binary(q->dev);
+   auto margs = find(name_equals(kern.name()), m.syms).args;
+   auto msec = find(type_equals(module::section::text), m.secs);
+
+   for_each([=](kernel::argument &karg, const module::argument &marg) {
+               karg.bind(*this, marg);
+            }, kern.args(), margs);
 
    // Create a new compute state if anything changed.
    if (!st || q != _q ||
@@ -161,7 +178,7 @@ kernel::exec_context::bind(command_queue *_q) {
       if (st)
          _q->pipe->delete_compute_state(_q->pipe, st);
 
-      cs.prog = kern.module(*q).sec(module::section::text).data.begin();
+      cs.prog = msec.data.begin();
       cs.req_local_mem = mem_local;
       cs.req_input_mem = input.size();
       st = q->pipe->create_compute_state(q->pipe, &cs);
@@ -172,8 +189,8 @@ kernel::exec_context::bind(command_queue *_q) {
 
 void
 kernel::exec_context::unbind() {
-   for (auto &arg : kern.args)
-      arg->unbind(*this);
+   for (auto &arg : kern.args())
+      arg.unbind(*this);
 
    input.clear();
    samplers.clear();
index 907a9fdfbc5efafaddfcd1c0291eb8b9c556fb5c..41d7b6f2c918c8acb27bd917f1ade18d59888f2f 100644 (file)
@@ -95,9 +95,18 @@ namespace clover {
          bool _set;
       };
 
-      kernel(program &prog,
-             const std::string &name,
-             const std::vector<module::argument> &margs);
+   private:
+      typedef adaptor_range<
+            derefs, std::vector<std::unique_ptr<argument>> &
+         > argument_range;
+
+      typedef adaptor_range<
+            derefs, const std::vector<std::unique_ptr<argument>> &
+         > const_argument_range;
+
+   public:
+      kernel(program &prog, const std::string &name,
+             const std::vector<clover::module::argument> &margs);
 
       kernel(const kernel &kern) = delete;
       kernel &
@@ -115,12 +124,13 @@ namespace clover {
       const std::string &name() const;
       std::vector<size_t> block_size() const;
 
+      argument_range args();
+      const_argument_range args() const;
+
       program &prog;
-      std::vector<std::unique_ptr<argument>> args;
 
    private:
-      const clover::module &
-      module(const command_queue &q) const;
+      const clover::module &module(const command_queue &q) const;
 
       class scalar_argument : public argument {
       public:
@@ -208,6 +218,7 @@ namespace clover {
          void *st;
       };
 
+      std::vector<std::unique_ptr<argument>> _args;
       std::string _name;
       exec_context exec;
    };
index de38438c5a35600daa271b41359c8b020e9d56cd..040f521505408da72228d5d4dbe5f6698f405444 100644 (file)
@@ -21,7 +21,6 @@
 //
 
 #include <type_traits>
-#include <algorithm>
 
 #include "core/module.hpp"
 
@@ -147,28 +146,4 @@ namespace clover {
    module::deserialize(compat::istream &is) {
       return _proc<module>(is);
    }
-
-   const module::symbol &
-   module::sym(compat::string name) const {
-      auto it = std::find_if(syms.begin(), syms.end(), [&](const symbol &x) {
-            return compat::string(x.name) == name;
-         });
-
-      if (it == syms.end())
-         throw noent_error();
-
-      return *it;
-   }
-
-   const module::section &
-   module::sec(typename section::type type) const {
-      auto it = std::find_if(secs.begin(), secs.end(), [&](const section &x) {
-            return x.type == type;
-         });
-
-      if (it == secs.end())
-         throw noent_error();
-
-      return *it;
-   }
 }
index 03b8d3a0b037f7762fb86e77b6eae87f6499e813..4a8dbcb4949bff04cc3eca3cd5501a5aed4fc4cf 100644 (file)
 
 namespace clover {
    struct module {
-      class noent_error {
-      public:
-         virtual ~noent_error() {}
-      };
-
       typedef uint32_t resource_id;
       typedef uint32_t size_t;
 
@@ -45,14 +40,14 @@ namespace clover {
          };
 
          section(resource_id id, enum type type, size_t size,
-                 const clover::compat::vector<char> &data) :
+                 const compat::vector<char> &data) :
                  id(id), type(type), size(size), data(data) { }
          section() : id(0), type(text), size(0), data() { }
 
          resource_id id;
          type type;
          size_t size;
-         clover::compat::vector<char> data;
+         compat::vector<char> data;
       };
 
       struct argument {
@@ -97,30 +92,22 @@ namespace clover {
       };
 
       struct symbol {
-         symbol(const clover::compat::vector<char> &name, resource_id section,
-                size_t offset, const clover::compat::vector<argument> &args) :
+         symbol(const compat::vector<char> &name, resource_id section,
+                size_t offset, const compat::vector<argument> &args) :
                 name(name), section(section), offset(offset), args(args) { }
          symbol() : name(), section(0), offset(0), args() { }
 
-         clover::compat::vector<char> name;
+         compat::vector<char> name;
          resource_id section;
          size_t offset;
-         clover::compat::vector<argument> args;
+         compat::vector<argument> args;
       };
 
       void serialize(compat::ostream &os) const;
       static module deserialize(compat::istream &is);
 
-      /// Look up a symbol by name.  Throws module::noent_error if not
-      /// found.
-      const symbol &sym(compat::string name) const;
-
-      /// Look up a section by type.  Throws module::noent_error if not
-      /// found.
-      const section &sec(typename section::type type) const;
-
-      clover::compat::vector<symbol> syms;
-      clover::compat::vector<section> secs;
+      compat::vector<symbol> syms;
+      compat::vector<section> secs;
    };
 }
 
index 8082cf0f6f6942e554bb66a913e71a1b49ef3bab..6d4a9ba81ff5ef69537659e2720b6efad7423d7b 100644 (file)
@@ -67,22 +67,38 @@ program::source() const {
    return _source;
 }
 
-const std::map<device *, module> &
-program::binaries() const {
-   return _binaries;
+program::device_range
+program::devices() const {
+   return map(derefs(), map(keys(), _binaries));
+}
+
+const module &
+program::binary(const device &dev) const {
+   return _binaries.find(const_cast<device *>(&dev))->second;
 }
 
 cl_build_status
-program::build_status(device &dev) const {
-   return _binaries.count(&dev) ? CL_BUILD_SUCCESS : CL_BUILD_NONE;
+program::build_status(const device &dev) const {
+   if (_binaries.count(const_cast<device *>(&dev)))
+      return CL_BUILD_SUCCESS;
+   else
+      return CL_BUILD_NONE;
 }
 
 std::string
-program::build_opts(device &dev) const {
+program::build_opts(const device &dev) const {
    return _opts.count(&dev) ? _opts.find(&dev)->second : "";
 }
 
 std::string
-program::build_log(device &dev) const {
+program::build_log(const device &dev) const {
    return _logs.count(&dev) ? _logs.find(&dev)->second : "";
 }
+
+const compat::vector<module::symbol> &
+program::symbols() const {
+   if (_binaries.empty())
+      throw error(CL_INVALID_PROGRAM_EXECUTABLE);
+
+   return _binaries.begin()->second.syms;
+}
index b8531bce1e8e8426d7508c2110bb48b1648d5077..25ec239507de9afb4df5589016e8e3b65cd861c8 100644 (file)
 
 namespace clover {
    class program : public ref_counter, public _cl_program {
+   private:
+      typedef adaptor_range<
+         derefs, adaptor_range<
+             keys, const std::map<device *, module> &>> device_range;
+
    public:
       program(context &ctx,
               const std::string &source);
@@ -45,18 +50,22 @@ namespace clover {
       void build(const ref_vector<device> &devs, const char *opts);
 
       const std::string &source() const;
-      const std::map<device *, module> &binaries() const;
 
-      cl_build_status build_status(device &dev) const;
-      std::string build_opts(device &dev) const;
-      std::string build_log(device &dev) const;
+      device_range devices() const;
+
+      const module &binary(const device &dev) const;
+      cl_build_status build_status(const device &dev) const;
+      std::string build_opts(const device &dev) const;
+      std::string build_log(const device &dev) const;
+
+      const compat::vector<module::symbol> &symbols() const;
 
       context &ctx;
 
    private:
       std::map<device *, module> _binaries;
-      std::map<device *, std::string> _logs;
-      std::map<device *, std::string> _opts;
+      std::map<const device *, std::string> _logs;
+      std::map<const device *, std::string> _opts;
       std::string _source;
    };
 }