nir/spirv: Add support for multiple entrypoints per shader
authorJason Ekstrand <jason.ekstrand@intel.com>
Thu, 31 Dec 2015 01:17:12 +0000 (17:17 -0800)
committerJason Ekstrand <jason.ekstrand@intel.com>
Thu, 31 Dec 2015 01:45:43 +0000 (17:45 -0800)
This is done by passing the entrypoint name into spirv_to_nir.  It will
then process the shader as if that were the only entrypoint we care about.
Instead of returning a nir_shader, it now returns a nir_function.

src/glsl/nir/spirv/nir_spirv.h
src/glsl/nir/spirv/spirv_to_nir.c
src/glsl/nir/spirv/vtn_private.h
src/glsl/nir/spirv2nir.c
src/vulkan/anv_pipeline.c

index 3254f10a88d36f0ecc277ad618e751735140df83..506bd981101f4df86e6b86585bc63849c62e1dc4 100644 (file)
@@ -36,8 +36,9 @@
 extern "C" {
 #endif
 
-nir_shader *spirv_to_nir(const uint32_t *words, size_t word_count,
-                         const nir_shader_compiler_options *options);
+nir_function *spirv_to_nir(const uint32_t *words, size_t word_count,
+                           const char *entry_point_name,
+                           const nir_shader_compiler_options *options);
 
 #ifdef __cplusplus
 }
index 9caba9746e6533552f14450e41a030a8e32a9a0d..5b31f7e7e2a165623982c59db11741ad5d427965 100644 (file)
@@ -3223,11 +3223,16 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
       assert(w[2] == SpvMemoryModelGLSL450);
       break;
 
-   case SpvOpEntryPoint:
+   case SpvOpEntryPoint: {
+      char *name = vtn_string_literal(b, &w[3], count - 3);
+      if (strcmp(name, b->entry_point_name) != 0)
+         break;
+
       assert(b->entry_point == NULL);
       b->entry_point = &b->values[w[2]];
       b->execution_model = w[1];
       break;
+   }
 
    case SpvOpString:
       vtn_push_value(b, w[1], vtn_value_type_string)->str =
@@ -3665,8 +3670,9 @@ stage_for_execution_model(SpvExecutionModel model)
    }
 }
 
-nir_shader *
+nir_function *
 spirv_to_nir(const uint32_t *words, size_t word_count,
+             const char *entry_point_name,
              const nir_shader_compiler_options *options)
 {
    const uint32_t *word_end = words + word_count;
@@ -3687,14 +3693,20 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
    b->value_id_bound = value_id_bound;
    b->values = rzalloc_array(b, struct vtn_value, value_id_bound);
    exec_list_make_empty(&b->functions);
+   b->entry_point_name = entry_point_name;
 
    /* Handle all the preamble instructions */
    words = vtn_foreach_instruction(b, words, word_end,
                                    vtn_handle_preamble_instruction);
 
+   if (b->entry_point == NULL) {
+      assert(!"Entry point not found");
+      ralloc_free(b);
+      return NULL;
+   }
+
    gl_shader_stage stage = stage_for_execution_model(b->execution_model);
-   nir_shader *shader = nir_shader_create(NULL, stage, options);
-   b->shader = shader;
+   b->shader = nir_shader_create(NULL, stage, options);
 
    /* Parse execution modes */
    vtn_foreach_execution_mode(b, b->entry_point,
@@ -3717,12 +3729,16 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
                               vtn_handle_phi_second_pass);
    }
 
+   assert(b->entry_point->value_type == vtn_value_type_function);
+   nir_function *entry_point = b->entry_point->func->impl->function;
+   assert(entry_point);
+
    ralloc_free(b);
 
    /* Because we can still have output reads in NIR, we need to lower
     * outputs to temporaries before we are truely finished.
     */
-   nir_lower_outputs_to_temporaries(shader);
+   nir_lower_outputs_to_temporaries(entry_point->shader);
 
-   return shader;
+   return entry_point;
 }
index 318f60804cc86c094c2b63b61809b76d6859d565..2af0e357acdb3e98320e00951b6f5e73f5e05c99 100644 (file)
@@ -317,6 +317,7 @@ struct vtn_builder {
    unsigned value_id_bound;
    struct vtn_value *values;
 
+   const char *entry_point_name;
    struct vtn_value *entry_point;
    SpvExecutionModel execution_model;
    bool origin_upper_left;
index f86825bedc502bd01a8486d7fdd248281cb09648..db56d09c98d5defd9c1de64940e1eba7f7e671d8 100644 (file)
@@ -49,6 +49,6 @@ int main(int argc, char **argv)
    const void *map = mmap(NULL, len, PROT_READ, MAP_PRIVATE, fd, 0);
    assert(map != NULL);
 
-   nir_shader *shader = spirv_to_nir(map, word_count, NULL);
-   nir_print_shader(shader, stderr);
+   nir_function *func = spirv_to_nir(map, word_count, "main", NULL);
+   nir_print_shader(func->shader, stderr);
 }
index 0f7835ea84f24bfcce2e22537257494fbfcb3c24..12430f824adec29b6a7021b264638f7f0d56983b 100644 (file)
@@ -96,6 +96,7 @@ anv_shader_compile_to_nir(struct anv_device *device,
       compiler->glsl_compiler_options[stage].NirOptions;
 
    nir_shader *nir;
+   nir_function *entry_point;
    if (module->nir) {
       /* Some things such as our meta clear/blit code will give us a NIR
        * shader directly.  In that case, we just ignore the SPIR-V entirely
@@ -103,12 +104,18 @@ anv_shader_compile_to_nir(struct anv_device *device,
       nir = module->nir;
       nir->options = nir_options;
       nir_validate_shader(nir);
+
+      assert(exec_list_length(&nir->functions) == 1);
+      struct exec_node *node = exec_list_get_head(&nir->functions);
+      entry_point = exec_node_data(nir_function, node, node);
    } else {
       uint32_t *spirv = (uint32_t *) module->data;
       assert(spirv[0] == SPIR_V_MAGIC_NUMBER);
       assert(module->size % 4 == 0);
 
-      nir = spirv_to_nir(spirv, module->size / 4, nir_options);
+      entry_point = spirv_to_nir(spirv, module->size / 4, entrypoint_name,
+                                 nir_options);
+      nir = entry_point->shader;
       assert(nir->stage == stage);
       nir_validate_shader(nir);
 
@@ -126,24 +133,15 @@ anv_shader_compile_to_nir(struct anv_device *device,
    nir->info.separate_shader = true;
 
    /* Pick off the single entrypoint that we want */
-   nir_function_impl *entrypoint = NULL;
    foreach_list_typed_safe(nir_function, func, node, &nir->functions) {
-      if (strcmp(entrypoint_name, func->name) != 0) {
-         /* Not our function, get rid of it */
+      if (func != entry_point)
          exec_node_remove(&func->node);
-         continue;
-      }
-
-      assert(entrypoint == NULL);
-      assert(func->impl);
-      entrypoint = func->impl;
    }
    assert(exec_list_length(&nir->functions) == 1);
-   assert(entrypoint != NULL);
 
    nir = brw_preprocess_nir(nir, compiler->scalar_stage[stage]);
 
-   nir_shader_gather_info(nir, entrypoint);
+   nir_shader_gather_info(nir, entry_point->impl);
 
    return nir;
 }