nir/spirv: Split up Uniform and UniformConstant storage classes
[mesa.git] / src / compiler / spirv / vtn_variables.c
index e564fb03cbe459b060d1948d3497b126b86f7bf3..754320afffb909b71d9f9881c6a4f8280c0537b1 100644 (file)
@@ -96,6 +96,10 @@ rewrite_deref_types(nir_deref *deref, const struct glsl_type *type)
 nir_deref_var *
 vtn_access_chain_to_deref(struct vtn_builder *b, struct vtn_access_chain *chain)
 {
+   /* Do on-the-fly copy propagation for samplers. */
+   if (chain->var->copy_prop_sampler)
+      return vtn_access_chain_to_deref(b, chain->var->copy_prop_sampler);
+
    nir_deref_var *deref_var;
    if (chain->var->var) {
       deref_var = nir_deref_var_create(b, chain->var->var);
@@ -115,6 +119,8 @@ vtn_access_chain_to_deref(struct vtn_builder *b, struct vtn_access_chain *chain)
       switch (base_type) {
       case GLSL_TYPE_UINT:
       case GLSL_TYPE_INT:
+      case GLSL_TYPE_UINT64:
+      case GLSL_TYPE_INT64:
       case GLSL_TYPE_FLOAT:
       case GLSL_TYPE_DOUBLE:
       case GLSL_TYPE_BOOL:
@@ -347,6 +353,8 @@ vtn_access_chain_to_offset(struct vtn_builder *b,
       switch (base_type) {
       case GLSL_TYPE_UINT:
       case GLSL_TYPE_INT:
+      case GLSL_TYPE_UINT64:
+      case GLSL_TYPE_INT64:
       case GLSL_TYPE_FLOAT:
       case GLSL_TYPE_DOUBLE:
       case GLSL_TYPE_BOOL:
@@ -395,6 +403,8 @@ vtn_type_block_size(struct vtn_type *type)
    switch (base_type) {
    case GLSL_TYPE_UINT:
    case GLSL_TYPE_INT:
+   case GLSL_TYPE_UINT64:
+   case GLSL_TYPE_INT64:
    case GLSL_TYPE_FLOAT:
    case GLSL_TYPE_BOOL:
    case GLSL_TYPE_DOUBLE: {
@@ -403,7 +413,9 @@ vtn_type_block_size(struct vtn_type *type)
       if (cols > 1) {
          assert(type->stride > 0);
          return type->stride * cols;
-      } else if (base_type == GLSL_TYPE_DOUBLE) {
+      } else if (base_type == GLSL_TYPE_DOUBLE ||
+                base_type == GLSL_TYPE_UINT64 ||
+                base_type == GLSL_TYPE_INT64) {
          return glsl_get_vector_elements(type->type) * 8;
       } else {
          return glsl_get_vector_elements(type->type) * 4;
@@ -526,6 +538,8 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
    switch (base_type) {
    case GLSL_TYPE_UINT:
    case GLSL_TYPE_INT:
+   case GLSL_TYPE_UINT64:
+   case GLSL_TYPE_INT64:
    case GLSL_TYPE_FLOAT:
    case GLSL_TYPE_DOUBLE:
    case GLSL_TYPE_BOOL:
@@ -738,6 +752,8 @@ _vtn_variable_load_store(struct vtn_builder *b, bool load,
    switch (base_type) {
    case GLSL_TYPE_UINT:
    case GLSL_TYPE_INT:
+   case GLSL_TYPE_UINT64:
+   case GLSL_TYPE_INT64:
    case GLSL_TYPE_FLOAT:
    case GLSL_TYPE_BOOL:
    case GLSL_TYPE_DOUBLE:
@@ -815,6 +831,8 @@ _vtn_variable_copy(struct vtn_builder *b, struct vtn_access_chain *dest,
    switch (base_type) {
    case GLSL_TYPE_UINT:
    case GLSL_TYPE_INT:
+   case GLSL_TYPE_UINT64:
+   case GLSL_TYPE_INT64:
    case GLSL_TYPE_FLOAT:
    case GLSL_TYPE_DOUBLE:
    case GLSL_TYPE_BOOL:
@@ -908,7 +926,10 @@ vtn_get_builtin_location(struct vtn_builder *b,
       set_mode_system_value(mode);
       break;
    case SpvBuiltInPrimitiveId:
-      if (*mode == nir_var_shader_out) {
+      if (b->shader->stage == MESA_SHADER_FRAGMENT) {
+         assert(*mode == nir_var_shader_in);
+         *location = VARYING_SLOT_PRIMITIVE_ID;
+      } else if (*mode == nir_var_shader_out) {
          *location = VARYING_SLOT_PRIMITIVE_ID;
       } else {
          *location = SYSTEM_VALUE_PRIMITIVE_ID;
@@ -972,8 +993,12 @@ vtn_get_builtin_location(struct vtn_builder *b,
       set_mode_system_value(mode);
       break;
    case SpvBuiltInSampleMask:
-      *location = SYSTEM_VALUE_SAMPLE_MASK_IN; /* XXX out? */
-      set_mode_system_value(mode);
+      if (*mode == nir_var_shader_out) {
+         *location = FRAG_RESULT_SAMPLE_MASK;
+      } else {
+         *location = SYSTEM_VALUE_SAMPLE_MASK_IN;
+         set_mode_system_value(mode);
+      }
       break;
    case SpvBuiltInFragDepth:
       *location = FRAG_RESULT_DEPTH;
@@ -1003,6 +1028,22 @@ vtn_get_builtin_location(struct vtn_builder *b,
       *location = SYSTEM_VALUE_GLOBAL_INVOCATION_ID;
       set_mode_system_value(mode);
       break;
+   case SpvBuiltInBaseVertex:
+      *location = SYSTEM_VALUE_BASE_VERTEX;
+      set_mode_system_value(mode);
+      break;
+   case SpvBuiltInBaseInstance:
+      *location = SYSTEM_VALUE_BASE_INSTANCE;
+      set_mode_system_value(mode);
+      break;
+   case SpvBuiltInDrawIndex:
+      *location = SYSTEM_VALUE_DRAW_ID;
+      set_mode_system_value(mode);
+      break;
+   case SpvBuiltInViewIndex:
+      *location = SYSTEM_VALUE_VIEW_INDEX;
+      set_mode_system_value(mode);
+      break;
    case SpvBuiltInHelperInvocation:
    default:
       unreachable("unsupported builtin");
@@ -1035,8 +1076,12 @@ apply_var_decoration(struct vtn_builder *b, nir_variable *nir_var,
       assert(nir_var->constant_initializer != NULL);
       nir_var->data.read_only = true;
       break;
+   case SpvDecorationNonReadable:
+      nir_var->data.image.write_only = true;
+      break;
    case SpvDecorationNonWritable:
       nir_var->data.read_only = true;
+      nir_var->data.image.read_only = true;
       break;
    case SpvDecorationComponent:
       nir_var->data.location_frac = dec->literals[0];
@@ -1053,9 +1098,9 @@ apply_var_decoration(struct vtn_builder *b, nir_variable *nir_var,
          nir_var->data.read_only = true;
 
          nir_constant *c = rzalloc(nir_var, nir_constant);
-         c->values[0].u32[0] = b->shader->info->cs.local_size[0];
-         c->values[0].u32[1] = b->shader->info->cs.local_size[1];
-         c->values[0].u32[2] = b->shader->info->cs.local_size[2];
+         c->values[0].u32[0] = b->shader->info.cs.local_size[0];
+         c->values[0].u32[1] = b->shader->info.cs.local_size[1];
+         c->values[0].u32[2] = b->shader->info.cs.local_size[2];
          nir_var->constant_initializer = c;
          break;
       }
@@ -1075,6 +1120,8 @@ apply_var_decoration(struct vtn_builder *b, nir_variable *nir_var,
       case SpvBuiltInFragCoord:
          nir_var->data.pixel_center_integer = b->pixel_center_integer;
          break;
+      default:
+         break;
       }
    }
 
@@ -1086,7 +1133,6 @@ apply_var_decoration(struct vtn_builder *b, nir_variable *nir_var,
    case SpvDecorationAliased:
    case SpvDecorationVolatile:
    case SpvDecorationCoherent:
-   case SpvDecorationNonReadable:
    case SpvDecorationUniform:
    case SpvDecorationStream:
    case SpvDecorationOffset:
@@ -1127,9 +1173,12 @@ apply_var_decoration(struct vtn_builder *b, nir_variable *nir_var,
    case SpvDecorationFPRoundingMode:
    case SpvDecorationFPFastMathMode:
    case SpvDecorationAlignment:
-      vtn_warn("Decoraiton only allowed for CL-style kernels: %s",
+      vtn_warn("Decoration only allowed for CL-style kernels: %s",
                spirv_decoration_to_string(dec->decoration));
       break;
+
+   default:
+      unreachable("Unhandled decoration");
    }
 }
 
@@ -1194,7 +1243,8 @@ var_decoration_cb(struct vtn_builder *b, struct vtn_value *val, int member,
          is_vertex_input = false;
          location += vtn_var->patch ? VARYING_SLOT_PATCH0 : VARYING_SLOT_VAR0;
       } else {
-         unreachable("Location must be on input or output variable");
+         vtn_warn("Location must be on input or output variable");
+         return;
       }
 
       if (vtn_var->var) {
@@ -1239,6 +1289,73 @@ var_decoration_cb(struct vtn_builder *b, struct vtn_value *val, int member,
    }
 }
 
+static enum vtn_variable_mode
+vtn_storage_class_to_mode(SpvStorageClass class,
+                          struct vtn_type *interface_type,
+                          nir_variable_mode *nir_mode_out)
+{
+   enum vtn_variable_mode mode;
+   nir_variable_mode nir_mode;
+   switch (class) {
+   case SpvStorageClassUniform:
+      if (interface_type->block) {
+         mode = vtn_variable_mode_ubo;
+         nir_mode = 0;
+      } else if (interface_type->buffer_block) {
+         mode = vtn_variable_mode_ssbo;
+         nir_mode = 0;
+      } else {
+         assert(!"Invalid uniform variable type");
+      }
+      break;
+   case SpvStorageClassUniformConstant:
+      if (glsl_type_is_image(interface_type->type)) {
+         mode = vtn_variable_mode_image;
+         nir_mode = nir_var_uniform;
+      } else if (glsl_type_is_sampler(interface_type->type)) {
+         mode = vtn_variable_mode_sampler;
+         nir_mode = nir_var_uniform;
+      } else {
+         assert(!"Invalid uniform constant variable type");
+      }
+      break;
+   case SpvStorageClassPushConstant:
+      mode = vtn_variable_mode_push_constant;
+      nir_mode = nir_var_uniform;
+      break;
+   case SpvStorageClassInput:
+      mode = vtn_variable_mode_input;
+      nir_mode = nir_var_shader_in;
+      break;
+   case SpvStorageClassOutput:
+      mode = vtn_variable_mode_output;
+      nir_mode = nir_var_shader_out;
+      break;
+   case SpvStorageClassPrivate:
+      mode = vtn_variable_mode_global;
+      nir_mode = nir_var_global;
+      break;
+   case SpvStorageClassFunction:
+      mode = vtn_variable_mode_local;
+      nir_mode = nir_var_local;
+      break;
+   case SpvStorageClassWorkgroup:
+      mode = vtn_variable_mode_workgroup;
+      nir_mode = nir_var_shared;
+      break;
+   case SpvStorageClassCrossWorkgroup:
+   case SpvStorageClassGeneric:
+   case SpvStorageClassAtomicCounter:
+   default:
+      unreachable("Unhandled variable storage class");
+   }
+
+   if (nir_mode_out)
+      *nir_mode_out = nir_mode;
+
+   return mode;
+}
+
 static bool
 is_per_vertex_inout(const struct vtn_variable *var, gl_shader_stage stage)
 {
@@ -1262,6 +1379,12 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
                      const uint32_t *w, unsigned count)
 {
    switch (opcode) {
+   case SpvOpUndef: {
+      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_undef);
+      val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
+      break;
+   }
+
    case SpvOpVariable: {
       struct vtn_variable *var = rzalloc(b, struct vtn_variable);
       var->type = vtn_value(b, w[1], vtn_value_type_type)->type;
@@ -1278,57 +1401,27 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
          without_array = without_array->array_element;
 
       nir_variable_mode nir_mode;
-      switch ((SpvStorageClass)w[3]) {
-      case SpvStorageClassUniform:
-      case SpvStorageClassUniformConstant:
-         if (without_array->block) {
-            var->mode = vtn_variable_mode_ubo;
-            b->shader->info->num_ubos++;
-         } else if (without_array->buffer_block) {
-            var->mode = vtn_variable_mode_ssbo;
-            b->shader->info->num_ssbos++;
-         } else if (glsl_type_is_image(without_array->type)) {
-            var->mode = vtn_variable_mode_image;
-            nir_mode = nir_var_uniform;
-            b->shader->info->num_images++;
-         } else if (glsl_type_is_sampler(without_array->type)) {
-            var->mode = vtn_variable_mode_sampler;
-            nir_mode = nir_var_uniform;
-            b->shader->info->num_textures++;
-         } else {
-            assert(!"Invalid uniform variable type");
-         }
-         break;
-      case SpvStorageClassPushConstant:
-         var->mode = vtn_variable_mode_push_constant;
-         assert(b->shader->num_uniforms == 0);
-         b->shader->num_uniforms = vtn_type_block_size(var->type);
-         break;
-      case SpvStorageClassInput:
-         var->mode = vtn_variable_mode_input;
-         nir_mode = nir_var_shader_in;
+      var->mode = vtn_storage_class_to_mode(w[3], without_array, &nir_mode);
+
+      switch (var->mode) {
+      case vtn_variable_mode_ubo:
+         b->shader->info.num_ubos++;
          break;
-      case SpvStorageClassOutput:
-         var->mode = vtn_variable_mode_output;
-         nir_mode = nir_var_shader_out;
+      case vtn_variable_mode_ssbo:
+         b->shader->info.num_ssbos++;
          break;
-      case SpvStorageClassPrivate:
-         var->mode = vtn_variable_mode_global;
-         nir_mode = nir_var_global;
+      case vtn_variable_mode_image:
+         b->shader->info.num_images++;
          break;
-      case SpvStorageClassFunction:
-         var->mode = vtn_variable_mode_local;
-         nir_mode = nir_var_local;
+      case vtn_variable_mode_sampler:
+         b->shader->info.num_textures++;
          break;
-      case SpvStorageClassWorkgroup:
-         var->mode = vtn_variable_mode_workgroup;
-         nir_mode = nir_var_shared;
+      case vtn_variable_mode_push_constant:
+         b->shader->num_uniforms = vtn_type_block_size(var->type);
          break;
-      case SpvStorageClassCrossWorkgroup:
-      case SpvStorageClassGeneric:
-      case SpvStorageClassAtomicCounter:
       default:
-         unreachable("Unhandled variable storage class");
+         /* No tallying is needed */
+         break;
       }
 
       switch (var->mode) {
@@ -1356,8 +1449,29 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
 
       case vtn_variable_mode_input:
       case vtn_variable_mode_output: {
+         /* In order to know whether or not we're a per-vertex inout, we need
+          * the patch qualifier.  This means walking the variable decorations
+          * early before we actually create any variables.  Not a big deal.
+          *
+          * GLSLang really likes to place decorations in the most interior
+          * thing it possibly can.  In particular, if you have a struct, it
+          * will place the patch decorations on the struct members.  This
+          * should be handled by the variable splitting below just fine.
+          *
+          * If you have an array-of-struct, things get even more weird as it
+          * will place the patch decorations on the struct even though it's
+          * inside an array and some of the members being patch and others not
+          * makes no sense whatsoever.  Since the only sensible thing is for
+          * it to be all or nothing, we'll call it patch if any of the members
+          * are declared patch.
+          */
          var->patch = false;
          vtn_foreach_decoration(b, val, var_is_patch_cb, &var->patch);
+         if (glsl_type_is_array(var->type->type) &&
+             glsl_type_is_struct(without_array->type)) {
+            vtn_foreach_decoration(b, without_array->val,
+                                   var_is_patch_cb, &var->patch);
+         }
 
          /* For inputs and outputs, we immediately split structures.  This
           * is for a couple of reasons.  For one, builtins may all come in
@@ -1397,6 +1511,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
                var->members[i]->interface_type =
                   interface_type->members[i]->type;
                var->members[i]->data.mode = nir_mode;
+               var->members[i]->data.patch = var->patch;
             }
          } else {
             var->var = rzalloc(b->shader, nir_variable);
@@ -1404,6 +1519,7 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
             var->var->type = var->type->type;
             var->var->interface_type = interface_type->type;
             var->var->data.mode = nir_mode;
+            var->var->data.patch = var->patch;
          }
 
          /* For inputs and outputs, we need to grab locations and builtin
@@ -1411,10 +1527,10 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
           */
          vtn_foreach_decoration(b, interface_type->val, var_decoration_cb, var);
          break;
+      }
 
       case vtn_variable_mode_param:
          unreachable("Not created through OpVariable");
-      }
 
       case vtn_variable_mode_ubo:
       case vtn_variable_mode_ssbo:
@@ -1538,6 +1654,16 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
    case SpvOpStore: {
       struct vtn_access_chain *dest =
          vtn_value(b, w[1], vtn_value_type_access_chain)->access_chain;
+
+      if (glsl_type_is_sampler(dest->var->type->type)) {
+         vtn_warn("OpStore of a sampler detected.  Doing on-the-fly copy "
+                  "propagation to workaround the problem.");
+         assert(dest->var->copy_prop_sampler == NULL);
+         dest->var->copy_prop_sampler =
+            vtn_value(b, w[2], vtn_value_type_access_chain)->access_chain;
+         break;
+      }
+
       struct vtn_ssa_value *src = vtn_ssa_value(b, w[2]);
       vtn_variable_store(b, src, dest);
       break;