nir/spirv: Rework access chains a bit to allow for literals
authorJason Ekstrand <jason.ekstrand@intel.com>
Thu, 21 Jan 2016 18:20:50 +0000 (10:20 -0800)
committerJason Ekstrand <jason.ekstrand@intel.com>
Fri, 22 Jan 2016 00:20:39 +0000 (16:20 -0800)
This makes access chains much easier to construct because each link can now be
given as a plain literal number and no longer has to be a valid SPIR-V id.
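
A rough sketch of how a caller fills in the extended chain, mirroring the
OpAccessChain handling in vtn_handle_variables() below; "b", "base" and
"index_id" are assumed to be provided by the caller and are not part of the
patch itself:

   /* Extend an existing chain by two links: a literal struct-member index
    * and a dynamic array index that is still referenced by SPIR-V id.
    */
   struct vtn_access_chain *chain = vtn_access_chain_extend(b, base, 2);
   unsigned idx = base->length;

   /* Struct member 2: stored directly as a literal, no constant id needed. */
   chain->link[idx].mode = vtn_access_mode_literal;
   chain->link[idx].id = 2;
   idx++;

   /* Dynamic array index: kept as a SPIR-V id and lowered to SSA later,
    * e.g. by vtn_access_link_as_ssa() or vtn_access_chain_to_deref().
    */
   chain->link[idx].mode = vtn_access_mode_id;
   chain->link[idx].id = index_id;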

src/glsl/nir/spirv/vtn_private.h
src/glsl/nir/spirv/vtn_variables.c

diff --git a/src/glsl/nir/spirv/vtn_private.h b/src/glsl/nir/spirv/vtn_private.h
index 682bff5e8bb1340cb91204828a5cf3f9ce61a262..e0f4b220c4ce10a04522a152f8c8cd0d8f62cb2e 100644
--- a/src/glsl/nir/spirv/vtn_private.h
+++ b/src/glsl/nir/spirv/vtn_private.h
@@ -236,13 +236,23 @@ struct vtn_type {
 
 struct vtn_variable;
 
+enum vtn_access_mode {
+   vtn_access_mode_id,
+   vtn_access_mode_literal,
+};
+
+struct vtn_access_link {
+   enum vtn_access_mode mode;
+   uint32_t id;
+};
+
 struct vtn_access_chain {
    struct vtn_variable *var;
 
    uint32_t length;
 
    /* Struct elements and array offsets */
-   uint32_t ids[0];
+   struct vtn_access_link link[0];
 };
 
 enum vtn_variable_mode {
diff --git a/src/glsl/nir/spirv/vtn_variables.c b/src/glsl/nir/spirv/vtn_variables.c
index a9c2857f4cfab480418c5b2cc01adf6d4e024034..d41f5cd3deca5739a54a255c9d9b13f2de8e2149 100644
--- a/src/glsl/nir/spirv/vtn_variables.c
+++ b/src/glsl/nir/spirv/vtn_variables.c
 
 #include "vtn_private.h"
 
+/* Makes a copy of @old with room for @new_ids additional links appended.
+ * The new links are left uninitialized for the caller to fill in.
+ */
+static struct vtn_access_chain *
+vtn_access_chain_extend(struct vtn_builder *b, struct vtn_access_chain *old,
+                        unsigned new_ids)
+{
+   struct vtn_access_chain *chain;
+
+   unsigned new_len = old->length + new_ids;
+   chain = ralloc_size(b, sizeof(*chain) + new_len * sizeof(chain->link[0]));
+
+   chain->var = old->var;
+   chain->length = new_len;
+
+   for (unsigned i = 0; i < old->length; i++)
+      chain->link[i] = old->link[i];
+
+   return chain;
+}
+
+/* Returns the index denoted by @link as an SSA value scaled by @stride.
+ * Literal links become immediate constants; id links are resolved with
+ * vtn_ssa_value() and multiplied by the stride when it is not 1.
+ */
+static nir_ssa_def *
+vtn_access_link_as_ssa(struct vtn_builder *b, struct vtn_access_link link,
+                       unsigned stride)
+{
+   assert(stride > 0);
+   if (link.mode == vtn_access_mode_literal) {
+      return nir_imm_int(&b->nb, link.id * stride);
+   } else if (stride == 1) {
+      return vtn_ssa_value(b, link.id)->def;
+   } else {
+      return nir_imul(&b->nb, vtn_ssa_value(b, link.id)->def,
+                              nir_imm_int(&b->nb, stride));
+   }
+}
+
 /* Crawls a chain of array derefs and rewrites the types so that the
  * lengths stay the same but the terminal type is the one given by
  * tail_type.  This is useful for split structures.
@@ -60,7 +93,6 @@ vtn_access_chain_to_deref(struct vtn_builder *b, struct vtn_access_chain *chain)
    nir_variable **members = chain->var->members;
 
    for (unsigned i = 0; i < chain->length; i++) {
-      struct vtn_value *idx_val = vtn_untyped_value(b, chain->ids[i]);
       enum glsl_base_type base_type = glsl_get_base_type(deref_type->type);
       switch (base_type) {
       case GLSL_TYPE_UINT:
@@ -81,15 +113,15 @@ vtn_access_chain_to_deref(struct vtn_builder *b, struct vtn_access_chain *chain)
 
          deref_arr->deref.type = deref_type->type;
 
-         if (idx_val->value_type == vtn_value_type_constant) {
+         if (chain->link[i].mode == vtn_access_mode_literal) {
             deref_arr->deref_array_type = nir_deref_array_type_direct;
-            deref_arr->base_offset = idx_val->constant->value.u[0];
+            deref_arr->base_offset = chain->link[i].id;
          } else {
-            assert(idx_val->value_type == vtn_value_type_ssa);
-            assert(glsl_type_is_scalar(idx_val->ssa->type));
+            assert(chain->link[i].mode == vtn_access_mode_id);
             deref_arr->deref_array_type = nir_deref_array_type_indirect;
             deref_arr->base_offset = 0;
-            deref_arr->indirect = nir_src_for_ssa(idx_val->ssa->def);
+            deref_arr->indirect =
+               nir_src_for_ssa(vtn_ssa_value(b, chain->link[i].id)->def);
          }
          tail->child = &deref_arr->deref;
          tail = tail->child;
@@ -97,8 +129,8 @@ vtn_access_chain_to_deref(struct vtn_builder *b, struct vtn_access_chain *chain)
       }
 
       case GLSL_TYPE_STRUCT: {
-         assert(idx_val->value_type == vtn_value_type_constant);
-         unsigned idx = idx_val->constant->value.u[0];
+         assert(chain->link[i].mode == vtn_access_mode_literal);
+         unsigned idx = chain->link[i].id;
          deref_type = deref_type->members[idx];
          if (members) {
             /* This is a pre-split structure. */
@@ -265,7 +297,7 @@ get_vulkan_resource_index(struct vtn_builder *b, struct vtn_access_chain *chain,
    nir_ssa_def *array_index;
    if (glsl_type_is_array(chain->var->type->type)) {
       assert(chain->length > 0);
-      array_index = vtn_ssa_value(b, chain->ids[0])->def;
+      array_index = vtn_access_link_as_ssa(b, chain->link[0], 1);
       *chain_idx = 1;
       *type = chain->var->type->array_element;
    } else {
@@ -315,9 +347,8 @@ vtn_access_chain_to_offset(struct vtn_builder *b,
 
       case GLSL_TYPE_ARRAY:
          offset = nir_iadd(&b->nb, offset,
-                           nir_imul(&b->nb,
-                                    vtn_ssa_value(b, chain->ids[idx])->def,
-                                    nir_imm_int(&b->nb, type->stride)));
+                           vtn_access_link_as_ssa(b, chain->link[idx],
+                                                  type->stride));
 
          if (glsl_type_is_vector(type->type)) {
             /* This had better be the tail */
@@ -330,10 +361,8 @@ vtn_access_chain_to_offset(struct vtn_builder *b,
          break;
 
       case GLSL_TYPE_STRUCT: {
-         struct vtn_value *member_val =
-            vtn_value(b, chain->ids[idx], vtn_value_type_constant);
-         unsigned member = member_val->constant->value.u[0];
-
+         assert(chain->link[idx].mode == vtn_access_mode_literal);
+         unsigned member = chain->link[idx].id;
          offset = nir_iadd(&b->nb, offset,
                            nir_imm_int(&b->nb, type->offsets[member]));
          type = type->members[member];
@@ -448,16 +477,15 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
          } else if (type->row_major) {
             /* Row-major but with an access chain. */
             nir_ssa_def *col_offset =
-               nir_imul(&b->nb, vtn_ssa_value(b, chain->ids[chain_idx])->def,
-                        nir_imm_int(&b->nb, type->array_element->stride));
+               vtn_access_link_as_ssa(b, chain->link[chain_idx],
+                                      type->array_element->stride);
             offset = nir_iadd(&b->nb, offset, col_offset);
 
             if (chain_idx + 1 < chain->length) {
                /* Picking off a single element */
                nir_ssa_def *row_offset =
-                  nir_imul(&b->nb,
-                           vtn_ssa_value(b, chain->ids[chain_idx + 1])->def,
-                           nir_imm_int(&b->nb, type->stride));
+                  vtn_access_link_as_ssa(b, chain->link[chain_idx + 1],
+                                         type->stride);
                offset = nir_iadd(&b->nb, offset, row_offset);
                _vtn_load_store_tail(b, op, load, index, offset, inout,
                                     glsl_scalar_type(base_type));
@@ -487,8 +515,7 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
          } else {
             /* Column-major with a deref. Fall through to array case. */
             nir_ssa_def *col_offset =
-               nir_imul(&b->nb, vtn_ssa_value(b, chain->ids[chain_idx])->def,
-                        nir_imm_int(&b->nb, type->stride));
+               vtn_access_link_as_ssa(b, chain->link[chain_idx], type->stride);
             offset = nir_iadd(&b->nb, offset, col_offset);
 
             _vtn_block_load_store(b, op, load, index, offset,
@@ -502,8 +529,7 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
       } else {
          /* Single component of a vector. Fall through to array case. */
          nir_ssa_def *elem_offset =
-            nir_imul(&b->nb, vtn_ssa_value(b, chain->ids[chain_idx])->def,
-                     nir_imm_int(&b->nb, type->stride));
+            vtn_access_link_as_ssa(b, chain->link[chain_idx], type->stride);
          offset = nir_iadd(&b->nb, offset, elem_offset);
 
          _vtn_block_load_store(b, op, load, index, offset, NULL, 0,
@@ -1158,18 +1184,20 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
          base = base_val->access_chain;
       }
 
-      uint32_t new_len = base->length + count - 4;
-      chain = ralloc_size(b, sizeof(*chain) + new_len * sizeof(chain->ids[0]));
-
-      *chain = *base;
+      chain = vtn_access_chain_extend(b, base, count - 4);
 
-      chain->length = new_len;
-      unsigned idx = 0;
-      for (int i = 0; i < base->length; i++)
-         chain->ids[idx++] = base->ids[i];
-
-      for (int i = 4; i < count; i++)
-         chain->ids[idx++] = w[i];
+      unsigned idx = base->length;
+      for (int i = 4; i < count; i++) {
+         struct vtn_value *link_val = vtn_untyped_value(b, w[i]);
+         if (link_val->value_type == vtn_value_type_constant) {
+            chain->link[idx].mode = vtn_access_mode_literal;
+            chain->link[idx].id = link_val->constant->value.u[0];
+         } else {
+            chain->link[idx].mode = vtn_access_mode_id;
+            chain->link[idx].id = w[i];
+         }
+         idx++;
+      }
 
       if (base_val->value_type == vtn_value_type_sampled_image) {
          struct vtn_value *val =