clover/nir/spirv: Use uniform rather than shader_in for kernel inputs
[mesa.git] / src / gallium / frontends / clover / nir / invocation.cpp
index 46440d96e09dc2fa9da424381b692281eacf010f..757ace61393a807045d932b76da98ce8e24aaf22 100644 (file)
 
 #include "core/device.hpp"
 #include "core/error.hpp"
+#include "core/module.hpp"
 #include "pipe/p_state.h"
 #include "util/algorithm.hpp"
 #include "util/functional.hpp"
 
 #include <compiler/glsl_types.h>
+#include <compiler/nir/nir_builder.h>
 #include <compiler/nir/nir_serialize.h>
 #include <compiler/spirv/nir_spirv.h>
 #include <util/u_math.h>
@@ -58,18 +60,101 @@ dev_get_nir_compiler_options(const device &dev)
    return static_cast<const nir_shader_compiler_options*>(co);
 }
 
+static void debug_function(void *private_data,
+                   enum nir_spirv_debug_level level, size_t spirv_offset,
+                   const char *message)
+{
+   assert(private_data);
+   auto r_log = reinterpret_cast<std::string *>(private_data);
+   *r_log += message;
+}
+
+struct clover_lower_nir_state {
+   std::vector<module::argument> &args;
+   uint32_t global_dims;
+   nir_variable *offset_vars[3];
+};
+
+static bool
+clover_lower_nir_filter(const nir_instr *instr, const void *)
+{
+   return instr->type == nir_instr_type_intrinsic;
+}
+
+static nir_ssa_def *
+clover_lower_nir_instr(nir_builder *b, nir_instr *instr, void *_state)
+{
+   clover_lower_nir_state *state = reinterpret_cast<clover_lower_nir_state*>(_state);
+   nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
+
+   switch (intrinsic->intrinsic) {
+   case nir_intrinsic_load_base_global_invocation_id: {
+      nir_ssa_def *loads[3];
+
+      /* create variables if we didn't do so alrady */
+      if (!state->offset_vars[0]) {
+         /* TODO: fix for 64 bit */
+         /* Even though we only place one scalar argument, clover will bind up to
+          * three 32 bit values
+         */
+         state->args.emplace_back(module::argument::scalar, 4, 4, 4,
+                                  module::argument::zero_ext,
+                                  module::argument::grid_offset);
+
+         const glsl_type *type = glsl_uint_type();
+         for (uint32_t i = 0; i < 3; i++) {
+            state->offset_vars[i] =
+               nir_variable_create(b->shader, nir_var_uniform, type,
+                                   "global_invocation_id_offsets");
+            state->offset_vars[i]->data.location = b->shader->num_uniforms++;
+         }
+      }
+
+      for (int i = 0; i < 3; i++) {
+         nir_variable *var = state->offset_vars[i];
+         loads[i] = var ? nir_load_var(b, var) : nir_imm_int(b, 0);
+      }
+
+      return nir_u2u(b, nir_vec(b, loads, state->global_dims),
+                     nir_dest_bit_size(intrinsic->dest));
+   }
+   default:
+      return NULL;
+   }
+}
+
+static bool
+clover_lower_nir(nir_shader *nir, std::vector<module::argument> &args, uint32_t dims)
+{
+   clover_lower_nir_state state = { args, dims };
+   return nir_shader_lower_instructions(nir,
+      clover_lower_nir_filter, clover_lower_nir_instr, &state);
+}
+
 module clover::nir::spirv_to_nir(const module &mod, const device &dev,
                                  std::string &r_log)
 {
    struct spirv_to_nir_options spirv_options = {};
    spirv_options.environment = NIR_SPIRV_OPENCL;
+   if (dev.address_bits() == 32u) {
+      spirv_options.shared_addr_format = nir_address_format_32bit_offset;
+      spirv_options.global_addr_format = nir_address_format_32bit_global;
+      spirv_options.temp_addr_format = nir_address_format_32bit_global;
+   } else {
+      spirv_options.shared_addr_format = nir_address_format_32bit_offset_as_64bit;
+      spirv_options.global_addr_format = nir_address_format_64bit_global;
+      spirv_options.temp_addr_format = nir_address_format_64bit_global;
+   }
    spirv_options.caps.address = true;
    spirv_options.caps.float64 = true;
    spirv_options.caps.int8 = true;
    spirv_options.caps.int16 = true;
    spirv_options.caps.int64 = true;
    spirv_options.caps.kernel = true;
+   spirv_options.caps.int64_atomics = dev.has_int64_atomics();
    spirv_options.constant_as_global = true;
+   spirv_options.debug.func = &debug_function;
+   spirv_options.debug.private_data = &r_log;
 
    module m;
    // We only insert one section.
@@ -99,19 +184,12 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
       nir->info.cs.local_size_variable = true;
       nir_validate_shader(nir, "clover");
 
-      // Calculate input offsets.
-      unsigned offset = 0;
-      nir_foreach_variable_safe(var, &nir->inputs) {
-         offset = align(offset, glsl_get_cl_alignment(var->type));
-         var->data.driver_location = offset;
-         offset += glsl_get_cl_size(var->type);
-      }
-
       // Inline all functions first.
       // according to the comment on nir_inline_functions
       NIR_PASS_V(nir, nir_lower_variable_initializers, nir_var_function_temp);
       NIR_PASS_V(nir, nir_lower_returns);
       NIR_PASS_V(nir, nir_inline_functions);
+      NIR_PASS_V(nir, nir_copy_prop);
       NIR_PASS_V(nir, nir_opt_deref);
 
       // Pick off the single entrypoint that we want.
@@ -133,18 +211,37 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
       NIR_PASS_V(nir, nir_lower_vars_to_ssa);
       NIR_PASS_V(nir, nir_opt_dce);
 
-      nir_variable_mode modes = (nir_variable_mode)(
-         nir_var_shader_in |
-         nir_var_mem_global |
-         nir_var_mem_shared);
-      nir_address_format format = nir->info.cs.ptr_size == 64 ?
-         nir_address_format_64bit_global : nir_address_format_32bit_global;
-      NIR_PASS_V(nir, nir_lower_explicit_io, modes, format);
-
       NIR_PASS_V(nir, nir_lower_system_values);
+      nir_lower_compute_system_values_options sysval_options = { 0 };
+      sysval_options.has_base_global_invocation_id = true;
+      NIR_PASS_V(nir, nir_lower_compute_system_values, &sysval_options);
+
+      auto args = sym.args;
+      NIR_PASS_V(nir, clover_lower_nir, args, dev.max_block_size().size());
+
+      // Calculate input offsets.
+      unsigned offset = 0;
+      nir_foreach_uniform_variable(var, nir) {
+         offset = align(offset, glsl_get_cl_alignment(var->type));
+         var->data.driver_location = offset;
+         offset += glsl_get_cl_size(var->type);
+      }
+
+      NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_mem_shared,
+                 glsl_get_cl_type_size_align);
+
+      /* use offsets for uniform and shared memory */
+      NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_uniform,
+                 nir_address_format_32bit_offset);
+
+      NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_shared,
+                 spirv_options.shared_addr_format);
+
+      NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_global,
+                 spirv_options.global_addr_format);
+
       if (compiler_options->lower_int64_options)
-         NIR_PASS_V(nir, nir_lower_int64,
-                    compiler_options->lower_int64_options);
+         NIR_PASS_V(nir, nir_lower_int64);
 
       NIR_PASS_V(nir, nir_opt_dce);
 
@@ -158,7 +255,7 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
                        reinterpret_cast<const char *>(&header) + sizeof(header));
       text.data.insert(text.data.end(), blob.data, blob.data + blob.size);
 
-      m.syms.emplace_back(sym.name, section_id, 0, sym.args);
+      m.syms.emplace_back(sym.name, section_id, 0, args);
       m.secs.push_back(text);
       section_id++;
    }