nir/vtn: Use return type rather than image type for tex ops
[mesa.git] / src / compiler / spirv / spirv_to_nir.c
index 747d1a0cdcce5e929984a34ce794f406e718ce1b..e2912809bb05ca3ce58b36e550b796f3dcf66c40 100644 (file)
@@ -1569,9 +1569,14 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
          vtn_mode_to_address_format(b, vtn_variable_mode_function));
 
       const struct vtn_type *sampled_type = vtn_get_type(b, w[2]);
-      vtn_fail_if(sampled_type->base_type != vtn_base_type_scalar ||
-                  glsl_get_bit_size(sampled_type->type) != 32,
-                  "Sampled type of OpTypeImage must be a 32-bit scalar");
+      if (b->shader->info.stage == MESA_SHADER_KERNEL) {
+         vtn_fail_if(sampled_type->base_type != vtn_base_type_void,
+                     "Sampled type of OpTypeImage must be void for kernels");
+      } else {
+         vtn_fail_if(sampled_type->base_type != vtn_base_type_scalar ||
+                     glsl_get_bit_size(sampled_type->type) != 32,
+                     "Sampled type of OpTypeImage must be a 32-bit scalar");
+      }
 
       enum glsl_sampler_dim dim;
       switch ((SpvDim)w[3]) {
@@ -1597,6 +1602,9 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
 
       if (count > 9)
          val->type->access_qualifier = w[9];
+      else if (b->shader->info.stage == MESA_SHADER_KERNEL)
+         /* Per the CL C spec: If no qualifier is provided, read_only is assumed. */
+         val->type->access_qualifier = SpvAccessQualifierReadOnly;
       else
          val->type->access_qualifier = SpvAccessQualifierReadWrite;
 
@@ -1619,6 +1627,9 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
       } else if (sampled == 2) {
          val->type->glsl_image = glsl_image_type(dim, is_array,
                                                  sampled_base_type);
+      } else if (b->shader->info.stage == MESA_SHADER_KERNEL) {
+         val->type->glsl_image = glsl_image_type(dim, is_array,
+                                                 GLSL_TYPE_VOID);
       } else {
          vtn_fail("We need to know if the image will be sampled");
       }
