nir/vtn: add support for phi nodes
authorConnor Abbott <cwabbott0@gmail.com>
Tue, 23 Jun 2015 17:34:55 +0000 (10:34 -0700)
committerConnor Abbott <cwabbott0@gmail.com>
Tue, 23 Jun 2015 17:34:55 +0000 (10:34 -0700)
src/glsl/nir/spirv_to_nir.c
src/glsl/nir/spirv_to_nir_private.h

index 4aabf3cc4e0c3ec87c6ce16062f908373b3e4175..a5e9c4aa8aec26f14d0ba0473890a75bd7caeef6 100644 (file)
@@ -1686,6 +1686,112 @@ vtn_handle_composite(struct vtn_builder *b, SpvOp opcode,
    val->ssa->type = val->type;
 }
 
+static void
+vtn_phi_node_init(struct vtn_builder *b, struct vtn_ssa_value *val)
+{
+   if (glsl_type_is_vector_or_scalar(val->type)) {
+      nir_phi_instr *phi = nir_phi_instr_create(b->shader);
+      nir_ssa_dest_init(&phi->instr, &phi->dest,
+                        glsl_get_vector_elements(val->type), NULL);
+      exec_list_make_empty(&phi->srcs);
+      nir_builder_instr_insert(&b->nb, &phi->instr);
+      val->def = &phi->dest.ssa;
+   } else {
+      unsigned elems = glsl_get_length(val->type);
+      for (unsigned i = 0; i < elems; i++)
+         vtn_phi_node_init(b, val->elems[i]);
+   }
+}
+
+static struct vtn_ssa_value *
+vtn_phi_node_create(struct vtn_builder *b, const struct glsl_type *type)
+{
+   struct vtn_ssa_value *val = vtn_create_ssa_value(b, type);
+   vtn_phi_node_init(b, val);
+   return val;
+}
+
+static void
+vtn_handle_phi_first_pass(struct vtn_builder *b, const uint32_t *w)
+{
+   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
+   val->type = vtn_value(b, w[1], vtn_value_type_type)->type;
+   val->ssa = vtn_phi_node_create(b, val->type);
+}
+
+static void
+vtn_phi_node_add_src(struct vtn_ssa_value *phi, const nir_block *pred,
+                     struct vtn_ssa_value *val)
+{
+   assert(phi->type == val->type);
+   if (glsl_type_is_vector_or_scalar(phi->type)) {
+      nir_phi_instr *phi_instr = nir_instr_as_phi(phi->def->parent_instr);
+      nir_phi_src *src = ralloc(phi_instr, nir_phi_src);
+      src->pred = (nir_block *) pred;
+      src->src = nir_src_for_ssa(val->def);
+      exec_list_push_tail(&phi_instr->srcs, &src->node);
+   } else {
+      unsigned elems = glsl_get_length(phi->type);
+      for (unsigned i = 0; i < elems; i++)
+         vtn_phi_node_add_src(phi->elems[i], pred, val->elems[i]);
+   }
+}
+
+static struct vtn_ssa_value *
+vtn_get_phi_node_src(struct vtn_builder *b, nir_block *block,
+                     const struct glsl_type *type, const uint32_t *w,
+                     unsigned count)
+{
+   struct hash_entry *entry = _mesa_hash_table_search(b->block_table, block);
+   if (entry) {
+      struct vtn_block *spv_block = entry->data;
+      for (unsigned off = 4; off < count; off += 2) {
+         if (spv_block == vtn_value(b, w[off], vtn_value_type_block)->block) {
+            return vtn_ssa_value(b, w[off - 1]);
+         }
+      }
+   }
+
+   nir_builder_insert_before_block(&b->nb, block);
+   struct vtn_ssa_value *phi = vtn_phi_node_create(b, type);
+
+   struct set_entry *entry2;
+   set_foreach(block->predecessors, entry2) {
+      nir_block *pred = (nir_block *) entry2->key;
+      struct vtn_ssa_value *val = vtn_get_phi_node_src(b, pred, type, w,
+                                                       count);
+      vtn_phi_node_add_src(phi, pred, val);
+   }
+
+   return phi;
+}
+
+static bool
+vtn_handle_phi_second_pass(struct vtn_builder *b, SpvOp opcode,
+                           const uint32_t *w, unsigned count)
+{
+   if (opcode == SpvOpLabel) {
+      b->block = vtn_value(b, w[1], vtn_value_type_block)->block;
+      return true;
+   }
+
+   if (opcode != SpvOpPhi)
+      return true;
+
+   struct vtn_ssa_value *phi = vtn_value(b, w[2], vtn_value_type_ssa)->ssa;
+
+   struct set_entry *entry;
+   set_foreach(b->block->block->predecessors, entry) {
+      nir_block *pred = (nir_block *) entry->key;
+
+      struct vtn_ssa_value *val = vtn_get_phi_node_src(b, pred, phi->type, w,
+                                                       count);
+      vtn_phi_node_add_src(phi, pred, val);
+   }
+
+   return true;
+}
+
 static bool
 vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
                                 const uint32_t *w, unsigned count)
