nir: Support 8 and 16 component vectors for reduceable intrinsics
authorJesse Natalie <jenatali@microsoft.com>
Mon, 22 Jun 2020 23:48:43 +0000 (16:48 -0700)
committerJesse Natalie <jenatali@microsoft.com>
Fri, 24 Jul 2020 01:23:20 +0000 (18:23 -0700)
Reviewed-by: Boris Brezillon <boris.brezillon@collabora.com>
Reviewed-by: Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com>
Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6030>

src/compiler/nir/nir_builder.h
src/compiler/nir/nir_lower_alu_to_scalar.c
src/compiler/nir/nir_opcodes.py

index 48b22888516abe91fdeb1d80663d5e59aacbab0f..283eb2b1f198dcb73fd138f5db8fbe4630f7cbdb 100644 (file)
@@ -513,6 +513,8 @@ nir_fdot(nir_builder *build, nir_ssa_def *src0, nir_ssa_def *src1)
    case 2: return nir_fdot2(build, src0, src1);
    case 3: return nir_fdot3(build, src0, src1);
    case 4: return nir_fdot4(build, src0, src1);
+   case 8: return nir_fdot8(build, src0, src1);
+   case 16: return nir_fdot16(build, src0, src1);
    default:
       unreachable("bad component size");
    }
@@ -528,6 +530,8 @@ nir_ball_iequal(nir_builder *b, nir_ssa_def *src0, nir_ssa_def *src1)
    case 2: return nir_ball_iequal2(b, src0, src1);
    case 3: return nir_ball_iequal3(b, src0, src1);
    case 4: return nir_ball_iequal4(b, src0, src1);
+   case 8: return nir_ball_iequal8(b, src0, src1);
+   case 16: return nir_ball_iequal16(b, src0, src1);
    default:
       unreachable("bad component size");
    }
@@ -541,6 +545,8 @@ nir_bany_inequal(nir_builder *b, nir_ssa_def *src0, nir_ssa_def *src1)
    case 2: return nir_bany_inequal2(b, src0, src1);
    case 3: return nir_bany_inequal3(b, src0, src1);
    case 4: return nir_bany_inequal4(b, src0, src1);
+   case 8: return nir_bany_inequal8(b, src0, src1);
+   case 16: return nir_bany_inequal16(b, src0, src1);
    default:
       unreachable("bad component size");
    }
index e3258429b5800b579cb5a9b52debd16b8653140d..138318fbf60b0bd0b02e56a7ac9421b009558d01 100644 (file)
@@ -114,6 +114,8 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
    case name##2: \
    case name##3: \
    case name##4: \
+   case name##8: \
+   case name##16: \
       return lower_reduction(alu, chan, merge, b); \
 
    switch (alu->op) {
index f18668493b22509bd6267b4b26b7702b5a8086e0..7846bf2192c290b6b2d3c9eeaeedd3491cec3b43 100644 (file)
@@ -77,7 +77,7 @@ class Opcode(object):
       assert len(input_sizes) == len(input_types)
       assert 0 <= output_size <= 4 or (output_size == 8) or (output_size == 16)
       for size in input_sizes:
-         assert 0 <= size <= 4
+         assert 0 <= size <= 4 or (size == 8) or (size == 16)
          if output_size != 0:
             assert size != 0
       self.name = name
@@ -544,19 +544,18 @@ def binop_reduce(name, output_size, output_type, src_type, prereduce_expr,
       return reduce_expr.format(src0=src0, src1=src1)
    def prereduce(src0, src1):
       return "(" + prereduce_expr.format(src0=src0, src1=src1) + ")"
-   src0 = prereduce("src0.x", "src1.x")
-   src1 = prereduce("src0.y", "src1.y")
-   src2 = prereduce("src0.z", "src1.z")
-   src3 = prereduce("src0.w", "src1.w")
-   opcode(name + "2", output_size, output_type,
-          [2, 2], [src_type, src_type], False, _2src_commutative,
-          final(reduce_(src0, src1)))
+   srcs = [prereduce("src0." + letter, "src1." + letter) for letter in "xyzwefghijklmnop"]
+   def pairwise_reduce(start, size):
+      if (size == 1):
+         return srcs[start]
+      return reduce_(pairwise_reduce(start, size // 2), pairwise_reduce(start + size // 2, size // 2))
+   for size in [2, 4, 8, 16]:
+      opcode(name + str(size), output_size, output_type,
+             [size, size], [src_type, src_type], False, _2src_commutative,
+             final(pairwise_reduce(0, size)))
    opcode(name + "3", output_size, output_type,
           [3, 3], [src_type, src_type], False, _2src_commutative,
-          final(reduce_(reduce_(src0, src1), src2)))
-   opcode(name + "4", output_size, output_type,
-          [4, 4], [src_type, src_type], False, _2src_commutative,
-          final(reduce_(reduce_(src0, src1), reduce_(src2, src3))))
+          final(reduce_(reduce_(srcs[0], srcs[1]), srcs[2])))
 
 def binop_reduce_all_sizes(name, output_size, src_type, prereduce_expr,
                            reduce_expr, final_expr):