nir/spirv: support physical pointers
author Karol Herbst <kherbst@redhat.com>
Thu, 31 Jan 2019 00:56:25 +0000 (01:56 +0100)
committer Karol Herbst <karolherbst@gmail.com>
Tue, 19 Mar 2019 04:08:07 +0000 (04:08 +0000)
v2: add load_kernel_input

Signed-off-by: Karol Herbst <kherbst@redhat.com>
squash! nir/spirv: support physical pointers

src/compiler/nir/nir_builder.h
src/compiler/nir/nir_intrinsics.py
src/compiler/nir/nir_lower_io.c
src/compiler/spirv/nir_spirv.h
src/compiler/spirv/spirv_to_nir.c
src/compiler/spirv/vtn_private.h
src/compiler/spirv/vtn_variables.c

index 9662cd2a217ba3801eea9b400411c341e1d90863..bcf03957bc73f9aae6a948315f05c3cc07f03a2d 100644 (file)
@@ -828,6 +828,14 @@ nir_ssa_for_alu_src(nir_builder *build, nir_alu_instr *instr, unsigned srcn)
    return nir_imov_alu(build, *src, num_components);
 }
 
+static inline unsigned /* bit size to use for pointer (deref) SSA values created through this builder */
+nir_get_ptr_bitsize(nir_builder *build)
+{
+   if (build->shader->info.stage == MESA_SHADER_KERNEL) /* kernel shaders record their own pointer width */
+      return build->shader->info.cs.ptr_size;
+   return 32; /* every other stage keeps the fixed 32-bit width that was previously hard-coded */
+}
+
 static inline nir_deref_instr *
 nir_build_deref_var(nir_builder *build, nir_variable *var)
 {
@@ -838,7 +846,8 @@ nir_build_deref_var(nir_builder *build, nir_variable *var)
    deref->type = var->type;
    deref->var = var;
 
-   nir_ssa_dest_init(&deref->instr, &deref->dest, 1, 32, NULL);
+   nir_ssa_dest_init(&deref->instr, &deref->dest, 1,
+                     nir_get_ptr_bitsize(build), NULL);
 
    nir_builder_instr_insert(build, &deref->instr);
 
index d88e4ef7d458011ac4624149dcc3d7a8d8d12ff9..ea092a991ca4e9fd05f37d4160b8f514addd57aa 100644 (file)
@@ -635,6 +635,8 @@ load("constant", 1, [BASE, RANGE], [CAN_ELIMINATE, CAN_REORDER])
 # src[] = { address }.
 # const_index[] = { access, align_mul, align_offset }
 load("global", 1, [ACCESS, ALIGN_MUL, ALIGN_OFFSET], [CAN_ELIMINATE])
+# src[] = { address }. const_index[] = { base, range, align_mul, align_offset }
+load("kernel_input", 1, [BASE, RANGE, ALIGN_MUL, ALIGN_OFFSET], [CAN_ELIMINATE, CAN_REORDER])
 
 # Stores work the same way as loads, except now the first source is the value
 # to store and the second (and possibly third) source specify where to store
index 786f295128624ab36c14e58ce09aa97915b21f30..749ac91d47e13d9cd9869c975f4221d1186e5411 100644 (file)
@@ -680,6 +680,10 @@ build_explicit_io_load(nir_builder *b, nir_intrinsic_instr *intrin,
       assert(addr_format_is_global(addr_format));
       op = nir_intrinsic_load_global;
       break;
+   case nir_var_shader_in:
+      assert(addr_format_is_global(addr_format));
+      op = nir_intrinsic_load_kernel_input;
+      break;
    default:
       unreachable("Unsupported explicit IO variable mode");
    }
@@ -687,14 +691,13 @@ build_explicit_io_load(nir_builder *b, nir_intrinsic_instr *intrin,
    nir_intrinsic_instr *load = nir_intrinsic_instr_create(b->shader, op);
 
    if (addr_format_is_global(addr_format)) {
-      assert(op == nir_intrinsic_load_global);
       load->src[0] = nir_src_for_ssa(addr_to_global(b, addr, addr_format));
    } else {
       load->src[0] = nir_src_for_ssa(addr_to_index(b, addr, addr_format));
       load->src[1] = nir_src_for_ssa(addr_to_offset(b, addr, addr_format));
    }
 
-   if (mode != nir_var_mem_ubo)
+   if (mode != nir_var_mem_ubo && mode != nir_var_shader_in)
       nir_intrinsic_set_access(load, nir_intrinsic_access(intrin));
 
    /* TODO: We should try and provide a better alignment.  For OpenCL, we need
@@ -821,17 +824,20 @@ lower_explicit_io_deref(nir_builder *b, nir_deref_instr *deref,
 
    b->cursor = nir_after_instr(&deref->instr);
 
-   /* Var derefs must be lowered away by the driver */
-   assert(deref->deref_type != nir_deref_type_var);
+   nir_ssa_def *parent_addr = NULL;
+   if (deref->deref_type != nir_deref_type_var) {
+      assert(deref->parent.is_ssa);
+      parent_addr = deref->parent.ssa;
+   }
 
-   assert(deref->parent.is_ssa);
-   nir_ssa_def *parent_addr = deref->parent.ssa;
 
    nir_ssa_def *addr = NULL;
    assert(deref->dest.is_ssa);
    switch (deref->deref_type) {
    case nir_deref_type_var:
-      unreachable("Must be lowered by the driver");
+      assert(deref->mode == nir_var_shader_in);
+      addr = nir_imm_intN_t(b, deref->var->data.driver_location,
+                            deref->dest.ssa.bit_size);
       break;
 
    case nir_deref_type_array: {
index 35b30660e296f338c8a09a1798b743c0235c622e..d8ccbb48ff8682ce3ea6bc616319897dba62093f 100644 (file)
@@ -70,6 +70,8 @@ struct spirv_to_nir_options {
    const struct glsl_type *phys_ssbo_ptr_type;
    const struct glsl_type *push_const_ptr_type;
    const struct glsl_type *shared_ptr_type;
+   const struct glsl_type *global_ptr_type;
+   const struct glsl_type *temp_ptr_type;
 
    struct {
       void (*func)(void *private_data,
index df5bba2c2a03be085c45c59e9061318c7d58cfe2..0ef8d67519a6d89fd891aadd426860ab356be96f 100644 (file)
@@ -32,6 +32,8 @@
 #include "nir/nir_deref.h"
 #include "spirv_info.h"
 
+#include "util/u_math.h"
+
 #include <stdio.h>
 
 void
@@ -1242,7 +1244,8 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
 
       val->type->base_type = vtn_base_type_array;
       val->type->array_element = array_element;
-      val->type->stride = 0;
+      if (b->shader->info.stage == MESA_SHADER_KERNEL)
+         val->type->stride = glsl_get_cl_size(array_element->type);
 
       vtn_foreach_decoration(b, val, array_stride_decoration_cb, NULL);
       val->type->type = glsl_array_type(array_element->type, val->type->length,
@@ -1270,6 +1273,15 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
          };
       }
 
+      if (b->shader->info.stage == MESA_SHADER_KERNEL) {
+         unsigned offset = 0;
+         for (unsigned i = 0; i < num_fields; i++) {
+            offset = align(offset, glsl_get_cl_alignment(fields[i].type));
+            fields[i].offset = offset;
+            offset += glsl_get_cl_size(fields[i].type);
+         }
+      }
+
       struct member_decoration_ctx ctx = {
          .num_fields = num_fields,
          .fields = fields,
@@ -1307,6 +1319,7 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
        * declaration.
        */
       val = vtn_untyped_value(b, w[1]);
+      struct vtn_type *deref_type = vtn_untyped_value(b, w[3])->type;
 
       SpvStorageClass storage_class = w[2];
 
@@ -1335,6 +1348,19 @@ vtn_handle_type(struct vtn_builder *b, SpvOp opcode,
             break;
          case SpvStorageClassWorkgroup:
             val->type->type = b->options->shared_ptr_type;
+            if (b->physical_ptrs)
+               val->type->stride = align(glsl_get_cl_size(deref_type->type), glsl_get_cl_alignment(deref_type->type));
+            break;
+         case SpvStorageClassCrossWorkgroup:
+            val->type->type = b->options->global_ptr_type;
+            if (b->physical_ptrs)
+               val->type->stride = align(glsl_get_cl_size(deref_type->type), glsl_get_cl_alignment(deref_type->type));
+            break;
+         case SpvStorageClassFunction:
+            if (b->physical_ptrs) {
+               val->type->type = b->options->temp_ptr_type;
+               val->type->stride = align(glsl_get_cl_size(deref_type->type), glsl_get_cl_alignment(deref_type->type));
+            }
             break;
          default:
             /* In this case, no variable pointers are allowed so all deref
@@ -3751,12 +3777,18 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
                      "AddressingModelPhysical32 only supported for kernels");
          b->shader->info.cs.ptr_size = 32;
          b->physical_ptrs = true;
+         b->options->shared_ptr_type = glsl_uint_type();
+         b->options->global_ptr_type = glsl_uint_type();
+         b->options->temp_ptr_type = glsl_uint_type();
          break;
       case SpvAddressingModelPhysical64:
          vtn_fail_if(b->shader->info.stage != MESA_SHADER_KERNEL,
                      "AddressingModelPhysical64 only supported for kernels");
          b->shader->info.cs.ptr_size = 64;
          b->physical_ptrs = true;
+         b->options->shared_ptr_type = glsl_uint64_t_type();
+         b->options->global_ptr_type = glsl_uint64_t_type();
+         b->options->temp_ptr_type = glsl_uint64_t_type();
          break;
       case SpvAddressingModelLogical:
          vtn_fail_if(b->shader->info.stage >= MESA_SHADER_STAGES,
@@ -4086,6 +4118,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    case SpvOpAccessChain:
    case SpvOpPtrAccessChain:
    case SpvOpInBoundsAccessChain:
+   case SpvOpInBoundsPtrAccessChain:
    case SpvOpArrayLength:
    case SpvOpConvertPtrToU:
    case SpvOpConvertUToPtr:
@@ -4402,6 +4435,10 @@ vtn_create_builder(const uint32_t *words, size_t word_count,
 {
    /* Initialize the vtn_builder object */
    struct vtn_builder *b = rzalloc(NULL, struct vtn_builder);
+   struct spirv_to_nir_options *dup_options =
+      ralloc(b, struct spirv_to_nir_options);
+   *dup_options = *options;
+
    b->spirv = words;
    b->spirv_word_count = word_count;
    b->file = NULL;
@@ -4410,7 +4447,7 @@ vtn_create_builder(const uint32_t *words, size_t word_count,
    exec_list_make_empty(&b->functions);
    b->entry_point_stage = stage;
    b->entry_point_name = entry_point_name;
-   b->options = options;
+   b->options = dup_options;
 
    /*
     * Handle the SPIR-V header (first 5 dwords).
index 20aedc170d91a928627a8b4f482aef4c2c1e8f39..463d6173640110a674d590ef64114ab716c81e3e 100644 (file)
@@ -571,7 +571,7 @@ struct vtn_builder {
    size_t spirv_word_count;
 
    nir_shader *shader;
-   const struct spirv_to_nir_options *options;
+   struct spirv_to_nir_options *options;
    struct vtn_block *block;
 
    /* Current offset, file, line, and column.  Useful for debugging.  Set
index 053d6089e45e22805d01c0fef609ec38a058d409..da4c57f7c77ca2a465bd97e8972679f696d801de 100644 (file)
@@ -1872,9 +1872,8 @@ vtn_pointer_from_ssa(struct vtn_builder *b, nir_ssa_def *ssa,
    } else {
       const struct glsl_type *deref_type = ptr_type->deref->type;
       if (!vtn_pointer_is_external_block(b, ptr)) {
-         assert(ssa->bit_size == 32 && ssa->num_components == 1);
          ptr->deref = nir_build_deref_cast(&b->nb, ssa, nir_mode,
-                                           glsl_get_bare_type(deref_type), 0);
+                                           deref_type, 0);
       } else if (vtn_type_contains_block(b, ptr->type) &&
                  ptr->mode != vtn_variable_mode_phys_ssbo) {
          /* This is a pointer to somewhere in an array of blocks, not a
@@ -2317,9 +2316,10 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
 
    case SpvOpAccessChain:
    case SpvOpPtrAccessChain:
-   case SpvOpInBoundsAccessChain: {
+   case SpvOpInBoundsAccessChain:
+   case SpvOpInBoundsPtrAccessChain: {
       struct vtn_access_chain *chain = vtn_access_chain_create(b, count - 4);
-      chain->ptr_as_array = (opcode == SpvOpPtrAccessChain);
+      chain->ptr_as_array = (opcode == SpvOpPtrAccessChain || opcode == SpvOpInBoundsPtrAccessChain);
 
       unsigned idx = 0;
       for (int i = 4; i < count; i++) {