zink/spirv: implement loops
authorErik Faye-Lund <erik.faye-lund@collabora.com>
Mon, 18 Mar 2019 19:29:49 +0000 (20:29 +0100)
committerErik Faye-Lund <erik.faye-lund@collabora.com>
Mon, 28 Oct 2019 08:51:43 +0000 (08:51 +0000)
Acked-by: Jordan Justen <jordan.l.justen@intel.com>
src/gallium/drivers/zink/nir_to_spirv/nir_to_spirv.c
src/gallium/drivers/zink/nir_to_spirv/spirv_builder.c
src/gallium/drivers/zink/nir_to_spirv/spirv_builder.h

index c4c8da2efa1b62fe205ccd7ee1d489622265f026..e2b0cee44380c0550b73a6d9bb9de492c1836d36 100644 (file)
@@ -54,6 +54,7 @@ struct ntv_context {
    const SpvId *block_ids;
    size_t num_blocks;
    bool block_started;
+   SpvId loop_break, loop_cont;
 };
 
 static SpvId
@@ -1183,6 +1184,25 @@ branch_conditional(struct ntv_context *ctx, SpvId condition, SpvId then_id,
    ctx->block_started = false;
 }
 
+static void
+emit_jump(struct ntv_context *ctx, nir_jump_instr *jump)
+{
+   switch (jump->type) {
+   case nir_jump_break:
+      assert(ctx->loop_break);
+      branch(ctx, ctx->loop_break);
+      break;
+
+   case nir_jump_continue:
+      assert(ctx->loop_cont);
+      branch(ctx, ctx->loop_cont);
+      break;
+
+   default:
+      unreachable("Unsupported jump type\n");
+   }
+}
+
 static void
 emit_block(struct ntv_context *ctx, struct nir_block *block)
 {
@@ -1208,7 +1228,7 @@ emit_block(struct ntv_context *ctx, struct nir_block *block)
          unreachable("nir_instr_type_phi not supported");
          break;
       case nir_instr_type_jump:
-         unreachable("nir_instr_type_jump not supported");
+         emit_jump(ctx, nir_instr_as_jump(instr));
          break;
       case nir_instr_type_call:
          unreachable("nir_instr_type_call not supported");
@@ -1260,13 +1280,45 @@ emit_if(struct ntv_context *ctx, nir_if *if_stmt)
    emit_cf_list(ctx, &if_stmt->then_list);
 
    if (has_else) {
-      branch(ctx, endif_id);
+      if (ctx->block_started)
+         branch(ctx, endif_id);
+
       emit_cf_list(ctx, &if_stmt->else_list);
    }
 
    start_block(ctx, endif_id);
 }
 
+static void
+emit_loop(struct ntv_context *ctx, nir_loop *loop)
+{
+   SpvId header_id = spirv_builder_new_id(&ctx->builder);
+   SpvId begin_id = block_label(ctx, nir_loop_first_block(loop));
+   SpvId break_id = spirv_builder_new_id(&ctx->builder);
+   SpvId cont_id = spirv_builder_new_id(&ctx->builder);
+
+   /* create a header-block */
+   start_block(ctx, header_id);
+   spirv_builder_loop_merge(&ctx->builder, break_id, cont_id, SpvLoopControlMaskNone);
+   branch(ctx, begin_id);
+
+   SpvId save_break = ctx->loop_break;
+   SpvId save_cont = ctx->loop_cont;
+   ctx->loop_break = break_id;
+   ctx->loop_cont = cont_id;
+
+   emit_cf_list(ctx, &loop->body);
+
+   ctx->loop_break = save_break;
+   ctx->loop_cont = save_cont;
+
+   branch(ctx, cont_id);
+   start_block(ctx, cont_id);
+   branch(ctx, header_id);
+
+   start_block(ctx, break_id);
+}
+
 static void
 emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
 {
@@ -1281,7 +1333,7 @@ emit_cf_list(struct ntv_context *ctx, struct exec_list *list)
          break;
 
       case nir_cf_node_loop:
-         unreachable("nir_cf_node_loop not supported");
+         emit_loop(ctx, nir_cf_node_as_loop(node));
          break;
 
       case nir_cf_node_function:
index 8fce21bbc356bc7c2d8cf1abd59713e832acc336..053fba7b38b8aabe65bb959318469c5d291cab18 100644 (file)
@@ -450,6 +450,17 @@ spirv_builder_emit_selection_merge(struct spirv_builder *b, SpvId merge_block,
    spirv_buffer_emit_word(&b->instructions, selection_control);
 }
 
+void
+spirv_builder_loop_merge(struct spirv_builder *b, SpvId merge_block,
+                         SpvId cont_target, SpvLoopControlMask loop_control)
+{
+   spirv_buffer_prepare(&b->instructions, 4);
+   spirv_buffer_emit_word(&b->instructions, SpvOpLoopMerge | (4 << 16));
+   spirv_buffer_emit_word(&b->instructions, merge_block);
+   spirv_buffer_emit_word(&b->instructions, cont_target);
+   spirv_buffer_emit_word(&b->instructions, loop_control);
+}
+
 void
 spirv_builder_emit_branch_conditional(struct spirv_builder *b, SpvId condition,
                                       SpvId true_label, SpvId false_label)
index 3cf3d5d3565d824bbe2387e9141b9822b98b4ed5..03acd4110e763a7bfb132fd64cd8ebf2983941df 100644 (file)
@@ -182,6 +182,10 @@ void
 spirv_builder_emit_selection_merge(struct spirv_builder *b, SpvId merge_block,
                                    SpvSelectionControlMask selection_control);
 
+void
+spirv_builder_loop_merge(struct spirv_builder *b, SpvId merge_block,
+                         SpvId cont_target, SpvLoopControlMask loop_control);
+
 void
 spirv_builder_emit_branch_conditional(struct spirv_builder *b, SpvId condition,
                                       SpvId true_label, SpvId false_label);