clover/llvm: Clean up codestyle of get_kernel_args().
authorFrancisco Jerez <currojerez@riseup.net>
Tue, 17 May 2016 14:02:52 +0000 (16:02 +0200)
committerFrancisco Jerez <currojerez@riseup.net>
Tue, 12 Jul 2016 03:22:59 +0000 (20:22 -0700)
Reviewed-by: Serge Martin <edb+mesa@sigluy.net>
Tested-by: Jan Vesely <jan.vesely@rutgers.edu>
src/gallium/state_trackers/clover/llvm/invocation.cpp

index 754e477bd292ca0a242bf4b74b81599ae2ac490e..0fc6190280f06c66b381a5d519bf370056106cbe 100644 (file)
 using namespace clover;
 using namespace clover::llvm;
 
+using ::llvm::cast;
+using ::llvm::dyn_cast;
 using ::llvm::Function;
+using ::llvm::isa;
 using ::llvm::LLVMContext;
 using ::llvm::Module;
 using ::llvm::raw_string_ostream;
@@ -362,147 +365,127 @@ namespace {
    }
 #endif
 
+   enum module::argument::type
+   get_image_type(const std::string &type,
+                  const std::string &qual) {
+      if (type == "image2d_t" && qual == "read_only")
+         return module::argument::image2d_rd;
+      else if (type == "image2d_t" && qual == "write_only")
+         return module::argument::image2d_wr;
+      else if (type == "image3d_t" && qual == "read_only")
+         return module::argument::image3d_rd;
+      else if (type == "image3d_t" && qual == "write_only")
+         return module::argument::image3d_wr;
+      else
+         unreachable("Unknown image type");
+   }
+
    std::vector<module::argument>
