compiler: Add SYSTEM_VALUE_IS_INDEXED_DRAW and instrinsics
[mesa.git] / src / compiler / nir / nir.c
index a9fac96d1e4e5c85a3e51e7a4a8d2914c5ad9fa8..dc1c560319e03c68ae8f51c96107062a0a71f120 100644 (file)
 
 #include "nir.h"
 #include "nir_control_flow_private.h"
+#include "util/half_float.h"
+#include <limits.h>
 #include <assert.h>
+#include <math.h>
 
 nir_shader *
 nir_shader_create(void *mem_ctx,
@@ -44,7 +47,12 @@ nir_shader_create(void *mem_ctx,
 
    shader->options = options;
 
-   shader->info = si ? si : rzalloc(shader, shader_info);
+   if (si) {
+      assert(si->stage == stage);
+      shader->info = *si;
+   } else {
+      shader->info.stage = stage;
+   }
 
    exec_list_make_empty(&shader->functions);
    exec_list_make_empty(&shader->registers);
@@ -57,8 +65,6 @@ nir_shader_create(void *mem_ctx,
    shader->num_uniforms = 0;
    shader->num_shared = 0;
 
-   shader->stage = stage;
-
    return shader;
 }
 
@@ -142,7 +148,7 @@ nir_shader_add_variable(nir_shader *shader, nir_variable *var)
       break;
 
    case nir_var_shared:
-      assert(shader->stage == MESA_SHADER_COMPUTE);
+      assert(shader->info.stage == MESA_SHADER_COMPUTE);
       exec_list_push_tail(&shader->shared, &var->node);
       break;
 
@@ -161,8 +167,10 @@ nir_variable_create(nir_shader *shader, nir_variable_mode mode,
    var->type = type;
    var->data.mode = mode;
 
-   if ((mode == nir_var_shader_in && shader->stage != MESA_SHADER_VERTEX) ||
-       (mode == nir_var_shader_out && shader->stage != MESA_SHADER_FRAGMENT))
+   if ((mode == nir_var_shader_in &&
+        shader->info.stage != MESA_SHADER_VERTEX) ||
+       (mode == nir_var_shader_out &&
+        shader->info.stage != MESA_SHADER_FRAGMENT))
       var->data.interpolation = INTERP_MODE_SMOOTH;
 
    if (mode == nir_var_shader_in || mode == nir_var_uniform)
@@ -204,6 +212,9 @@ nir_function_create(nir_shader *shader, const char *name)
    return func;
 }
 
+/* NOTE: if the instruction you are copying a src to is already added
+ * to the IR, use nir_instr_rewrite_src() instead.
+ */
 void nir_src_copy(nir_src *dest, const nir_src *src, void *mem_ctx)
 {
    dest->is_ssa = src->is_ssa;
@@ -345,7 +356,7 @@ nir_block_create(nir_shader *shader)
                                           _mesa_key_pointer_equal);
    block->imm_dom = NULL;
    /* XXX maybe it would be worth it to defer allocation?  This
-    * way it doesn't get allocated for shader ref's that never run
+    * way it doesn't get allocated for shader refs that never run
     * nir_calc_dominance?  For example, state-tracker creates an
     * initial IR, clones that, runs appropriate lowering pass, passes
     * to driver which does common lowering/opt, and then stores ref
@@ -475,7 +486,7 @@ nir_load_const_instr *
 nir_load_const_instr_create(nir_shader *shader, unsigned num_components,
                             unsigned bit_size)
 {
-   nir_load_const_instr *instr = ralloc(shader, nir_load_const_instr);
+   nir_load_const_instr *instr = rzalloc(shader, nir_load_const_instr);
    instr_init(&instr->instr, nir_instr_type_load_const);
 
    nir_ssa_def_init(&instr->instr, &instr->def, num_components, bit_size, NULL);
@@ -540,6 +551,28 @@ nir_tex_instr_create(nir_shader *shader, unsigned num_srcs)
    return instr;
 }
 
+void
+nir_tex_instr_add_src(nir_tex_instr *tex,
+                      nir_tex_src_type src_type,
+                      nir_src src)
+{
+   nir_tex_src *new_srcs = rzalloc_array(tex, nir_tex_src,
+                                         tex->num_srcs + 1);
+
+   for (unsigned i = 0; i < tex->num_srcs; i++) {
+      new_srcs[i].src_type = tex->src[i].src_type;
+      nir_instr_move_src(&tex->instr, &new_srcs[i].src,
+                         &tex->src[i].src);
+   }
+
+   ralloc_free(tex->src);
+   tex->src = new_srcs;
+
+   tex->src[tex->num_srcs].src_type = src_type;
+   nir_instr_rewrite_src(&tex->instr, &tex->src[tex->num_srcs].src, src);
+   tex->num_srcs++;
+}
+
 void
 nir_tex_instr_remove_src(nir_tex_instr *tex, unsigned src_idx)
 {
@@ -699,8 +732,13 @@ deref_foreach_leaf_build_recur(nir_deref_var *deref, nir_deref *tail,
    assert(tail->child == NULL);
    switch (glsl_get_base_type(tail->type)) {
    case GLSL_TYPE_UINT:
+   case GLSL_TYPE_UINT16:
+   case GLSL_TYPE_UINT64:
    case GLSL_TYPE_INT:
+   case GLSL_TYPE_INT16:
+   case GLSL_TYPE_INT64:
    case GLSL_TYPE_FLOAT:
+   case GLSL_TYPE_FLOAT16:
    case GLSL_TYPE_DOUBLE:
    case GLSL_TYPE_BOOL:
       if (glsl_type_is_vector_or_scalar(tail->type))
@@ -845,7 +883,10 @@ nir_deref_get_const_initializer_load(nir_shader *shader, nir_deref_var *deref)
    case GLSL_TYPE_FLOAT:
    case GLSL_TYPE_INT:
    case GLSL_TYPE_UINT:
+   case GLSL_TYPE_FLOAT16:
    case GLSL_TYPE_DOUBLE:
+   case GLSL_TYPE_INT16:
+   case GLSL_TYPE_UINT16:
    case GLSL_TYPE_UINT64:
    case GLSL_TYPE_INT64:
    case GLSL_TYPE_BOOL:
@@ -858,6 +899,72 @@ nir_deref_get_const_initializer_load(nir_shader *shader, nir_deref_var *deref)
    return load;
 }
 
+static nir_const_value
+const_value_float(double d, unsigned bit_size)
+{
+   nir_const_value v;
+   switch (bit_size) {
+   case 16: v.u16[0] = _mesa_float_to_half(d);  break;
+   case 32: v.f32[0] = d;                       break;
+   case 64: v.f64[0] = d;                       break;
+   default:
+      unreachable("Invalid bit size");
+   }
+   return v;
+}
+
+static nir_const_value
+const_value_int(int64_t i, unsigned bit_size)
+{
+   nir_const_value v;
+   switch (bit_size) {
+   case 8:  v.i8[0]  = i;  break;
+   case 16: v.i16[0] = i;  break;
+   case 32: v.i32[0] = i;  break;
+   case 64: v.i64[0] = i;  break;
+   default:
+      unreachable("Invalid bit size");
+   }
+   return v;
+}
+
+nir_const_value
+nir_alu_binop_identity(nir_op binop, unsigned bit_size)
+{
+   const int64_t max_int = (1ull << (bit_size - 1)) - 1;
+   const int64_t min_int = -max_int - 1;
+   switch (binop) {
+   case nir_op_iadd:
+      return const_value_int(0, bit_size);
+   case nir_op_fadd:
+      return const_value_float(0, bit_size);
+   case nir_op_imul:
+      return const_value_int(1, bit_size);
+   case nir_op_fmul:
+      return const_value_float(1, bit_size);
+   case nir_op_imin:
+      return const_value_int(max_int, bit_size);
+   case nir_op_umin:
+      return const_value_int(~0ull, bit_size);
+   case nir_op_fmin:
+      return const_value_float(INFINITY, bit_size);
+   case nir_op_imax:
+      return const_value_int(min_int, bit_size);
+   case nir_op_umax:
+      return const_value_int(0, bit_size);
+   case nir_op_fmax:
+      return const_value_float(-INFINITY, bit_size);
+   case nir_op_iand:
+      return const_value_int(~0ull, bit_size);
+   case nir_op_ior:
+      return const_value_int(0, bit_size);
+   case nir_op_ixor:
+      return const_value_int(0, bit_size);
+   default:
+      unreachable("Invalid reduction operation");
+   }
+}
+
 nir_function_impl *
 nir_cf_node_get_function(nir_cf_node *node)
 {
@@ -1055,7 +1162,7 @@ remove_defs_uses(nir_instr *instr)
    nir_foreach_src(instr, remove_use_cb, instr);
 }
 
-void nir_instr_remove(nir_instr *instr)
+void nir_instr_remove_v(nir_instr *instr)
 {
    remove_defs_uses(instr);
    exec_node_remove(&instr->node);
@@ -1506,6 +1613,19 @@ nir_instr_rewrite_dest(nir_instr *instr, nir_dest *dest, nir_dest new_dest)
       src_add_all_uses(dest->reg.indirect, instr, NULL);
 }
 
+void
+nir_instr_rewrite_deref(nir_instr *instr, nir_deref_var **deref,
+                        nir_deref_var *new_deref)
+{
+   if (*deref)
+      visit_deref_src(*deref, remove_use_cb, NULL);
+
+   *deref = new_deref;
+
+   if (*deref)
+      visit_deref_src(*deref, add_use_cb, instr);
+}
+
 /* note: does *not* take ownership of 'name' */
 void
 nir_ssa_def_init(nir_instr *instr, nir_ssa_def *def,
@@ -1602,7 +1722,7 @@ nir_ssa_def_rewrite_uses_after(nir_ssa_def *def, nir_src new_src,
 }
 
 uint8_t
-nir_ssa_def_components_read(nir_ssa_def *def)
+nir_ssa_def_components_read(const nir_ssa_def *def)
 {
    uint8_t read_mask = 0;
    nir_foreach_use(use, def) {
@@ -1871,10 +1991,16 @@ nir_intrinsic_from_system_value(gl_system_value val)
       return nir_intrinsic_load_base_instance;
    case SYSTEM_VALUE_VERTEX_ID_ZERO_BASE:
       return nir_intrinsic_load_vertex_id_zero_base;
+   case SYSTEM_VALUE_IS_INDEXED_DRAW:
+      return nir_intrinsic_load_is_indexed_draw;
+   case SYSTEM_VALUE_FIRST_VERTEX:
+      return nir_intrinsic_load_first_vertex;
    case SYSTEM_VALUE_BASE_VERTEX:
       return nir_intrinsic_load_base_vertex;
    case SYSTEM_VALUE_INVOCATION_ID:
       return nir_intrinsic_load_invocation_id;
+   case SYSTEM_VALUE_FRAG_COORD:
+      return nir_intrinsic_load_frag_coord;
    case SYSTEM_VALUE_FRONT_FACE:
       return nir_intrinsic_load_front_face;
    case SYSTEM_VALUE_SAMPLE_ID:
@@ -1903,6 +2029,28 @@ nir_intrinsic_from_system_value(gl_system_value val)
       return nir_intrinsic_load_patch_vertices_in;
    case SYSTEM_VALUE_HELPER_INVOCATION:
       return nir_intrinsic_load_helper_invocation;
+   case SYSTEM_VALUE_VIEW_INDEX:
+      return nir_intrinsic_load_view_index;
+   case SYSTEM_VALUE_SUBGROUP_SIZE:
+      return nir_intrinsic_load_subgroup_size;
+   case SYSTEM_VALUE_SUBGROUP_INVOCATION:
+      return nir_intrinsic_load_subgroup_invocation;
+   case SYSTEM_VALUE_SUBGROUP_EQ_MASK:
+      return nir_intrinsic_load_subgroup_eq_mask;
+   case SYSTEM_VALUE_SUBGROUP_GE_MASK:
+      return nir_intrinsic_load_subgroup_ge_mask;
+   case SYSTEM_VALUE_SUBGROUP_GT_MASK:
+      return nir_intrinsic_load_subgroup_gt_mask;
+   case SYSTEM_VALUE_SUBGROUP_LE_MASK:
+      return nir_intrinsic_load_subgroup_le_mask;
+   case SYSTEM_VALUE_SUBGROUP_LT_MASK:
+      return nir_intrinsic_load_subgroup_lt_mask;
+   case SYSTEM_VALUE_NUM_SUBGROUPS:
+      return nir_intrinsic_load_num_subgroups;
+   case SYSTEM_VALUE_SUBGROUP_ID:
+      return nir_intrinsic_load_subgroup_id;
+   case SYSTEM_VALUE_LOCAL_GROUP_SIZE:
+      return nir_intrinsic_load_local_group_size;
    default:
       unreachable("system value does not directly correspond to intrinsic");
    }
@@ -1922,10 +2070,16 @@ nir_system_value_from_intrinsic(nir_intrinsic_op intrin)
       return SYSTEM_VALUE_BASE_INSTANCE;
    case nir_intrinsic_load_vertex_id_zero_base:
       return SYSTEM_VALUE_VERTEX_ID_ZERO_BASE;
+   case nir_intrinsic_load_first_vertex:
+      return SYSTEM_VALUE_FIRST_VERTEX;
+   case nir_intrinsic_load_is_indexed_draw:
+      return SYSTEM_VALUE_IS_INDEXED_DRAW;
    case nir_intrinsic_load_base_vertex:
       return SYSTEM_VALUE_BASE_VERTEX;
    case nir_intrinsic_load_invocation_id:
       return SYSTEM_VALUE_INVOCATION_ID;
+   case nir_intrinsic_load_frag_coord:
+      return SYSTEM_VALUE_FRAG_COORD;
    case nir_intrinsic_load_front_face:
       return SYSTEM_VALUE_FRONT_FACE;
    case nir_intrinsic_load_sample_id:
@@ -1954,100 +2108,29 @@ nir_system_value_from_intrinsic(nir_intrinsic_op intrin)
       return SYSTEM_VALUE_VERTICES_IN;
    case nir_intrinsic_load_helper_invocation:
       return SYSTEM_VALUE_HELPER_INVOCATION;
+   case nir_intrinsic_load_view_index:
+      return SYSTEM_VALUE_VIEW_INDEX;
+   case nir_intrinsic_load_subgroup_size:
+      return SYSTEM_VALUE_SUBGROUP_SIZE;
+   case nir_intrinsic_load_subgroup_invocation:
+      return SYSTEM_VALUE_SUBGROUP_INVOCATION;
+   case nir_intrinsic_load_subgroup_eq_mask:
+      return SYSTEM_VALUE_SUBGROUP_EQ_MASK;
+   case nir_intrinsic_load_subgroup_ge_mask:
+      return SYSTEM_VALUE_SUBGROUP_GE_MASK;
+   case nir_intrinsic_load_subgroup_gt_mask:
+      return SYSTEM_VALUE_SUBGROUP_GT_MASK;
+   case nir_intrinsic_load_subgroup_le_mask:
+      return SYSTEM_VALUE_SUBGROUP_LE_MASK;
+   case nir_intrinsic_load_subgroup_lt_mask:
+      return SYSTEM_VALUE_SUBGROUP_LT_MASK;
+   case nir_intrinsic_load_num_subgroups:
+      return SYSTEM_VALUE_NUM_SUBGROUPS;
+   case nir_intrinsic_load_subgroup_id:
+      return SYSTEM_VALUE_SUBGROUP_ID;
+   case nir_intrinsic_load_local_group_size:
+      return SYSTEM_VALUE_LOCAL_GROUP_SIZE;
    default:
       unreachable("intrinsic doesn't produce a system value");
    }
 }
-
-nir_op
-nir_type_conversion_op(nir_alu_type src, nir_alu_type dst)
-{
-   nir_alu_type src_base_type = (nir_alu_type) nir_alu_type_get_base_type(src);
-   nir_alu_type dst_base_type = (nir_alu_type) nir_alu_type_get_base_type(dst);
-   unsigned src_bitsize = nir_alu_type_get_type_size(src);
-   unsigned dst_bitsize = nir_alu_type_get_type_size(dst);
-
-   if (src_base_type == dst_base_type) {
-      if (src_bitsize == dst_bitsize)
-         return (src_base_type == nir_type_float) ? nir_op_fmov : nir_op_imov;
-
-      assert(src_bitsize == 64 || dst_bitsize == 64);
-      if (src_base_type == nir_type_float)
-         /* TODO: implement support for float16 */
-         return (src_bitsize == 64) ? nir_op_d2f : nir_op_f2d;
-      else if (src_base_type == nir_type_uint)
-         return (src_bitsize == 64) ? nir_op_imov : nir_op_u2u64;
-      else if (src_base_type == nir_type_int)
-         return (src_bitsize == 64) ? nir_op_imov : nir_op_i2i64;
-      unreachable("Invalid conversion");
-   }
-
-   /* Different base type but same bit_size */
-   if (src_bitsize == dst_bitsize) {
-      /* TODO: This does not include specific conversions between
-       * signed or unsigned integer types of bit size different than 32 yet.
-       */
-      assert(src_bitsize == 32);
-      switch (src_base_type) {
-      case nir_type_uint:
-         return (dst_base_type == nir_type_float) ? nir_op_u2f : nir_op_imov;
-      case nir_type_int:
-         return (dst_base_type == nir_type_float) ? nir_op_i2f : nir_op_imov;
-      case nir_type_bool:
-         return (dst_base_type == nir_type_float) ? nir_op_b2f : nir_op_b2i;
-      case nir_type_float:
-         switch (dst_base_type) {
-         case nir_type_uint:
-            return nir_op_f2u;
-         case nir_type_bool:
-            return nir_op_f2b;
-         default:
-            return nir_op_f2i;
-         };
-      default:
-         unreachable("Invalid conversion");
-      };
-   }
-
-   /* Different bit_size and different base type */
-   /* TODO: Implement integer support for types with bit_size != 32 */
-   switch (src_base_type) {
-   case nir_type_uint:
-      if (dst == nir_type_float64)
-         return nir_op_u2d;
-      else if (dst == nir_type_int64)
-         return nir_op_u2i64;
-      break;
-   case nir_type_int:
-      if (dst == nir_type_float64)
-         return nir_op_i2d;
-      else if (dst == nir_type_uint64)
-         return nir_op_i2i64;
-      break;
-   case nir_type_bool:
-      assert(dst == nir_type_float64);
-      return nir_op_u2d;
-   case nir_type_float:
-      assert(src_bitsize == 32 || src_bitsize == 64);
-      if (src_bitsize != 64) {
-         assert(dst == nir_type_float64);
-         return nir_op_f2d;
-      }
-      assert(dst_bitsize == 32);
-      switch (dst_base_type) {
-      case nir_type_uint:
-         return nir_op_d2u;
-      case nir_type_int:
-         return nir_op_d2i;
-      case nir_type_bool:
-         return nir_op_d2b;
-      case nir_type_float:
-         return nir_op_d2f;
-      default:
-         unreachable("Invalid conversion");
-      };
-   default:
-      unreachable("Invalid conversion");
-   };
-   unreachable("Invalid conversion");
-}