nir/spirv: Add initial support for specialization constants
authorJason Ekstrand <jason.ekstrand@intel.com>
Wed, 13 Jan 2016 00:28:28 +0000 (16:28 -0800)
committerJason Ekstrand <jason.ekstrand@intel.com>
Wed, 13 Jan 2016 23:18:36 +0000 (15:18 -0800)
src/glsl/nir/spirv/nir_spirv.h
src/glsl/nir/spirv/spirv_to_nir.c
src/glsl/nir/spirv/vtn_private.h
src/glsl/nir/spirv2nir.c
src/vulkan/anv_pipeline.c

index 354c0a902df5de579e9e02a59aedcd595ab23862..9c9c93d36c25f24d2ce4c2b21457f7a0cc8aca3f 100644 (file)
 extern "C" {
 #endif
 
+struct nir_spirv_specialization {
+   uint32_t id;
+   uint32_t data;
+};
+
 nir_function *spirv_to_nir(const uint32_t *words, size_t word_count,
+                           struct nir_spirv_specialization *specializations,
+                           unsigned num_specializations,
                            gl_shader_stage stage, const char *entry_point_name,
                            const nir_shader_compiler_options *options);
 
index bbde4554c6bee994137f71279a809ecf2849a959..deea1adb3ea3716dcd97964982804e3bbfdde28c 100644 (file)
@@ -804,6 +804,33 @@ vtn_null_constant(struct vtn_builder *b, const struct glsl_type *type)
    return c;
 }
 
+static void
+spec_constant_deocoration_cb(struct vtn_builder *b, struct vtn_value *v,
+                             int member, const struct vtn_decoration *dec,
+                             void *data)
+{
+   assert(member == -1);
+   if (dec->decoration != SpvDecorationSpecId)
+      return;
+
+   uint32_t *const_value = data;
+
+   for (unsigned i = 0; i < b->num_specializations; i++) {
+      if (b->specializations[i].id == dec->literals[0]) {
+         *const_value = b->specializations[i].data;
+         return;
+      }
+   }
+}
+
+static uint32_t
+get_specialization(struct vtn_builder *b, struct vtn_value *val,
+                   uint32_t const_value)
+{
+   vtn_foreach_decoration(b, val, spec_constant_deocoration_cb, &const_value);
+   return const_value;
+}
+
 static void
 vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
                     const uint32_t *w, unsigned count)
@@ -820,10 +847,25 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
       assert(val->const_type == glsl_bool_type());
       val->constant->value.u[0] = NIR_FALSE;
       break;
+
+   case SpvOpSpecConstantTrue:
+   case SpvOpSpecConstantFalse: {
+      assert(val->const_type == glsl_bool_type());
+      uint32_t int_val =
+         get_specialization(b, val, (opcode == SpvOpSpecConstantTrue));
+      val->constant->value.u[0] = int_val ? NIR_TRUE : NIR_FALSE;
+      break;
+   }
+
    case SpvOpConstant:
       assert(glsl_type_is_scalar(val->const_type));
       val->constant->value.u[0] = w[3];
       break;
+   case SpvOpSpecConstant:
+      assert(glsl_type_is_scalar(val->const_type));
+      val->constant->value.u[0] = get_specialization(b, val, w[3]);
+      break;
+   case SpvOpSpecConstantComposite:
    case SpvOpConstantComposite: {
       unsigned elem_count = count - 3;
       nir_constant **elems = ralloc_array(b, nir_constant *, elem_count);
@@ -3493,6 +3535,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
 
 nir_function *
 spirv_to_nir(const uint32_t *words, size_t word_count,
+             struct nir_spirv_specialization *spec, unsigned num_spec,
              gl_shader_stage stage, const char *entry_point_name,
              const nir_shader_compiler_options *options)
 {
@@ -3533,6 +3576,9 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
    vtn_foreach_execution_mode(b, b->entry_point,
                               vtn_handle_execution_mode, NULL);
 
+   b->specializations = spec;
+   b->num_specializations = num_spec;
+
    /* Handle all variable, type, and constant instructions */
    words = vtn_foreach_instruction(b, words, word_end,
                                    vtn_handle_variable_or_type_instruction);
index 7ab3c9fba4c4fe17355bab17f7b70c775253849c..1f88eeda9414c61a61c32a8def44b4adb81c2efe 100644 (file)
@@ -310,6 +310,9 @@ struct vtn_builder {
     */
    struct hash_table *phi_table;
 
+   unsigned num_specializations;
+   struct nir_spirv_specialization *specializations;
+
    /*
     * NIR variable for each SPIR-V builtin.
     */
index 4cb484cd0743cb2cd4876a8b7c71275780d42d34..c837186bdfce56605dcb26cdc0d3b25dd0e11446 100644 (file)
@@ -49,7 +49,7 @@ int main(int argc, char **argv)
    const void *map = mmap(NULL, len, PROT_READ, MAP_PRIVATE, fd, 0);
    assert(map != NULL);
 
-   nir_function *func = spirv_to_nir(map, word_count, MESA_SHADER_FRAGMENT,
-                                     "main", NULL);
+   nir_function *func = spirv_to_nir(map, word_count, NULL, 0,
+                                     MESA_SHADER_FRAGMENT, "main", NULL);
    nir_print_shader(func->shader, stderr);
 }
index db4e19bf486ab1bfe67045abb80cab7878947513..3d632dec0383edac27df8db04f90c2634c19d4e4 100644 (file)
@@ -113,7 +113,7 @@ anv_shader_compile_to_nir(struct anv_device *device,
       assert(spirv[0] == SPIR_V_MAGIC_NUMBER);
       assert(module->size % 4 == 0);
 
-      entry_point = spirv_to_nir(spirv, module->size / 4, stage,
+      entry_point = spirv_to_nir(spirv, module->size / 4, NULL, 0, stage,
                                  entrypoint_name, nir_options);
       nir = entry_point->shader;
       assert(nir->stage == stage);