spirv: fix typo in spec_constant_decoration_cb()
[mesa.git] / src / compiler / spirv / vtn_variables.c
index 5d5460ce29f85ba13386e3aee68f55d9e0f14a23..e3845365bddb7bcfdb0785f40b443fcb57a2384d 100644 (file)
@@ -35,7 +35,8 @@ vtn_access_chain_extend(struct vtn_builder *b, struct vtn_access_chain *old,
    struct vtn_access_chain *chain;
 
    unsigned new_len = old->length + new_ids;
-   chain = ralloc_size(b, sizeof(*chain) + new_len * sizeof(chain->link[0]));
+   /* TODO: don't use rzalloc */
+   chain = rzalloc_size(b, sizeof(*chain) + new_len * sizeof(chain->link[0]));
 
    chain->var = old->var;
    chain->length = new_len;
@@ -185,8 +186,7 @@ _vtn_local_load_store(struct vtn_builder *b, bool load, nir_deref_var *deref,
                                    nir_intrinsic_store_var;
 
       nir_intrinsic_instr *intrin = nir_intrinsic_instr_create(b->shader, op);
-      intrin->variables[0] =
-         nir_deref_as_var(nir_copy_deref(intrin, &deref->deref));
+      intrin->variables[0] = nir_deref_var_clone(deref, intrin);
       intrin->num_components = glsl_get_vector_elements(tail->type);
 
       if (load) {
@@ -385,9 +385,86 @@ end:
    return offset;
 }
 
+/* Tries to compute the size of an interface block based on the strides and
+ * offsets that are provided to us in the SPIR-V source.
+ */
+static unsigned
+vtn_type_block_size(struct vtn_type *type)
+{
+   enum glsl_base_type base_type = glsl_get_base_type(type->type);
+   switch (base_type) {
+   case GLSL_TYPE_UINT:
+   case GLSL_TYPE_INT:
+   case GLSL_TYPE_FLOAT:
+   case GLSL_TYPE_BOOL:
+   case GLSL_TYPE_DOUBLE: {
+      unsigned cols = type->row_major ? glsl_get_vector_elements(type->type) :
+                                        glsl_get_matrix_columns(type->type);
+      if (cols > 1) {
+         assert(type->stride > 0);
+         return type->stride * cols;
+      } else if (base_type == GLSL_TYPE_DOUBLE) {
+         return glsl_get_vector_elements(type->type) * 8;
+      } else {
+         return glsl_get_vector_elements(type->type) * 4;
+      }
+   }
+
+   case GLSL_TYPE_STRUCT:
+   case GLSL_TYPE_INTERFACE: {
+      unsigned size = 0;
+      unsigned num_fields = glsl_get_length(type->type);
+      for (unsigned f = 0; f < num_fields; f++) {
+         unsigned field_end = type->offsets[f] +
+                              vtn_type_block_size(type->members[f]);
+         size = MAX2(size, field_end);
+      }
+      return size;
+   }
+
+   case GLSL_TYPE_ARRAY:
+      assert(type->stride > 0);
+      assert(glsl_get_length(type->type) > 0);
+      return type->stride * glsl_get_length(type->type);
+
+   default:
+      assert(!"Invalid block type");
+      return 0;
+   }
+}
+
+static void
+vtn_access_chain_get_offset_size(struct vtn_access_chain *chain,
+                                 unsigned *access_offset,
+                                 unsigned *access_size)
+{
+   /* Only valid for push constants accesses now. */
+   assert(chain->var->mode == vtn_variable_mode_push_constant);
+
+   struct vtn_type *type = chain->var->type;
+
+   *access_offset = 0;
+
+   for (unsigned i = 0; i < chain->length; i++) {
+      if (chain->link[i].mode != vtn_access_mode_literal)
+         break;
+
+      if (glsl_type_is_struct(type->type)) {
+         *access_offset += type->offsets[chain->link[i].id];
+         type = type->members[chain->link[i].id];
+      } else {
+         *access_offset += type->stride * chain->link[i].id;
+         type = type->array_element;
+      }
+   }
+
+   *access_size = vtn_type_block_size(type);
+}
+
 static void
 _vtn_load_store_tail(struct vtn_builder *b, nir_intrinsic_op op, bool load,
                      nir_ssa_def *index, nir_ssa_def *offset,
+                     unsigned access_offset, unsigned access_size,
                      struct vtn_ssa_value **inout, const struct glsl_type *type)
 {
    nir_intrinsic_instr *instr = nir_intrinsic_instr_create(b->nb.shader, op);
@@ -399,18 +476,25 @@ _vtn_load_store_tail(struct vtn_builder *b, nir_intrinsic_op op, bool load,
       instr->src[src++] = nir_src_for_ssa((*inout)->def);
    }
 
-   /* We set the base and size for push constant load to the entire push
-    * constant block for now.
-    */
    if (op == nir_intrinsic_load_push_constant) {
-      nir_intrinsic_set_base(instr, 0);
-      nir_intrinsic_set_range(instr, 128);
+      assert(access_offset % 4 == 0);
+
+      nir_intrinsic_set_base(instr, access_offset);
+      nir_intrinsic_set_range(instr, access_size);
    }
 
    if (index)
       instr->src[src++] = nir_src_for_ssa(index);
 
-   instr->src[src++] = nir_src_for_ssa(offset);
+   if (op == nir_intrinsic_load_push_constant) {
+      /* We need to subtract the offset from where the intrinsic will load the
+       * data. */
+      instr->src[src++] =
+         nir_src_for_ssa(nir_isub(&b->nb, offset,
+                                  nir_imm_int(&b->nb, access_offset)));
+   } else {
+      instr->src[src++] = nir_src_for_ssa(offset);
+   }
 
    if (load) {
       nir_ssa_dest_init(&instr->instr, &instr->dest,
@@ -428,6 +512,7 @@ _vtn_load_store_tail(struct vtn_builder *b, nir_intrinsic_op op, bool load,
 static void
 _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
                       nir_ssa_def *index, nir_ssa_def *offset,
+                      unsigned access_offset, unsigned access_size,
                       struct vtn_access_chain *chain, unsigned chain_idx,
                       struct vtn_type *type, struct vtn_ssa_value **inout)
 {
@@ -472,6 +557,7 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
                   nir_iadd(&b->nb, offset,
                            nir_imm_int(&b->nb, i * type->stride));
                _vtn_load_store_tail(b, op, load, index, elem_offset,
+                                    access_offset, access_size,
                                     &(*inout)->elems[i],
                                     glsl_vector_type(base_type, vec_width));
             }
@@ -493,8 +579,9 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
                offset = nir_iadd(&b->nb, offset, row_offset);
                if (load)
                   *inout = vtn_create_ssa_value(b, glsl_scalar_type(base_type));
-               _vtn_load_store_tail(b, op, load, index, offset, inout,
-                                    glsl_scalar_type(base_type));
+               _vtn_load_store_tail(b, op, load, index, offset,
+                                    access_offset, access_size,
+                                    inout, glsl_scalar_type(base_type));
             } else {
                /* Grabbing a column; picking one element off each row */
                unsigned num_comps = glsl_get_vector_elements(type->type);
@@ -514,6 +601,7 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
                   }
                   comp = &temp_val;
                   _vtn_load_store_tail(b, op, load, index, elem_offset,
+                                       access_offset, access_size,
                                        &comp, glsl_scalar_type(base_type));
                   comps[i] = comp->def;
                }
@@ -532,20 +620,25 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
             offset = nir_iadd(&b->nb, offset, col_offset);
 
             _vtn_block_load_store(b, op, load, index, offset,
+                                  access_offset, access_size,
                                   chain, chain_idx + 1,
                                   type->array_element, inout);
          }
       } else if (chain == NULL) {
          /* Single whole vector */
          assert(glsl_type_is_vector_or_scalar(type->type));
-         _vtn_load_store_tail(b, op, load, index, offset, inout, type->type);
+         _vtn_load_store_tail(b, op, load, index, offset,
+                              access_offset, access_size,
+                              inout, type->type);
       } else {
          /* Single component of a vector. Fall through to array case. */
          nir_ssa_def *elem_offset =
             vtn_access_link_as_ssa(b, chain->link[chain_idx], type->stride);
          offset = nir_iadd(&b->nb, offset, elem_offset);
 
-         _vtn_block_load_store(b, op, load, index, offset, NULL, 0,
+         _vtn_block_load_store(b, op, load, index, offset,
+                               access_offset, access_size,
+                               NULL, 0,
                                type->array_element, inout);
       }
       return;
@@ -555,7 +648,9 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
       for (unsigned i = 0; i < elems; i++) {
          nir_ssa_def *elem_off =
             nir_iadd(&b->nb, offset, nir_imm_int(&b->nb, i * type->stride));
-         _vtn_block_load_store(b, op, load, index, elem_off, NULL, 0,
+         _vtn_block_load_store(b, op, load, index, elem_off,
+                               access_offset, access_size,
+                               NULL, 0,
                                type->array_element, &(*inout)->elems[i]);
       }
       return;
@@ -566,7 +661,9 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
       for (unsigned i = 0; i < elems; i++) {
          nir_ssa_def *elem_off =
             nir_iadd(&b->nb, offset, nir_imm_int(&b->nb, type->offsets[i]));
-         _vtn_block_load_store(b, op, load, index, elem_off, NULL, 0,
+         _vtn_block_load_store(b, op, load, index, elem_off,
+                               access_offset, access_size,
+                               NULL, 0,
                                type->members[i], &(*inout)->elems[i]);
       }
       return;
@@ -581,6 +678,7 @@ static struct vtn_ssa_value *
 vtn_block_load(struct vtn_builder *b, struct vtn_access_chain *src)
 {
    nir_intrinsic_op op;
+   unsigned access_offset = 0, access_size = 0;
    switch (src->var->mode) {
    case vtn_variable_mode_ubo:
       op = nir_intrinsic_load_ubo;
@@ -590,6 +688,7 @@ vtn_block_load(struct vtn_builder *b, struct vtn_access_chain *src)
       break;
    case vtn_variable_mode_push_constant:
       op = nir_intrinsic_load_push_constant;
+      vtn_access_chain_get_offset_size(src, &access_offset, &access_size);
       break;
    default:
       assert(!"Invalid block variable mode");
@@ -602,6 +701,7 @@ vtn_block_load(struct vtn_builder *b, struct vtn_access_chain *src)
 
    struct vtn_ssa_value *value = NULL;
    _vtn_block_load_store(b, op, true, index, offset,
+                         access_offset, access_size,
                          src, chain_idx, type, &value);
    return value;
 }
@@ -616,7 +716,7 @@ vtn_block_store(struct vtn_builder *b, struct vtn_ssa_value *src,
    offset = vtn_access_chain_to_offset(b, dst, &index, &type, &chain_idx, true);
 
    _vtn_block_load_store(b, nir_intrinsic_store_ssbo, false, index, offset,
-                         dst, chain_idx, type, &src);
+                         0, 0, dst, chain_idx, type, &src);
 }
 
 static bool
@@ -783,7 +883,7 @@ vtn_get_builtin_location(struct vtn_builder *b,
       *location = VARYING_SLOT_CLIP_DIST0; /* XXX CLIP_DIST1? */
       break;
    case SpvBuiltInCullDistance:
-      /* XXX figure this out */
+      *location = VARYING_SLOT_CULL_DIST0;
       break;
    case SpvBuiltInVertexIndex:
       *location = SYSTEM_VALUE_VERTEX_ID;
@@ -805,8 +905,12 @@ vtn_get_builtin_location(struct vtn_builder *b,
       set_mode_system_value(mode);
       break;
    case SpvBuiltInPrimitiveId:
-      *location = VARYING_SLOT_PRIMITIVE_ID;
-      *mode = nir_var_shader_out;
+      if (*mode == nir_var_shader_out) {
+         *location = VARYING_SLOT_PRIMITIVE_ID;
+      } else {
+         *location = SYSTEM_VALUE_PRIMITIVE_ID;
+         set_mode_system_value(mode);
+      }
       break;
    case SpvBuiltInInvocationId:
       *location = SYSTEM_VALUE_INVOCATION_ID;
@@ -814,7 +918,12 @@ vtn_get_builtin_location(struct vtn_builder *b,
       break;
    case SpvBuiltInLayer:
       *location = VARYING_SLOT_LAYER;
-      *mode = nir_var_shader_out;
+      if (b->shader->stage == MESA_SHADER_FRAGMENT)
+         *mode = nir_var_shader_in;
+      else if (b->shader->stage == MESA_SHADER_GEOMETRY)
+         *mode = nir_var_shader_out;
+      else
+         unreachable("invalid stage for SpvBuiltInLayer");
       break;
    case SpvBuiltInViewportIndex:
       *location = VARYING_SLOT_VIEWPORT;
@@ -921,7 +1030,6 @@ apply_var_decoration(struct vtn_builder *b, nir_variable *nir_var,
       nir_var->data.location_frac = dec->literals[0];
       break;
    case SpvDecorationIndex:
-      nir_var->data.explicit_index = true;
       nir_var->data.index = dec->literals[0];
       break;
    case SpvDecorationBuiltIn: {
@@ -933,16 +1041,15 @@ apply_var_decoration(struct vtn_builder *b, nir_variable *nir_var,
          nir_var->data.read_only = true;
 
          nir_constant *c = rzalloc(nir_var, nir_constant);
-         c->value.u[0] = b->shader->info.cs.local_size[0];
-         c->value.u[1] = b->shader->info.cs.local_size[1];
-         c->value.u[2] = b->shader->info.cs.local_size[2];
+         c->values[0].u32[0] = b->shader->info->cs.local_size[0];
+         c->values[0].u32[1] = b->shader->info->cs.local_size[1];
+         c->values[0].u32[2] = b->shader->info->cs.local_size[2];
          nir_var->constant_initializer = c;
          break;
       }
 
       nir_variable_mode mode = nir_var->data.mode;
       vtn_get_builtin_location(b, builtin, &nir_var->data.location, &mode);
-      nir_var->data.explicit_location = true;
       nir_var->data.mode = mode;
 
       if (builtin == SpvBuiltInFragCoord || builtin == SpvBuiltInSamplePosition)
@@ -1022,25 +1129,24 @@ var_decoration_cb(struct vtn_builder *b, struct vtn_value *val, int member,
    case SpvDecorationDescriptorSet:
       vtn_var->descriptor_set = dec->literals[0];
       return;
+   case SpvDecorationInputAttachmentIndex:
+      vtn_var->input_attachment_index = dec->literals[0];
+      return;
    default:
       break;
    }
 
-   /* Now we handle decorations that apply to a particular nir_variable */
-   nir_variable *nir_var = vtn_var->var;
    if (val->value_type == vtn_value_type_access_chain) {
       assert(val->access_chain->length == 0);
       assert(val->access_chain->var == void_var);
       assert(member == -1);
    } else {
       assert(val->value_type == vtn_value_type_type);
-      if (member != -1)
-         nir_var = vtn_var->members[member];
    }
 
-   /* Location is odd in that it can apply in three different cases: To a
-    * non-split variable, to a whole split variable, or to one structure
-    * member of a split variable.
+   /* Location is odd.  If applied to a split structure, we have to walk the
+    * whole thing and accumulate the location.  It's easier to handle as a
+    * special case.
     */
    if (dec->decoration == SpvDecorationLocation) {
       unsigned location = dec->literals[0];
@@ -1058,13 +1164,12 @@ var_decoration_cb(struct vtn_builder *b, struct vtn_value *val, int member,
          is_vertex_input = false;
          location += VARYING_SLOT_VAR0;
       } else {
-         assert(!"Location must be on input or output variable");
+         unreachable("Location must be on input or output variable");
       }
 
-      if (nir_var) {
+      if (vtn_var->var) {
          /* This handles the member and lone variable cases */
-         nir_var->data.location = location;
-         nir_var->data.explicit_location = true;
+         vtn_var->var->data.location = location;
       } else {
          /* This handles the structure member case */
          assert(vtn_var->members);
@@ -1072,66 +1177,35 @@ var_decoration_cb(struct vtn_builder *b, struct vtn_value *val, int member,
             glsl_get_length(glsl_without_array(vtn_var->type->type));
          for (unsigned i = 0; i < length; i++) {
             vtn_var->members[i]->data.location = location;
-            vtn_var->members[i]->data.explicit_location = true;
             location +=
                glsl_count_attribute_slots(vtn_var->members[i]->interface_type,
                                           is_vertex_input);
          }
       }
       return;
-   }
-
-   if (nir_var == NULL)
-      return;
-
-   apply_var_decoration(b, nir_var, dec);
-}
-
-/* Tries to compute the size of an interface block based on the strides and
- * offsets that are provided to us in the SPIR-V source.
- */
-static unsigned
-vtn_type_block_size(struct vtn_type *type)
-{
-   enum glsl_base_type base_type = glsl_get_base_type(type->type);
-   switch (base_type) {
-   case GLSL_TYPE_UINT:
-   case GLSL_TYPE_INT:
-   case GLSL_TYPE_FLOAT:
-   case GLSL_TYPE_BOOL:
-   case GLSL_TYPE_DOUBLE: {
-      unsigned cols = type->row_major ? glsl_get_vector_elements(type->type) :
-                                        glsl_get_matrix_columns(type->type);
-      if (cols > 1) {
-         assert(type->stride > 0);
-         return type->stride * cols;
-      } else if (base_type == GLSL_TYPE_DOUBLE) {
-         return glsl_get_vector_elements(type->type) * 8;
+   } else {
+      if (vtn_var->var) {
+         assert(member <= 0);
+         apply_var_decoration(b, vtn_var->var, dec);
+      } else if (vtn_var->members) {
+         if (member >= 0) {
+            assert(vtn_var->members);
+            apply_var_decoration(b, vtn_var->members[member], dec);
+         } else {
+            unsigned length =
+               glsl_get_length(glsl_without_array(vtn_var->type->type));
+            for (unsigned i = 0; i < length; i++)
+               apply_var_decoration(b, vtn_var->members[i], dec);
+         }
       } else {
-         return glsl_get_vector_elements(type->type) * 4;
-      }
-   }
-
-   case GLSL_TYPE_STRUCT:
-   case GLSL_TYPE_INTERFACE: {
-      unsigned size = 0;
-      unsigned num_fields = glsl_get_length(type->type);
-      for (unsigned f = 0; f < num_fields; f++) {
-         unsigned field_end = type->offsets[f] +
-                              vtn_type_block_size(type->members[f]);
-         size = MAX2(size, field_end);
+         /* A few variables, those with external storage, have no actual
+          * nir_variables associated with them.  Fortunately, all decorations
+          * we care about for those variables are on the type only.
+          */
+         assert(vtn_var->mode == vtn_variable_mode_ubo ||
+                vtn_var->mode == vtn_variable_mode_ssbo ||
+                vtn_var->mode == vtn_variable_mode_push_constant);
       }
-      return size;
-   }
-
-   case GLSL_TYPE_ARRAY:
-      assert(type->stride > 0);
-      assert(glsl_get_length(type->type) > 0);
-      return type->stride * glsl_get_length(type->type);
-
-   default:
-      assert(!"Invalid block type");
-      return 0;
    }
 }
 
@@ -1161,18 +1235,18 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
       case SpvStorageClassUniformConstant:
          if (without_array->block) {
             var->mode = vtn_variable_mode_ubo;
-            b->shader->info.num_ubos++;
+            b->shader->info->num_ubos++;
          } else if (without_array->buffer_block) {
             var->mode = vtn_variable_mode_ssbo;
-            b->shader->info.num_ssbos++;
+            b->shader->info->num_ssbos++;
          } else if (glsl_type_is_image(without_array->type)) {
             var->mode = vtn_variable_mode_image;
             nir_mode = nir_var_uniform;
-            b->shader->info.num_images++;
+            b->shader->info->num_images++;
          } else if (glsl_type_is_sampler(without_array->type)) {
             var->mode = vtn_variable_mode_sampler;
             nir_mode = nir_var_uniform;
-            b->shader->info.num_textures++;
+            b->shader->info->num_textures++;
          } else {
             assert(!"Invalid uniform variable type");
          }
@@ -1316,6 +1390,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
           */
          var->var->data.binding = var->binding;
          var->var->data.descriptor_set = var->descriptor_set;
+         var->var->data.index = var->input_attachment_index;
 
          if (var->mode == vtn_variable_mode_image)
             var->var->data.image.format = without_array->image_format;
@@ -1365,7 +1440,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
          struct vtn_value *link_val = vtn_untyped_value(b, w[i]);
          if (link_val->value_type == vtn_value_type_constant) {
             chain->link[idx].mode = vtn_access_mode_literal;
-            chain->link[idx].id = link_val->constant->value.u[0];
+            chain->link[idx].id = link_val->constant->values[0].u32[0];
          } else {
             chain->link[idx].mode = vtn_access_mode_id;
             chain->link[idx].id = w[i];