nir/spirv: Handle OpBranchConditional
authorJason Ekstrand <jason.ekstrand@intel.com>
Mon, 4 May 2015 17:23:09 +0000 (10:23 -0700)
committerJason Ekstrand <jason.ekstrand@intel.com>
Mon, 31 Aug 2015 23:58:20 +0000 (16:58 -0700)
We do control-flow handling as a two-step process.  The first step is to
walk the instructions list and record various information about blocks and
functions.  This is where the acutal nir_function_overload objects get
created.  We also record the start/stop instruction for each block.  Then
a second pass walks over each of the functions and over the blocks in each
function in a way that's NIR-friendly and actually parses the instructions.

src/glsl/nir/spirv_to_nir.c

index 840b4c6fc65170f300fa7a4a350b51fdab62912c..0bbae8ee874f84fad7af0473749f799f25ed47b6 100644 (file)
@@ -44,6 +44,19 @@ enum vtn_value_type {
    vtn_value_type_ssa,
 };
 
+struct vtn_block {
+   const uint32_t *label;
+   const uint32_t *branch;
+   nir_block *block;
+};
+
+struct vtn_function {
+   struct exec_node node;
+
+   nir_function_overload *overload;
+   struct vtn_block *start_block;
+};
+
 struct vtn_value {
    enum vtn_value_type value_type;
    const char *name;
@@ -54,8 +67,8 @@ struct vtn_value {
       const struct glsl_type *type;
       nir_constant *constant;
       nir_deref_var *deref;
-      nir_function_impl *impl;
-      nir_block *block;
+      struct vtn_function *func;
+      struct vtn_block *block;
       nir_ssa_def *ssa;
    };
 };
@@ -71,12 +84,17 @@ struct vtn_builder {
    nir_shader *shader;
    nir_function_impl *impl;
    struct exec_list *cf_list;
+   struct vtn_block *block;
+   struct vtn_block *merge_block;
 
    unsigned value_id_bound;
    struct vtn_value *values;
 
    SpvExecutionModel execution_model;
    struct vtn_value *entry_point;
+
+   struct vtn_function *func;
+   struct exec_list functions;
 };
 
 static struct vtn_value *
@@ -672,60 +690,10 @@ vtn_handle_variables(struct vtn_builder *b, SpvOp opcode,
 }
 
 static void
-vtn_handle_functions(struct vtn_builder *b, SpvOp opcode,
-                     const uint32_t *w, unsigned count)
+vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
+                         const uint32_t *w, unsigned count)
 {
-   switch (opcode) {
-   case SpvOpFunction: {
-      assert(b->impl == NULL);
-
-      const struct glsl_type *result_type =
-         vtn_value(b, w[1], vtn_value_type_type)->type;
-      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
-      const struct glsl_type *func_type =
-         vtn_value(b, w[4], vtn_value_type_type)->type;
-
-      assert(glsl_get_function_return_type(func_type) == result_type);
-
-      nir_function *func =
-         nir_function_create(b->shader, ralloc_strdup(b->shader, val->name));
-
-      nir_function_overload *overload = nir_function_overload_create(func);
-      overload->num_params = glsl_get_length(func_type);
-      overload->params = ralloc_array(overload, nir_parameter,
-                                      overload->num_params);
-      for (unsigned i = 0; i < overload->num_params; i++) {
-         const struct glsl_function_param *param =
-            glsl_get_function_param(func_type, i);
-         overload->params[i].type = param->type;
-         if (param->in) {
-            if (param->out) {
-               overload->params[i].param_type = nir_parameter_inout;
-            } else {
-               overload->params[i].param_type = nir_parameter_in;
-            }
-         } else {
-            if (param->out) {
-               overload->params[i].param_type = nir_parameter_out;
-            } else {
-               assert(!"Parameter is neither in nor out");
-            }
-         }
-      }
-
-      val->impl = b->impl = nir_function_impl_create(overload);
-      b->cf_list = &b->impl->body;
-
-      break;
-   }
-   case SpvOpFunctionEnd:
-      b->impl = NULL;
-      break;
-   case SpvOpFunctionParameter:
-   case SpvOpFunctionCall:
-   default:
-      unreachable("Unhandled opcode");
-   }
+   unreachable("Unhandled opcode");
 }
 
 static void
