nir+vtn: vec8+vec16 support
authorRob Clark <robdclark@gmail.com>
Sat, 9 Mar 2019 16:17:55 +0000 (17:17 +0100)
committerKarol Herbst <karolherbst@gmail.com>
Sat, 21 Dec 2019 11:00:17 +0000 (11:00 +0000)
This introduces new vec8 and vec16 instructions (which are the only
instructions taking more than 4 sources), in order to construct 8 and 16
component vectors.

In order to avoid fixing up the non-autogenerated nir_build_alu() sites
and making them pass 16 src args for the benefit of the two instructions
that take more than 4 srcs (ie vec8 and vec16), nir_build_alu() is has
nir_build_alu_tail() split out and re-used by nir_build_alu2() (which is
used for the > 4 src args case).

v2 (Karol Herbst):
  use nir_build_alu2 for vec8 and vec16
  use python's array multiplication syntax
  add nir_op_vec helper
  simplify nir_vec
  nir_build_alu_tail -> nir_builder_alu_instr_finish_and_insert
  use nir_build_alu for opcodes with <= 4 sources
v3 (Karol Herbst):
  fix nir_serialize
v4 (Dave Airlie):
  fix serialization of glsl_type
  handle vec8/16 in lowering of bools
v5 (Karol Herbst):
  fix load store vectorizer

Signed-off-by: Karol Herbst <kherbst@redhat.com>
Reviewed-by: Dave Airlie <airlied@redhat.com>
14 files changed:
src/compiler/glsl_types.cpp
src/compiler/nir/nir.h
src/compiler/nir/nir_builder.h
src/compiler/nir/nir_builder_opcodes_h.py
src/compiler/nir/nir_constant_expressions.py
src/compiler/nir/nir_lower_alu_to_scalar.c
src/compiler/nir/nir_lower_bool_to_float.c
src/compiler/nir/nir_lower_bool_to_int32.c
src/compiler/nir/nir_opcodes.py
src/compiler/nir/nir_opt_load_store_vectorize.c
src/compiler/nir/nir_print.c
src/compiler/nir/nir_search.c
src/compiler/nir/nir_validate.c
src/compiler/spirv/spirv_to_nir.c

index 4450b23a8c4ab781d4533bf1d0ad4bdb5f2e3499..c958cc4b90dc7bb16db7ef9df71eb079f4d5696e 100644 (file)
@@ -2630,9 +2630,13 @@ encode_type_to_blob(struct blob *blob, const glsl_type *type)
    case GLSL_TYPE_INT64:
    case GLSL_TYPE_BOOL:
       encoded.basic.interface_row_major = type->interface_row_major;
-      assert(type->vector_elements < 8);
       assert(type->matrix_columns < 8);
-      encoded.basic.vector_elements = type->vector_elements;
+      if (type->vector_elements <= 4)
+         encoded.basic.vector_elements = type->vector_elements;
+      else if (type->vector_elements == 8)
+         encoded.basic.vector_elements = 5;
+      else if (type->vector_elements == 16)
+         encoded.basic.vector_elements = 6;
       encoded.basic.matrix_columns = type->matrix_columns;
       encoded.basic.explicit_stride = MIN2(type->explicit_stride, 0xfffff);
       blob_write_uint32(blob, encoded.u32);
@@ -2741,6 +2745,11 @@ decode_type_from_blob(struct blob_reader *blob)
       unsigned explicit_stride = encoded.basic.explicit_stride;
       if (explicit_stride == 0xfffff)
          explicit_stride = blob_read_uint32(blob);
