nir/spirv: Add an actual CFG data structure
authorJason Ekstrand <jason.ekstrand@intel.com>
Tue, 29 Dec 2015 02:19:40 +0000 (18:19 -0800)
committerJason Ekstrand <jason.ekstrand@intel.com>
Tue, 29 Dec 2015 20:50:31 +0000 (12:50 -0800)
The current data structure doesn't handle much that we couldn't handle
before.  However, this will be absolutely crucial for doing swith
statements.  Also, this should fix structured continues.

src/glsl/Makefile.sources
src/glsl/nir/spirv/spirv_to_nir.c
src/glsl/nir/spirv/vtn_cfg.c [new file with mode: 0644]
src/glsl/nir/spirv/vtn_private.h

index aa87cb1480f26b7c6a4a24804618b6162d3cfad9..65c493cd6775a61fca09e5d4a033255777114d24 100644 (file)
@@ -94,6 +94,7 @@ NIR_FILES = \
 SPIRV_FILES = \
        nir/spirv/nir_spirv.h \
        nir/spirv/spirv_to_nir.c \
+       nir/spirv/vtn_cfg.c \
        nir/spirv/vtn_glsl450.c
 
 # libglsl
index 815b447857bae338c7da6325dc4a94e4bab82b70..16930c461979e28a821ace68f669ed122a982d85 100644 (file)
@@ -181,7 +181,7 @@ vtn_string_literal(struct vtn_builder *b, const uint32_t *words,
    return ralloc_strndup(b, (char *)words, word_count * sizeof(*words));
 }
 
-static const uint32_t *
+const uint32_t *
 vtn_foreach_instruction(struct vtn_builder *b, const uint32_t *start,
                         const uint32_t *end, vtn_instruction_handler handler)
 {
@@ -3359,127 +3359,6 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
    return true;
 }
 
-static bool
-vtn_handle_first_cfg_pass_instruction(struct vtn_builder *b, SpvOp opcode,
-                                      const uint32_t *w, unsigned count)
-{
-   switch (opcode) {
-   case SpvOpFunction: {
-      assert(b->func == NULL);
-      b->func = rzalloc(b, struct vtn_function);
-
-      const struct glsl_type *result_type =
-         vtn_value(b, w[1], vtn_value_type_type)->type->type;
-      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
-      val->func = b->func;
-
-      const struct glsl_type *func_type =
-         vtn_value(b, w[4], vtn_value_type_type)->type->type;
-
-      assert(glsl_get_function_return_type(func_type) == result_type);
-
-      nir_function *func =
-         nir_function_create(b->shader, ralloc_strdup(b->shader, val->name));
-
-      func->num_params = glsl_get_length(func_type);
-      func->params = ralloc_array(b->shader, nir_parameter, func->num_params);
-      for (unsigned i = 0; i < func->num_params; i++) {
-         const struct glsl_function_param *param =
-            glsl_get_function_param(func_type, i);
-         func->params[i].type = param->type;
-         if (param->in) {
-            if (param->out) {
-               func->params[i].param_type = nir_parameter_inout;
-            } else {
-               func->params[i].param_type = nir_parameter_in;
-            }
-         } else {
-            if (param->out) {
-               func->params[i].param_type = nir_parameter_out;
-            } else {
-               assert(!"Parameter is neither in nor out");
-            }
-         }
-      }
-
-      func->return_type = glsl_get_function_return_type(func_type);
-
-      b->func->impl = nir_function_impl_create(func);
-      if (!glsl_type_is_void(func->return_type)) {
-         b->func->impl->return_var =
-            nir_local_variable_create(b->func->impl, func->return_type, "ret");
-      }
-
-      b->func_param_idx = 0;
-      break;
-   }
-
-   case SpvOpFunctionEnd:
-      b->func->end = w;
-      b->func = NULL;
-      break;
-
-   case SpvOpFunctionParameter: {
-      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_deref);
-
-      assert(b->func_param_idx < b->func->impl->num_params);
-      unsigned idx = b->func_param_idx++;
-
-      nir_variable *param =
-         nir_local_variable_create(b->func->impl,
-                                   b->func->impl->function->params[idx].type,
-                                   val->name);
-
-      b->func->impl->params[idx] = param;
-      val->deref = nir_deref_var_create(b, param);
-      val->deref_type = vtn_value(b, w[1], vtn_value_type_type)->type;
-      break;
-   }
-
-   case SpvOpLabel: {
-      assert(b->block == NULL);
-      b->block = rzalloc(b, struct vtn_block);
-      b->block->label = w;
-      vtn_push_value(b, w[1], vtn_value_type_block)->block = b->block;
-
-      if (b->func->start_block == NULL) {
-         /* This is the first block encountered for this function.  In this
-          * case, we set the start block and add it to the list of
-          * implemented functions that we'll walk later.
-          */
-         b->func->start_block = b->block;
-         exec_list_push_tail(&b->functions, &b->func->node);
-      }
-      break;
-   }
-
-   case SpvOpBranch:
-   case SpvOpBranchConditional:
-   case SpvOpSwitch:
-   case SpvOpKill:
-   case SpvOpReturn:
-   case SpvOpReturnValue:
-   case SpvOpUnreachable:
-      assert(b->block);
-      b->block->branch = w;
-      b->block = NULL;
-      break;
-
-   case SpvOpSelectionMerge:
-   case SpvOpLoopMerge:
-      assert(b->block && b->block->merge_op == SpvOpNop);
-      b->block->merge_op = opcode;
-      b->block->merge_block_id = w[1];
-      break;
-
-   default:
-      /* Continue on as per normal */
-      return true;
-   }
-
-   return true;
-}
-
 static bool
 vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
                             const uint32_t *w, unsigned count)
