X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Famd%2Fcompiler%2Faco_instruction_selection.cpp;h=31a5e410e9f8042222107f567ea2f4d8a729ceda;hb=b497b774a5008c5c424b05cdbc3f4e96a6765912;hp=480aecf42f19347692be6cfa54594b2b2533d697;hpb=55537ed9d3e8869eaa9890a254ab35f7ce530ae1;p=mesa.git diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 480aecf42f1..31a5e410e9f 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -561,16 +561,8 @@ void emit_vop2_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode o Temp t = src0; src0 = src1; src1 = t; - } else if (src0.type() == RegType::vgpr && - op != aco_opcode::v_madmk_f32 && - op != aco_opcode::v_madak_f32 && - op != aco_opcode::v_madmk_f16 && - op != aco_opcode::v_madak_f16) { - /* If the instruction is not commutative, we emit a VOP3A instruction */ - bld.vop2_e64(op, Definition(dst), src0, src1); - return; } else { - src1 = bld.copy(bld.def(RegType::vgpr, src1.size()), src1); //TODO: as_vgpr + src1 = as_vgpr(ctx, src1); } } @@ -626,6 +618,24 @@ void emit_vopc_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode o if (src0.type() == RegType::vgpr) { /* to swap the operands, we might also have to change the opcode */ switch (op) { + case aco_opcode::v_cmp_lt_f16: + op = aco_opcode::v_cmp_gt_f16; + break; + case aco_opcode::v_cmp_ge_f16: + op = aco_opcode::v_cmp_le_f16; + break; + case aco_opcode::v_cmp_lt_i16: + op = aco_opcode::v_cmp_gt_i16; + break; + case aco_opcode::v_cmp_ge_i16: + op = aco_opcode::v_cmp_le_i16; + break; + case aco_opcode::v_cmp_lt_u16: + op = aco_opcode::v_cmp_gt_u16; + break; + case aco_opcode::v_cmp_ge_u16: + op = aco_opcode::v_cmp_le_u16; + break; case aco_opcode::v_cmp_lt_f32: op = aco_opcode::v_cmp_gt_f32; break; @@ -695,10 +705,10 @@ void emit_sopc_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode o } void emit_comparison(isel_context *ctx, nir_alu_instr *instr, Temp dst, - aco_opcode v32_op, aco_opcode v64_op, aco_opcode s32_op = aco_opcode::num_opcodes, aco_opcode s64_op = aco_opcode::num_opcodes) + aco_opcode v16_op, aco_opcode v32_op, aco_opcode v64_op, aco_opcode s32_op = aco_opcode::num_opcodes, aco_opcode s64_op = aco_opcode::num_opcodes) { - aco_opcode s_op = instr->src[0].src.ssa->bit_size == 64 ? s64_op : s32_op; - aco_opcode v_op = instr->src[0].src.ssa->bit_size == 64 ? v64_op : v32_op; + aco_opcode s_op = instr->src[0].src.ssa->bit_size == 64 ? s64_op : instr->src[0].src.ssa->bit_size == 32 ? s32_op : aco_opcode::num_opcodes; + aco_opcode v_op = instr->src[0].src.ssa->bit_size == 64 ? v64_op : instr->src[0].src.ssa->bit_size == 32 ? v32_op : v16_op; bool divergent_vals = ctx->divergent_vals[instr->dest.dest.ssa.index]; bool use_valu = s_op == aco_opcode::num_opcodes || divergent_vals || @@ -932,6 +942,58 @@ Temp emit_floor_f64(isel_context *ctx, Builder& bld, Definition dst, Temp val) return add->definitions[0].getTemp(); } +Temp convert_int(Builder& bld, Temp src, unsigned src_bits, unsigned dst_bits, bool is_signed, Temp dst=Temp()) { + if (!dst.id()) { + if (dst_bits % 32 == 0 || src.type() == RegType::sgpr) + dst = bld.tmp(src.type(), DIV_ROUND_UP(dst_bits, 32u)); + else + dst = bld.tmp(RegClass(RegType::vgpr, dst_bits / 8u).as_subdword()); + } + + if (dst.bytes() == src.bytes() && dst_bits < src_bits) + return bld.copy(Definition(dst), src); + else if (dst.bytes() < src.bytes()) + return bld.pseudo(aco_opcode::p_extract_vector, Definition(dst), src, Operand(0u)); + + Temp tmp = dst; + if (dst_bits == 64) + tmp = src_bits == 32 ? src : bld.tmp(src.type(), 1); + + if (tmp == src) { + } else if (src.regClass() == s1) { + if (is_signed) + bld.sop1(src_bits == 8 ? aco_opcode::s_sext_i32_i8 : aco_opcode::s_sext_i32_i16, Definition(tmp), src); + else + bld.sop2(aco_opcode::s_and_b32, Definition(tmp), bld.def(s1, scc), Operand(src_bits == 8 ? 0xFFu : 0xFFFFu), src); + } else { + assert(src_bits != 8 || src.regClass() == v1b); + assert(src_bits != 16 || src.regClass() == v2b); + aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; + sdwa->operands[0] = Operand(src); + sdwa->definitions[0] = Definition(tmp); + if (is_signed) + sdwa->sel[0] = src_bits == 8 ? sdwa_sbyte : sdwa_sword; + else + sdwa->sel[0] = src_bits == 8 ? sdwa_ubyte : sdwa_uword; + sdwa->dst_sel = tmp.bytes() == 2 ? sdwa_uword : sdwa_udword; + bld.insert(std::move(sdwa)); + } + + if (dst_bits == 64) { + if (is_signed && dst.regClass() == s2) { + Temp high = bld.sop2(aco_opcode::s_ashr_i32, bld.def(s1), bld.def(s1, scc), tmp, Operand(31u)); + bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, high); + } else if (is_signed && dst.regClass() == v2) { + Temp high = bld.vop2(aco_opcode::v_ashrrev_i32, bld.def(v1), Operand(31u), tmp); + bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, high); + } else { + bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, Operand(0u)); + } + } + + return dst; +} + void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) { if (!instr->dest.dest.is_ssa) { @@ -1569,7 +1631,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_fsub: { Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); + Temp src1 = get_alu_src(ctx, instr->src[1]); if (dst.regClass() == v2b) { Temp tmp = bld.tmp(v1); if (src1.type() == RegType::vgpr || src0.type() != RegType::vgpr) @@ -1584,7 +1646,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) emit_vop2_instruction(ctx, instr, aco_opcode::v_subrev_f32, dst, true); } else if (dst.regClass() == v2) { Instruction* add = bld.vop3(aco_opcode::v_add_f64, Definition(dst), - src0, src1); + as_vgpr(ctx, src0), as_vgpr(ctx, src1)); VOP3A_instruction* sub = static_cast(add); sub->neg[1] = true; } else { @@ -1643,7 +1705,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fmax3: { - if (dst.size() == 1) { + if (dst.regClass() == v2b) { + Temp tmp = bld.tmp(v1); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f16, tmp, false); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f32, dst, ctx->block->fp_mode.must_flush_denorms32); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1653,7 +1719,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fmin3: { - if (dst.size() == 1) { + if (dst.regClass() == v2b) { + Temp tmp = bld.tmp(v1); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f16, tmp, false); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f32, dst, ctx->block->fp_mode.must_flush_denorms32); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1663,7 +1733,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fmed3: { - if (dst.size() == 1) { + if (dst.regClass() == v2b) { + Temp tmp = bld.tmp(v1); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f16, tmp, false); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f32, dst, ctx->block->fp_mode.must_flush_denorms32); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1823,8 +1897,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_fsat: { Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - Temp one = bld.copy(bld.def(s1), Operand(0x3c00u)); - Temp tmp = bld.vop3(aco_opcode::v_med3_f16, bld.def(v1), Operand(0u), one, src); + Temp tmp = bld.vop3(aco_opcode::v_med3_f16, bld.def(v1), Operand(0u), Operand(0x3f800000u), src); bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); } else if (dst.regClass() == v1) { bld.vop3(aco_opcode::v_med3_f32, Definition(dst), Operand(0u), Operand(0x3f800000u), src); @@ -2048,14 +2121,16 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_ldexp: { - if (dst.size() == 1) { - bld.vop3(aco_opcode::v_ldexp_f32, Definition(dst), - as_vgpr(ctx, get_alu_src(ctx, instr->src[0])), - get_alu_src(ctx, instr->src[1])); - } else if (dst.size() == 2) { - bld.vop3(aco_opcode::v_ldexp_f64, Definition(dst), - as_vgpr(ctx, get_alu_src(ctx, instr->src[0])), - get_alu_src(ctx, instr->src[1])); + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = get_alu_src(ctx, instr->src[1]); + if (dst.regClass() == v2b) { + Temp tmp = bld.tmp(v1); + emit_vop2_instruction(ctx, instr, aco_opcode::v_ldexp_f16, tmp, false); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { + bld.vop3(aco_opcode::v_ldexp_f32, Definition(dst), as_vgpr(ctx, src0), src1); + } else if (dst.regClass() == v2) { + bld.vop3(aco_opcode::v_ldexp_f64, Definition(dst), as_vgpr(ctx, src0), src1); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); nir_print_instr(&instr->instr, stderr); @@ -2083,7 +2158,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp src = get_alu_src(ctx, instr->src[0]); if (instr->src[0].src.ssa->bit_size == 16) { Temp tmp = bld.vop1(aco_opcode::v_frexp_exp_i16_f16, bld.def(v1), src); - bld.pseudo(aco_opcode::p_extract_vector, Definition(dst), tmp, Operand(0u)); + tmp = bld.pseudo(aco_opcode::p_extract_vector, bld.def(v1b), tmp, Operand(0u)); + convert_int(bld, tmp, 8, 32, true, dst); } else if (instr->src[0].src.ssa->bit_size == 32) { bld.vop1(aco_opcode::v_frexp_exp_i32_f32, Definition(dst), src); } else if (instr->src[0].src.ssa->bit_size == 64) { @@ -2163,14 +2239,29 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) bld.vop1(aco_opcode::v_cvt_f64_f32, Definition(dst), src); break; } + case nir_op_i2f16: { + assert(dst.regClass() == v2b); + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 8) + src = convert_int(bld, src, 8, 16, true); + Temp tmp = bld.vop1(aco_opcode::v_cvt_f16_i16, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + break; + } case nir_op_i2f32: { assert(dst.size() == 1); - emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_i32, dst); + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size <= 16) + src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, true); + bld.vop1(aco_opcode::v_cvt_f32_i32, Definition(dst), src); break; } case nir_op_i2f64: { - if (instr->src[0].src.ssa->bit_size == 32) { - emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f64_i32, dst); + if (instr->src[0].src.ssa->bit_size <= 32) { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size <= 16) + src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, true); + bld.vop1(aco_opcode::v_cvt_f64_i32, Definition(dst), src); } else if (instr->src[0].src.ssa->bit_size == 64) { Temp src = get_alu_src(ctx, instr->src[0]); RegClass rc = RegClass(src.type(), 1); @@ -2188,14 +2279,34 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } break; } + case nir_op_u2f16: { + assert(dst.regClass() == v2b); + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 8) + src = convert_int(bld, src, 8, 16, false); + Temp tmp = bld.vop1(aco_opcode::v_cvt_f16_u16, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + break; + } case nir_op_u2f32: { assert(dst.size() == 1); - emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_u32, dst); + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 8) { + //TODO: we should use v_cvt_f32_ubyte1/v_cvt_f32_ubyte2/etc depending on the register assignment + bld.vop1(aco_opcode::v_cvt_f32_ubyte0, Definition(dst), src); + } else { + if (instr->src[0].src.ssa->bit_size == 16) + src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, true); + bld.vop1(aco_opcode::v_cvt_f32_u32, Definition(dst), src); + } break; } case nir_op_u2f64: { - if (instr->src[0].src.ssa->bit_size == 32) { - emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f64_u32, dst); + if (instr->src[0].src.ssa->bit_size <= 32) { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size <= 16) + src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, false); + bld.vop1(aco_opcode::v_cvt_f64_u32, Definition(dst), src); } else if (instr->src[0].src.ssa->bit_size == 64) { Temp src = get_alu_src(ctx, instr->src[0]); RegClass rc = RegClass(src.type(), 1); @@ -2212,6 +2323,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } break; } + case nir_op_f2i8: case nir_op_f2i16: { Temp src = get_alu_src(ctx, instr->src[0]); if (instr->src[0].src.ssa->bit_size == 16) @@ -2222,11 +2334,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) src = bld.vop1(aco_opcode::v_cvt_i32_f64, bld.def(v1), src); if (dst.type() == RegType::vgpr) - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src); + bld.pseudo(aco_opcode::p_extract_vector, Definition(dst), src, Operand(0u)); else bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), src); break; } + case nir_op_f2u8: case nir_op_f2u16: { Temp src = get_alu_src(ctx, instr->src[0]); if (instr->src[0].src.ssa->bit_size == 16) @@ -2237,7 +2350,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) src = bld.vop1(aco_opcode::v_cvt_u32_f64, bld.def(v1), src); if (dst.type() == RegType::vgpr) - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src); + bld.pseudo(aco_opcode::p_extract_vector, Definition(dst), src, Operand(0u)); else bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), src); break; @@ -2306,7 +2419,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_f2i64: { Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 32 && dst.type() == RegType::vgpr) { + if (instr->src[0].src.ssa->bit_size == 16) + src = bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src); + + if (instr->src[0].src.ssa->bit_size <= 32 && dst.type() == RegType::vgpr) { Temp exponent = bld.vop1(aco_opcode::v_frexp_exp_i32_f32, bld.def(v1), src); exponent = bld.vop3(aco_opcode::v_med3_i32, bld.def(v1), Operand(0x0u), exponent, Operand(64u)); Temp mantissa = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0x7fffffu), src); @@ -2332,13 +2448,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp new_upper = bld.vsub32(bld.def(v1), upper, sign, false, borrow); bld.pseudo(aco_opcode::p_create_vector, Definition(dst), new_lower, new_upper); - } else if (instr->src[0].src.ssa->bit_size == 32 && dst.type() == RegType::sgpr) { + } else if (instr->src[0].src.ssa->bit_size <= 32 && dst.type() == RegType::sgpr) { if (src.type() == RegType::vgpr) src = bld.as_uniform(src); Temp exponent = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), src, Operand(0x80017u)); - exponent = bld.sop2(aco_opcode::s_sub_u32, bld.def(s1), bld.def(s1, scc), exponent, Operand(126u)); - exponent = bld.sop2(aco_opcode::s_max_u32, bld.def(s1), bld.def(s1, scc), Operand(0u), exponent); - exponent = bld.sop2(aco_opcode::s_min_u32, bld.def(s1), bld.def(s1, scc), Operand(64u), exponent); + exponent = bld.sop2(aco_opcode::s_sub_i32, bld.def(s1), bld.def(s1, scc), exponent, Operand(126u)); + exponent = bld.sop2(aco_opcode::s_max_i32, bld.def(s1), bld.def(s1, scc), Operand(0u), exponent); + exponent = bld.sop2(aco_opcode::s_min_i32, bld.def(s1), bld.def(s1, scc), Operand(64u), exponent); Temp mantissa = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), Operand(0x7fffffu), src); Temp sign = bld.sop2(aco_opcode::s_ashr_i32, bld.def(s1), bld.def(s1, scc), src, Operand(31u)); mantissa = bld.sop2(aco_opcode::s_or_b32, bld.def(s1), bld.def(s1, scc), Operand(0x800000u), mantissa); @@ -2382,7 +2498,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_f2u64: { Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 32 && dst.type() == RegType::vgpr) { + if (instr->src[0].src.ssa->bit_size == 16) + src = bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src); + + if (instr->src[0].src.ssa->bit_size <= 32 && dst.type() == RegType::vgpr) { Temp exponent = bld.vop1(aco_opcode::v_frexp_exp_i32_f32, bld.def(v1), src); Temp exponent_in_range = bld.vopc(aco_opcode::v_cmp_ge_i32, bld.hint_vcc(bld.def(bld.lm)), Operand(64u), exponent); exponent = bld.vop2(aco_opcode::v_max_i32, bld.def(v1), Operand(0x0u), exponent); @@ -2405,12 +2524,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) upper = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0xffffffffu), upper, exponent_in_range); bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lower, upper); - } else if (instr->src[0].src.ssa->bit_size == 32 && dst.type() == RegType::sgpr) { + } else if (instr->src[0].src.ssa->bit_size <= 32 && dst.type() == RegType::sgpr) { if (src.type() == RegType::vgpr) src = bld.as_uniform(src); Temp exponent = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), src, Operand(0x80017u)); - exponent = bld.sop2(aco_opcode::s_sub_u32, bld.def(s1), bld.def(s1, scc), exponent, Operand(126u)); - exponent = bld.sop2(aco_opcode::s_max_u32, bld.def(s1), bld.def(s1, scc), Operand(0u), exponent); + exponent = bld.sop2(aco_opcode::s_sub_i32, bld.def(s1), bld.def(s1, scc), exponent, Operand(126u)); + exponent = bld.sop2(aco_opcode::s_max_i32, bld.def(s1), bld.def(s1, scc), Operand(0u), exponent); Temp mantissa = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), Operand(0x7fffffu), src); mantissa = bld.sop2(aco_opcode::s_or_b32, bld.def(s1), bld.def(s1, scc), Operand(0x800000u), mantissa); Temp exponent_small = bld.sop2(aco_opcode::s_sub_u32, bld.def(s1), bld.def(s1, scc), Operand(24u), exponent); @@ -2449,6 +2568,22 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } break; } + case nir_op_b2f16: { + Temp src = get_alu_src(ctx, instr->src[0]); + assert(src.regClass() == bld.lm); + + if (dst.regClass() == s1) { + src = bool_to_scalar_condition(ctx, src); + bld.sop2(aco_opcode::s_mul_i32, Definition(dst), Operand(0x3c00u), src); + } else if (dst.regClass() == v2b) { + Temp one = bld.copy(bld.def(v1), Operand(0x3c00u)); + Temp tmp = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), one, src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else { + unreachable("Wrong destination register class for nir_op_b2f16."); + } + break; + } case nir_op_b2f32: { Temp src = get_alu_src(ctx, instr->src[0]); assert(src.regClass() == bld.lm); @@ -2480,159 +2615,19 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_i2i8: - case nir_op_u2u8: { - Temp src = get_alu_src(ctx, instr->src[0]); - /* we can actually just say dst = src */ - if (src.regClass() == s1) - bld.copy(Definition(dst), src); - else - emit_extract_vector(ctx, src, 0, dst); - break; - } - case nir_op_i2i16: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 8) { - if (dst.regClass() == s1) { - bld.sop1(aco_opcode::s_sext_i32_i8, Definition(dst), Operand(src)); - } else { - assert(src.regClass() == v1b); - aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; - sdwa->operands[0] = Operand(src); - sdwa->definitions[0] = Definition(dst); - sdwa->sel[0] = sdwa_sbyte; - sdwa->dst_sel = sdwa_sword; - ctx->block->instructions.emplace_back(std::move(sdwa)); - } - } else { - Temp src = get_alu_src(ctx, instr->src[0]); - /* we can actually just say dst = src */ - if (src.regClass() == s1) - bld.copy(Definition(dst), src); - else - emit_extract_vector(ctx, src, 0, dst); - } - break; - } - case nir_op_u2u16: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 8) { - if (dst.regClass() == s1) - bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFu), src); - else { - assert(src.regClass() == v1b); - aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; - sdwa->operands[0] = Operand(src); - sdwa->definitions[0] = Definition(dst); - sdwa->sel[0] = sdwa_ubyte; - sdwa->dst_sel = sdwa_uword; - ctx->block->instructions.emplace_back(std::move(sdwa)); - } - } else { - Temp src = get_alu_src(ctx, instr->src[0]); - /* we can actually just say dst = src */ - if (src.regClass() == s1) - bld.copy(Definition(dst), src); - else - emit_extract_vector(ctx, src, 0, dst); - } - break; - } - case nir_op_i2i32: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 8) { - if (dst.regClass() == s1) { - bld.sop1(aco_opcode::s_sext_i32_i8, Definition(dst), Operand(src)); - } else { - assert(src.regClass() == v1b); - aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; - sdwa->operands[0] = Operand(src); - sdwa->definitions[0] = Definition(dst); - sdwa->sel[0] = sdwa_sbyte; - sdwa->dst_sel = sdwa_sdword; - ctx->block->instructions.emplace_back(std::move(sdwa)); - } - } else if (instr->src[0].src.ssa->bit_size == 16) { - if (dst.regClass() == s1) { - bld.sop1(aco_opcode::s_sext_i32_i16, Definition(dst), Operand(src)); - } else { - assert(src.regClass() == v2b); - aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; - sdwa->operands[0] = Operand(src); - sdwa->definitions[0] = Definition(dst); - sdwa->sel[0] = sdwa_sword; - sdwa->dst_sel = sdwa_udword; - ctx->block->instructions.emplace_back(std::move(sdwa)); - } - } else if (instr->src[0].src.ssa->bit_size == 64) { - /* we can actually just say dst = src, as it would map the lower register */ - emit_extract_vector(ctx, src, 0, dst); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } - break; - } - case nir_op_u2u32: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 8) { - if (dst.regClass() == s1) - bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFu), src); - else { - assert(src.regClass() == v1b); - aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; - sdwa->operands[0] = Operand(src); - sdwa->definitions[0] = Definition(dst); - sdwa->sel[0] = sdwa_ubyte; - sdwa->dst_sel = sdwa_udword; - ctx->block->instructions.emplace_back(std::move(sdwa)); - } - } else if (instr->src[0].src.ssa->bit_size == 16) { - if (dst.regClass() == s1) { - bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFFFu), src); - } else { - assert(src.regClass() == v2b); - aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; - sdwa->operands[0] = Operand(src); - sdwa->definitions[0] = Definition(dst); - sdwa->sel[0] = sdwa_uword; - sdwa->dst_sel = sdwa_udword; - ctx->block->instructions.emplace_back(std::move(sdwa)); - } - } else if (instr->src[0].src.ssa->bit_size == 64) { - /* we can actually just say dst = src, as it would map the lower register */ - emit_extract_vector(ctx, src, 0, dst); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } - break; - } + case nir_op_i2i16: + case nir_op_i2i32: case nir_op_i2i64: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (src.regClass() == s1) { - Temp high = bld.sop2(aco_opcode::s_ashr_i32, bld.def(s1), bld.def(s1, scc), src, Operand(31u)); - bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src, high); - } else if (src.regClass() == v1) { - Temp high = bld.vop2(aco_opcode::v_ashrrev_i32, bld.def(v1), Operand(31u), src); - bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src, high); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } + convert_int(bld, get_alu_src(ctx, instr->src[0]), + instr->src[0].src.ssa->bit_size, instr->dest.dest.ssa.bit_size, true, dst); break; } + case nir_op_u2u8: + case nir_op_u2u16: + case nir_op_u2u32: case nir_op_u2u64: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 32) { - bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src, Operand(0u)); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } + convert_int(bld, get_alu_src(ctx, instr->src[0]), + instr->src[0].src.ssa->bit_size, instr->dest.dest.ssa.bit_size, false, dst); break; } case nir_op_b2b32: @@ -2697,13 +2692,15 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) if (dst.type() == RegType::vgpr) { bld.pseudo(aco_opcode::p_split_vector, bld.def(dst.regClass()), Definition(dst), get_alu_src(ctx, instr->src[0])); } else { - bld.sop2(aco_opcode::s_bfe_u32, Definition(dst), get_alu_src(ctx, instr->src[0]), Operand(uint32_t(16 << 16 | 16))); + bld.sop2(aco_opcode::s_bfe_u32, Definition(dst), bld.def(s1, scc), get_alu_src(ctx, instr->src[0]), Operand(uint32_t(16 << 16 | 16))); } break; case nir_op_pack_32_2x16_split: { Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = get_alu_src(ctx, instr->src[1]); if (dst.regClass() == v1) { + src0 = emit_extract_vector(ctx, src0, 0, v2b); + src1 = emit_extract_vector(ctx, src1, 0, v2b); bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src0, src1); } else { src0 = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), src0, Operand(0xFFFFu)); @@ -2920,34 +2917,34 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_flt: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_f32, aco_opcode::v_cmp_lt_f64); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_f16, aco_opcode::v_cmp_lt_f32, aco_opcode::v_cmp_lt_f64); break; } case nir_op_fge: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_f32, aco_opcode::v_cmp_ge_f64); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_f16, aco_opcode::v_cmp_ge_f32, aco_opcode::v_cmp_ge_f64); break; } case nir_op_feq: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_f32, aco_opcode::v_cmp_eq_f64); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_f16, aco_opcode::v_cmp_eq_f32, aco_opcode::v_cmp_eq_f64); break; } case nir_op_fne: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_neq_f32, aco_opcode::v_cmp_neq_f64); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_neq_f16, aco_opcode::v_cmp_neq_f32, aco_opcode::v_cmp_neq_f64); break; } case nir_op_ilt: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_i32, aco_opcode::v_cmp_lt_i64, aco_opcode::s_cmp_lt_i32); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_i16, aco_opcode::v_cmp_lt_i32, aco_opcode::v_cmp_lt_i64, aco_opcode::s_cmp_lt_i32); break; } case nir_op_ige: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_i32, aco_opcode::v_cmp_ge_i64, aco_opcode::s_cmp_ge_i32); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_i16, aco_opcode::v_cmp_ge_i32, aco_opcode::v_cmp_ge_i64, aco_opcode::s_cmp_ge_i32); break; } case nir_op_ieq: { if (instr->src[0].src.ssa->bit_size == 1) emit_boolean_logic(ctx, instr, Builder::s_xnor, dst); else - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_i32, aco_opcode::v_cmp_eq_i64, aco_opcode::s_cmp_eq_i32, + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_i16, aco_opcode::v_cmp_eq_i32, aco_opcode::v_cmp_eq_i64, aco_opcode::s_cmp_eq_i32, ctx->program->chip_class >= GFX8 ? aco_opcode::s_cmp_eq_u64 : aco_opcode::num_opcodes); break; } @@ -2955,16 +2952,16 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) if (instr->src[0].src.ssa->bit_size == 1) emit_boolean_logic(ctx, instr, Builder::s_xor, dst); else - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lg_i32, aco_opcode::v_cmp_lg_i64, aco_opcode::s_cmp_lg_i32, + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lg_i16, aco_opcode::v_cmp_lg_i32, aco_opcode::v_cmp_lg_i64, aco_opcode::s_cmp_lg_i32, ctx->program->chip_class >= GFX8 ? aco_opcode::s_cmp_lg_u64 : aco_opcode::num_opcodes); break; } case nir_op_ult: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_u32, aco_opcode::v_cmp_lt_u64, aco_opcode::s_cmp_lt_u32); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_u16, aco_opcode::v_cmp_lt_u32, aco_opcode::v_cmp_lt_u64, aco_opcode::s_cmp_lt_u32); break; } case nir_op_uge: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_u32, aco_opcode::v_cmp_ge_u64, aco_opcode::s_cmp_ge_u32); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_u16, aco_opcode::v_cmp_ge_u32, aco_opcode::v_cmp_ge_u64, aco_opcode::s_cmp_ge_u32); break; } case nir_op_fddx: @@ -3025,6 +3022,12 @@ void visit_load_const(isel_context *ctx, nir_load_const_instr *instr) int val = instr->value[0].b ? -1 : 0; Operand op = bld.lm.size() == 1 ? Operand((uint32_t) val) : Operand((uint64_t) val); bld.sop1(Builder::s_mov, Definition(dst), op); + } else if (instr->def.bit_size == 8) { + /* ensure that the value is correctly represented in the low byte of the register */ + bld.sopk(aco_opcode::s_movk_i32, Definition(dst), instr->value[0].u8); + } else if (instr->def.bit_size == 16) { + /* ensure that the value is correctly represented in the low half of the register */ + bld.sopk(aco_opcode::s_movk_i32, Definition(dst), instr->value[0].u16); } else if (dst.size() == 1) { bld.copy(Definition(dst), Operand(instr->value[0].u32)); } else { @@ -3765,11 +3768,8 @@ bool load_input_from_temps(isel_context *ctx, nir_intrinsic_instr *instr, Temp d unsigned idx = nir_intrinsic_base(instr) + nir_intrinsic_component(instr) + 4 * nir_src_as_uint(*off_src); Temp *src = &ctx->inputs.temps[idx]; - Temp vec = create_vec_from_array(ctx, src, dst.size(), dst.regClass().type(), 4u); - assert(vec.size() == dst.size()); + create_vec_from_array(ctx, src, dst.size(), dst.regClass().type(), 4u, 0, dst); - Builder bld(ctx->program, ctx->block); - bld.copy(Definition(dst), vec); return true; } @@ -8378,7 +8378,7 @@ void visit_tex(isel_context *ctx, nir_tex_instr *instr) if (instr->sampler_dim == GLSL_SAMPLER_DIM_1D && ctx->options->chip_class == GFX9) { assert(has_ddx && has_ddy && ddx.size() == 1 && ddy.size() == 1); Temp zero = bld.copy(bld.def(v1), Operand(0u)); - derivs = {ddy, zero, ddy, zero}; + derivs = {ddx, zero, ddy, zero}; } else { for (unsigned i = 0; has_ddx && i < ddx.size(); i++) derivs.emplace_back(emit_extract_vector(ctx, ddx, i, v1)); @@ -9758,7 +9758,8 @@ static void create_vs_exports(isel_context *ctx) for (unsigned i = 0; i <= VARYING_SLOT_VAR31; ++i) { if (i < VARYING_SLOT_VAR0 && i != VARYING_SLOT_LAYER && - i != VARYING_SLOT_PRIMITIVE_ID) + i != VARYING_SLOT_PRIMITIVE_ID && + i != VARYING_SLOT_VIEWPORT) continue; export_vs_varying(ctx, i, false, NULL); @@ -10117,7 +10118,7 @@ static void emit_stream_output(isel_context *ctx, Temp out[4]; bool all_undef = true; - assert(ctx->stage == vertex_vs || ctx->stage == gs_copy_vs); + assert(ctx->stage & hw_vs); for (unsigned i = 0; i < num_comps; i++) { out[i] = ctx->outputs.temps[loc * 4 + start + i]; all_undef = all_undef && !out[i].id();