From 456edf0b302f317d16e4041aceab4c3ddc8347af Mon Sep 17 00:00:00 2001 From: Jesse Natalie Date: Mon, 22 Jun 2020 16:48:43 -0700 Subject: [PATCH] nir: Support 8 and 16 component vectors for reduceable intrinsics Reviewed-by: Boris Brezillon Reviewed-by: Alyssa Rosenzweig Reviewed-by: Jason Ekstrand Part-of: --- src/compiler/nir/nir_builder.h | 6 ++++++ src/compiler/nir/nir_lower_alu_to_scalar.c | 2 ++ src/compiler/nir/nir_opcodes.py | 23 +++++++++++----------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/compiler/nir/nir_builder.h b/src/compiler/nir/nir_builder.h index 48b22888516..283eb2b1f19 100644 --- a/src/compiler/nir/nir_builder.h +++ b/src/compiler/nir/nir_builder.h @@ -513,6 +513,8 @@ nir_fdot(nir_builder *build, nir_ssa_def *src0, nir_ssa_def *src1) case 2: return nir_fdot2(build, src0, src1); case 3: return nir_fdot3(build, src0, src1); case 4: return nir_fdot4(build, src0, src1); + case 8: return nir_fdot8(build, src0, src1); + case 16: return nir_fdot16(build, src0, src1); default: unreachable("bad component size"); } @@ -528,6 +530,8 @@ nir_ball_iequal(nir_builder *b, nir_ssa_def *src0, nir_ssa_def *src1) case 2: return nir_ball_iequal2(b, src0, src1); case 3: return nir_ball_iequal3(b, src0, src1); case 4: return nir_ball_iequal4(b, src0, src1); + case 8: return nir_ball_iequal8(b, src0, src1); + case 16: return nir_ball_iequal16(b, src0, src1); default: unreachable("bad component size"); } @@ -541,6 +545,8 @@ nir_bany_inequal(nir_builder *b, nir_ssa_def *src0, nir_ssa_def *src1) case 2: return nir_bany_inequal2(b, src0, src1); case 3: return nir_bany_inequal3(b, src0, src1); case 4: return nir_bany_inequal4(b, src0, src1); + case 8: return nir_bany_inequal8(b, src0, src1); + case 16: return nir_bany_inequal16(b, src0, src1); default: unreachable("bad component size"); } diff --git a/src/compiler/nir/nir_lower_alu_to_scalar.c b/src/compiler/nir/nir_lower_alu_to_scalar.c index e3258429b58..138318fbf60 100644 --- a/src/compiler/nir/nir_lower_alu_to_scalar.c +++ b/src/compiler/nir/nir_lower_alu_to_scalar.c @@ -114,6 +114,8 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data) case name##2: \ case name##3: \ case name##4: \ + case name##8: \ + case name##16: \ return lower_reduction(alu, chan, merge, b); \ switch (alu->op) { diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py index f18668493b2..7846bf2192c 100644 --- a/src/compiler/nir/nir_opcodes.py +++ b/src/compiler/nir/nir_opcodes.py @@ -77,7 +77,7 @@ class Opcode(object): assert len(input_sizes) == len(input_types) assert 0 <= output_size <= 4 or (output_size == 8) or (output_size == 16) for size in input_sizes: - assert 0 <= size <= 4 + assert 0 <= size <= 4 or (size == 8) or (size == 16) if output_size != 0: assert size != 0 self.name = name @@ -544,19 +544,18 @@ def binop_reduce(name, output_size, output_type, src_type, prereduce_expr, return reduce_expr.format(src0=src0, src1=src1) def prereduce(src0, src1): return "(" + prereduce_expr.format(src0=src0, src1=src1) + ")" - src0 = prereduce("src0.x", "src1.x") - src1 = prereduce("src0.y", "src1.y") - src2 = prereduce("src0.z", "src1.z") - src3 = prereduce("src0.w", "src1.w") - opcode(name + "2", output_size, output_type, - [2, 2], [src_type, src_type], False, _2src_commutative, - final(reduce_(src0, src1))) + srcs = [prereduce("src0." + letter, "src1." + letter) for letter in "xyzwefghijklmnop"] + def pairwise_reduce(start, size): + if (size == 1): + return srcs[start] + return reduce_(pairwise_reduce(start, size // 2), pairwise_reduce(start + size // 2, size // 2)) + for size in [2, 4, 8, 16]: + opcode(name + str(size), output_size, output_type, + [size, size], [src_type, src_type], False, _2src_commutative, + final(pairwise_reduce(0, size))) opcode(name + "3", output_size, output_type, [3, 3], [src_type, src_type], False, _2src_commutative, - final(reduce_(reduce_(src0, src1), src2))) - opcode(name + "4", output_size, output_type, - [4, 4], [src_type, src_type], False, _2src_commutative, - final(reduce_(reduce_(src0, src1), reduce_(src2, src3)))) + final(reduce_(reduce_(srcs[0], srcs[1]), srcs[2]))) def binop_reduce_all_sizes(name, output_size, src_type, prereduce_expr, reduce_expr, final_expr): -- 2.30.2