@@ -3487,9 +3366,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    switch (opcode) {
    case SpvOpLabel: {
       struct vtn_block *block = vtn_value(b, w[1], vtn_value_type_block)->block;
-      assert(block->block == NULL);
-
-      block->block = nir_cursor_current_block(b->nb.cursor);
+      assert(block->block == nir_cursor_current_block(b->nb.cursor));
       break;
    }
 
@@ -3697,196 +3574,119 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    return true;
 }
 
-static void
-vtn_walk_blocks(struct vtn_builder *b, struct vtn_block *start,
-                struct vtn_block *break_block, struct vtn_block *cont_block,
-                struct vtn_block *end_block)
-{
-   struct vtn_block *block = start;
-   while (block != end_block) {
-      if (block->merge_op == SpvOpLoopMerge) {
-         /* This is the jump into a loop. */
-         struct vtn_block *new_cont_block = block;
-         struct vtn_block *new_break_block =
-            vtn_value(b, block->merge_block_id, vtn_value_type_block)->block;
+static void vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list);
 
-         nir_loop *loop = nir_loop_create(b->shader);
-         nir_cf_node_insert(b->nb.cursor, &loop->cf_node);
-
-         /* Reset the merge_op to prerevent infinite recursion */
-         block->merge_op = SpvOpNop;
-
-         b->nb.cursor = nir_after_cf_list(&loop->body);
-         vtn_walk_blocks(b, block, new_break_block, new_cont_block, NULL);
-
-         b->nb.cursor = nir_after_cf_node(&loop->cf_node);
-         block = new_break_block;
-         continue;
-      }
+static void
+vtn_emit_branch(struct vtn_builder *b, enum vtn_branch_type branch_type)
+{
+   nir_jump_type jump_type;
+   switch (branch_type) {
+   case vtn_branch_type_break:      jump_type = nir_jump_break;      break;
+   case vtn_branch_type_continue:   jump_type = nir_jump_continue;   break;
+   case vtn_branch_type_return:     jump_type = nir_jump_return;     break;
+   default:
+      unreachable("Invalid branch type");
+   }
 
-      const uint32_t *w = block->branch;
-      SpvOp branch_op = w[0] & SpvOpCodeMask;
+   nir_jump_instr *jump = nir_jump_instr_create(b->shader, jump_type);
+   nir_builder_instr_insert(&b->nb, &jump->instr);
+}
 
-      b->block = block;
-      vtn_foreach_instruction(b, block->label, block->branch,
-                              vtn_handle_body_instruction);
+static void
+vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list)
+{
+   list_for_each_entry(struct vtn_cf_node, node, cf_list, link) {
+      switch (node->type) {
+      case vtn_cf_node_type_block: {
+         struct vtn_block *block = (struct vtn_block *)node;
 
-      nir_block *cur_block = nir_cursor_current_block(b->nb.cursor);
-      assert(cur_block == block->block);
-      _mesa_hash_table_insert(b->block_table, cur_block, block);
+         block->block = nir_cursor_current_block(b->nb.cursor);
+         _mesa_hash_table_insert(b->block_table, block->block, block);
 
-      switch (branch_op) {
-      case SpvOpBranch: {
-         struct vtn_block *branch_block =
-            vtn_value(b, w[1], vtn_value_type_block)->block;
+         vtn_foreach_instruction(b, block->label,
+                                 block->merge ? block->merge : block->branch,
+                                 vtn_handle_body_instruction);
 
-         if (branch_block == break_block) {
-            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
-                                                         nir_jump_break);
-            nir_builder_instr_insert(&b->nb, &jump->instr);
+         if ((*block->branch & SpvOpCodeMask) == SpvOpReturnValue) {
+            struct vtn_ssa_value *src = vtn_ssa_value(b, block->branch[1]);
+            vtn_variable_store(b, src,
+                               nir_deref_var_create(b, b->impl->return_var),
+                               NULL);
+         }
 
-            return;
-         } else if (branch_block == cont_block) {
-            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
-                                                         nir_jump_continue);
-            nir_builder_instr_insert(&b->nb, &jump->instr);
+         if (block->branch_type != vtn_branch_type_none)
+            vtn_emit_branch(b, block->branch_type);
 
-            return;
-         } else if (branch_block == end_block) {
-            /* We're branching to the merge block of an if, since for loops
-             * and functions end_block == NULL, so we're done here.
-             */
-            return;
-         } else {
-            /* We're branching to another block, and according to the rules,
-             * we can only branch to another block with one predecessor (so
-             * we're the only one jumping to it) so we can just process it
-             * next.
-             */
-            block = branch_block;
-            continue;
-         }
+         break;
       }
 
-      case SpvOpBranchConditional: {
-         /* Gather up the branch blocks */
-         struct vtn_block *then_block =
-            vtn_value(b, w[2], vtn_value_type_block)->block;
-         struct vtn_block *else_block =
-            vtn_value(b, w[3], vtn_value_type_block)->block;
+      case vtn_cf_node_type_if: {
+         struct vtn_if *vtn_if = (struct vtn_if *)node;
 
          nir_if *if_stmt = nir_if_create(b->shader);
-         if_stmt->condition = nir_src_for_ssa(vtn_ssa_value(b, w[1])->def);
+         if_stmt->condition =
+            nir_src_for_ssa(vtn_ssa_value(b, vtn_if->condition)->def);
          nir_cf_node_insert(b->nb.cursor, &if_stmt->cf_node);
 
-         if (then_block == break_block) {
-            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
-                                                         nir_jump_break);
-            nir_instr_insert_after_cf_list(&if_stmt->then_list,
-                                           &jump->instr);
-            block = else_block;
-         } else if (else_block == break_block) {
-            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
-                                                         nir_jump_break);
-            nir_instr_insert_after_cf_list(&if_stmt->else_list,
-                                           &jump->instr);
-            block = then_block;
-         } else if (then_block == cont_block) {
-            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
-                                                         nir_jump_continue);
-            nir_instr_insert_after_cf_list(&if_stmt->then_list,
-                                           &jump->instr);
-            block = else_block;
-         } else if (else_block == cont_block) {
-            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
-                                                         nir_jump_continue);
-            nir_instr_insert_after_cf_list(&if_stmt->else_list,
-                                           &jump->instr);
-            block = then_block;
-         } else {
-            /* According to the rules we're branching to two blocks that don't
-             * have any other predecessors, so we can handle this as a
-             * conventional if.
-             */
-            assert(block->merge_op == SpvOpSelectionMerge);
-            struct vtn_block *merge_block =
-               vtn_value(b, block->merge_block_id, vtn_value_type_block)->block;
+         b->nb.cursor = nir_after_cf_list(&if_stmt->then_list);
+         if (vtn_if->then_type == vtn_branch_type_none)
+            vtn_emit_cf_list(b, &vtn_if->then_body);
+         else
+            vtn_emit_branch(b, vtn_if->then_type);
 
-            b->nb.cursor = nir_after_cf_list(&if_stmt->then_list);
-            vtn_walk_blocks(b, then_block, break_block, cont_block, merge_block);
+         b->nb.cursor = nir_after_cf_list(&if_stmt->else_list);
+         if (vtn_if->else_type == vtn_branch_type_none)
+            vtn_emit_cf_list(b, &vtn_if->else_body);
+         else
+            vtn_emit_branch(b, vtn_if->else_type);
 
-            b->nb.cursor = nir_after_cf_list(&if_stmt->else_list);
-            vtn_walk_blocks(b, else_block, break_block, cont_block, merge_block);
+         b->nb.cursor = nir_after_cf_node(&if_stmt->cf_node);
+         break;
+      }
 
-            b->nb.cursor = nir_after_cf_node(&if_stmt->cf_node);
-            block = merge_block;
-            continue;
-         }
+      case vtn_cf_node_type_loop: {
+         struct vtn_loop *vtn_loop = (struct vtn_loop *)node;
 
-         /* If we got here then we inserted a predicated break or continue
-          * above and we need to handle the other case.  We already set
-          * `block` above to indicate what block to visit after the
-          * predicated break.
-          */
+         nir_loop *loop = nir_loop_create(b->shader);
+         nir_cf_node_insert(b->nb.cursor, &loop->cf_node);
 
-         /* It's possible that the other branch is also a break/continue.
-          * If it is, we handle that here.
-          */
-         if (block == break_block) {
-            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
-                                                         nir_jump_break);
-            nir_builder_instr_insert(&b->nb, &jump->instr);
-
-            return;
-         } else if (block == cont_block) {
-            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
-                                                         nir_jump_continue);
-            nir_builder_instr_insert(&b->nb, &jump->instr);
-
-            return;
-         }
+         if (!list_empty(&vtn_loop->cont_body)) {
+            /* If we have a non-trivial continue body then we need to put
+             * it at the beginning of the loop with a flag to ensure that
+             * it doesn't get executed in the first iteration.
+             */
+            nir_variable *do_cont =
+               nir_local_variable_create(b->nb.impl, glsl_bool_type(), "cont");
 
-         /* If we got here then there was a predicated break/continue but
-          * the other half of the if has stuff in it.  `block` was already
-          * set above so there is nothing left for us to do.
-          */
-         continue;
-      }
+            b->nb.cursor = nir_before_cf_node(&loop->cf_node);
+            nir_store_var(&b->nb, do_cont, nir_imm_int(&b->nb, NIR_FALSE), 1);
 
