clover/nir/spirv: Use uniform rather than shader_in for kernel inputs
[mesa.git] / src / compiler / spirv / spirv_to_nir.c
index 96c0c0767db48f9c2540a05b72c05db1e56f8684..6dd18075c847cb641807765fd355d43fa8909edc 100644 (file)
@@ -775,6 +775,33 @@ wrap_type_in_array(const struct glsl_type *type,
                           glsl_get_explicit_stride(array_type));
 }
 
+static bool
+vtn_type_needs_explicit_layout(struct vtn_builder *b, enum vtn_variable_mode mode)
+{
+   /* For OpenCL we never want to strip the info from the types, and it makes
+    * type comparisons easier in later stages.
+    */
+   if (b->options->environment == NIR_SPIRV_OPENCL)
+      return true;
+
+   switch (mode) {
+   case vtn_variable_mode_input:
+   case vtn_variable_mode_output:
+      /* Layout decorations kept because we need offsets for XFB arrays of
+       * blocks.
+       */
+      return b->shader->info.has_transform_feedback_varyings;
+
+   case vtn_variable_mode_ssbo:
+   case vtn_variable_mode_phys_ssbo:
+   case vtn_variable_mode_ubo:
+      return true;
+
+   default:
+      return false;
+   }
+}
+
 const struct glsl_type *
 vtn_type_get_nir_type(struct vtn_builder *b, struct vtn_type *type,
                       enum vtn_variable_mode mode)
@@ -787,16 +814,65 @@ vtn_type_get_nir_type(struct vtn_builder *b, struct vtn_type *type,
    }
 
    if (mode == vtn_variable_mode_uniform) {
-      struct vtn_type *tail = vtn_type_without_array(type);
-      if (tail->base_type == vtn_base_type_image) {
-         return wrap_type_in_array(tail->glsl_image, type->type);
-      } else if (tail->base_type == vtn_base_type_sampler) {
-         return wrap_type_in_array(glsl_bare_sampler_type(), type->type);
-      } else if (tail->base_type == vtn_base_type_sampled_image) {
-         return wrap_type_in_array(tail->image->glsl_image, type->type);
+      switch (type->base_type) {
+      case vtn_base_type_array: {
+         const struct glsl_type *elem_type =
+            vtn_type_get_nir_type(b, type->array_element, mode);
+
+         return glsl_array_type(elem_type, type->length,
+                                glsl_get_explicit_stride(type->type));
+      }
+
+      case vtn_base_type_struct: {
+         bool need_new_struct = false;
+         const uint32_t num_fields = type->length;
+         NIR_VLA(struct glsl_struct_field, fields, num_fields);
+         for (unsigned i = 0; i < num_fields; i++) {
+            fields[i] = *glsl_get_struct_field_data(type->type, i);
+            const struct glsl_type *field_nir_type =
+               vtn_type_get_nir_type(b, type->members[i], mode);
+            if (fields[i].type != field_nir_type) {
+               fields[i].type = field_nir_type;
+               need_new_struct = true;
+            }
+         }
+         if (need_new_struct) {
+            if (glsl_type_is_interface(type->type)) {
+               return glsl_interface_type(fields, num_fields,
+                                          /* packing */ 0, false,
+                                          glsl_get_type_name(type->type));
+            } else {
+               return glsl_struct_type(fields, num_fields,
+                                       glsl_get_type_name(type->type),
+                                       glsl_struct_type_is_packed(type->type));
+            }
+         } else {
+            /* No changes, just pass it on */
+            return type->type;
+         }
+      }
+
+      case vtn_base_type_image:
+         return type->glsl_image;
+
+      case vtn_base_type_sampler:
+         return glsl_bare_sampler_type();
+
+      case vtn_base_type_sampled_image:
+         return type->image->glsl_image;
+
+      default:
+         return type->type;
       }
    }
 
+   /* Layout decorations are allowed but ignored in certain conditions,
+    * to allow SPIR-V generators perform type deduplication.  Discard
+    * unnecessary ones when passing to NIR.
+    */
+   if (!vtn_type_needs_explicit_layout(b, mode))
+      return glsl_get_bare_type(type->type);
+
    return type->type;
 }
 
@@ -2857,6 +2933,8 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode,
    SpvScope scope = SpvScopeInvocation;
    SpvMemorySemanticsMask semantics = 0;
 