@@ -2790,9 +2801,18 @@ vtn_handle_texture(struct vtn_builder *b, SpvOp opcode,
    if (sampler && (access & ACCESS_NON_UNIFORM))
       instr->sampler_non_uniform = true;
 
-   /* for non-query ops, get dest_type from sampler type */
+   /* for non-query ops, get dest_type from SPIR-V return type */
    if (dest_type == nir_type_invalid) {
-      switch (glsl_get_sampler_result_type(image->type)) {
+      /* the return type should match the image type, unless the image type is
+       * VOID (CL image), in which case the return type dictates the sampler
+       */
+      enum glsl_base_type sampler_base =
+         glsl_get_sampler_result_type(image->type);
+      enum glsl_base_type ret_base = glsl_get_base_type(ret_type->type);
+      vtn_fail_if(sampler_base != ret_base && sampler_base != GLSL_TYPE_VOID,
+                  "SPIR-V return type mismatches image type. This is only valid "
+                  "for untyped images (OpenCL).");
+      switch (ret_base) {
       case GLSL_TYPE_FLOAT:   dest_type = nir_type_float;   break;
       case GLSL_TYPE_INT:     dest_type = nir_type_int;     break;
       case GLSL_TYPE_UINT:    dest_type = nir_type_uint;    break;
@@ -3090,8 +3110,10 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode,
 
    intrin->src[0] = nir_src_for_ssa(&image.image->dest.ssa);
 
-   /* ImageQuerySize doesn't take any extra parameters */
-   if (opcode != SpvOpImageQuerySize) {
+   if (opcode == SpvOpImageQuerySize) {
+      /* ImageQuerySize only has an LOD which is currently always 0 */
+      intrin->src[1] = nir_src_for_ssa(nir_imm_int(&b->nb, 0));
+   } else {
       /* The image coordinate is always 4 components but we may not have that
        * many.  Swizzle to compensate.
        */
@@ -3131,17 +3153,20 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode,
    case SpvOpAtomicStore:
    case SpvOpImageWrite: {
       const uint32_t value_id = opcode == SpvOpAtomicStore ? w[4] : w[3];
-      nir_ssa_def *value = vtn_get_nir_ssa(b, value_id);
+      struct vtn_ssa_value *value = vtn_ssa_value(b, value_id);
       /* nir_intrinsic_image_deref_store always takes a vec4 value */
       assert(op == nir_intrinsic_image_deref_store);
       intrin->num_components = 4;
-      intrin->src[3] = nir_src_for_ssa(expand_to_vec4(&b->nb, value));
+      intrin->src[3] = nir_src_for_ssa(expand_to_vec4(&b->nb, value->def));
       /* Only OpImageWrite can support a lod parameter if
        * SPV_AMD_shader_image_load_store_lod is used but the current NIR
        * intrinsics definition for atomics requires us to set it for
        * OpAtomicStore.
        */
       intrin->src[4] = nir_src_for_ssa(image.lod);
+
+      if (opcode == SpvOpImageWrite)
+         nir_intrinsic_set_type(intrin, nir_get_nir_type_for_glsl_type(value->type));
       break;
    }
 
@@ -3194,6 +3219,9 @@ vtn_handle_image(struct vtn_builder *b, SpvOp opcode,
          result = nir_channels(&b->nb, result, (1 << dest_components) - 1);
 
       vtn_push_nir_ssa(b, w[2], result);
+
+      if (opcode == SpvOpImageRead)
+         nir_intrinsic_set_type(intrin, nir_get_nir_type_for_glsl_type(type->type));
    } else {
       nir_builder_instr_insert(&b->nb, &intrin->instr);
    }
@@ -4163,6 +4191,9 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
          break;
 
       case SpvCapabilityImageBasic:
+         spv_check_supported(kernel_image, cap);
+         break;
+
       case SpvCapabilityImageReadWrite:
       case SpvCapabilityImageMipmap:
       case SpvCapabilityPipes:
@@ -4378,18 +4409,28 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
                      "AddressingModelPhysical32 only supported for kernels");
          b->shader->info.cs.ptr_size = 32;
          b->physical_ptrs = true;
-         b->options->shared_addr_format = nir_address_format_32bit_global;
-         b->options->global_addr_format = nir_address_format_32bit_global;
-         b->options->temp_addr_format = nir_address_format_32bit_global;
+         assert(nir_address_format_bit_size(b->options->global_addr_format) == 32);
+         assert(nir_address_format_num_components(b->options->global_addr_format) == 1);
+         assert(nir_address_format_bit_size(b->options->shared_addr_format) == 32);
+         assert(nir_address_format_num_components(b->options->shared_addr_format) == 1);
+         if (!b->options->constant_as_global) {
+            assert(nir_address_format_bit_size(b->options->ubo_addr_format) == 32);
+            assert(nir_address_format_num_components(b->options->ubo_addr_format) == 1);
+         }
          break;
       case SpvAddressingModelPhysical64:
          vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                      "AddressingModelPhysical64 only supported for kernels");
          b->shader->info.cs.ptr_size = 64;
          b->physical_ptrs = true;
-         b->options->shared_addr_format = nir_address_format_64bit_global;
-         b->options->global_addr_format = nir_address_format_64bit_global;
-         b->options->temp_addr_format = nir_address_format_64bit_global;
+         assert(nir_address_format_bit_size(b->options->global_addr_format) == 64);
+         assert(nir_address_format_num_components(b->options->global_addr_format) == 1);
+         assert(nir_address_format_bit_size(b->options->shared_addr_format) == 64);
+         assert(nir_address_format_num_components(b->options->shared_addr_format) == 1);
+         if (!b->options->constant_as_global) {
+            assert(nir_address_format_bit_size(b->options->ubo_addr_format) == 64);
+            assert(nir_address_format_num_components(b->options->ubo_addr_format) == 1);
+         }
          break;
       case SpvAddressingModelLogical:
          vtn_fail_if(b->shader->info.stage == MESA_SHADER_KERNEL,
@@ -5463,12 +5504,23 @@ vtn_emit_kernel_entry_point_wrapper(struct vtn_builder *b,
 
       /* input variable */
       nir_variable *in_var = rzalloc(b->nb.shader, nir_variable);
-      in_var->data.mode = nir_var_shader_in;
+      in_var->data.mode = nir_var_uniform;
       in_var->data.read_only = true;
       in_var->data.location = i;
+      if (param_type->base_type == vtn_base_type_image) {
+         in_var->data.access = 0;
+         if (param_type->access_qualifier & SpvAccessQualifierReadOnly)
+            in_var->data.access |= ACCESS_NON_WRITEABLE;
+         if (param_type->access_qualifier & SpvAccessQualifierWriteOnly)
+            in_var->data.access |= ACCESS_NON_READABLE;
+      }
 
       if (is_by_val)
          in_var->type = param_type->deref->type;
+      else if (param_type->base_type == vtn_base_type_image)
+         in_var->type = param_type->glsl_image;
+      else if (param_type->base_type == vtn_base_type_sampler)
+         in_var->type = glsl_bare_sampler_type();
       else
          in_var->type = param_type->type;
 
@@ -5483,6 +5535,10 @@ vtn_emit_kernel_entry_point_wrapper(struct vtn_builder *b,
          nir_copy_var(&b->nb, copy_var, in_var);
          call->params[i] =
             nir_src_for_ssa(&nir_build_deref_var(&b->nb, copy_var)->dest.ssa);
+      } else if (param_type->base_type == vtn_base_type_image ||
+                 param_type->base_type == vtn_base_type_sampler) {
+         /* Don't load the var, just pass a deref of it */
+         call->params[i] = nir_src_for_ssa(&nir_build_deref_var(&b->nb, in_var)->dest.ssa);
       } else {
          call->params[i] = nir_src_for_ssa(nir_load_var(&b->nb, in_var));
       }
@@ -5531,6 +5587,10 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
       return NULL;
    }
 
+   /* Ensure a sane address mode is being used for function temps */
+   assert(nir_address_format_bit_size(b->options->temp_addr_format) == nir_get_ptr_bitsize(b->shader));
+   assert(nir_address_format_num_components(b->options->temp_addr_format) == 1);
+
    /* Set shader info defaults */
    if (stage == MESA_SHADER_GEOMETRY)
       b->shader->info.gs.invocations = 1;