-      case SpvOpReturn: {
-         nir_jump_instr *jump = nir_jump_instr_create(b->shader,
-                                                      nir_jump_return);
-         nir_builder_instr_insert(&b->nb, &jump->instr);
-         return;
-      }
+            b->nb.cursor = nir_after_cf_list(&loop->body);
+            nir_if *cont_if = nir_if_create(b->shader);
+            cont_if->condition = nir_src_for_ssa(nir_load_var(&b->nb, do_cont));
+            nir_cf_node_insert(b->nb.cursor, &cont_if->cf_node);
 
-      case SpvOpReturnValue: {
-         struct vtn_ssa_value *src = vtn_ssa_value(b, w[1]);
-         vtn_variable_store(b, src,
-                            nir_deref_var_create(b, b->impl->return_var),
-                            NULL);
+            b->nb.cursor = nir_after_cf_list(&cont_if->then_list);
+            vtn_emit_cf_list(b, &vtn_loop->cont_body);
 
-         nir_jump_instr *jump = nir_jump_instr_create(b->shader,
-                                                      nir_jump_return);
-         nir_builder_instr_insert(&b->nb, &jump->instr);
-         return;
-      }
+            b->nb.cursor = nir_after_cf_node(&cont_if->cf_node);
+            nir_store_var(&b->nb, do_cont, nir_imm_int(&b->nb, NIR_TRUE), 1);
+         }
 