@@ -1835,6 +1941,7 @@ vtn_handle_first_cfg_pass_instruction(struct vtn_builder *b, SpvOp opcode,
    }
 
    case SpvOpFunctionEnd:
+      b->func->end = w;
       b->func = NULL;
       break;
 
@@ -2063,6 +2170,10 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
       vtn_handle_composite(b, opcode, w, count);
       break;
 
+   case SpvOpPhi:
+      vtn_handle_phi_first_pass(b, w);
+      break;
+
    default:
       unreachable("Unhandled opcode");
    }
@@ -2106,6 +2217,12 @@ vtn_walk_blocks(struct vtn_builder *b, struct vtn_block *start,
       vtn_foreach_instruction(b, block->label, block->branch,
                               vtn_handle_body_instruction);
 
+      nir_cf_node *cur_cf_node =
+         exec_node_data(nir_cf_node, exec_list_get_tail(b->nb.cf_node_list),
+                        node);
+      nir_block *cur_block = nir_cf_node_as_block(cur_cf_node);
+      _mesa_hash_table_insert(b->block_table, cur_block, block);
+
       switch (branch_op) {
       case SpvOpBranch: {
          struct vtn_block *branch_block =
@@ -2287,9 +2404,13 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
       b->impl = nir_function_impl_create(func->overload);
       b->const_table = _mesa_hash_table_create(b, _mesa_hash_pointer,
                                                _mesa_key_pointer_equal);
+      b->block_table = _mesa_hash_table_create(b, _mesa_hash_pointer,
+                                               _mesa_key_pointer_equal);
       nir_builder_init(&b->nb, b->impl);
       nir_builder_insert_after_cf_list(&b->nb, &b->impl->body);
       vtn_walk_blocks(b, func->start_block, NULL, NULL, NULL);
+      vtn_foreach_instruction(b, func->start_block->label, func->end,
+                              vtn_handle_phi_second_pass);
    }
 
    ralloc_free(b);
index 937c45b08c27f3e699bcd8ce2923fa500a06944a..b157e023a684c5e733e31d698d1a82257d18e803 100644 (file)
@@ -60,6 +60,8 @@ struct vtn_function {
 
    nir_function_overload *overload;
    struct vtn_block *start_block;
+
+   const uint32_t *end;
 };
 
 typedef bool (*vtn_instruction_handler)(struct vtn_builder *, uint32_t,
@@ -118,6 +120,12 @@ struct vtn_builder {
     */
    struct hash_table *const_table;
 
+   /*
+    * Map from nir_block to the vtn_block which ends with it -- used for
+    * handling phi nodes.
+    */
+   struct hash_table *block_table;
+
    unsigned value_id_bound;
    struct vtn_value *values;