+      uint32_t vector_elements = encoded.basic.vector_elements;
+      if (vector_elements == 5)
+         vector_elements = 8;
+      else if (vector_elements == 6)
+         vector_elements = 16;
       return glsl_type::get_instance(base_type, encoded.basic.vector_elements,
                                      encoded.basic.matrix_columns,
                                      explicit_stride,
index fbef503f0fa168c90b445f87c867e1011a62fec0..53305ae714814848b450eee397d7fa076037c73d 100644 (file)
@@ -58,10 +58,19 @@ extern "C" {
 
 #define NIR_FALSE 0u
 #define NIR_TRUE (~0u)
-#define NIR_MAX_VEC_COMPONENTS 4
+#define NIR_MAX_VEC_COMPONENTS 16
 #define NIR_MAX_MATRIX_COLUMNS 4
 #define NIR_STREAM_PACKED (1 << 8)
-typedef uint8_t nir_component_mask_t;
+typedef uint16_t nir_component_mask_t;
+
+static inline bool
+nir_num_components_valid(unsigned num_components)
+{
+   return (num_components >= 1  &&
+           num_components <= 4) ||
+           num_components == 8  ||
+           num_components == 16;
+}
 
 /** Defines a cast function
  *
@@ -1030,6 +1039,8 @@ nir_op_vec(unsigned components)
    case  2: return nir_op_vec2;
    case  3: return nir_op_vec3;
    case  4: return nir_op_vec4;
+   case  8: return nir_op_vec8;
+   case 16: return nir_op_vec16;
    default: unreachable("bad component count");
    }
 }
index aed475938263e732cfb9219631ecd84d56d16eb4..8b5923211db294246b1d1193363abb1176fecae9 100644 (file)
@@ -874,7 +874,7 @@ nir_ssa_for_src(nir_builder *build, nir_src src, int num_components)
 static inline nir_ssa_def *
 nir_ssa_for_alu_src(nir_builder *build, nir_alu_instr *instr, unsigned srcn)
 {
-   static uint8_t trivial_swizzle[] = { 0, 1, 2, 3 };
+   static uint8_t trivial_swizzle[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 };
    STATIC_ASSERT(ARRAY_SIZE(trivial_swizzle) == NIR_MAX_VEC_COMPONENTS);
 
    nir_alu_src *src = &instr->src[srcn];
index 53fb23ca2b351cba113854c3cf14d4973d811629..f0d8cf1db68c6f5a556e6b71958f2665dba5bbdc 100644 (file)
@@ -31,14 +31,22 @@ def src_decl_list(num_srcs):
    return ', '.join('nir_ssa_def *src' + str(i) for i in range(num_srcs))
 
 def src_list(num_srcs):
-   return ', '.join('src' + str(i) if i < num_srcs else 'NULL' for i in range(4))
+   if num_srcs <= 4:
+      return ', '.join('src' + str(i) if i < num_srcs else 'NULL' for i in range(4))
+   else:
+      return ', '.join('src' + str(i) for i in range(num_srcs))
 %>
 
 % for name, opcode in sorted(opcodes.items()):
 static inline nir_ssa_def *
 nir_${name}(nir_builder *build, ${src_decl_list(opcode.num_inputs)})
 {
+% if opcode.num_inputs <= 4:
    return nir_build_alu(build, nir_op_${name}, ${src_list(opcode.num_inputs)});
+% else:
+   nir_ssa_def *srcs[${opcode.num_inputs}] = {${src_list(opcode.num_inputs)}};
+   return nir_build_alu_src_arr(build, nir_op_${name}, srcs);
+% endif
 }
 % endfor
 
index 267a2615964c39181ce48d5859be784496197c03..8b8cd5ffa698e782a0869694d879e07ec3947b04 100644 (file)
@@ -292,6 +292,18 @@ struct ${type}${width}_vec {
    ${type}${width}_t y;
    ${type}${width}_t z;
    ${type}${width}_t w;
+   ${type}${width}_t e;
+   ${type}${width}_t f;
+   ${type}${width}_t g;
+   ${type}${width}_t h;
+   ${type}${width}_t i;
+   ${type}${width}_t j;
+   ${type}${width}_t k;
+   ${type}${width}_t l;
+   ${type}${width}_t m;
+   ${type}${width}_t n;
+   ${type}${width}_t o;
+   ${type}${width}_t p;
 };
 % endfor
 % endfor
@@ -324,7 +336,7 @@ struct ${type}${width}_vec {
             _src[${j}][${k}].${get_const_field(input_types[j])},
          % endif
       % endfor
-      % for k in range(op.input_sizes[j], 4):
+      % for k in range(op.input_sizes[j], 16):
          0,
       % endfor
       };
@@ -418,18 +430,18 @@ struct ${type}${width}_vec {
       % for k in range(op.output_size):
          % if output_type == "int1" or output_type == "uint1":
             /* 1-bit integers get truncated */
-            _dst_val[${k}].b = dst.${"xyzw"[k]} & 1;
+            _dst_val[${k}].b = dst.${"xyzwefghijklmnop"[k]} & 1;
          % elif output_type.startswith("bool"):
             ## Sanitize the C value to a proper NIR 0/-1 bool
-            _dst_val[${k}].${get_const_field(output_type)} = -(int)dst.${"xyzw"[k]};
+            _dst_val[${k}].${get_const_field(output_type)} = -(int)dst.${"xyzwefghijklmnop"[k]};
          % elif output_type == "float16":
             if (nir_is_rounding_mode_rtz(execution_mode, 16)) {
-               _dst_val[${k}].u16 = _mesa_float_to_float16_rtz(dst.${"xyzw"[k]});
+               _dst_val[${k}].u16 = _mesa_float_to_float16_rtz(dst.${"xyzwefghijklmnop"[k]});
             } else {
-               _dst_val[${k}].u16 = _mesa_float_to_float16_rtne(dst.${"xyzw"[k]});
+               _dst_val[${k}].u16 = _mesa_float_to_float16_rtne(dst.${"xyzwefghijklmnop"[k]});
             }
          % else:
-            _dst_val[${k}].${get_const_field(output_type)} = dst.${"xyzw"[k]};
+            _dst_val[${k}].${get_const_field(output_type)} = dst.${"xyzwefghijklmnop"[k]};
          % endif
 
          % if op.name != "fquantize2f16" and type_base_type(output_type) == "float":
index f672699b5a4555d3f21341bc843586b8fb81bf95..e919370ec1aab4ac1e2cc51ee6417bd2bc49877c 100644 (file)
@@ -117,6 +117,8 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
       return lower_reduction(alu, chan, merge, b); \
 
    switch (alu->op) {
+   case nir_op_vec16:
+   case nir_op_vec8:
    case nir_op_vec4:
    case nir_op_vec3:
    case nir_op_vec2:
index c07121f6d882be4f770e961391ce972031d31531..d552b0575903108231799f151c0f358aa1ed276e 100644 (file)
@@ -56,6 +56,8 @@ lower_alu_instr(nir_builder *b, nir_alu_instr *alu)
    case nir_op_vec2:
    case nir_op_vec3:
    case nir_op_vec4:
+   case nir_op_vec8:
+   case nir_op_vec16:
       /* These we expect to have booleans but the opcode doesn't change */
       break;
 
index e331de488a3676ab1b6f4e4a3932ef253219cd92..0978207e72d2a44ef6aab3576f197d0cab4f57b0 100644 (file)
@@ -53,6 +53,8 @@ lower_alu_instr(nir_alu_instr *alu)
    case nir_op_vec2:
    case nir_op_vec3:
    case nir_op_vec4:
+   case nir_op_vec8:
+   case nir_op_vec16:
    case nir_op_inot:
    case nir_op_iand:
    case nir_op_ior:
index 2ab04ed9b1dbba731e412f5a94d6761a78b29afa..86485e395083c2ba4354c96a1bc07c09ea185dc9 100644 (file)
@@ -75,7 +75,7 @@ class Opcode(object):
       assert isinstance(algebraic_properties, str)
       assert isinstance(const_expr, str)
       assert len(input_sizes) == len(input_types)
-      assert 0 <= output_size <= 4
+      assert 0 <= output_size <= 4 or (output_size == 8) or (output_size == 16)
       for size in input_sizes:
          assert 0 <= size <= 4
          if output_size != 0:
@@ -1057,6 +1057,40 @@ dst.z = src2.x;
 dst.w = src3.x;
 """)
 
+opcode("vec8", 8, tuint,
+       [1] * 8, [tuint] * 8,
+       False, "", """
+dst.x = src0.x;
+dst.y = src1.x;
+dst.z = src2.x;
+dst.w = src3.x;
+dst.e = src4.x;
+dst.f = src5.x;
+dst.g = src6.x;
+dst.h = src7.x;
+""")
+
+opcode("vec16", 16, tuint,
+       [1] * 16, [tuint] * 16,
+       False, "", """
+dst.x = src0.x;
+dst.y = src1.x;
+dst.z = src2.x;
+dst.w = src3.x;
+dst.e = src4.x;
+dst.f = src5.x;
+dst.g = src6.x;
+dst.h = src7.x;
+dst.i = src8.x;
+dst.j = src9.x;
+dst.k = src10.x;
+dst.l = src11.x;
+dst.m = src12.x;
+dst.n = src13.x;
+dst.o = src14.x;
+dst.p = src15.x;
+""")
+
 # An integer multiply instruction for address calculation.  This is
 # similar to imul, except that the results are undefined in case of
 # overflow.  Overflow is defined according to the size of the variable
index c40d4ada4ee2faba968b45b4b3362016521af0d1..6587251c7e794263388e3c67c262e49262d69010 100644 (file)
@@ -643,7 +643,7 @@ new_bitsize_acceptable(struct vectorize_ctx *ctx, unsigned new_bit_size,
       return false;
 
    unsigned new_num_components = size / new_bit_size;
-   if (new_num_components > NIR_MAX_VEC_COMPONENTS)
+   if (!nir_num_components_valid(new_num_components))
       return false;
 
    unsigned high_offset = high->offset_signed - low->offset_signed;
index ab82f4fbb50456436962188103e2be6e15af156b..9ec6dcfc9476d99ed4922ba350b4c2115fbf9c0f 100644 (file)
@@ -171,6 +171,12 @@ print_dest(nir_dest *dest, print_state *state)
       print_reg_dest(&dest->reg, state);
 }
 
+static const char *
+comp_mask_string(unsigned num_components)
+{
+   return (num_components > 4) ? "abcdefghijklmnop" : "xyzw";
+}
+
 static void
 print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state)
 {
@@ -206,7 +212,7 @@ print_alu_src(nir_alu_instr *instr, unsigned src, print_state *state)
          if (!nir_alu_instr_channel_used(instr, src, i))
             continue;
 
-         fprintf(fp, "%c", "xyzw"[instr->src[src].swizzle[i]]);
+         fprintf(fp, "%c", comp_mask_string(live_channels)[instr->src[src].swizzle[i]]);
       }
    }
 
@@ -224,10 +230,11 @@ print_alu_dest(nir_alu_dest *dest, print_state *state)
 
    if (!dest->dest.is_ssa &&
        dest->write_mask != (1 << dest->dest.reg.reg->num_components) - 1) {
+      unsigned live_channels = dest->dest.reg.reg->num_components;
       fprintf(fp, ".");
       for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
          if ((dest->write_mask >> i) & 1)
-            fprintf(fp, "%c", "xyzw"[i]);
+            fprintf(fp, "%c", comp_mask_string(live_channels)[i]);
    }
 }
 
@@ -569,8 +576,8 @@ print_var_decl(nir_variable *var, print_state *state)
       switch (var->data.mode) {
       case nir_var_shader_in:
       case nir_var_shader_out:
-         if (num_components < 4 && num_components != 0) {
-            const char *xyzw = "xyzw";
+         if (num_components < 16 && num_components != 0) {
+            const char *xyzw = comp_mask_string(num_components);
             for (int i = 0; i < num_components; i++)
                components_local[i + 1] = xyzw[i + var->data.location_frac];
 
@@ -816,9 +823,9 @@ print_intrinsic_instr(nir_intrinsic_instr *instr, print_state *state)
          /* special case wrmask to show it as a writemask.. */
          unsigned wrmask = nir_intrinsic_write_mask(instr);
          fprintf(fp, " wrmask=");
-         for (unsigned i = 0; i < 4; i++)
+         for (unsigned i = 0; i < instr->num_components; i++)
             if ((wrmask >> i) & 1)
-               fprintf(fp, "%c", "xyzw"[i]);
+               fprintf(fp, "%c", comp_mask_string(instr->num_components)[i]);
          break;
       }
 
index c1b179525abaf9b94337e76698ccb17bccd76d70..458a4eeb1ce1ec7ebe8c13b4cb9721a2bd417004 100644 (file)
@@ -56,7 +56,13 @@ static bool
 nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
                         const struct per_op_table *pass_op_table);
 
-static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
+static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] =
+{
+    0,  1,  2,  3,
+    4,  5,  6,  7,
+    8,  9, 10, 11,
+   12, 13, 14, 15,
+};
 
 /**
  * Check if a source produces a value of the given type.
index 2a0839768b8e8f8d5cf75953359340abb62673bc..38e32e4ce3d912673265ca5b7a18a2d1c6429293 100644 (file)
@@ -128,8 +128,7 @@ static void validate_src(nir_src *src, validate_state *state,
 static void
 validate_num_components(validate_state *state, unsigned num_components)
 {
-   validate_assert(state, num_components >= 1 &&
-                          num_components <= 4);
+   validate_assert(state, nir_num_components_valid(num_components));
 }
 
 static void
index 29fef9fad6632cb51719b460aa46891d1ac12d07..63a165d5b847e9763169387e4ea16243942f40bd 100644 (file)
@@ -3819,10 +3819,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
       case SpvCapabilityInputAttachment:
       case SpvCapabilityImageGatherExtended:
       case SpvCapabilityStorageImageExtendedFormats:
+      case SpvCapabilityVector16:
          break;
 
       case SpvCapabilityLinkage:
-      case SpvCapabilityVector16:
       case SpvCapabilityFloat16Buffer:
       case SpvCapabilitySparseResidency:
          vtn_warn("Unsupported SPIR-V capability: %s",