-      case SpvOpKill: {
-         nir_intrinsic_instr *discard =
-            nir_intrinsic_instr_create(b->shader, nir_intrinsic_discard);
-         nir_builder_instr_insert(&b->nb, &discard->instr);
-         return;
+         b->nb.cursor = nir_after_cf_node(&loop->cf_node);
+         vtn_emit_cf_list(b, &vtn_loop->body);
+         break;
       }
 
-      case SpvOpSwitch:
-      case SpvOpUnreachable:
+      case vtn_cf_node_type_switch:
+      case vtn_cf_node_type_case:
       default:
-         unreachable("Unhandled opcode");
+         unreachable("Invalid CF node type");
       }
    }
 }
 
+
 nir_shader *
 spirv_to_nir(const uint32_t *words, size_t word_count,
              gl_shader_stage stage,
@@ -3924,9 +3724,7 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
    words = vtn_foreach_instruction(b, words, word_end,
                                    vtn_handle_preamble_instruction);
 
-   /* Do a very quick CFG analysis pass */
-   vtn_foreach_instruction(b, words, word_end,
-                           vtn_handle_first_cfg_pass_instruction);
+   vtn_build_cfg(b, words, word_end);
 
    foreach_list_typed(struct vtn_function, func, node, &b->functions) {
       b->impl = func->impl;
@@ -3936,7 +3734,7 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
                                                _mesa_key_pointer_equal);
       nir_builder_init(&b->nb, b->impl);
       b->nb.cursor = nir_after_cf_list(&b->impl->body);
-      vtn_walk_blocks(b, func->start_block, NULL, NULL, NULL);
+      vtn_emit_cf_list(b, &func->body);
       vtn_foreach_instruction(b, func->start_block->label, func->end,
                               vtn_handle_phi_second_pass);
    }
