nir: add nir_var_shader_storage
[mesa.git] / src / glsl / nir / nir.c
index 85ff0f46e2e574db9840e72063d401e1b05013af..78ff886218da8f05b75fdb3b9630b50c8b5339f7 100644 (file)
@@ -57,13 +57,9 @@ reg_create(void *mem_ctx, struct exec_list *list)
 {
    nir_register *reg = ralloc(mem_ctx, nir_register);
 
-   reg->parent_instr = NULL;
-   reg->uses = _mesa_set_create(reg, _mesa_hash_pointer,
-                                _mesa_key_pointer_equal);
-   reg->defs = _mesa_set_create(reg, _mesa_hash_pointer,
-                                _mesa_key_pointer_equal);
-   reg->if_uses = _mesa_set_create(reg, _mesa_hash_pointer,
-                                   _mesa_key_pointer_equal);
+   list_inithead(&reg->uses);
+   list_inithead(&reg->defs);
+   list_inithead(&reg->if_uses);
 
    reg->num_components = 0;
    reg->num_array_elems = 0;
@@ -151,18 +147,18 @@ void nir_src_copy(nir_src *dest, const nir_src *src, void *mem_ctx)
 
 void nir_dest_copy(nir_dest *dest, const nir_dest *src, void *mem_ctx)
 {
-   dest->is_ssa = src->is_ssa;
-   if (src->is_ssa) {
-      dest->ssa = src->ssa;
+   /* Copying an SSA definition makes no sense whatsoever. */
+   assert(!src->is_ssa);
+
+   dest->is_ssa = false;
+
+   dest->reg.base_offset = src->reg.base_offset;
+   dest->reg.reg = src->reg.reg;
+   if (src->reg.indirect) {
+      dest->reg.indirect = ralloc(mem_ctx, nir_src);
+      nir_src_copy(dest->reg.indirect, src->reg.indirect, mem_ctx);
    } else {
-      dest->reg.base_offset = src->reg.base_offset;
-      dest->reg.reg = src->reg.reg;
-      if (src->reg.indirect) {
-         dest->reg.indirect = ralloc(mem_ctx, nir_src);
-         nir_src_copy(dest->reg.indirect, src->reg.indirect, mem_ctx);
-      } else {
-         dest->reg.indirect = NULL;
-      }
+      dest->reg.indirect = NULL;
    }
 }
 
@@ -381,11 +377,11 @@ alu_src_init(nir_alu_src *src)
 }
 
 nir_alu_instr *
-nir_alu_instr_create(void *mem_ctx, nir_op op)
+nir_alu_instr_create(nir_shader *shader, nir_op op)
 {
    unsigned num_srcs = nir_op_infos[op].num_inputs;
    nir_alu_instr *instr =
-      ralloc_size(mem_ctx,
+      ralloc_size(shader,
                   sizeof(nir_alu_instr) + num_srcs * sizeof(nir_alu_src));
 
    instr_init(&instr->instr, nir_instr_type_alu);
@@ -398,18 +394,18 @@ nir_alu_instr_create(void *mem_ctx, nir_op op)
 }
 
 nir_jump_instr *
-nir_jump_instr_create(void *mem_ctx, nir_jump_type type)
+nir_jump_instr_create(nir_shader *shader, nir_jump_type type)
 {
-   nir_jump_instr *instr = ralloc(mem_ctx, nir_jump_instr);
+   nir_jump_instr *instr = ralloc(shader, nir_jump_instr);
    instr_init(&instr->instr, nir_instr_type_jump);
    instr->type = type;
    return instr;
 }
 
 nir_load_const_instr *
-nir_load_const_instr_create(void *mem_ctx, unsigned num_components)
+nir_load_const_instr_create(nir_shader *shader, unsigned num_components)
 {
-   nir_load_const_instr *instr = ralloc(mem_ctx, nir_load_const_instr);
+   nir_load_const_instr *instr = ralloc(shader, nir_load_const_instr);
    instr_init(&instr->instr, nir_instr_type_load_const);
 
    nir_ssa_def_init(&instr->instr, &instr->def, num_components, NULL);
@@ -418,11 +414,11 @@ nir_load_const_instr_create(void *mem_ctx, unsigned num_components)
 }
 
 nir_intrinsic_instr *
-nir_intrinsic_instr_create(void *mem_ctx, nir_intrinsic_op op)
+nir_intrinsic_instr_create(nir_shader *shader, nir_intrinsic_op op)
 {
    unsigned num_srcs = nir_intrinsic_infos[op].num_srcs;
    nir_intrinsic_instr *instr =
-      ralloc_size(mem_ctx,
+      ralloc_size(shader,
                   sizeof(nir_intrinsic_instr) + num_srcs * sizeof(nir_src));
 
    instr_init(&instr->instr, nir_instr_type_intrinsic);
@@ -438,9 +434,9 @@ nir_intrinsic_instr_create(void *mem_ctx, nir_intrinsic_op op)
 }
 
 nir_call_instr *
-nir_call_instr_create(void *mem_ctx, nir_function_overload *callee)
+nir_call_instr_create(nir_shader *shader, nir_function_overload *callee)
 {
-   nir_call_instr *instr = ralloc(mem_ctx, nir_call_instr);
+   nir_call_instr *instr = ralloc(shader, nir_call_instr);
    instr_init(&instr->instr, nir_instr_type_call);
 
    instr->callee = callee;
@@ -452,9 +448,9 @@ nir_call_instr_create(void *mem_ctx, nir_function_overload *callee)
 }
 
 nir_tex_instr *
-nir_tex_instr_create(void *mem_ctx, unsigned num_srcs)
+nir_tex_instr_create(nir_shader *shader, unsigned num_srcs)
 {
-   nir_tex_instr *instr = ralloc(mem_ctx, nir_tex_instr);
+   nir_tex_instr *instr = ralloc(shader, nir_tex_instr);
    instr_init(&instr->instr, nir_instr_type_tex);
 
    dest_init(&instr->dest);
@@ -472,9 +468,9 @@ nir_tex_instr_create(void *mem_ctx, unsigned num_srcs)
 }
 
 nir_phi_instr *
-nir_phi_instr_create(void *mem_ctx)
+nir_phi_instr_create(nir_shader *shader)
 {
-   nir_phi_instr *instr = ralloc(mem_ctx, nir_phi_instr);
+   nir_phi_instr *instr = ralloc(shader, nir_phi_instr);
    instr_init(&instr->instr, nir_instr_type_phi);
 
    dest_init(&instr->dest);
@@ -483,9 +479,9 @@ nir_phi_instr_create(void *mem_ctx)
 }
 
 nir_parallel_copy_instr *
-nir_parallel_copy_instr_create(void *mem_ctx)
+nir_parallel_copy_instr_create(nir_shader *shader)
 {
-   nir_parallel_copy_instr *instr = ralloc(mem_ctx, nir_parallel_copy_instr);
+   nir_parallel_copy_instr *instr = ralloc(shader, nir_parallel_copy_instr);
    instr_init(&instr->instr, nir_instr_type_parallel_copy);
 
    exec_list_make_empty(&instr->entries);
@@ -494,9 +490,9 @@ nir_parallel_copy_instr_create(void *mem_ctx)
 }
 
 nir_ssa_undef_instr *
-nir_ssa_undef_instr_create(void *mem_ctx, unsigned num_components)
+nir_ssa_undef_instr_create(nir_shader *shader, unsigned num_components)
 {
-   nir_ssa_undef_instr *instr = ralloc(mem_ctx, nir_ssa_undef_instr);
+   nir_ssa_undef_instr *instr = ralloc(shader, nir_ssa_undef_instr);
    instr_init(&instr->instr, nir_instr_type_ssa_undef);
 
    nir_ssa_def_init(&instr->instr, &instr->def, num_components, NULL);
@@ -543,7 +539,7 @@ copy_deref_var(void *mem_ctx, nir_deref_var *deref)
    nir_deref_var *ret = nir_deref_var_create(mem_ctx, deref->var);
    ret->deref.type = deref->deref.type;
    if (deref->deref.child)
-      ret->deref.child = nir_copy_deref(mem_ctx, deref->deref.child);
+      ret->deref.child = nir_copy_deref(ret, deref->deref.child);
    return ret;
 }
 
@@ -558,7 +554,7 @@ copy_deref_array(void *mem_ctx, nir_deref_array *deref)
    }
    ret->deref.type = deref->deref.type;
    if (deref->deref.child)
-      ret->deref.child = nir_copy_deref(mem_ctx, deref->deref.child);
+      ret->deref.child = nir_copy_deref(ret, deref->deref.child);
    return ret;
 }
 
@@ -568,7 +564,7 @@ copy_deref_struct(void *mem_ctx, nir_deref_struct *deref)
    nir_deref_struct *ret = nir_deref_struct_create(mem_ctx, deref->index);
    ret->deref.type = deref->deref.type;
    if (deref->deref.child)
-      ret->deref.child = nir_copy_deref(mem_ctx, deref->deref.child);
+      ret->deref.child = nir_copy_deref(ret, deref->deref.child);
    return ret;
 }
 
