nir/spirv: Handle control-flow with loops
authorJason Ekstrand <jason.ekstrand@intel.com>
Wed, 6 May 2015 19:37:10 +0000 (12:37 -0700)
committerJason Ekstrand <jason.ekstrand@intel.com>
Sat, 16 May 2015 18:16:34 +0000 (11:16 -0700)
src/glsl/nir/spirv_to_nir.c
src/glsl/nir/spirv_to_nir_private.h

index 3bbf91453fdaaf1ed2edee9b1346162482bcd4c6..a4f13603dacb661ec04653ed5a4ecbdbf17d1cac 100644 (file)
@@ -1000,6 +1000,13 @@ vtn_handle_first_cfg_pass_instruction(struct vtn_builder *b, SpvOp opcode,
       b->block = NULL;
       break;
 
+   case SpvOpSelectionMerge:
+   case SpvOpLoopMerge:
+      assert(b->block && b->block->merge_op == SpvOpNop);
+      b->block->merge_op = opcode;
+      b->block->merge_block_id = w[1];
+      break;
+
    default:
       /* Continue on as per normal */
       return true;
@@ -1015,19 +1022,20 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    switch (opcode) {
    case SpvOpLabel: {
       struct vtn_block *block = vtn_value(b, w[1], vtn_value_type_block)->block;
+      assert(block->block == NULL);
+
       struct exec_node *list_tail = exec_list_get_tail(b->nb.cf_node_list);
       nir_cf_node *tail_node = exec_node_data(nir_cf_node, list_tail, node);
       assert(tail_node->type == nir_cf_node_block);
       block->block = nir_cf_node_as_block(tail_node);
+
       assert(exec_list_is_empty(&block->block->instr_list));
       break;
    }
 
    case SpvOpLoopMerge:
    case SpvOpSelectionMerge:
-      assert(b->merge_block == NULL);
-      /* TODO: Selection Control */
-      b->merge_block = vtn_value(b, w[1], vtn_value_type_block)->block;
+      /* This is handled by cfg pre-pass and walk_blocks */
       break;
 
    case SpvOpUndef:
@@ -1186,19 +1194,68 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
 
 static void
 vtn_walk_blocks(struct vtn_builder *b, struct vtn_block *start,
-                struct vtn_block *end)
+                struct vtn_block *break_block, struct vtn_block *cont_block,
+                struct vtn_block *end_block)
 {
    struct vtn_block *block = start;
-   while (block != end) {
+   while (block != end_block) {
+      const uint32_t *w = block->branch;
+      SpvOp branch_op = w[0] & SpvOpCodeMask;
+
+      if (block->block != NULL) {
+         /* We've already visited this block once before so this is a
+          * back-edge.  Back-edges are only allowed to point to a loop
+          * merge.
+          */
+         assert(block == cont_block);
+         return;
+      }
+
+      b->block = block;
       vtn_foreach_instruction(b, block->label, block->branch,
                               vtn_handle_body_instruction);
 
-      const uint32_t *w = block->branch;
-      SpvOp branch_op = w[0] & SpvOpCodeMask;
       switch (branch_op) {
       case SpvOpBranch: {
-         assert(vtn_value(b, w[1], vtn_value_type_block)->block == end);
-         return;
+         struct vtn_block *branch_block =
+            vtn_value(b, w[1], vtn_value_type_block)->block;
+
+         if (branch_block == break_block) {
+            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
+                                                         nir_jump_break);
+            nir_builder_instr_insert(&b->nb, &jump->instr);
+
+            return;
+         } else if (branch_block == cont_block) {
+            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
+                                                         nir_jump_continue);
+            nir_builder_instr_insert(&b->nb, &jump->instr);
+
+            return;
+         } else if (branch_block == end_block) {
+            return;
+         } else if (branch_block->merge_op == SpvOpLoopMerge) {
+            /* This is the jump into a loop. */
+            cont_block = branch_block;
+            break_block = vtn_value(b, branch_block->merge_block_id,
+                                    vtn_value_type_block)->block;
+
+            nir_loop *loop = nir_loop_create(b->shader);
+            nir_cf_node_insert_end(b->nb.cf_node_list, &loop->cf_node);
+
+            struct exec_list *old_list = b->nb.cf_node_list;
+
+            nir_builder_insert_after_cf_list(&b->nb, &loop->body);
+            vtn_walk_blocks(b, branch_block, break_block, cont_block, NULL);
+
+            nir_builder_insert_after_cf_list(&b->nb, old_list);
+            block = break_block;
+            continue;
+         } else {
+            /* TODO: Can this ever happen? */
+            block = branch_block;
+            continue;
+         }
       }
 
       case SpvOpBranchConditional: {
@@ -1207,28 +1264,99 @@ vtn_walk_blocks(struct vtn_builder *b, struct vtn_block *start,
             vtn_value(b, w[2], vtn_value_type_block)->block;
          struct vtn_block *else_block =
             vtn_value(b, w[3], vtn_value_type_block)->block;
-         struct vtn_block *merge_block = b->merge_block;
 
          nir_if *if_stmt = nir_if_create(b->shader);
          if_stmt->condition = nir_src_for_ssa(vtn_ssa_value(b, w[1]));
          nir_cf_node_insert_end(b->nb.cf_node_list, &if_stmt->cf_node);
 
-         struct exec_list *old_list = b->nb.cf_node_list;
+         if (then_block == break_block) {
+            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
+                                                         nir_jump_break);
+            nir_instr_insert_after_cf_list(&if_stmt->then_list,
+                                           &jump->instr);
+            block = else_block;
+         } else if (else_block == break_block) {
+            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
+                                                         nir_jump_break);
+            nir_instr_insert_after_cf_list(&if_stmt->else_list,
+                                           &jump->instr);
+            block = then_block;
+         } else if (then_block == cont_block) {
+            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
+                                                         nir_jump_continue);
+            nir_instr_insert_after_cf_list(&if_stmt->then_list,
+                                           &jump->instr);
+            block = else_block;
+         } else if (else_block == cont_block) {
+            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
+                                                         nir_jump_continue);
+            nir_instr_insert_after_cf_list(&if_stmt->else_list,
+                                           &jump->instr);
+            block = then_block;
+         } else {
+            /* Conventional if statement */
+            assert(block->merge_op == SpvOpSelectionMerge);
+            struct vtn_block *merge_block =
+               vtn_value(b, block->merge_block_id, vtn_value_type_block)->block;
+
+            struct exec_list *old_list = b->nb.cf_node_list;
 
-         nir_builder_insert_after_cf_list(&b->nb, &if_stmt->then_list);
-         vtn_walk_blocks(b, then_block, merge_block);
+            nir_builder_insert_after_cf_list(&b->nb, &if_stmt->then_list);
+            vtn_walk_blocks(b, then_block, break_block, cont_block, merge_block);
 
-         nir_builder_insert_after_cf_list(&b->nb, &if_stmt->else_list);
-         vtn_walk_blocks(b, else_block, merge_block);
+            nir_builder_insert_after_cf_list(&b->nb, &if_stmt->else_list);
+            vtn_walk_blocks(b, else_block, break_block, cont_block, merge_block);
+
+            nir_builder_insert_after_cf_list(&b->nb, old_list);
+            block = merge_block;
+            continue;
+         }
 
-         nir_builder_insert_after_cf_list(&b->nb, old_list);
-         block = merge_block;
+         /* If we got here then we inserted a predicated break or continue
+          * above and we need to handle the other case.  We already set
+          * `block` above to indicate what block to visit after the
+          * predicated break.
+          */
+
+         /* It's possible that the other branch is also a break/continue.
+          * If it is, we handle that here.
+          */
+         if (block == break_block) {
+            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
+                                                         nir_jump_break);
+            nir_builder_instr_insert(&b->nb, &jump->instr);
+
+            return;
+         } else if (block == cont_block) {
+            nir_jump_instr *jump = nir_jump_instr_create(b->shader,
+                                                         nir_jump_continue);
+            nir_builder_instr_insert(&b->nb, &jump->instr);
+
+            return;
+         }
+
+         /* If we got here then there was a predicated break/continue but
+          * the other half of the if has stuff in it.  `block` was already
+          * set above so there is nothing left for us to do.
+          */
          continue;
       }
 