+   enum gl_access_qualifier access = 0;
+
    struct vtn_value *res_val;
    switch (opcode) {
    case SpvOpAtomicExchange:
@@ -2879,6 +2957,7 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode,
       image = *res_val->image;
       scope = vtn_constant_uint(b, w[4]);
       semantics = vtn_constant_uint(b, w[5]);
+      access |= ACCESS_COHERENT;
       break;
 
    case SpvOpAtomicStore:
@@ -2886,6 +2965,7 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode,
       image = *res_val->image;
       scope = vtn_constant_uint(b, w[2]);
       semantics = vtn_constant_uint(b, w[3]);
+      access |= ACCESS_COHERENT;
       break;
 
    case SpvOpImageQuerySize:
@@ -3010,8 +3090,10 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode,
 
    intrin->src[0] = nir_src_for_ssa(&image.image->dest.ssa);
 
-   /* ImageQuerySize doesn't take any extra parameters */
-   if (opcode != SpvOpImageQuerySize) {
+   if (opcode == SpvOpImageQuerySize) {
+      /* ImageQuerySize only has an LOD which is currently always 0 */
+      intrin->src[1] = nir_src_for_ssa(nir_imm_int(&b->nb, 0));
+   } else {
       /* The image coordinate is always 4 components but we may not have that
        * many.  Swizzle to compensate.
        */
@@ -3032,7 +3114,6 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode,
     * chains to find the NonUniform decoration.  It's either right there or we
     * can assume it doesn't exist.
     */
-   enum gl_access_qualifier access = 0;
    vtn_foreach_decoration(b, res_val, non_uniform_decoration_cb, &access);
    nir_intrinsic_set_access(intrin, access);
 
@@ -3300,6 +3381,8 @@ vtn_handle_atomics(struct vtn_builder *b, SpvOp opcode,
       nir_intrinsic_op op  = get_ssbo_nir_atomic_op(b, opcode);
       atomic = nir_intrinsic_instr_create(b->nb.shader, op);
 
+      nir_intrinsic_set_access(atomic, ACCESS_COHERENT);
+
       int src = 0;
       switch (opcode) {
       case SpvOpAtomicLoad:
@@ -3351,6 +3434,9 @@ vtn_handle_atomics(struct vtn_builder *b, SpvOp opcode,
       atomic = nir_intrinsic_instr_create(b->nb.shader, op);
       atomic->src[0] = nir_src_for_ssa(&deref->dest.ssa);
 
+      if (ptr->mode != vtn_variable_mode_workgroup)
+         nir_intrinsic_set_access(atomic, ACCESS_COHERENT);
+
       switch (opcode) {
       case SpvOpAtomicLoad:
          atomic->num_components = glsl_get_vector_elements(deref_type);
@@ -4294,18 +4380,28 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
                      "AddressingModelPhysical32 only supported for kernels");
          b->shader->info.cs.ptr_size = 32;
          b->physical_ptrs = true;
-         b->options->shared_addr_format = nir_address_format_32bit_global;
-         b->options->global_addr_format = nir_address_format_32bit_global;
-         b->options->temp_addr_format = nir_address_format_32bit_global;
+         assert(nir_address_format_bit_size(b->options->global_addr_format) == 32);
+         assert(nir_address_format_num_components(b->options->global_addr_format) == 1);
+         assert(nir_address_format_bit_size(b->options->shared_addr_format) == 32);
+         assert(nir_address_format_num_components(b->options->shared_addr_format) == 1);
+         if (!b->options->constant_as_global) {
+            assert(nir_address_format_bit_size(b->options->ubo_addr_format) == 32);
+            assert(nir_address_format_num_components(b->options->ubo_addr_format) == 1);
+         }
          break;
       case SpvAddressingModelPhysical64:
          vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                      "AddressingModelPhysical64 only supported for kernels");
          b->shader->info.cs.ptr_size = 64;
          b->physical_ptrs = true;
-         b->options->shared_addr_format = nir_address_format_64bit_global;
-         b->options->global_addr_format = nir_address_format_64bit_global;
-         b->options->temp_addr_format = nir_address_format_64bit_global;
+         assert(nir_address_format_bit_size(b->options->global_addr_format) == 64);
+         assert(nir_address_format_num_components(b->options->global_addr_format) == 1);
+         assert(nir_address_format_bit_size(b->options->shared_addr_format) == 64);
+         assert(nir_address_format_num_components(b->options->shared_addr_format) == 1);
+         if (!b->options->constant_as_global) {
+            assert(nir_address_format_bit_size(b->options->ubo_addr_format) == 64);
+            assert(nir_address_format_num_components(b->options->ubo_addr_format) == 1);
+         }
          break;
       case SpvAddressingModelLogical:
          vtn_fail_if(b->shader->info.stage == MESA_SHADER_KERNEL,
@@ -5379,7 +5475,7 @@ vtn_emit_kernel_entry_point_wrapper(struct vtn_builder *b,
 
       /* input variable */
       nir_variable *in_var = rzalloc(b->nb.shader, nir_variable);
-      in_var->data.mode = nir_var_shader_in;
+      in_var->data.mode = nir_var_uniform;
       in_var->data.read_only = true;
       in_var->data.location = i;
 
@@ -5447,6 +5543,10 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
       return NULL;
    }
 
+   /* Ensure a sane address mode is being used for function temps */
+   assert(nir_address_format_bit_size(b->options->temp_addr_format) == nir_get_ptr_bitsize(b->shader));
+   assert(nir_address_format_num_components(b->options->temp_addr_format) == 1);
+
    /* Set shader info defaults */
    if (stage == MESA_SHADER_GEOMETRY)
       b->shader->info.gs.invocations = 1;
@@ -5510,6 +5610,9 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
    if (entry_point->num_params && b->shader->info.stage == MESA_SHADER_KERNEL)
       entry_point = vtn_emit_kernel_entry_point_wrapper(b, entry_point);
 
+   /* structurize the CFG */
+   nir_lower_goto_ifs(b->shader);
+
    entry_point->is_entrypoint = true;
 
    /* When multiple shader stages exist in the same SPIR-V module, we