@@ -589,6 +585,66 @@ nir_copy_deref(void *mem_ctx, nir_deref *deref)
    return NULL;
 }
 
+/* Returns a load_const instruction that represents the constant
+ * initializer for the given deref chain.  The caller is responsible for
+ * ensuring that there actually is a constant initializer.
+ */
+nir_load_const_instr *
+nir_deref_get_const_initializer_load(nir_shader *shader, nir_deref_var *deref)
+{
+   nir_constant *constant = deref->var->constant_initializer;
+   assert(constant);
+
+   const nir_deref *tail = &deref->deref;
+   unsigned matrix_offset = 0;
+   while (tail->child) {
+      switch (tail->child->deref_type) {
+      case nir_deref_type_array: {
+         nir_deref_array *arr = nir_deref_as_array(tail->child);
+         assert(arr->deref_array_type == nir_deref_array_type_direct);
+         if (glsl_type_is_matrix(tail->type)) {
+            assert(arr->deref.child == NULL);
+            matrix_offset = arr->base_offset;
+         } else {
+            constant = constant->elements[arr->base_offset];
+         }
+         break;
+      }
+
+      case nir_deref_type_struct: {
+         constant = constant->elements[nir_deref_as_struct(tail->child)->index];
+         break;
+      }
+
+      default:
+         unreachable("Invalid deref child type");
+      }
+
+      tail = tail->child;
+   }
+
+   nir_load_const_instr *load =
+      nir_load_const_instr_create(shader, glsl_get_vector_elements(tail->type));
+
+   matrix_offset *= load->def.num_components;
+   for (unsigned i = 0; i < load->def.num_components; i++) {
+      switch (glsl_get_base_type(tail->type)) {
+      case GLSL_TYPE_FLOAT:
+      case GLSL_TYPE_INT:
+      case GLSL_TYPE_UINT:
+         load->value.u[i] = constant->value.u[matrix_offset + i];
+         break;
+      case GLSL_TYPE_BOOL:
+         load->value.u[i] = constant->value.b[matrix_offset + i] ?
+                             NIR_TRUE : NIR_FALSE;
+         break;
+      default:
+         unreachable("Invalid immediate type");
+      }
+   }
+
+   return load;
+}
 
 /**
  * \name Control flow modification
@@ -1010,11 +1066,14 @@ update_if_uses(nir_cf_node *node)
 
    nir_if *if_stmt = nir_cf_node_as_if(node);
 
-   struct set *if_uses_set = if_stmt->condition.is_ssa ?
-                             if_stmt->condition.ssa->if_uses :
-                             if_stmt->condition.reg.reg->uses;
-
-   _mesa_set_add(if_uses_set, if_stmt);
+   if_stmt->condition.parent_if = if_stmt;
+   if (if_stmt->condition.is_ssa) {
+      list_addtail(&if_stmt->condition.use_link,
+                   &if_stmt->condition.ssa->if_uses);
+   } else {
+      list_addtail(&if_stmt->condition.use_link,
+                   &if_stmt->condition.reg.reg->if_uses);
+   }
 }
 
 void
@@ -1167,16 +1226,7 @@ cleanup_cf_node(nir_cf_node *node)
       foreach_list_typed(nir_cf_node, child, node, &if_stmt->else_list)
          cleanup_cf_node(child);
 
-      struct set *if_uses;
-      if (if_stmt->condition.is_ssa) {
-         if_uses = if_stmt->condition.ssa->if_uses;
-      } else {
-         if_uses = if_stmt->condition.reg.reg->if_uses;
-      }
-
-      struct set_entry *entry = _mesa_set_search(if_uses, if_stmt);
-      assert(entry);
-      _mesa_set_remove(if_uses, entry);
+      list_del(&if_stmt->condition.use_link);
       break;
    }
 
@@ -1233,9 +1283,9 @@ add_use_cb(nir_src *src, void *state)
 {
    nir_instr *instr = state;
 
-   struct set *uses_set = src->is_ssa ? src->ssa->uses : src->reg.reg->uses;
-
-   _mesa_set_add(uses_set, instr);
+   src->parent_instr = instr;
+   list_addtail(&src->use_link,
+                src->is_ssa ? &src->ssa->uses : &src->reg.reg->uses);
 
    return true;
 }
@@ -1260,8 +1310,10 @@ add_reg_def_cb(nir_dest *dest, void *state)
 {
    nir_instr *instr = state;
 
-   if (!dest->is_ssa)
-      _mesa_set_add(dest->reg.reg->defs, instr);
+   if (!dest->is_ssa) {
+      dest->reg.parent_instr = instr;
+      list_addtail(&dest->reg.def_link, &dest->reg.reg->defs);
+   }
 
    return true;
 }
@@ -1376,13 +1428,7 @@ nir_instr_insert_after_cf_list(struct exec_list *list, nir_instr *after)
 static bool
 remove_use_cb(nir_src *src, void *state)
 {
-   nir_instr *instr = state;
-
-   struct set *uses_set = src->is_ssa ? src->ssa->uses : src->reg.reg->uses;
-
-   struct set_entry *entry = _mesa_set_search(uses_set, instr);
-   if (entry)
-      _mesa_set_remove(uses_set, entry);
+   list_del(&src->use_link);
 
    return true;
 }
@@ -1390,16 +1436,8 @@ remove_use_cb(nir_src *src, void *state)
 static bool
 remove_def_cb(nir_dest *dest, void *state)
 {
-   nir_instr *instr = state;
-
-   if (dest->is_ssa)
-      return true;
-
-   nir_register *reg = dest->reg.reg;
-
-   struct set_entry *entry = _mesa_set_search(reg->defs, instr);
-   if (entry)
-      _mesa_set_remove(reg->defs, entry);
+   if (!dest->is_ssa)
+      list_del(&dest->reg.def_link);
 
    return true;
 }
@@ -1774,60 +1812,77 @@ nir_srcs_equal(nir_src src1, nir_src src2)
 }
 
 static bool
-src_does_not_use_def(nir_src *src, void *void_def)
+src_is_valid(const nir_src *src)
 {
-   nir_ssa_def *def = void_def;
-
-   if (src->is_ssa) {
-      return src->ssa != def;
-   } else {
-      return true;
-   }
+   return src->is_ssa ? (src->ssa != NULL) : (src->reg.reg != NULL);
 }
 
-static bool
-src_does_not_use_reg(nir_src *src, void *void_reg)
+static void
+src_remove_all_uses(nir_src *src)
 {
-   nir_register *reg = void_reg;
+   for (; src; src = src->is_ssa ? NULL : src->reg.indirect) {
+      if (!src_is_valid(src))
+         continue;
 
-   if (src->is_ssa) {
-      return true;
-   } else {
-      return src->reg.reg != reg;
+      list_del(&src->use_link);
+   }
+}
+
+static void
+src_add_all_uses(nir_src *src, nir_instr *parent_instr, nir_if *parent_if)
+{
+   for (; src; src = src->is_ssa ? NULL : src->reg.indirect) {
+      if (!src_is_valid(src))
+         continue;
+
+      if (parent_instr) {
+         src->parent_instr = parent_instr;
+         if (src->is_ssa)
+            list_addtail(&src->use_link, &src->ssa->uses);
+         else
+            list_addtail(&src->use_link, &src->reg.reg->uses);
+      } else {
+         assert(parent_if);
+         src->parent_if = parent_if;
+         if (src->is_ssa)
+            list_addtail(&src->use_link, &src->ssa->if_uses);
+         else
+            list_addtail(&src->use_link, &src->reg.reg->if_uses);
+      }
    }
 }
 
 void
 nir_instr_rewrite_src(nir_instr *instr, nir_src *src, nir_src new_src)
 {
-   if (src->is_ssa) {
-      nir_ssa_def *old_ssa = src->ssa;
-      *src = new_src;
-      if (old_ssa && nir_foreach_src(instr, src_does_not_use_def, old_ssa)) {
-         struct set_entry *entry = _mesa_set_search(old_ssa->uses, instr);
-         assert(entry);
-         _mesa_set_remove(old_ssa->uses, entry);
-      }
-   } else {
-      if (src->reg.indirect)
-         nir_instr_rewrite_src(instr, src->reg.indirect, new_src);
-
-      nir_register *old_reg = src->reg.reg;
-      *src = new_src;
-      if (old_reg && nir_foreach_src(instr, src_does_not_use_reg, old_reg)) {
-         struct set_entry *entry = _mesa_set_search(old_reg->uses, instr);
-         assert(entry);
-         _mesa_set_remove(old_reg->uses, entry);
-      }
-   }
+   assert(!src_is_valid(src) || src->parent_instr == instr);
 
-   if (new_src.is_ssa) {
-      if (new_src.ssa)
-         _mesa_set_add(new_src.ssa->uses, instr);
-   } else {
-      if (new_src.reg.reg)
-         _mesa_set_add(new_src.reg.reg->uses, instr);
-   }
+   src_remove_all_uses(src);
+   *src = new_src;
+   src_add_all_uses(src, instr, NULL);
+}
+
+void
+nir_instr_move_src(nir_instr *dest_instr, nir_src *dest, nir_src *src)
+{
+   assert(!src_is_valid(dest) || dest->parent_instr == dest_instr);
+
+   src_remove_all_uses(dest);
+   src_remove_all_uses(src);
+   *dest = *src;
+   *src = NIR_SRC_INIT;
+   src_add_all_uses(dest, dest_instr, NULL);
+}
+
+void
+nir_if_rewrite_condition(nir_if *if_stmt, nir_src new_src)
+{
+   nir_src *src = &if_stmt->condition;
+   assert(!src_is_valid(src) || src->parent_if == if_stmt);
+
+   src_remove_all_uses(src);
+   *src = new_src;
+   src_add_all_uses(src, NULL, if_stmt);
 }
 
 void
@@ -1836,10 +1891,8 @@ nir_ssa_def_init(nir_instr *instr, nir_ssa_def *def,
 {
    def->name = name;
    def->parent_instr = instr;
-   def->uses = _mesa_set_create(instr, _mesa_hash_pointer,
-                                _mesa_key_pointer_equal);
-   def->if_uses = _mesa_set_create(instr, _mesa_hash_pointer,
-                                   _mesa_key_pointer_equal);
+   list_inithead(&def->uses);
+   list_inithead(&def->if_uses);
    def->num_components = num_components;
 
    if (instr->block) {
@@ -1860,57 +1913,23 @@ nir_ssa_dest_init(nir_instr *instr, nir_dest *dest,
    nir_ssa_def_init(instr, &dest->ssa, num_components, name);
 }
 
-struct ssa_def_rewrite_state {
-   void *mem_ctx;
-   nir_ssa_def *old;
-   nir_src new_src;
-};
-
-static bool
-ssa_def_rewrite_uses_src(nir_src *src, void *void_state)
-{
-   struct ssa_def_rewrite_state *state = void_state;
-
-   if (src->is_ssa && src->ssa == state->old)
-      nir_src_copy(src, &state->new_src, state->mem_ctx);
-
-   return true;
-}
-
 void
 nir_ssa_def_rewrite_uses(nir_ssa_def *def, nir_src new_src, void *mem_ctx)
 {
-   struct ssa_def_rewrite_state state;
-   state.mem_ctx = mem_ctx;
-   state.old = def;
-   state.new_src = new_src;
-
    assert(!new_src.is_ssa || def != new_src.ssa);
 
-   struct set *new_uses, *new_if_uses;
-   if (new_src.is_ssa) {
-      new_uses = new_src.ssa->uses;
-      new_if_uses = new_src.ssa->if_uses;
-   } else {
-      new_uses = new_src.reg.reg->uses;
-      new_if_uses = new_src.reg.reg->if_uses;
-   }
-
-   struct set_entry *entry;
-   set_foreach(def->uses, entry) {
-      nir_instr *instr = (nir_instr *)entry->key;
-
-      _mesa_set_remove(def->uses, entry);
-      nir_foreach_src(instr, ssa_def_rewrite_uses_src, &state);
-      _mesa_set_add(new_uses, instr);
+   nir_foreach_use_safe(def, use_src) {
+      nir_instr *src_parent_instr = use_src->parent_instr;
+      list_del(&use_src->use_link);
+      nir_src_copy(use_src, &new_src, mem_ctx);
+      src_add_all_uses(use_src, src_parent_instr, NULL);
    }
 
-   set_foreach(def->if_uses, entry) {
-      nir_if *if_use = (nir_if *)entry->key;
-
-      _mesa_set_remove(def->if_uses, entry);
-      nir_src_copy(&if_use->condition, &new_src, mem_ctx);
-      _mesa_set_add(new_if_uses, if_use);
+   nir_foreach_if_use_safe(def, use_src) {
+      nir_if *src_parent_if = use_src->parent_if;
+      list_del(&use_src->use_link);
+      nir_src_copy(use_src, &new_src, mem_ctx);
+      src_add_all_uses(use_src, NULL, src_parent_if);
    }
 }