+      case SpvOpReturn: {
+         nir_jump_instr *jump = nir_jump_instr_create(b->shader,
+                                                      nir_jump_return);
+         nir_builder_instr_insert(&b->nb, &jump->instr);
+         return;
+      }
+
+      case SpvOpKill: {
+         nir_intrinsic_instr *discard =
+            nir_intrinsic_instr_create(b->shader, nir_intrinsic_discard);
+         nir_builder_instr_insert(&b->nb, &discard->instr);
+         return;
+      }
+
       case SpvOpSwitch:
-      case SpvOpKill:
-      case SpvOpReturn:
       case SpvOpReturnValue:
       case SpvOpUnreachable:
       default:
@@ -1275,7 +1403,7 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
       b->impl = nir_function_impl_create(func->overload);
       nir_builder_init(&b->nb, b->impl);
       nir_builder_insert_after_cf_list(&b->nb, &b->impl->body);
-      vtn_walk_blocks(b, func->start_block, NULL);
+      vtn_walk_blocks(b, func->start_block, NULL, NULL, NULL);
    }
 
    ralloc_free(b);
index fd80dd4e161ac524726714b4ab5c5c2d75ce08ae..d2b364bdfeb3081f269cf6ddbc58f62045e2099f 100644 (file)
@@ -47,6 +47,9 @@ enum vtn_value_type {
 };
 
 struct vtn_block {
+   /* Merge opcode if this block contains a merge; SpvOpNop otherwise. */
+   SpvOp merge_op;
+   uint32_t merge_block_id;
    const uint32_t *label;
    const uint32_t *branch;
    nir_block *block;
@@ -92,7 +95,6 @@ struct vtn_builder {
    nir_shader *shader;
    nir_function_impl *impl;
    struct vtn_block *block;
-   struct vtn_block *merge_block;
 
    unsigned value_id_bound;
    struct vtn_value *values;