@@ -841,22 +809,118 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
    return true;
 }
 
+static bool
+vtn_handle_first_cfg_pass_instruction(struct vtn_builder *b, SpvOp opcode,
+                                      const uint32_t *w, unsigned count)
+{
+   switch (opcode) {
+   case SpvOpFunction: {
+      assert(b->func == NULL);
+      b->func = rzalloc(b, struct vtn_function);
+
+      const struct glsl_type *result_type =
+         vtn_value(b, w[1], vtn_value_type_type)->type;
+      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
+      const struct glsl_type *func_type =
+         vtn_value(b, w[4], vtn_value_type_type)->type;
+
+      assert(glsl_get_function_return_type(func_type) == result_type);
+
+      nir_function *func =
+         nir_function_create(b->shader, ralloc_strdup(b->shader, val->name));
+
+      nir_function_overload *overload = nir_function_overload_create(func);
+      overload->num_params = glsl_get_length(func_type);
+      overload->params = ralloc_array(overload, nir_parameter,
+                                      overload->num_params);
+      for (unsigned i = 0; i < overload->num_params; i++) {
+         const struct glsl_function_param *param =
+            glsl_get_function_param(func_type, i);
+         overload->params[i].type = param->type;
+         if (param->in) {
+            if (param->out) {
+               overload->params[i].param_type = nir_parameter_inout;
+            } else {
+               overload->params[i].param_type = nir_parameter_in;
+            }
+         } else {
+            if (param->out) {
+               overload->params[i].param_type = nir_parameter_out;
+            } else {
+               assert(!"Parameter is neither in nor out");
+            }
+         }
+      }
+      b->func->overload = overload;
+      break;
+   }
+
+   case SpvOpFunctionEnd:
+      b->func = NULL;
+      break;
+
+   case SpvOpFunctionParameter:
+      break; /* Does nothing */
+
+   case SpvOpLabel: {
+      assert(b->block == NULL);
+      b->block = rzalloc(b, struct vtn_block);
+      b->block->label = w;
+      vtn_push_value(b, w[1], vtn_value_type_block)->block = b->block;
+
+      if (b->func->start_block == NULL) {
+         /* This is the first block encountered for this function.  In this
+          * case, we set the start block and add it to the list of
+          * implemented functions that we'll walk later.
+          */
+         b->func->start_block = b->block;
+         exec_list_push_tail(&b->functions, &b->func->node);
+      }
+      break;
+   }
+
+   case SpvOpBranch:
+   case SpvOpBranchConditional:
+   case SpvOpSwitch:
+   case SpvOpKill:
+   case SpvOpReturn:
+   case SpvOpReturnValue:
+   case SpvOpUnreachable:
+      assert(b->block);
+      b->block->branch = w;
+      b->block = NULL;
+      break;
+
+   default:
+      /* Continue on as per normal */
+      return true;
+   }
+
+   return true;
+}
+
 static bool
 vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
                             const uint32_t *w, unsigned count)
 {
    switch (opcode) {
    case SpvOpLabel: {
+      struct vtn_block *block = vtn_value(b, w[1], vtn_value_type_block)->block;
       struct exec_node *list_tail = exec_list_get_tail(b->cf_list);
       nir_cf_node *tail_node = exec_node_data(nir_cf_node, list_tail, node);
       assert(tail_node->type == nir_cf_node_block);
-      nir_block *block = nir_cf_node_as_block(tail_node);
-
-      assert(exec_list_is_empty(&block->instr_list));
-      vtn_push_value(b, w[1], vtn_value_type_block)->block = block;
+      block->block = nir_cf_node_as_block(tail_node);
+      assert(exec_list_is_empty(&block->block->instr_list));
       break;
    }
 
+   case SpvOpLoopMerge:
+   case SpvOpSelectionMerge:
+      assert(b->merge_block == NULL);
+      /* TODO: Selection Control */
+      b->merge_block = vtn_value(b, w[1], vtn_value_type_block)->block;
+      break;
+
    case SpvOpUndef:
       vtn_push_value(b, w[2], vtn_value_type_undef);
       break;
@@ -878,11 +942,8 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
       vtn_handle_variables(b, opcode, w, count);
       break;
 
-   case SpvOpFunction:
-   case SpvOpFunctionEnd:
-   case SpvOpFunctionParameter:
    case SpvOpFunctionCall:
-      vtn_handle_functions(b, opcode, w, count);
+      vtn_handle_function_call(b, opcode, w, count);
       break;
 
    case SpvOpTextureSample:
@@ -1011,11 +1072,66 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    return true;
 }
 
+static void
+vtn_walk_blocks(struct vtn_builder *b, struct vtn_block *start,
+                struct vtn_block *end)
+{
+   struct vtn_block *block = start;
+   while (block != end) {
+      vtn_foreach_instruction(b, block->label, block->branch,
+                              vtn_handle_body_instruction);
+
+      const uint32_t *w = block->branch;
+      SpvOp branch_op = w[0] & SpvOpCodeMask;
+      switch (branch_op) {
+      case SpvOpBranch: {
+         assert(vtn_value(b, w[1], vtn_value_type_block)->block == end);
+         return;
+      }
+
+      case SpvOpBranchConditional: {
+         /* Gather up the branch blocks */
+         struct vtn_block *then_block =
+            vtn_value(b, w[2], vtn_value_type_block)->block;
+         struct vtn_block *else_block =
+            vtn_value(b, w[3], vtn_value_type_block)->block;
+         struct vtn_block *merge_block = b->merge_block;
+
+         nir_if *if_stmt = nir_if_create(b->shader);
+         if_stmt->condition = nir_src_for_ssa(vtn_ssa_value(b, w[1]));
+         nir_cf_node_insert_end(b->cf_list, &if_stmt->cf_node);
+
+         struct exec_list *old_list = b->cf_list;
+
+         b->cf_list = &if_stmt->then_list;
+         vtn_walk_blocks(b, then_block, merge_block);
+
+         b->cf_list = &if_stmt->else_list;
+         vtn_walk_blocks(b, else_block, merge_block);
+
+         b->cf_list = old_list;
+         block = merge_block;
+         continue;
+      }
+
+      case SpvOpSwitch:
+      case SpvOpKill:
+      case SpvOpReturn:
+      case SpvOpReturnValue:
+      case SpvOpUnreachable:
+      default:
+         unreachable("Unhandled opcode");
+      }
+   }
+}
+
 nir_shader *
 spirv_to_nir(const uint32_t *words, size_t word_count,
              gl_shader_stage stage,
              const nir_shader_compiler_options *options)
 {
+   const uint32_t *word_end = words + word_count;
+
    /* Handle the SPIR-V header (first 4 dwords)  */
    assert(word_count > 5);
 
@@ -1034,16 +1150,21 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
    b->shader = shader;
    b->value_id_bound = value_id_bound;
    b->values = ralloc_array(b, struct vtn_value, value_id_bound);
-
-   const uint32_t *word_end = words + word_count;
+   exec_list_make_empty(&b->functions);
 
    /* Handle all the preamble instructions */
    words = vtn_foreach_instruction(b, words, word_end,
                                    vtn_handle_preamble_instruction);
 
-   words = vtn_foreach_instruction(b, words, word_end,
-                                   vtn_handle_body_instruction);
-   assert(words == word_end);
+   /* Do a very quick CFG analysis pass */
+   vtn_foreach_instruction(b, words, word_end,
+                           vtn_handle_first_cfg_pass_instruction);
+
+   foreach_list_typed(struct vtn_function, func, node, &b->functions) {
+      b->impl = nir_function_impl_create(func->overload);
+      b->cf_list = &b->impl->body;
+      vtn_walk_blocks(b, func->start_block, NULL);
+   }
 
    ralloc_free(b);