From 33277bd66e32d50a96b7cd5dfe73a6a962138ea2 Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Mon, 11 Nov 2019 17:37:43 +0000 Subject: [PATCH] aco: refactor reduction lowering helpers MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit Should make 64-bit integer reductions easier to implement. v4: use num_opcodes instead of last_opcode Signed-off-by: Rhys Perry Reviewed-by: Daniel Schürmann (v3) --- src/amd/compiler/aco_lower_to_hw_instr.cpp | 289 ++++++++------------- 1 file changed, 115 insertions(+), 174 deletions(-) diff --git a/src/amd/compiler/aco_lower_to_hw_instr.cpp b/src/amd/compiler/aco_lower_to_hw_instr.cpp index 8db54064202..0183b1d1ac5 100644 --- a/src/amd/compiler/aco_lower_to_hw_instr.cpp +++ b/src/amd/compiler/aco_lower_to_hw_instr.cpp @@ -41,65 +41,101 @@ struct lower_context { std::vector> instructions; }; -void emit_dpp_op(lower_context *ctx, PhysReg dst, PhysReg src0, PhysReg src1, PhysReg vtmp, - aco_opcode op, Format format, bool clobber_vcc, unsigned dpp_ctrl, - unsigned row_mask, unsigned bank_mask, bool bound_ctrl_zero, unsigned size, +aco_opcode get_reduce_opcode(chip_class chip, ReduceOp op) { + switch (op) { + case iadd32: return chip >= GFX9 ? aco_opcode::v_add_u32 : aco_opcode::v_add_co_u32; + case imul32: return aco_opcode::v_mul_lo_u32; + case fadd32: return aco_opcode::v_add_f32; + case fmul32: return aco_opcode::v_mul_f32; + case imax32: return aco_opcode::v_max_i32; + case imin32: return aco_opcode::v_min_i32; + case umin32: return aco_opcode::v_min_u32; + case umax32: return aco_opcode::v_max_u32; + case fmin32: return aco_opcode::v_min_f32; + case fmax32: return aco_opcode::v_max_f32; + case iand32: return aco_opcode::v_and_b32; + case ixor32: return aco_opcode::v_xor_b32; + case ior32: return aco_opcode::v_or_b32; + case iadd64: return aco_opcode::num_opcodes; + case imul64: return aco_opcode::num_opcodes; + case fadd64: return aco_opcode::v_add_f64; + case fmul64: return aco_opcode::v_mul_f64; + case imin64: return aco_opcode::num_opcodes; + case imax64: return aco_opcode::num_opcodes; + case umin64: return aco_opcode::num_opcodes; + case umax64: return aco_opcode::num_opcodes; + case fmin64: return aco_opcode::v_min_f64; + case fmax64: return aco_opcode::v_max_f64; + case iand64: return aco_opcode::num_opcodes; + case ior64: return aco_opcode::num_opcodes; + case ixor64: return aco_opcode::num_opcodes; + default: return aco_opcode::num_opcodes; + } +} + +void emit_dpp_op(lower_context *ctx, PhysReg dst_reg, PhysReg src0_reg, PhysReg src1_reg, + PhysReg vtmp, ReduceOp op, unsigned size, + unsigned dpp_ctrl, unsigned row_mask, unsigned bank_mask, bool bound_ctrl, Operand *identity=NULL) /* for VOP3 with sparse writes */ { + Builder bld(ctx->program, &ctx->instructions); RegClass rc = RegClass(RegType::vgpr, size); - if (format == Format::VOP3) { - Builder bld(ctx->program, &ctx->instructions); + Definition dst(dst_reg, rc); + Operand src0(src0_reg, rc); + Operand src1(src1_reg, rc); - if (identity) - bld.vop1(aco_opcode::v_mov_b32, Definition(vtmp, v1), identity[0]); - if (identity && size >= 2) - bld.vop1(aco_opcode::v_mov_b32, Definition(PhysReg{vtmp+1}, v1), identity[1]); + aco_opcode opcode = get_reduce_opcode(ctx->program->chip_class, op); + bool vop3 = op == imul32 || size == 2; - for (unsigned i = 0; i < size; i++) - bld.vop1_dpp(aco_opcode::v_mov_b32, Definition(PhysReg{vtmp+i}, v1), Operand(PhysReg{src0+i}, v1), - dpp_ctrl, row_mask, bank_mask, bound_ctrl_zero); - - if (clobber_vcc) - bld.vop3(op, Definition(dst, rc), Definition(vcc, s2), Operand(vtmp, rc), Operand(src1, rc)); + if (!vop3) { + if (opcode == aco_opcode::v_add_co_u32) + bld.vop2_dpp(opcode, dst, bld.def(s2, vcc), src0, src1, dpp_ctrl, row_mask, bank_mask, bound_ctrl); else - bld.vop3(op, Definition(dst, rc), Operand(vtmp, rc), Operand(src1, rc)); + bld.vop2_dpp(opcode, dst, src0, src1, dpp_ctrl, row_mask, bank_mask, bound_ctrl); + return; + } + + if (identity) + bld.vop1(aco_opcode::v_mov_b32, Definition(vtmp, v1), identity[0]); + if (identity && size >= 2) + bld.vop1(aco_opcode::v_mov_b32, Definition(PhysReg{vtmp+1}, v1), identity[1]); + + for (unsigned i = 0; i < size; i++) + bld.vop1_dpp(aco_opcode::v_mov_b32, Definition(PhysReg{vtmp+i}, v1), Operand(PhysReg{src0_reg+i}, v1), + dpp_ctrl, row_mask, bank_mask, bound_ctrl); + + bld.vop3(opcode, dst, Operand(vtmp, rc), src1); +} + +void emit_op(lower_context *ctx, PhysReg dst_reg, PhysReg src0_reg, PhysReg src1_reg, + ReduceOp op, unsigned size) +{ + Builder bld(ctx->program, &ctx->instructions); + RegClass rc = RegClass(RegType::vgpr, size); + Definition dst(dst_reg, rc); + Operand src0(src0_reg, RegClass(src0_reg.reg >= 256 ? RegType::vgpr : RegType::sgpr, size)); + Operand src1(src1_reg, rc); + + aco_opcode opcode = get_reduce_opcode(ctx->program->chip_class, op); + bool vop3 = op == imul32 || size == 2; + + if (vop3) { + bld.vop3(opcode, dst, src0, src1); + } else if (opcode == aco_opcode::v_add_co_u32) { + bld.vop2(opcode, dst, bld.def(s2, vcc), src0, src1); } else { - assert(format == Format::VOP2 || format == Format::VOP1); - assert(size == 1 || (op == aco_opcode::v_mov_b32)); - - for (unsigned i = 0; i < size; i++) { - aco_ptr dpp{create_instruction( - op, (Format) ((uint32_t) format | (uint32_t) Format::DPP), - format == Format::VOP2 ? 2 : 1, clobber_vcc ? 2 : 1)}; - dpp->operands[0] = Operand(PhysReg{src0+i}, rc); - if (format == Format::VOP2) - dpp->operands[1] = Operand(PhysReg{src1+i}, rc); - dpp->definitions[0] = Definition(PhysReg{dst+i}, rc); - if (clobber_vcc) - dpp->definitions[1] = Definition(vcc, s2); - dpp->dpp_ctrl = dpp_ctrl; - dpp->row_mask = row_mask; - dpp->bank_mask = bank_mask; - dpp->bound_ctrl = bound_ctrl_zero; - ctx->instructions.emplace_back(std::move(dpp)); - } + bld.vop2(opcode, dst, src0, src1); } } -void emit_op(lower_context *ctx, PhysReg dst, PhysReg src0, PhysReg src1, - aco_opcode op, Format format, bool clobber_vcc, unsigned size) +void emit_dpp_mov(lower_context *ctx, PhysReg dst, PhysReg src0, unsigned size, + unsigned dpp_ctrl, unsigned row_mask, unsigned bank_mask, bool bound_ctrl) { - aco_ptr instr; - if (format == Format::VOP3) - instr.reset(create_instruction(op, format, 2, clobber_vcc ? 2 : 1)); - else - instr.reset(create_instruction(op, format, 2, clobber_vcc ? 2 : 1)); - instr->operands[0] = Operand(src0, src0.reg >= 256 ? v1 : s1); - instr->operands[1] = Operand(src1, v1); - instr->definitions[0] = Definition(dst, v1); - if (clobber_vcc) - instr->definitions[1] = Definition(vcc, s2); - ctx->instructions.emplace_back(std::move(instr)); + Builder bld(ctx->program, &ctx->instructions); + for (unsigned i = 0; i < size; i++) { + bld.vop1_dpp(aco_opcode::v_mov_b32, Definition(PhysReg{dst+i}, v1), Operand(PhysReg{src0+i}, v1), + dpp_ctrl, row_mask, bank_mask, bound_ctrl); + } } uint32_t get_reduction_identity(ReduceOp op, unsigned idx) @@ -151,95 +187,6 @@ uint32_t get_reduction_identity(ReduceOp op, unsigned idx) return 0; } -aco_opcode get_reduction_opcode(lower_context *ctx, ReduceOp op, bool *clobber_vcc, Format *format) -{ - *clobber_vcc = false; - *format = Format::VOP2; - switch (op) { - case iadd32: - *clobber_vcc = ctx->program->chip_class < GFX9; - return ctx->program->chip_class < GFX9 ? aco_opcode::v_add_co_u32 : aco_opcode::v_add_u32; - case imul32: - *format = Format::VOP3; - return aco_opcode::v_mul_lo_u32; - case fadd32: - return aco_opcode::v_add_f32; - case fmul32: - return aco_opcode::v_mul_f32; - case imax32: - return aco_opcode::v_max_i32; - case imin32: - return aco_opcode::v_min_i32; - case umin32: - return aco_opcode::v_min_u32; - case umax32: - return aco_opcode::v_max_u32; - case fmin32: - return aco_opcode::v_min_f32; - case fmax32: - return aco_opcode::v_max_f32; - case iand32: - return aco_opcode::v_and_b32; - case ixor32: - return aco_opcode::v_xor_b32; - case ior32: - return aco_opcode::v_or_b32; - case iadd64: - case imul64: - assert(false); - break; - case fadd64: - *format = Format::VOP3; - return aco_opcode::v_add_f64; - case fmul64: - *format = Format::VOP3; - return aco_opcode::v_mul_f64; - case imin64: - case imax64: - case umin64: - case umax64: - assert(false); - break; - case fmin64: - *format = Format::VOP3; - return aco_opcode::v_min_f64; - case fmax64: - *format = Format::VOP3; - return aco_opcode::v_max_f64; - case iand64: - case ior64: - case ixor64: - assert(false); - break; - default: - unreachable("Invalid reduction operation"); - break; - } - return aco_opcode::v_min_u32; -} - -void emit_vopn(lower_context *ctx, PhysReg dst, PhysReg src0, PhysReg src1, - RegClass rc, aco_opcode op, Format format, bool clobber_vcc) -{ - aco_ptr instr; - switch (format) { - case Format::VOP2: - instr.reset(create_instruction(op, format, 2, clobber_vcc ? 2 : 1)); - break; - case Format::VOP3: - instr.reset(create_instruction(op, format, 2, clobber_vcc ? 2 : 1)); - break; - default: - assert(false); - } - instr->operands[0] = Operand(src0, rc); - instr->operands[1] = Operand(src1, rc); - instr->definitions[0] = Definition(dst, rc); - if (clobber_vcc) - instr->definitions[1] = Definition(vcc, s2); - ctx->instructions.emplace_back(std::move(instr)); -} - void emit_reduction(lower_context *ctx, aco_opcode op, ReduceOp reduce_op, unsigned cluster_size, PhysReg tmp, PhysReg stmp, PhysReg vtmp, PhysReg sitmp, Operand src, Definition dst) { @@ -247,9 +194,6 @@ void emit_reduction(lower_context *ctx, aco_opcode op, ReduceOp reduce_op, unsig Builder bld(ctx->program, &ctx->instructions); - Format format; - bool should_clobber_vcc; - aco_opcode reduce_opcode = get_reduction_opcode(ctx, reduce_op, &should_clobber_vcc, &format); Operand identity[2]; identity[0] = Operand(get_reduction_identity(reduce_op, 0)); identity[1] = Operand(get_reduction_identity(reduce_op, 1)); @@ -284,49 +228,47 @@ void emit_reduction(lower_context *ctx, aco_opcode op, ReduceOp reduce_op, unsig switch (op) { case aco_opcode::p_reduce: if (cluster_size == 1) break; - emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, - dpp_quad_perm(1, 0, 3, 2), 0xf, 0xf, false, src.size()); + emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(), + dpp_quad_perm(1, 0, 3, 2), 0xf, 0xf, false); if (cluster_size == 2) break; - emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, - dpp_quad_perm(2, 3, 0, 1), 0xf, 0xf, false, src.size()); + emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(), + dpp_quad_perm(2, 3, 0, 1), 0xf, 0xf, false); if (cluster_size == 4) break; - emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, - dpp_row_half_mirror, 0xf, 0xf, false, src.size()); + emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(), + dpp_row_half_mirror, 0xf, 0xf, false); if (cluster_size == 8) break; - emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, - dpp_row_mirror, 0xf, 0xf, false, src.size()); + emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(), + dpp_row_mirror, 0xf, 0xf, false); if (cluster_size == 16) break; if (cluster_size == 32) { for (unsigned i = 0; i < src.size(); i++) bld.ds(aco_opcode::ds_swizzle_b32, Definition(PhysReg{vtmp+i}, v1), Operand(PhysReg{tmp+i}, s1), ds_pattern_bitmode(0x1f, 0, 0x10)); bld.sop1(aco_opcode::s_mov_b64, Definition(exec, s2), Operand(stmp, s2)); exec_restored = true; - emit_vopn(ctx, dst.physReg(), vtmp, tmp, src.regClass(), reduce_opcode, format, should_clobber_vcc); + emit_op(ctx, dst.physReg(), vtmp, tmp, reduce_op, src.size()); dst_written = true; } else if (ctx->program->chip_class >= GFX10) { assert(cluster_size == 64); /* GFX10+ doesn't support row_bcast15 and row_bcast31 */ for (unsigned i = 0; i < src.size(); i++) bld.vop3(aco_opcode::v_permlanex16_b32, Definition(PhysReg{vtmp+i}, v1), Operand(PhysReg{tmp+i}, v1), Operand(0u), Operand(0u)); - emit_op(ctx, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, src.size()); + emit_op(ctx, tmp, tmp, vtmp, reduce_op, src.size()); for (unsigned i = 0; i < src.size(); i++) bld.vop3(aco_opcode::v_readlane_b32, Definition(PhysReg{sitmp+i}, s1), Operand(PhysReg{tmp+i}, v1), Operand(31u)); - emit_op(ctx, tmp, sitmp, tmp, reduce_opcode, format, should_clobber_vcc, src.size()); + emit_op(ctx, tmp, sitmp, tmp, reduce_op, src.size()); } else { assert(cluster_size == 64); - emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, - dpp_row_bcast15, 0xa, 0xf, false, src.size()); - emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, - dpp_row_bcast31, 0xc, 0xf, false, src.size()); + emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(), + dpp_row_bcast15, 0xa, 0xf, false); + emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(), + dpp_row_bcast31, 0xc, 0xf, false); } break; case aco_opcode::p_exclusive_scan: if (ctx->program->chip_class >= GFX10) { /* gfx10 doesn't support wf_sr1, so emulate it */ /* shift rows right */ - for (unsigned i = 0; i < src.size(); i++) { - bld.vop1_dpp(aco_opcode::v_mov_b32, Definition(PhysReg{vtmp+i}, v1), Operand(PhysReg{tmp+i}, s1), dpp_row_sr(1), 0xf, 0xf, true); - } + emit_dpp_mov(ctx, vtmp, tmp, src.size(), dpp_row_sr(1), 0xf, 0xf, true); /* fill in the gaps in rows 1 and 3 */ bld.sop1(aco_opcode::s_mov_b32, Definition(exec_lo, s1), Operand(0x10000u)); @@ -347,8 +289,7 @@ void emit_reduction(lower_context *ctx, aco_opcode op, ReduceOp reduce_op, unsig } std::swap(tmp, vtmp); } else { - emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, aco_opcode::v_mov_b32, Format::VOP1, false, - dpp_wf_sr1, 0xf, 0xf, true, src.size()); + emit_dpp_mov(ctx, tmp, tmp, src.size(), dpp_wf_sr1, 0xf, 0xf, true); } for (unsigned i = 0; i < src.size(); i++) { if (!identity[i].isConstant() || identity[i].constantValue()) { /* bound_ctrl should take case of this overwise */ @@ -361,14 +302,14 @@ void emit_reduction(lower_context *ctx, aco_opcode op, ReduceOp reduce_op, unsig /* fall through */ case aco_opcode::p_inclusive_scan: assert(cluster_size == 64); - emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, - dpp_row_sr(1), 0xf, 0xf, false, src.size(), identity); - emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, - dpp_row_sr(2), 0xf, 0xf, false, src.size(), identity); - emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, - dpp_row_sr(4), 0xf, 0xf, false, src.size(), identity); - emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, - dpp_row_sr(8), 0xf, 0xf, false, src.size(), identity); + emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(), + dpp_row_sr(1), 0xf, 0xf, false, identity); + emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(), + dpp_row_sr(2), 0xf, 0xf, false, identity); + emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(), + dpp_row_sr(4), 0xf, 0xf, false, identity); + emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(), + dpp_row_sr(8), 0xf, 0xf, false, identity); if (ctx->program->chip_class >= GFX10) { bld.sop1(aco_opcode::s_mov_b32, Definition(exec_lo, s1), Operand(0xffff0000u)); bld.sop1(aco_opcode::s_mov_b32, Definition(exec_hi, s1), Operand(0xffff0000u)); @@ -379,18 +320,18 @@ void emit_reduction(lower_context *ctx, aco_opcode op, ReduceOp reduce_op, unsig Operand(0xffffffffu), Operand(0xffffffffu)).instr; static_cast(perm)->opsel[0] = true; /* FI (Fetch Inactive) */ } - emit_op(ctx, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, src.size()); + emit_op(ctx, tmp, tmp, vtmp, reduce_op, src.size()); bld.sop1(aco_opcode::s_mov_b32, Definition(exec_lo, s1), Operand(0u)); bld.sop1(aco_opcode::s_mov_b32, Definition(exec_hi, s1), Operand(0xffffffffu)); for (unsigned i = 0; i < src.size(); i++) bld.vop3(aco_opcode::v_readlane_b32, Definition(PhysReg{sitmp+i}, s1), Operand(PhysReg{tmp+i}, v1), Operand(31u)); - emit_op(ctx, tmp, sitmp, tmp, reduce_opcode, format, should_clobber_vcc, src.size()); + emit_op(ctx, tmp, sitmp, tmp, reduce_op, src.size()); } else { - emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, - dpp_row_bcast15, 0xa, 0xf, false, src.size(), identity); - emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_opcode, format, should_clobber_vcc, - dpp_row_bcast31, 0xc, 0xf, false, src.size(), identity); + emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(), + dpp_row_bcast15, 0xa, 0xf, false, identity); + emit_dpp_op(ctx, tmp, tmp, tmp, vtmp, reduce_op, src.size(), + dpp_row_bcast31, 0xc, 0xf, false, identity); } break; default: -- 2.30.2