diff --git a/src/glsl/nir/spirv/vtn_cfg.c b/src/glsl/nir/spirv/vtn_cfg.c
new file mode 100644 (file)
index 0000000..e625817
--- /dev/null
@@ -0,0 +1,312 @@
+/*
+ * Copyright © 2015 Intel Corporation
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice (including the next
+ * paragraph) shall be included in all copies or substantial portions of the
+ * Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "vtn_private.h"
+
+static bool
+vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
+                                   const uint32_t *w, unsigned count)
+{
+   switch (opcode) {
+   case SpvOpFunction: {
+      assert(b->func == NULL);
+      b->func = rzalloc(b, struct vtn_function);
+
+      list_inithead(&b->func->body);
+      b->func->control = w[3];
+
+      const struct glsl_type *result_type =
+         vtn_value(b, w[1], vtn_value_type_type)->type->type;
+      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
+      val->func = b->func;
+
+      const struct glsl_type *func_type =
+         vtn_value(b, w[4], vtn_value_type_type)->type->type;
+
+      assert(glsl_get_function_return_type(func_type) == result_type);
+
+      nir_function *func =
+         nir_function_create(b->shader, ralloc_strdup(b->shader, val->name));
+
+      func->num_params = glsl_get_length(func_type);
+      func->params = ralloc_array(b->shader, nir_parameter, func->num_params);
+      for (unsigned i = 0; i < func->num_params; i++) {
+         const struct glsl_function_param *param =
+            glsl_get_function_param(func_type, i);
+         func->params[i].type = param->type;
+         if (param->in) {
+            if (param->out) {
+               func->params[i].param_type = nir_parameter_inout;
+            } else {
+               func->params[i].param_type = nir_parameter_in;
+            }
+         } else {
+            if (param->out) {
+               func->params[i].param_type = nir_parameter_out;
+            } else {
+               assert(!"Parameter is neither in nor out");
+            }
+         }
+      }
+
+      func->return_type = glsl_get_function_return_type(func_type);
+
+      b->func->impl = nir_function_impl_create(func);
+      if (!glsl_type_is_void(func->return_type)) {
+         b->func->impl->return_var =
+            nir_local_variable_create(b->func->impl, func->return_type, "ret");
+      }
+
+      b->func_param_idx = 0;
+      break;
+   }
+
+   case SpvOpFunctionEnd:
+      b->func->end = w;
+      b->func = NULL;
+      break;
+
+   case SpvOpFunctionParameter: {
+      struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_deref);
+
+      assert(b->func_param_idx < b->func->impl->num_params);
+      unsigned idx = b->func_param_idx++;
+
+      nir_variable *param =
+         nir_local_variable_create(b->func->impl,
+                                   b->func->impl->function->params[idx].type,
+                                   val->name);
+
+      b->func->impl->params[idx] = param;
+      val->deref = nir_deref_var_create(b, param);
+      val->deref_type = vtn_value(b, w[1], vtn_value_type_type)->type;
+      break;
+   }
+
+   case SpvOpLabel: {
+      assert(b->block == NULL);
+      b->block = rzalloc(b, struct vtn_block);
+      b->block->node.type = vtn_cf_node_type_block;
+      b->block->label = w;
+      vtn_push_value(b, w[1], vtn_value_type_block)->block = b->block;
+
+      if (b->func->start_block == NULL) {
+         /* This is the first block encountered for this function.  In this
+          * case, we set the start block and add it to the list of
+          * implemented functions that we'll walk later.
+          */
+         b->func->start_block = b->block;
+         exec_list_push_tail(&b->functions, &b->func->node);
+      }
+      break;
+   }
+
+   case SpvOpSelectionMerge:
+   case SpvOpLoopMerge:
+      assert(b->block && b->block->merge == NULL);
+      b->block->merge = w;
+      break;
+
+   case SpvOpBranch:
+   case SpvOpBranchConditional:
+   case SpvOpSwitch:
+   case SpvOpKill:
+   case SpvOpReturn:
+   case SpvOpReturnValue:
+   case SpvOpUnreachable:
+      assert(b->block && b->block->branch == NULL);
+      b->block->branch = w;
+      b->block = NULL;
+      break;
+
+   default:
+      /* Continue on as per normal */
+      return true;
+   }
+
+   return true;
+}
+
+static void
+vtn_cfg_walk_blocks(struct vtn_builder *b, struct list_head *cf_list,
+                    struct vtn_block *start, struct vtn_block *break_block,
+                    struct vtn_block *cont_block, struct vtn_block *end_block)
+{
+   struct vtn_block *block = start;
+   while (block != end_block) {
+      if (block->merge && (*block->merge & SpvOpCodeMask) == SpvOpLoopMerge &&
+          !block->loop) {
+         struct vtn_loop *loop = ralloc(b, struct vtn_loop);
+
+         loop->node.type = vtn_cf_node_type_loop;
+         list_inithead(&loop->body);
+         list_inithead(&loop->cont_body);
+         loop->control = block->merge[3];
+
+         list_addtail(&loop->node.link, cf_list);
+         block->loop = loop;
+
+         struct vtn_block *loop_break =
+            vtn_value(b, block->merge[1], vtn_value_type_block)->block;
+         struct vtn_block *loop_cont =
+            vtn_value(b, block->merge[2], vtn_value_type_block)->block;
+
+         /* Note: This recursive call will start with the current block as
+          * its start block.  If we weren't careful, we would get here
+          * again and end up in infinite recursion.  This is why we set
+          * block->loop above and check for it before creating one.  This
+          * way, we only create the loop once and the second call that
+          * tries to handle this loop goes to the cases below and gets
+          * handled as a regular block.
+          */
+         vtn_cfg_walk_blocks(b, &loop->body, block,
+                             loop_break, loop_cont, NULL );
+         vtn_cfg_walk_blocks(b, &loop->body, loop_cont, NULL, NULL, block);
+
+         block = loop_break;
+         continue;
+      }
+
+      list_addtail(&block->node.link, cf_list);
+
+      switch (*block->branch & SpvOpCodeMask) {
+      case SpvOpBranch: {
+         struct vtn_block *branch_block =
+            vtn_value(b, block->branch[1], vtn_value_type_block)->block;
+
+         if (branch_block == break_block) {
+            block->branch_type = vtn_branch_type_break;
+            return;
+         } else if (branch_block == cont_block) {
+            block->branch_type = vtn_branch_type_continue;
+            return;
+         } else if (branch_block == end_block) {
+            block->branch_type = vtn_branch_type_none;
+            return;
+         } else {
+            /* If it's not one of the above, then we must be jumping to the
+             * next block in the current CF list.  Just keep going.
+             */
+            block->branch_type = vtn_branch_type_none;
+            block = branch_block;
+            continue;
+         }
+      }
+
+      case SpvOpReturn:
+      case SpvOpReturnValue:
+         block->branch_type = vtn_branch_type_return;
+         return;
+
+      case SpvOpKill:
+         block->branch_type = vtn_branch_type_discard;
+         return;
+
+      case SpvOpBranchConditional: {
+         struct vtn_block *then_block =
+            vtn_value(b, block->branch[2], vtn_value_type_block)->block;
+         struct vtn_block *else_block =
+            vtn_value(b, block->branch[3], vtn_value_type_block)->block;
+
+         struct vtn_if *if_stmt = ralloc(b, struct vtn_if);
+
+         if_stmt->node.type = vtn_cf_node_type_if;
+         if_stmt->condition = block->branch[1];
+         list_inithead(&if_stmt->then_body);
+         list_inithead(&if_stmt->else_body);
+
+         list_addtail(&if_stmt->node.link, cf_list);
+
+         /* OpBranchConditional must be at the end of a block with either
+          * an OpSelectionMerge or an OpLoopMerge.
+          */
+         assert(block->merge);
+         if ((*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge)
+            if_stmt->control = block->merge[2];
+
+         struct vtn_block *next_block = NULL;
+         if (then_block == break_block) {
+            if_stmt->then_type = vtn_branch_type_break;
+         } else if (then_block == cont_block) {
+            if_stmt->then_type = vtn_branch_type_continue;
+         } else {
+            if_stmt->then_type = vtn_branch_type_none;
+            next_block = then_block;
+         }
+
+         if (else_block == break_block) {
+            if_stmt->else_type = vtn_branch_type_break;
+         } else if (else_block == cont_block) {
+            if_stmt->else_type = vtn_branch_type_continue;
+         } else {
+            if_stmt->else_type = vtn_branch_type_none;
+            next_block = else_block;
+         }
+
+         if (if_stmt->then_type == vtn_branch_type_none &&
+             if_stmt->else_type == vtn_branch_type_none) {
+            /* Neither side of the if is something we can short-circuit. */
+            assert((*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge);
+            struct vtn_block *merge_block =
+               vtn_value(b, block->merge[1], vtn_value_type_block)->block;
+
+            vtn_cfg_walk_blocks(b, &if_stmt->then_body, then_block,
+                                break_block, cont_block, merge_block);
+            vtn_cfg_walk_blocks(b, &if_stmt->else_body, else_block,
+                                break_block, cont_block, merge_block);
+
+            block = merge_block;
+            continue;
+         } else if (if_stmt->then_type != vtn_branch_type_none &&
+                    if_stmt->else_type != vtn_branch_type_none) {
+            /* Both sides were short-circuited.  We're done here. */
+            return;
+         } else {
+            /* Exeactly one side of the branch could be short-circuited.
+             * We set the branch up as a predicated break/continue and we
+             * continue on with the other side as if it were what comes
+             * after the if.
+             */
+            block = next_block;
+            continue;
+         }
+         unreachable("Should have returned or continued");
+      }
+
+      case SpvOpSwitch:
+      case SpvOpUnreachable:
+      default:
+         unreachable("Unhandled opcode");
+      }
+   }
+}
+
+void
+vtn_build_cfg(struct vtn_builder *b, const uint32_t *words, const uint32_t *end)
+{
+   vtn_foreach_instruction(b, words, end,
+                           vtn_cfg_handle_prepass_instruction);
+
+   foreach_list_typed(struct vtn_function, func, node, &b->functions)
+      vtn_cfg_walk_blocks(b, &func->body, func->start_block, NULL, NULL, NULL);
+}
index 2fea244bb747ded451acd83143d96f96be586638..e6d8e190cb92730082eeab21374722e5b4b21ca5 100644 (file)
@@ -49,12 +49,72 @@ enum vtn_value_type {
    vtn_value_type_sampled_image,
 };
 
+enum vtn_branch_type {
+   vtn_branch_type_none,
+   vtn_branch_type_break,
+   vtn_branch_type_continue,
+   vtn_branch_type_discard,
+   vtn_branch_type_return,
+};
+
+enum vtn_cf_node_type {
+   vtn_cf_node_type_block,
+   vtn_cf_node_type_if,
+   vtn_cf_node_type_loop,
+   vtn_cf_node_type_switch,
+   vtn_cf_node_type_case,
+};
+
+struct vtn_cf_node {
+   struct list_head link;
+   enum vtn_cf_node_type type;
+};
+
+struct vtn_loop {
+   struct vtn_cf_node node;
+
+   /* The main body of the loop */
+   struct list_head body;
+
+   /* The "continue" part of the loop.  This gets executed after the body
+    * and is where you go when you hit a continue.
+    */
+   struct list_head cont_body;
+
+   SpvLoopControlMask control;
+};
+
+struct vtn_if {
+   struct vtn_cf_node node;
+
+   uint32_t condition;
+
+   enum vtn_branch_type then_type;
+   struct list_head then_body;
+
+   enum vtn_branch_type else_type;
+   struct list_head else_body;
+
+   SpvSelectionControlMask control;
+};
+
 struct vtn_block {
-   /* Merge opcode if this block contains a merge; SpvOpNop otherwise. */
-   SpvOp merge_op;
-   uint32_t merge_block_id;
+   struct vtn_cf_node node;
+
+   /** A pointer to the label instruction */
    const uint32_t *label;
+
+   /** A pointer to the merge instruction (or NULL if non exists) */
+   const uint32_t *merge;
+
+   /** A pointer to the branch instruction that ends this block */
    const uint32_t *branch;
+
+   enum vtn_branch_type branch_type;
+
+   /** Points to the loop that this block starts (if it starts a loop) */
+   struct vtn_loop *loop;
+
    nir_block *block;
 };
 
@@ -64,12 +124,23 @@ struct vtn_function {
    nir_function_impl *impl;
    struct vtn_block *start_block;
 
+   struct list_head body;
+
    const uint32_t *end;
+
+   SpvFunctionControlMask control;
 };
 
+void vtn_build_cfg(struct vtn_builder *b, const uint32_t *words,
+                   const uint32_t *end);
+
 typedef bool (*vtn_instruction_handler)(struct vtn_builder *, uint32_t,
                                         const uint32_t *, unsigned);
 
+const uint32_t *
+vtn_foreach_instruction(struct vtn_builder *b, const uint32_t *start,
+                        const uint32_t *end, vtn_instruction_handler handler);
+
 struct vtn_ssa_value {
    union {
       nir_ssa_def *def;