clover/nir: use offset for temp memory
[mesa.git] / src / gallium / frontends / clover / nir / invocation.cpp
index 08dc78d18d9fa4ebc682d97bfccb50366d40f300..c916190f3a2811fe08d79860760e52ffaebdf6c1 100644 (file)
@@ -104,9 +104,9 @@ clover_lower_nir_instr(nir_builder *b, nir_instr *instr, void *_state)
          const glsl_type *type = glsl_uint_type();
          for (uint32_t i = 0; i < 3; i++) {
             state->offset_vars[i] =
-               nir_variable_create(b->shader, nir_var_shader_in, type,
+               nir_variable_create(b->shader, nir_var_uniform, type,
                                    "global_invocation_id_offsets");
-            state->offset_vars[i]->data.location = b->shader->num_inputs++;
+            state->offset_vars[i]->data.location = b->shader->num_uniforms++;
          }
       }
 
@@ -139,11 +139,11 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
    if (dev.address_bits() == 32u) {
       spirv_options.shared_addr_format = nir_address_format_32bit_offset;
       spirv_options.global_addr_format = nir_address_format_32bit_global;
-      spirv_options.temp_addr_format = nir_address_format_32bit_global;
+      spirv_options.temp_addr_format = nir_address_format_32bit_offset;
    } else {
       spirv_options.shared_addr_format = nir_address_format_32bit_offset_as_64bit;
       spirv_options.global_addr_format = nir_address_format_64bit_global;
-      spirv_options.temp_addr_format = nir_address_format_64bit_global;
+      spirv_options.temp_addr_format = nir_address_format_32bit_offset_as_64bit;
    }
    spirv_options.caps.address = true;
    spirv_options.caps.float64 = true;
@@ -201,8 +201,7 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
 
       nir_validate_shader(nir, "clover after function inlining");
 
-      NIR_PASS_V(nir, nir_lower_variable_initializers,
-                 static_cast<nir_variable_mode>(~nir_var_function_temp));
+      NIR_PASS_V(nir, nir_lower_variable_initializers, ~nir_var_function_temp);
 
       // copy propagate to prepare for lower_explicit_io
       NIR_PASS_V(nir, nir_split_var_copies);
@@ -221,25 +220,31 @@ module clover::nir::spirv_to_nir(const module &mod, const device &dev,
 
       // Calculate input offsets.
       unsigned offset = 0;
-      nir_foreach_shader_in_variable(var, nir) {
+      nir_foreach_uniform_variable(var, nir) {
          offset = align(offset, glsl_get_cl_alignment(var->type));
          var->data.driver_location = offset;
          offset += glsl_get_cl_size(var->type);
       }
 
-      NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_mem_shared,
+      NIR_PASS_V(nir, nir_lower_vars_to_explicit_types,
+                 nir_var_mem_shared | nir_var_function_temp,
                  glsl_get_cl_type_size_align);
 
-      /* use offsets for shader_in and shared memory */
-      NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_shader_in,
+      /* use offsets for uniform and shared memory */
+      NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_uniform,
                  nir_address_format_32bit_offset);
 
       NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_shared,
                  spirv_options.shared_addr_format);
 
+      NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_function_temp,
+                 spirv_options.temp_addr_format);
+
       NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_global,
                  spirv_options.global_addr_format);
 
+      NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_all, NULL);
+
       if (compiler_options->lower_int64_options)
          NIR_PASS_V(nir, nir_lower_int64);