clover/nir: support int64 atomics if the device supports it
[mesa.git] / src / gallium / frontends / clover / nir / invocation.cpp
index 46440d96e09dc2fa9da424381b692281eacf010f..8c6b34afc9c335e0ef91d9a826417ad911942b87 100644 (file)
@@ -63,12 +63,24 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
 {
    struct spirv_to_nir_options spirv_options = {};
    spirv_options.environment = NIR_SPIRV_OPENCL;
+   if (dev.address_bits() == 32u) {
+      spirv_options.shared_addr_format = nir_address_format_32bit_global;
+      spirv_options.global_addr_format = nir_address_format_32bit_global;
+      spirv_options.temp_addr_format = nir_address_format_32bit_global;
+      spirv_options.ubo_addr_format = nir_address_format_32bit_global;
+   } else {
+      spirv_options.shared_addr_format = nir_address_format_64bit_global;
+      spirv_options.global_addr_format = nir_address_format_64bit_global;
+      spirv_options.temp_addr_format = nir_address_format_64bit_global;
+      spirv_options.ubo_addr_format = nir_address_format_32bit_index_offset;
+   }
    spirv_options.caps.address = true;
    spirv_options.caps.float64 = true;
    spirv_options.caps.int8 = true;
    spirv_options.caps.int16 = true;
    spirv_options.caps.int64 = true;
    spirv_options.caps.kernel = true;
+   spirv_options.caps.int64_atomics = dev.has_int64_atomics();
    spirv_options.constant_as_global = true;
 
    module m;
@@ -101,7 +113,7 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
 
       // Calculate input offsets.
       unsigned offset = 0;
-      nir_foreach_variable_safe(var, &nir->inputs) {
+      nir_foreach_shader_in_variable_safe(var, nir) {
          offset = align(offset, glsl_get_cl_alignment(var->type));
          var->data.driver_location = offset;
          offset += glsl_get_cl_size(var->type);
@@ -112,6 +124,7 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
       NIR_PASS_V(nir, nir_lower_variable_initializers, nir_var_function_temp);
       NIR_PASS_V(nir, nir_lower_returns);
       NIR_PASS_V(nir, nir_inline_functions);
+      NIR_PASS_V(nir, nir_copy_prop);
       NIR_PASS_V(nir, nir_opt_deref);
 
       // Pick off the single entrypoint that we want.
@@ -133,18 +146,23 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
       NIR_PASS_V(nir, nir_lower_vars_to_ssa);
       NIR_PASS_V(nir, nir_opt_dce);
 
+      NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_mem_shared,
+                 glsl_get_cl_type_size_align);
+
+      /* use offsets for shader_in and shared memory */
       nir_variable_mode modes = (nir_variable_mode)(
          nir_var_shader_in |
-         nir_var_mem_global |
          nir_var_mem_shared);
+      NIR_PASS_V(nir, nir_lower_explicit_io, modes, nir_address_format_32bit_offset);
+
+      /* use global format for global memory */
       nir_address_format format = nir->info.cs.ptr_size == 64 ?
          nir_address_format_64bit_global : nir_address_format_32bit_global;
-      NIR_PASS_V(nir, nir_lower_explicit_io, modes, format);
+      NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_global, format);
 
       NIR_PASS_V(nir, nir_lower_system_values);
       if (compiler_options->lower_int64_options)
-         NIR_PASS_V(nir, nir_lower_int64,
-                    compiler_options->lower_int64_options);
+         NIR_PASS_V(nir, nir_lower_int64);
 
       NIR_PASS_V(nir, nir_opt_dce);