nir/spirv: Make vtn_get_builtin_location smarter
authorJason Ekstrand <jason.ekstrand@intel.com>
Fri, 23 Oct 2015 00:45:41 +0000 (17:45 -0700)
committerJason Ekstrand <jason.ekstrand@intel.com>
Fri, 23 Oct 2015 00:45:41 +0000 (17:45 -0700)
Instead of just stomping on the mode, it now validates asserts that the
previously set mode is correct and only changes it if needed.  We need to
do this because, in geometry shaders, there are some builtins that can be
either an input or an output depending on context.  We can get that
information from the SPIR-V source but we can't throw it away.

src/glsl/nir/spirv_to_nir.c

index c8594085d5ef0af20c455a8105745a85b0a4ea01..973ff7c6777aa807108b97a86649c35030799245 100644 (file)
@@ -633,21 +633,44 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
 }
 
 static void
-vtn_get_builtin_location(SpvBuiltIn builtin, int *location,
+set_mode_system_value(nir_variable_mode *mode)
+{
+   assert(*mode == nir_var_system_value || *mode == nir_var_shader_in);
+   *mode = nir_var_system_value;
+}
+
+static void
+validate_per_vertex_mode(struct vtn_builder *b, nir_variable_mode mode)
+{
+   switch (b->shader->stage) {
+   case MESA_SHADER_VERTEX:
+      assert(mode == nir_var_shader_out);
+      break;
+   case MESA_SHADER_GEOMETRY:
+      assert(mode == nir_var_shader_out || mode == nir_var_shader_in);
+      break;
+   default:
+      assert(!"Invalid shader stage");
+   }
+}
+
+static void
+vtn_get_builtin_location(struct vtn_builder *b,
+                         SpvBuiltIn builtin, int *location,
                          nir_variable_mode *mode)
 {
    switch (builtin) {
    case SpvBuiltInPosition:
       *location = VARYING_SLOT_POS;
-      *mode = nir_var_shader_out;
+      validate_per_vertex_mode(b, *mode);
       break;
    case SpvBuiltInPointSize:
       *location = VARYING_SLOT_PSIZ;
-      *mode = nir_var_shader_out;
+      validate_per_vertex_mode(b, *mode);
       break;
    case SpvBuiltInClipDistance:
       *location = VARYING_SLOT_CLIP_DIST0; /* XXX CLIP_DIST1? */
-      *mode = nir_var_shader_in;
+      validate_per_vertex_mode(b, *mode);
       break;
    case SpvBuiltInCullDistance:
       /* XXX figure this out */
@@ -657,11 +680,11 @@ vtn_get_builtin_location(SpvBuiltIn builtin, int *location,
        * builtin keyword VertexIndex to indicate the non-zero-based value.
        */
       *location = SYSTEM_VALUE_VERTEX_ID_ZERO_BASE;
-      *mode = nir_var_system_value;
+      set_mode_system_value(mode);
       break;
    case SpvBuiltInInstanceId:
       *location = SYSTEM_VALUE_INSTANCE_ID;
-      *mode = nir_var_system_value;
+      set_mode_system_value(mode);
       break;
    case SpvBuiltInPrimitiveId:
       *location = VARYING_SLOT_PRIMITIVE_ID;
@@ -669,7 +692,7 @@ vtn_get_builtin_location(SpvBuiltIn builtin, int *location,
       break;
    case SpvBuiltInInvocationId:
       *location = SYSTEM_VALUE_INVOCATION_ID;
-      *mode = nir_var_system_value;
+      set_mode_system_value(mode);
       break;
    case SpvBuiltInLayer:
       *location = VARYING_SLOT_LAYER;
@@ -682,35 +705,40 @@ vtn_get_builtin_location(SpvBuiltIn builtin, int *location,
       unreachable("no tessellation support");
    case SpvBuiltInFragCoord:
       *location = VARYING_SLOT_POS;
-      *mode = nir_var_shader_in;
+      assert(b->shader->stage == MESA_SHADER_FRAGMENT);
+      assert(*mode == nir_var_shader_in);
       break;
    case SpvBuiltInPointCoord:
       *location = VARYING_SLOT_PNTC;
-      *mode = nir_var_shader_out;
+      assert(b->shader->stage == MESA_SHADER_FRAGMENT);
+      assert(*mode == nir_var_shader_in);
       break;
    case SpvBuiltInFrontFacing:
       *location = VARYING_SLOT_FACE;
-      *mode = nir_var_shader_out;
+      assert(b->shader->stage == MESA_SHADER_FRAGMENT);
+      assert(*mode == nir_var_shader_in);
       break;
    case SpvBuiltInSampleId:
       *location = SYSTEM_VALUE_SAMPLE_ID;
-      *mode = nir_var_shader_in;
+      set_mode_system_value(mode);
       break;
    case SpvBuiltInSamplePosition:
       *location = SYSTEM_VALUE_SAMPLE_POS;
-      *mode = nir_var_shader_in;
+      set_mode_system_value(mode);
       break;
    case SpvBuiltInSampleMask:
       *location = SYSTEM_VALUE_SAMPLE_MASK_IN; /* XXX out? */
-      *mode = nir_var_shader_in;
+      set_mode_system_value(mode);
       break;
    case SpvBuiltInFragColor:
       *location = FRAG_RESULT_COLOR;
-      *mode = nir_var_shader_out;
+      assert(b->shader->stage == MESA_SHADER_FRAGMENT);
+      assert(*mode == nir_var_shader_out);
       break;
    case SpvBuiltInFragDepth:
       *location = FRAG_RESULT_DEPTH;
-      *mode = nir_var_shader_out;
+      assert(b->shader->stage == MESA_SHADER_FRAGMENT);
+      assert(*mode == nir_var_shader_out);
       break;
    case SpvBuiltInNumWorkgroups:
    case SpvBuiltInWorkgroupSize:
@@ -723,11 +751,11 @@ vtn_get_builtin_location(SpvBuiltIn builtin, int *location,
       unreachable("unsupported builtin");
    case SpvBuiltInWorkgroupId:
       *location = SYSTEM_VALUE_WORK_GROUP_ID;
-      *mode = nir_var_system_value;
+      set_mode_system_value(mode);
       break;
    case SpvBuiltInLocalInvocationId:
       *location = SYSTEM_VALUE_LOCAL_INVOCATION_ID;
-      *mode = nir_var_system_value;
+      set_mode_system_value(mode);
       break;
    case SpvBuiltInHelperInvocation:
    default:
@@ -792,8 +820,8 @@ var_decoration_cb(struct vtn_builder *b, struct vtn_value *val, int member,
    case SpvDecorationBuiltIn: {
       SpvBuiltIn builtin = dec->literals[0];
 
-      nir_variable_mode mode;
-      vtn_get_builtin_location(builtin, &var->data.location, &mode);
+      nir_variable_mode mode = var->data.mode;
+      vtn_get_builtin_location(b, builtin, &var->data.location, &mode);
       var->data.explicit_location = true;
       var->data.mode = mode;
       if (mode == nir_var_shader_in || mode == nir_var_system_value)
@@ -842,7 +870,7 @@ get_builtin_variable(struct vtn_builder *b,
 
    if (!var) {
       int location;
-      vtn_get_builtin_location(builtin, &location, &mode);
+      vtn_get_builtin_location(b, builtin, &location, &mode);
 
       var = nir_variable_create(b->shader, mode, type, "builtin");