-   get_kernel_args(const llvm::Module *mod, const std::string &kernel_name,
-                   const clang::CompilerInstance &c) {
+   make_kernel_args(const Module &mod, const std::string &kernel_name,
+                    const clang::CompilerInstance &c) {
       std::vector<module::argument> args;
       const auto address_spaces = c.getTarget().getAddressSpaceMap();
-      llvm::Function *kernel_func = mod->getFunction(kernel_name);
-      assert(kernel_func && "Kernel name not found in module.");
-      auto arg_md = get_kernel_arg_md(kernel_func);
-
-      llvm::DataLayout TD(mod);
-      llvm::Type *size_type =
-         TD.getSmallestLegalIntType(mod->getContext(), sizeof(cl_uint) * 8);
-
-      for (const auto &arg: kernel_func->args()) {
+      const Function &f = *mod.getFunction(kernel_name);
+      const auto arg_md = get_kernel_arg_md(&f);
+      ::llvm::DataLayout dl(&mod);
+      const auto size_type =
+         dl.getSmallestLegalIntType(mod.getContext(), sizeof(cl_uint) * 8);
 
-         llvm::Type *arg_type = arg.getType();
-         const unsigned arg_store_size = TD.getTypeStoreSize(arg_type);
+      for (const auto &arg : f.args()) {
+         const auto arg_type = arg.getType();
 
          // OpenCL 1.2 specification, Ch. 6.1.5: "A built-in data
          // type that is not a power of two bytes in size must be
          // aligned to the next larger power of two".  We need this
          // alignment for three element vectors, which have
          // non-power-of-2 store size.
+         const unsigned arg_store_size = dl.getTypeStoreSize(arg_type);
          const unsigned arg_api_size = util_next_power_of_two(arg_store_size);
 
-         llvm::Type *target_type = arg_type->isIntegerTy() ?
-               TD.getSmallestLegalIntType(mod->getContext(), arg_store_size * 8)
-               : arg_type;
-         unsigned target_size = TD.getTypeStoreSize(target_type);
-         unsigned target_align = TD.getABITypeAlignment(target_type);
-
-         llvm::StringRef type_name = arg_md[arg.getArgNo()].type_name;
-         llvm::StringRef access_qual = arg_md[arg.getArgNo()].access_qual;
-
-         // Image
-         const bool is_image2d = type_name == "image2d_t";
-         const bool is_image3d = type_name == "image3d_t";
-         if (is_image2d || is_image3d) {
-            const bool is_write_only = access_qual == "write_only";
-            const bool is_read_only = access_qual == "read_only";
-
-            enum module::argument::type marg_type;
-            if (is_image2d && is_read_only) {
-               marg_type = module::argument::image2d_rd;
-            } else if (is_image2d && is_write_only) {
-               marg_type = module::argument::image2d_wr;
-            } else if (is_image3d && is_read_only) {
-               marg_type = module::argument::image3d_rd;
-            } else if (is_image3d && is_write_only) {
-               marg_type = module::argument::image3d_wr;
-            } else {
-               assert(0 && "Wrong image access qualifier");
-            }
-
-            args.push_back(module::argument(marg_type,
-                                            arg_store_size, target_size,
-                                            target_align,
-                                            module::argument::zero_ext));
-            continue;
-         }
-
-         // Image size implicit argument
-         if (type_name == "__llvm_image_size") {
-            args.push_back(module::argument(module::argument::scalar,
-                                            sizeof(cl_uint),
-                                            TD.getTypeStoreSize(size_type),
-                                            TD.getABITypeAlignment(size_type),
-                                            module::argument::zero_ext,
-                                            module::argument::image_size));
-            continue;
-         }
-
-         // Image format implicit argument
-         if (type_name == "__llvm_image_format") {
-            args.push_back(module::argument(module::argument::scalar,
-                                            sizeof(cl_uint),
-                                            TD.getTypeStoreSize(size_type),
-                                            TD.getABITypeAlignment(size_type),
-                                            module::argument::zero_ext,
-                                            module::argument::image_format));
-            continue;
-         }
+         const auto target_type = !arg_type->isIntegerTy() ? arg_type :
+            dl.getSmallestLegalIntType(mod.getContext(), arg_store_size * 8);
+         const unsigned target_size = dl.getTypeStoreSize(target_type);
+         const unsigned target_align = dl.getABITypeAlignment(target_type);
+
+         const auto type_name = arg_md[arg.getArgNo()].type_name;
+
+         if (type_name == "image2d_t" || type_name == "image3d_t") {
+            // Image.
+            const auto access_qual = arg_md[arg.getArgNo()].access_qual;
+            args.emplace_back(get_image_type(type_name, access_qual),
+                              arg_store_size, target_size,
+                              target_align, module::argument::zero_ext);
+
+         } else if (type_name == "__llvm_image_size") {
+            // Image size implicit argument.
+            args.emplace_back(module::argument::scalar, sizeof(cl_uint),
+                              dl.getTypeStoreSize(size_type),
+                              dl.getABITypeAlignment(size_type),
+                              module::argument::zero_ext,
+                              module::argument::image_size);
+
+         } else if (type_name == "__llvm_image_format") {
+            // Image format implicit argument.
+            args.emplace_back(module::argument::scalar, sizeof(cl_uint),
+                              dl.getTypeStoreSize(size_type),
+                              dl.getABITypeAlignment(size_type),
+                              module::argument::zero_ext,
+                              module::argument::image_format);
 
-         // Other types
-         if (llvm::isa<llvm::PointerType>(arg_type) && arg.hasByValAttr()) {
-            arg_type =
-                  llvm::dyn_cast<llvm::PointerType>(arg_type)->getElementType();
-         }
+         } else {
+            // Other types.
+            const auto actual_type =
+               isa<::llvm::PointerType>(arg_type) && arg.hasByValAttr() ?
+               cast<::llvm::PointerType>(arg_type)->getElementType() : arg_type;
+
+            if (actual_type->isPointerTy()) {
+               const unsigned address_space =
+                  cast<::llvm::PointerType>(actual_type)->getAddressSpace();
+
+               if (address_space == address_spaces[clang::LangAS::opencl_local
+                                                   - clang::LangAS::Offset]) {
+                  args.emplace_back(module::argument::local, arg_api_size,
+                                    target_size, target_align,
+                                    module::argument::zero_ext);
+               } else {
+                  // XXX: Correctly handle constant address space.  There is no
+                  // way for r600g to pass a handle for constant buffers back
+                  // to clover like it can for global buffers, so
+                  // creating constant arguments will break r600g.  For now,
+                  // continue treating constant buffers as global buffers
+                  // until we can come up with a way to create handles for
+                  // constant buffers.
+                  args.emplace_back(module::argument::global, arg_api_size,
+                                    target_size, target_align,
+                                    module::argument::zero_ext);
+               }
 
-         if (arg_type->isPointerTy()) {
-            unsigned address_space = llvm::cast<llvm::PointerType>(arg_type)->getAddressSpace();
-            if (address_space == address_spaces[clang::LangAS::opencl_local
-                                                     - clang::LangAS::Offset]) {
-               args.push_back(module::argument(module::argument::local,
-                                               arg_api_size, target_size,
-                                               target_align,
-                                               module::argument::zero_ext));
             } else {
-               // XXX: Correctly handle constant address space.  There is no
-               // way for r600g to pass a handle for constant buffers back
-               // to clover like it can for global buffers, so
-               // creating constant arguments will break r600g.  For now,
-               // continue treating constant buffers as global buffers
-               // until we can come up with a way to create handles for
-               // constant buffers.
-               args.push_back(module::argument(module::argument::global,
-                                               arg_api_size, target_size,
-                                               target_align,
-                                               module::argument::zero_ext));
-           }
+               const bool needs_sign_ext = f.getAttributes().hasAttribute(
+                  arg.getArgNo() + 1, ::llvm::Attribute::SExt);
 
-         } else {
-            llvm::AttributeSet attrs = kernel_func->getAttributes();
-            enum module::argument::ext_type ext_type =
-                  (attrs.hasAttribute(arg.getArgNo() + 1,
-                                     llvm::Attribute::SExt) ?
-                   module::argument::sign_ext :
-                   module::argument::zero_ext);
-
-            args.push_back(
-               module::argument(module::argument::scalar, arg_api_size,
-                                target_size, target_align, ext_type));
+               args.emplace_back(module::argument::scalar, arg_api_size,
+                                 target_size, target_align,
+                                 (needs_sign_ext ? module::argument::sign_ext :
+                                  module::argument::zero_ext));
+            }
          }
       }
 
       // Append implicit arguments.  XXX - The types, ordering and
       // vector size of the implicit arguments should depend on the
       // target according to the selected calling convention.
-      args.push_back(
-         module::argument(module::argument::scalar, sizeof(cl_uint),
-                          TD.getTypeStoreSize(size_type),
-                          TD.getABITypeAlignment(size_type),
-                          module::argument::zero_ext,
-                          module::argument::grid_dimension));
-
-      args.push_back(
-         module::argument(module::argument::scalar, sizeof(cl_uint),
-                          TD.getTypeStoreSize(size_type),
-                          TD.getABITypeAlignment(size_type),
-                          module::argument::zero_ext,
-                          module::argument::grid_offset));
+      args.emplace_back(module::argument::scalar, sizeof(cl_uint),
+                        dl.getTypeStoreSize(size_type),
+                        dl.getABITypeAlignment(size_type),
+                        module::argument::zero_ext,
+                        module::argument::grid_dimension);
+
+      args.emplace_back(module::argument::scalar, sizeof(cl_uint),
+                        dl.getTypeStoreSize(size_type),
+                        dl.getABITypeAlignment(size_type),
+                        module::argument::zero_ext,
+                        module::argument::grid_offset);
 
       return args;
    }
@@ -531,7 +514,7 @@ namespace {
                                   find_kernels(const_cast<Module *>(&mod)))) {
          if (offsets.count(name))
             m.syms.emplace_back(name, 0, offsets.at(name),
-                                get_kernel_args(&mod, name, c));
+                                make_kernel_args(mod, name, c));
       }
 
       m.secs.push_back(make_text_section(code));