X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Famd%2Fcompiler%2Faco_builder_h.py;h=8d541cbd72f2a15a4d2856498728d8e8189f01e0;hb=fc9f502a5bd853128a9c2932c793180035883efc;hp=18e4bf752ec18c37ab14e1a5783a7eae9cea16a3;hpb=70f63c198863e60e844978e1ca2e9773159ca8d3;p=mesa.git diff --git a/src/amd/compiler/aco_builder_h.py b/src/amd/compiler/aco_builder_h.py index 18e4bf752ec..8d541cbd72f 100644 --- a/src/amd/compiler/aco_builder_h.py +++ b/src/amd/compiler/aco_builder_h.py @@ -69,6 +69,13 @@ dpp_row_sr(unsigned amount) return (dpp_ctrl)(((unsigned) _dpp_row_sr) | amount); } +inline dpp_ctrl +dpp_row_rr(unsigned amount) +{ + assert(amount > 0 && amount < 16); + return (dpp_ctrl)(((unsigned) _dpp_row_rr) | amount); +} + inline unsigned ds_pattern_bitmode(unsigned and_mask, unsigned or_mask, unsigned xor_mask) { @@ -78,6 +85,8 @@ ds_pattern_bitmode(unsigned and_mask, unsigned or_mask, unsigned xor_mask) aco_ptr create_s_mov(Definition dst, Operand src); +extern uint8_t int8_mul_table[512]; + enum sendmsg { sendmsg_none = 0, _sendmsg_gs = 2, @@ -166,11 +175,25 @@ public: std::vector> *instructions; std::vector>::iterator it; + bool is_precise = false; + bool is_nuw = false; - Builder(Program *pgm) : program(pgm), use_iterator(false), start(false), lm(pgm->lane_mask), instructions(NULL) {} + Builder(Program *pgm) : program(pgm), use_iterator(false), start(false), lm(pgm ? pgm->lane_mask : s2), instructions(NULL) {} Builder(Program *pgm, Block *block) : program(pgm), use_iterator(false), start(false), lm(pgm ? pgm->lane_mask : s2), instructions(&block->instructions) {} Builder(Program *pgm, std::vector> *instrs) : program(pgm), use_iterator(false), start(false), lm(pgm ? pgm->lane_mask : s2), instructions(instrs) {} + Builder precise() const { + Builder res = *this; + res.is_precise = true; + return res; + }; + + Builder nuw() const { + Builder res = *this; + res.is_nuw = true; + return res; + } + void moveEnd(Block *block) { instructions = &block->instructions; } @@ -294,7 +317,8 @@ public: % for fixed in ['m0', 'vcc', 'exec', 'scc']: Operand ${fixed}(Temp tmp) { % if fixed == 'vcc' or fixed == 'exec': - assert(tmp.regClass() == lm); + //vcc_hi and exec_hi can still be used in wave32 + assert(tmp.type() == RegType::sgpr && tmp.bytes() <= 8); % endif Operand op(tmp); op.setFixed(aco::${fixed}); @@ -303,7 +327,8 @@ public: Definition ${fixed}(Definition def) { % if fixed == 'vcc' or fixed == 'exec': - assert(def.regClass() == lm); + //vcc_hi and exec_hi can still be used in wave32 + assert(def.regClass().type() == RegType::sgpr && def.bytes() <= 8); % endif def.setFixed(aco::${fixed}); return def; @@ -311,7 +336,8 @@ public: Definition hint_${fixed}(Definition def) { % if fixed == 'vcc' or fixed == 'exec': - assert(def.regClass() == lm); + //vcc_hi and exec_hi can still be used in wave32 + assert(def.regClass().type() == RegType::sgpr && def.bytes() <= 8); % endif def.setHint(aco::${fixed}); return def; @@ -352,6 +378,7 @@ public: Result copy(Definition dst, Op op_) { Operand op = op_.op; + assert(op.bytes() == dst.bytes()); if (dst.regClass() == s1 && op.size() == 1 && op.isLiteral()) { uint32_t imm = op.constantValue(); if (imm == 0x3e22f983) { @@ -372,15 +399,58 @@ public: } } - if (dst.regClass() == s2) { + if (dst.regClass() == s1) { + return sop1(aco_opcode::s_mov_b32, dst, op); + } else if (dst.regClass() == s2) { return sop1(aco_opcode::s_mov_b64, dst, op); - } else if (op.size() > 1) { - return pseudo(aco_opcode::p_create_vector, dst, op); } else if (dst.regClass() == v1 || dst.regClass() == v1.as_linear()) { return vop1(aco_opcode::v_mov_b32, dst, op); + } else if (op.bytes() > 2) { + return pseudo(aco_opcode::p_create_vector, dst, op); + } else if (op.bytes() == 1 && op.isConstant()) { + uint8_t val = op.constantValue(); + Operand op32((uint32_t)val | (val & 0x80u ? 0xffffff00u : 0u)); + aco_ptr sdwa; + if (op32.isLiteral()) { + sdwa.reset(create_instruction(aco_opcode::v_mul_u32_u24, asSDWA(Format::VOP2), 2, 1)); + uint32_t a = (uint32_t)int8_mul_table[val * 2]; + uint32_t b = (uint32_t)int8_mul_table[val * 2 + 1]; + sdwa->operands[0] = Operand(a | (a & 0x80u ? 0xffffff00u : 0x0u)); + sdwa->operands[1] = Operand(b | (b & 0x80u ? 0xffffff00u : 0x0u)); + } else { + sdwa.reset(create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)); + sdwa->operands[0] = op32; + } + sdwa->definitions[0] = dst; + sdwa->sel[0] = sdwa_udword; + sdwa->sel[1] = sdwa_udword; + sdwa->dst_sel = sdwa_ubyte; + sdwa->dst_preserve = true; + return insert(std::move(sdwa)); + } else if (op.bytes() == 2 && op.isConstant() && !op.isLiteral()) { + aco_ptr sdwa{create_instruction(aco_opcode::v_add_f16, asSDWA(Format::VOP2), 2, 1)}; + sdwa->operands[0] = op; + sdwa->operands[1] = Operand(0u); + sdwa->definitions[0] = dst; + sdwa->sel[0] = sdwa_uword; + sdwa->sel[1] = sdwa_udword; + sdwa->dst_sel = dst.bytes() == 1 ? sdwa_ubyte : sdwa_uword; + sdwa->dst_preserve = true; + return insert(std::move(sdwa)); + } else if (dst.regClass().is_subdword()) { + if (program->chip_class >= GFX8) { + aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; + sdwa->operands[0] = op; + sdwa->definitions[0] = dst; + sdwa->sel[0] = op.bytes() == 1 ? sdwa_ubyte : sdwa_uword; + sdwa->dst_sel = dst.bytes() == 1 ? sdwa_ubyte : sdwa_uword; + sdwa->dst_preserve = true; + return insert(std::move(sdwa)); + } else { + return vop1(aco_opcode::v_mov_b32, dst, op); + } } else { - assert(dst.regClass() == s1); - return sop1(aco_opcode::s_mov_b32, dst, op); + unreachable("Unhandled case in bld.copy()"); } } @@ -392,7 +462,7 @@ public: if (!carry_in.op.isUndefined()) return vop2(aco_opcode::v_addc_co_u32, Definition(dst), hint_vcc(def(lm)), a, b, carry_in); else if (program->chip_class >= GFX10 && carry_out) - return vop3(aco_opcode::v_add_co_u32_e64, Definition(dst), def(s2), a, b); + return vop3(aco_opcode::v_add_co_u32_e64, Definition(dst), def(lm), a, b); else if (program->chip_class < GFX9 || carry_out) return vop2(aco_opcode::v_add_co_u32, Definition(dst), hint_vcc(def(lm)), a, b); else @@ -472,13 +542,15 @@ formats = [("pseudo", [Format.PSEUDO], 'Pseudo_instruction', list(itertools.prod ("smem", [Format.SMEM], 'SMEM_instruction', [(0, 4), (0, 3), (1, 0), (1, 3), (1, 2), (0, 0)]), ("ds", [Format.DS], 'DS_instruction', [(1, 1), (1, 2), (0, 3), (0, 4)]), ("mubuf", [Format.MUBUF], 'MUBUF_instruction', [(0, 4), (1, 3)]), - ("mimg", [Format.MIMG], 'MIMG_instruction', [(0, 4), (1, 3), (0, 3), (1, 2)]), #TODO(pendingchaos): less shapes? + ("mtbuf", [Format.MTBUF], 'MTBUF_instruction', [(0, 4), (1, 3)]), + ("mimg", [Format.MIMG], 'MIMG_instruction', [(0, 3), (1, 3)]), ("exp", [Format.EXP], 'Export_instruction', [(0, 4)]), ("branch", [Format.PSEUDO_BRANCH], 'Pseudo_branch_instruction', itertools.product([0], [0, 1])), ("barrier", [Format.PSEUDO_BARRIER], 'Pseudo_barrier_instruction', [(0, 0)]), - ("reduction", [Format.PSEUDO_REDUCTION], 'Pseudo_reduction_instruction', [(3, 2), (3, 4)]), + ("reduction", [Format.PSEUDO_REDUCTION], 'Pseudo_reduction_instruction', [(3, 2)]), ("vop1", [Format.VOP1], 'VOP1_instruction', [(1, 1), (2, 2)]), ("vop2", [Format.VOP2], 'VOP2_instruction', itertools.product([1, 2], [2, 3])), + ("vop2_sdwa", [Format.VOP2, Format.SDWA], 'SDWA_instruction', itertools.product([1, 2], [2, 3])), ("vopc", [Format.VOPC], 'VOPC_instruction', itertools.product([1, 2], [2])), ("vop3", [Format.VOP3A], 'VOP3A_instruction', [(1, 3), (1, 2), (1, 1), (2, 2)]), ("vintrp", [Format.VINTRP], 'Interp_instruction', [(1, 2), (1, 3)]), @@ -490,8 +562,9 @@ formats = [("pseudo", [Format.PSEUDO], 'Pseudo_instruction', list(itertools.prod ("vopc_e64", [Format.VOPC, Format.VOP3A], 'VOP3A_instruction', itertools.product([1, 2], [2])), ("flat", [Format.FLAT], 'FLAT_instruction', [(0, 3), (1, 2)]), ("global", [Format.GLOBAL], 'FLAT_instruction', [(0, 3), (1, 2)])] +formats = [(f if len(f) == 5 else f + ('',)) for f in formats] %>\\ -% for name, formats, struct, shapes in formats: +% for name, formats, struct, shapes, extra_field_setup in formats: % for num_definitions, num_operands in shapes: <% args = ['aco_opcode opcode'] @@ -508,6 +581,8 @@ formats = [("pseudo", [Format.PSEUDO], 'Pseudo_instruction', list(itertools.prod ${struct} *instr = create_instruction<${struct}>(opcode, (Format)(${'|'.join('(int)Format::%s' % f.name for f in formats)}), ${num_operands}, ${num_definitions}); % for i in range(num_definitions): instr->definitions[${i}] = def${i}; + instr->definitions[${i}].setPrecise(is_precise); + instr->definitions[${i}].setNUW(is_nuw); % endfor % for i in range(num_operands): instr->operands[${i}] = op${i}.op; @@ -516,7 +591,9 @@ formats = [("pseudo", [Format.PSEUDO], 'Pseudo_instruction', list(itertools.prod % for dest, field_name in zip(f.get_builder_field_dests(), f.get_builder_field_names()): instr->${dest} = ${field_name}; % endfor + ${f.get_builder_initialization(num_operands)} % endfor + ${extra_field_setup} return insert(instr); }