X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Famd%2Fcompiler%2Faco_instruction_selection.cpp;h=31a5e410e9f8042222107f567ea2f4d8a729ceda;hb=b497b774a5008c5c424b05cdbc3f4e96a6765912;hp=716853d23ce998dcf8a5c8c7f1325babb8a644b2;hpb=798dd98d6e530afc5dab2f973785fbbd4e598dee;p=mesa.git diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 716853d23ce..31a5e410e9f 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -85,6 +85,7 @@ struct if_context { unsigned BB_if_idx; unsigned invert_idx; + bool uniform_has_then_branch; bool then_branch_divergent; Block BB_invert; Block BB_endif; @@ -270,21 +271,24 @@ Temp emit_extract_vector(isel_context* ctx, Temp src, uint32_t idx, RegClass dst assert(idx == 0); return src; } - assert(src.size() > idx); + + assert(src.bytes() > (idx * dst_rc.bytes())); Builder bld(ctx->program, ctx->block); auto it = ctx->allocated_vec.find(src.id()); - /* the size check needs to be early because elements other than 0 may be garbage */ - if (it != ctx->allocated_vec.end() && it->second[0].size() == dst_rc.size()) { + if (it != ctx->allocated_vec.end() && dst_rc.bytes() == it->second[idx].regClass().bytes()) { if (it->second[idx].regClass() == dst_rc) { return it->second[idx]; } else { - assert(dst_rc.size() == it->second[idx].regClass().size()); + assert(!dst_rc.is_subdword()); assert(dst_rc.type() == RegType::vgpr && it->second[idx].type() == RegType::sgpr); return bld.copy(bld.def(dst_rc), it->second[idx]); } } - if (src.size() == dst_rc.size()) { + if (dst_rc.is_subdword()) + src = as_vgpr(ctx, src); + + if (src.bytes() == dst_rc.bytes()) { assert(idx == 0); return bld.copy(bld.def(dst_rc), src); } else { @@ -303,8 +307,19 @@ void emit_split_vector(isel_context* ctx, Temp vec_src, unsigned num_components) aco_ptr split{create_instruction(aco_opcode::p_split_vector, Format::PSEUDO, 1, num_components)}; split->operands[0] = Operand(vec_src); std::array elems; + RegClass rc; + if (num_components > vec_src.size()) { + if (vec_src.type() == RegType::sgpr) + return; + + /* sub-dword split */ + assert(vec_src.type() == RegType::vgpr); + rc = RegClass(RegType::vgpr, vec_src.bytes() / num_components).as_subdword(); + } else { + rc = RegClass(vec_src.type(), vec_src.size() / num_components); + } for (unsigned i = 0; i < num_components; i++) { - elems[i] = {ctx->program->allocateId(), RegClass(vec_src.type(), vec_src.size() / num_components)}; + elems[i] = {ctx->program->allocateId(), rc}; split->definitions[i] = Definition(elems[i]); } ctx->block->instructions.emplace_back(std::move(split)); @@ -350,6 +365,82 @@ void expand_vector(isel_context* ctx, Temp vec_src, Temp dst, unsigned num_compo ctx->allocated_vec.emplace(dst.id(), elems); } +/* adjust misaligned small bit size loads */ +void byte_align_scalar(isel_context *ctx, Temp vec, Operand offset, Temp dst) +{ + Builder bld(ctx->program, ctx->block); + Operand shift; + Temp select = Temp(); + if (offset.isConstant()) { + assert(offset.constantValue() && offset.constantValue() < 4); + shift = Operand(offset.constantValue() * 8); + } else { + /* bit_offset = 8 * (offset & 0x3) */ + Temp tmp = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), offset, Operand(3u)); + select = bld.tmp(s1); + shift = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.scc(Definition(select)), tmp, Operand(3u)); + } + + if (vec.size() == 1) { + bld.sop2(aco_opcode::s_lshr_b32, Definition(dst), 
bld.def(s1, scc), vec, shift); + } else if (vec.size() == 2) { + Temp tmp = dst.size() == 2 ? dst : bld.tmp(s2); + bld.sop2(aco_opcode::s_lshr_b64, Definition(tmp), bld.def(s1, scc), vec, shift); + if (tmp == dst) + emit_split_vector(ctx, dst, 2); + else + emit_extract_vector(ctx, tmp, 0, dst); + } else if (vec.size() == 4) { + Temp lo = bld.tmp(s2), hi = bld.tmp(s2); + bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), vec); + hi = bld.pseudo(aco_opcode::p_extract_vector, bld.def(s1), hi, Operand(0u)); + if (select != Temp()) + hi = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), hi, Operand(0u), select); + lo = bld.sop2(aco_opcode::s_lshr_b64, bld.def(s2), bld.def(s1, scc), lo, shift); + Temp mid = bld.tmp(s1); + lo = bld.pseudo(aco_opcode::p_split_vector, bld.def(s1), Definition(mid), lo); + hi = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), hi, shift); + mid = bld.sop2(aco_opcode::s_or_b32, bld.def(s1), bld.def(s1, scc), hi, mid); + bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lo, mid); + emit_split_vector(ctx, dst, 2); + } +} + +/* this function trims subdword vectors: + * if dst is vgpr - split the src and create a shrunk version according to the mask. + * if dst is sgpr - split the src, but move the original to sgpr. */ +void trim_subdword_vector(isel_context *ctx, Temp vec_src, Temp dst, unsigned num_components, unsigned mask) +{ + assert(vec_src.type() == RegType::vgpr); + emit_split_vector(ctx, vec_src, num_components); + + Builder bld(ctx->program, ctx->block); + std::array elems; + unsigned component_size = vec_src.bytes() / num_components; + RegClass rc = RegClass(RegType::vgpr, component_size).as_subdword(); + + unsigned k = 0; + for (unsigned i = 0; i < num_components; i++) { + if (mask & (1 << i)) + elems[k++] = emit_extract_vector(ctx, vec_src, i, rc); + } + + if (dst.type() == RegType::vgpr) { + assert(dst.bytes() == k * component_size); + aco_ptr vec{create_instruction(aco_opcode::p_create_vector, Format::PSEUDO, k, 1)}; + for (unsigned i = 0; i < k; i++) + vec->operands[i] = Operand(elems[i]); + vec->definitions[0] = Definition(dst); + bld.insert(std::move(vec)); + } else { + // TODO: alignbyte if mask doesn't start with 1? 
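// A host-side sketch (illustrative only, not part of the patch) of what byte_align_scalar
// above computes for the two-dword case: the SMEM load returned the dword-aligned pair
// {lo, hi}, the wanted value starts byte_offset (1..3) bytes into it, and a 64-bit right
// shift by 8 * byte_offset realigns it. Function name is made up for the example.
#include <cstdint>
static uint32_t byte_align_scalar_model(uint32_t lo, uint32_t hi, unsigned byte_offset)
{
    uint64_t pair = ((uint64_t)hi << 32) | lo;     // the pair s_lshr_b64 operates on
    return (uint32_t)(pair >> (8 * byte_offset));  // shifted result, low dword kept
}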
+ assert(mask & 1); + assert(dst.size() == vec_src.size()); + bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), vec_src); + } + ctx->allocated_vec.emplace(dst.id(), elems); +} + Temp bool_to_vector_condition(isel_context *ctx, Temp val, Temp dst = Temp(0, s2)) { Builder bld(ctx->program, ctx->block); @@ -393,11 +484,32 @@ Temp get_alu_src(struct isel_context *ctx, nir_alu_src src, unsigned size=1) } Temp vec = get_ssa_temp(ctx, src.src.ssa); - unsigned elem_size = vec.size() / src.src.ssa->num_components; - assert(elem_size > 0); /* TODO: 8 and 16-bit vectors not supported */ - assert(vec.size() % elem_size == 0); + unsigned elem_size = vec.bytes() / src.src.ssa->num_components; + assert(elem_size > 0); + assert(vec.bytes() % elem_size == 0); + + if (elem_size < 4 && vec.type() == RegType::sgpr) { + assert(src.src.ssa->bit_size == 8 || src.src.ssa->bit_size == 16); + assert(size == 1); + unsigned swizzle = src.swizzle[0]; + if (vec.size() > 1) { + assert(src.src.ssa->bit_size == 16); + vec = emit_extract_vector(ctx, vec, swizzle / 2, s1); + swizzle = swizzle & 1; + } + if (swizzle == 0) + return vec; + + Temp dst{ctx->program->allocateId(), s1}; + aco_ptr bfe{create_instruction(aco_opcode::s_bfe_u32, Format::SOP2, 2, 1)}; + bfe->operands[0] = Operand(vec); + bfe->operands[1] = Operand(uint32_t((src.src.ssa->bit_size << 16) | (src.src.ssa->bit_size * swizzle))); + bfe->definitions[0] = Definition(dst); + ctx->block->instructions.emplace_back(std::move(bfe)); + return dst; + } - RegClass elem_rc = RegClass(vec.type(), elem_size); + RegClass elem_rc = elem_size < 4 ? RegClass(vec.type(), elem_size).as_subdword() : RegClass(vec.type(), elem_size / 4); if (size == 1) { return emit_extract_vector(ctx, vec, src.swizzle[0], elem_rc); } else { @@ -408,7 +520,7 @@ Temp get_alu_src(struct isel_context *ctx, nir_alu_src src, unsigned size=1) elems[i] = emit_extract_vector(ctx, vec, src.swizzle[i], elem_rc); vec_instr->operands[i] = Operand{elems[i]}; } - Temp dst{ctx->program->allocateId(), RegClass(vec.type(), elem_size * size)}; + Temp dst{ctx->program->allocateId(), RegClass(vec.type(), elem_size * size / 4)}; vec_instr->definitions[0] = Definition(dst); ctx->block->instructions.emplace_back(std::move(vec_instr)); ctx->allocated_vec.emplace(dst.id(), elems); @@ -449,16 +561,8 @@ void emit_vop2_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode o Temp t = src0; src0 = src1; src1 = t; - } else if (src0.type() == RegType::vgpr && - op != aco_opcode::v_madmk_f32 && - op != aco_opcode::v_madak_f32 && - op != aco_opcode::v_madmk_f16 && - op != aco_opcode::v_madak_f16) { - /* If the instruction is not commutative, we emit a VOP3A instruction */ - bld.vop2_e64(op, Definition(dst), src0, src1); - return; } else { - src1 = bld.copy(bld.def(RegType::vgpr, src1.size()), src1); //TODO: as_vgpr + src1 = as_vgpr(ctx, src1); } } @@ -514,6 +618,24 @@ void emit_vopc_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode o if (src0.type() == RegType::vgpr) { /* to swap the operands, we might also have to change the opcode */ switch (op) { + case aco_opcode::v_cmp_lt_f16: + op = aco_opcode::v_cmp_gt_f16; + break; + case aco_opcode::v_cmp_ge_f16: + op = aco_opcode::v_cmp_le_f16; + break; + case aco_opcode::v_cmp_lt_i16: + op = aco_opcode::v_cmp_gt_i16; + break; + case aco_opcode::v_cmp_ge_i16: + op = aco_opcode::v_cmp_le_i16; + break; + case aco_opcode::v_cmp_lt_u16: + op = aco_opcode::v_cmp_gt_u16; + break; + case aco_opcode::v_cmp_ge_u16: + op = aco_opcode::v_cmp_le_u16; + break; case 
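// Sketch of how the s_bfe_u32 operand built in get_alu_src above is encoded and what the
// instruction computes: the field width sits in bits [22:16] of the second source and the
// bit offset in bits [4:0], so (bit_size << 16) | (bit_size * swizzle) selects component
// 'swizzle' of an 8-/16-bit vector packed in one SGPR. Illustrative helper, assumptions:
// the patch only ever passes widths of 8 or 16 here.
#include <cstdint>
static uint32_t s_bfe_u32_model(uint32_t src0, uint32_t src1)
{
    unsigned offset = src1 & 0x1fu;          // bits [4:0]
    unsigned width  = (src1 >> 16) & 0x7fu;  // bits [22:16]
    uint32_t field  = src0 >> offset;
    return width < 32 ? field & ((1u << width) - 1u) : field;
}
// e.g. the second 16-bit component: s_bfe_u32_model(x, (16u << 16) | 16u) == (x >> 16) & 0xffff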
aco_opcode::v_cmp_lt_f32: op = aco_opcode::v_cmp_gt_f32; break; @@ -583,10 +705,10 @@ void emit_sopc_instruction(isel_context *ctx, nir_alu_instr *instr, aco_opcode o } void emit_comparison(isel_context *ctx, nir_alu_instr *instr, Temp dst, - aco_opcode v32_op, aco_opcode v64_op, aco_opcode s32_op = aco_opcode::num_opcodes, aco_opcode s64_op = aco_opcode::num_opcodes) + aco_opcode v16_op, aco_opcode v32_op, aco_opcode v64_op, aco_opcode s32_op = aco_opcode::num_opcodes, aco_opcode s64_op = aco_opcode::num_opcodes) { - aco_opcode s_op = instr->src[0].src.ssa->bit_size == 64 ? s64_op : s32_op; - aco_opcode v_op = instr->src[0].src.ssa->bit_size == 64 ? v64_op : v32_op; + aco_opcode s_op = instr->src[0].src.ssa->bit_size == 64 ? s64_op : instr->src[0].src.ssa->bit_size == 32 ? s32_op : aco_opcode::num_opcodes; + aco_opcode v_op = instr->src[0].src.ssa->bit_size == 64 ? v64_op : instr->src[0].src.ssa->bit_size == 32 ? v32_op : v16_op; bool divergent_vals = ctx->divergent_vals[instr->dest.dest.ssa.index]; bool use_valu = s_op == aco_opcode::num_opcodes || divergent_vals || @@ -626,12 +748,18 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst) if (dst.type() == RegType::vgpr) { aco_ptr bcsel; - if (dst.size() == 1) { + if (dst.regClass() == v2b) { + then = as_vgpr(ctx, then); + els = as_vgpr(ctx, els); + + Temp tmp = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), els, then, cond); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { then = as_vgpr(ctx, then); els = as_vgpr(ctx, els); bld.vop2(aco_opcode::v_cndmask_b32, Definition(dst), els, then, cond); - } else if (dst.size() == 2) { + } else if (dst.regClass() == v2) { Temp then_lo = bld.tmp(v1), then_hi = bld.tmp(v1); bld.pseudo(aco_opcode::p_split_vector, Definition(then_lo), Definition(then_hi), then); Temp else_lo = bld.tmp(v1), else_hi = bld.tmp(v1); @@ -814,6 +942,58 @@ Temp emit_floor_f64(isel_context *ctx, Builder& bld, Definition dst, Temp val) return add->definitions[0].getTemp(); } +Temp convert_int(Builder& bld, Temp src, unsigned src_bits, unsigned dst_bits, bool is_signed, Temp dst=Temp()) { + if (!dst.id()) { + if (dst_bits % 32 == 0 || src.type() == RegType::sgpr) + dst = bld.tmp(src.type(), DIV_ROUND_UP(dst_bits, 32u)); + else + dst = bld.tmp(RegClass(RegType::vgpr, dst_bits / 8u).as_subdword()); + } + + if (dst.bytes() == src.bytes() && dst_bits < src_bits) + return bld.copy(Definition(dst), src); + else if (dst.bytes() < src.bytes()) + return bld.pseudo(aco_opcode::p_extract_vector, Definition(dst), src, Operand(0u)); + + Temp tmp = dst; + if (dst_bits == 64) + tmp = src_bits == 32 ? src : bld.tmp(src.type(), 1); + + if (tmp == src) { + } else if (src.regClass() == s1) { + if (is_signed) + bld.sop1(src_bits == 8 ? aco_opcode::s_sext_i32_i8 : aco_opcode::s_sext_i32_i16, Definition(tmp), src); + else + bld.sop2(aco_opcode::s_and_b32, Definition(tmp), bld.def(s1, scc), Operand(src_bits == 8 ? 0xFFu : 0xFFFFu), src); + } else { + assert(src_bits != 8 || src.regClass() == v1b); + assert(src_bits != 16 || src.regClass() == v2b); + aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; + sdwa->operands[0] = Operand(src); + sdwa->definitions[0] = Definition(tmp); + if (is_signed) + sdwa->sel[0] = src_bits == 8 ? sdwa_sbyte : sdwa_sword; + else + sdwa->sel[0] = src_bits == 8 ? sdwa_ubyte : sdwa_uword; + sdwa->dst_sel = tmp.bytes() == 2 ? 
sdwa_uword : sdwa_udword; + bld.insert(std::move(sdwa)); + } + + if (dst_bits == 64) { + if (is_signed && dst.regClass() == s2) { + Temp high = bld.sop2(aco_opcode::s_ashr_i32, bld.def(s1), bld.def(s1, scc), tmp, Operand(31u)); + bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, high); + } else if (is_signed && dst.regClass() == v2) { + Temp high = bld.vop2(aco_opcode::v_ashrrev_i32, bld.def(v1), Operand(31u), tmp); + bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, high); + } else { + bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, Operand(0u)); + } + } + + return dst; +} + void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) { if (!instr->dest.dest.is_ssa) { @@ -829,14 +1009,38 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_vec3: case nir_op_vec4: { std::array elems; - aco_ptr vec{create_instruction(aco_opcode::p_create_vector, Format::PSEUDO, instr->dest.dest.ssa.num_components, 1)}; - for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i) { + unsigned num = instr->dest.dest.ssa.num_components; + for (unsigned i = 0; i < num; ++i) elems[i] = get_alu_src(ctx, instr->src[i]); - vec->operands[i] = Operand{elems[i]}; + + if (instr->dest.dest.ssa.bit_size >= 32 || dst.type() == RegType::vgpr) { + aco_ptr vec{create_instruction(aco_opcode::p_create_vector, Format::PSEUDO, instr->dest.dest.ssa.num_components, 1)}; + for (unsigned i = 0; i < num; ++i) + vec->operands[i] = Operand{elems[i]}; + vec->definitions[0] = Definition(dst); + ctx->block->instructions.emplace_back(std::move(vec)); + ctx->allocated_vec.emplace(dst.id(), elems); + } else { + // TODO: that is a bit suboptimal.. + Temp mask = bld.copy(bld.def(s1), Operand((1u << instr->dest.dest.ssa.bit_size) - 1)); + for (unsigned i = 0; i < num - 1; ++i) + if (((i+1) * instr->dest.dest.ssa.bit_size) % 32) + elems[i] = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), elems[i], mask); + for (unsigned i = 0; i < num; ++i) { + unsigned bit = i * instr->dest.dest.ssa.bit_size; + if (bit % 32 == 0) { + elems[bit / 32] = elems[i]; + } else { + elems[i] = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), + elems[i], Operand((i * instr->dest.dest.ssa.bit_size) % 32)); + elems[bit / 32] = bld.sop2(aco_opcode::s_or_b32, bld.def(s1), bld.def(s1, scc), elems[bit / 32], elems[i]); + } + } + if (dst.size() == 1) + bld.copy(Definition(dst), elems[0]); + else + bld.pseudo(aco_opcode::p_create_vector, Definition(dst), elems[0], elems[1]); } - vec->definitions[0] = Definition(dst); - ctx->block->instructions.emplace_back(std::move(vec)); - ctx->allocated_vec.emplace(dst.id(), elems); break; } case nir_op_mov: { @@ -1390,11 +1594,16 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fmul: { - if (dst.size() == 1) { + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); + if (dst.regClass() == v2b) { + Temp tmp = bld.tmp(v1); + emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f16, tmp, true); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f32, dst, true); - } else if (dst.size() == 2) { - bld.vop3(aco_opcode::v_mul_f64, Definition(dst), get_alu_src(ctx, instr->src[0]), - as_vgpr(ctx, get_alu_src(ctx, instr->src[1]))); + } else if (dst.regClass() == v2) { + bld.vop3(aco_opcode::v_mul_f64, Definition(dst), src0, src1); } else { 
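// A rough scalar model of the SGPR fallback above for nir_op_vec* with sub-dword
// components: each element is masked to its bit size (s_and_b32), shifted into position
// (s_lshl_b32) and OR'd into the destination dword(s) (s_or_b32). The real code skips the
// mask where it is provably unnecessary; this sketch masks everything. Names illustrative.
#include <cstdint>
static void pack_subdword_vec_model(const uint32_t *elems, unsigned num,
                                    unsigned bit_size, uint32_t *out_dwords)
{
    uint32_t mask = bit_size == 32 ? ~0u : (1u << bit_size) - 1u;
    for (unsigned i = 0; i < num; i++) {
        unsigned bit = i * bit_size;
        uint32_t val = elems[i] & mask;
        if (bit % 32 == 0)
            out_dwords[bit / 32] = val;
        else
            out_dwords[bit / 32] |= val << (bit % 32);
    }
}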
fprintf(stderr, "Unimplemented NIR instr bit size: "); nir_print_instr(&instr->instr, stderr); @@ -1403,11 +1612,16 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fadd: { - if (dst.size() == 1) { + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); + if (dst.regClass() == v2b) { + Temp tmp = bld.tmp(v1); + emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f16, tmp, true); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f32, dst, true); - } else if (dst.size() == 2) { - bld.vop3(aco_opcode::v_add_f64, Definition(dst), get_alu_src(ctx, instr->src[0]), - as_vgpr(ctx, get_alu_src(ctx, instr->src[1]))); + } else if (dst.regClass() == v2) { + bld.vop3(aco_opcode::v_add_f64, Definition(dst), src0, src1); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); nir_print_instr(&instr->instr, stderr); @@ -1418,15 +1632,21 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) case nir_op_fsub: { Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = get_alu_src(ctx, instr->src[1]); - if (dst.size() == 1) { + if (dst.regClass() == v2b) { + Temp tmp = bld.tmp(v1); + if (src1.type() == RegType::vgpr || src0.type() != RegType::vgpr) + emit_vop2_instruction(ctx, instr, aco_opcode::v_sub_f16, tmp, false); + else + emit_vop2_instruction(ctx, instr, aco_opcode::v_subrev_f16, tmp, true); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { if (src1.type() == RegType::vgpr || src0.type() != RegType::vgpr) emit_vop2_instruction(ctx, instr, aco_opcode::v_sub_f32, dst, false); else emit_vop2_instruction(ctx, instr, aco_opcode::v_subrev_f32, dst, true); - } else if (dst.size() == 2) { + } else if (dst.regClass() == v2) { Instruction* add = bld.vop3(aco_opcode::v_add_f64, Definition(dst), - get_alu_src(ctx, instr->src[0]), - as_vgpr(ctx, get_alu_src(ctx, instr->src[1]))); + as_vgpr(ctx, src0), as_vgpr(ctx, src1)); VOP3A_instruction* sub = static_cast(add); sub->neg[1] = true; } else { @@ -1437,18 +1657,21 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fmax: { - if (dst.size() == 1) { + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); + if (dst.regClass() == v2b) { + // TODO: check fp_mode.must_flush_denorms16_64 + Temp tmp = bld.tmp(v1); + emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f16, tmp, true); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32); - } else if (dst.size() == 2) { + } else if (dst.regClass() == v2) { if (ctx->block->fp_mode.must_flush_denorms16_64 && ctx->program->chip_class < GFX9) { - Temp tmp = bld.vop3(aco_opcode::v_max_f64, bld.def(v2), - get_alu_src(ctx, instr->src[0]), - as_vgpr(ctx, get_alu_src(ctx, instr->src[1]))); + Temp tmp = bld.vop3(aco_opcode::v_max_f64, bld.def(v2), src0, src1); bld.vop3(aco_opcode::v_mul_f64, Definition(dst), Operand(0x3FF0000000000000lu), tmp); } else { - bld.vop3(aco_opcode::v_max_f64, Definition(dst), - get_alu_src(ctx, instr->src[0]), - as_vgpr(ctx, get_alu_src(ctx, instr->src[1]))); + bld.vop3(aco_opcode::v_max_f64, Definition(dst), src0, src1); } } else { fprintf(stderr, "Unimplemented 
NIR instr bit size: "); @@ -1458,18 +1681,21 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fmin: { - if (dst.size() == 1) { + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); + if (dst.regClass() == v2b) { + // TODO: check fp_mode.must_flush_denorms16_64 + Temp tmp = bld.tmp(v1); + emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f16, tmp, true); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32); - } else if (dst.size() == 2) { + } else if (dst.regClass() == v2) { if (ctx->block->fp_mode.must_flush_denorms16_64 && ctx->program->chip_class < GFX9) { - Temp tmp = bld.vop3(aco_opcode::v_min_f64, bld.def(v2), - get_alu_src(ctx, instr->src[0]), - as_vgpr(ctx, get_alu_src(ctx, instr->src[1]))); + Temp tmp = bld.vop3(aco_opcode::v_min_f64, bld.def(v2), src0, src1); bld.vop3(aco_opcode::v_mul_f64, Definition(dst), Operand(0x3FF0000000000000lu), tmp); } else { - bld.vop3(aco_opcode::v_min_f64, Definition(dst), - get_alu_src(ctx, instr->src[0]), - as_vgpr(ctx, get_alu_src(ctx, instr->src[1]))); + bld.vop3(aco_opcode::v_min_f64, Definition(dst), src0, src1); } } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1479,7 +1705,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fmax3: { - if (dst.size() == 1) { + if (dst.regClass() == v2b) { + Temp tmp = bld.tmp(v1); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f16, tmp, false); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f32, dst, ctx->block->fp_mode.must_flush_denorms32); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1489,7 +1719,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fmin3: { - if (dst.size() == 1) { + if (dst.regClass() == v2b) { + Temp tmp = bld.tmp(v1); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f16, tmp, false); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f32, dst, ctx->block->fp_mode.must_flush_denorms32); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1499,7 +1733,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fmed3: { - if (dst.size() == 1) { + if (dst.regClass() == v2b) { + Temp tmp = bld.tmp(v1); + emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f16, tmp, false); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f32, dst, ctx->block->fp_mode.must_flush_denorms32); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1595,9 +1833,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_frsq: { - if (dst.size() == 1) { - emit_rsq(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0])); - } else if (dst.size() == 2) { + Temp src = get_alu_src(ctx, instr->src[0]); + if (dst.regClass() == v2b) { + Temp tmp = bld.vop1(aco_opcode::v_rsq_f16, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else 
if (dst.regClass() == v1) { + emit_rsq(ctx, bld, Definition(dst), src); + } else if (dst.regClass() == v2) { emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f64, dst); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1608,11 +1850,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_fneg: { Temp src = get_alu_src(ctx, instr->src[0]); - if (dst.size() == 1) { + if (dst.regClass() == v2b) { + Temp tmp = bld.vop2(aco_opcode::v_xor_b32, bld.def(v1), Operand(0x8000u), as_vgpr(ctx, src)); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { if (ctx->block->fp_mode.must_flush_denorms32) src = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand(0x3f800000u), as_vgpr(ctx, src)); bld.vop2(aco_opcode::v_xor_b32, Definition(dst), Operand(0x80000000u), as_vgpr(ctx, src)); - } else if (dst.size() == 2) { + } else if (dst.regClass() == v2) { if (ctx->block->fp_mode.must_flush_denorms16_64) src = bld.vop3(aco_opcode::v_mul_f64, bld.def(v2), Operand(0x3FF0000000000000lu), as_vgpr(ctx, src)); Temp upper = bld.tmp(v1), lower = bld.tmp(v1); @@ -1628,11 +1873,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_fabs: { Temp src = get_alu_src(ctx, instr->src[0]); - if (dst.size() == 1) { + if (dst.regClass() == v2b) { + Temp tmp = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0x7FFFu), as_vgpr(ctx, src)); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { if (ctx->block->fp_mode.must_flush_denorms32) src = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), Operand(0x3f800000u), as_vgpr(ctx, src)); bld.vop2(aco_opcode::v_and_b32, Definition(dst), Operand(0x7FFFFFFFu), as_vgpr(ctx, src)); - } else if (dst.size() == 2) { + } else if (dst.regClass() == v2) { if (ctx->block->fp_mode.must_flush_denorms16_64) src = bld.vop3(aco_opcode::v_mul_f64, bld.def(v2), Operand(0x3FF0000000000000lu), as_vgpr(ctx, src)); Temp upper = bld.tmp(v1), lower = bld.tmp(v1); @@ -1648,11 +1896,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_fsat: { Temp src = get_alu_src(ctx, instr->src[0]); - if (dst.size() == 1) { + if (dst.regClass() == v2b) { + Temp tmp = bld.vop3(aco_opcode::v_med3_f16, bld.def(v1), Operand(0u), Operand(0x3f800000u), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { bld.vop3(aco_opcode::v_med3_f32, Definition(dst), Operand(0u), Operand(0x3f800000u), src); /* apparently, it is not necessary to flush denorms if this instruction is used with these operands */ // TODO: confirm that this holds under any circumstances - } else if (dst.size() == 2) { + } else if (dst.regClass() == v2) { Instruction* add = bld.vop3(aco_opcode::v_add_f64, Definition(dst), src, Operand(0u)); VOP3A_instruction* vop3 = static_cast(add); vop3->clamp = true; @@ -1664,8 +1915,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_flog2: { - if (dst.size() == 1) { - emit_log2(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0])); + Temp src = get_alu_src(ctx, instr->src[0]); + if (dst.regClass() == v2b) { + Temp tmp = bld.vop1(aco_opcode::v_log_f16, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { + emit_log2(ctx, bld, Definition(dst), src); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); 
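// The 16-bit fneg/fabs cases above rely on the IEEE half-precision layout: bit 15 is the
// sign bit, so XOR with 0x8000 negates and AND with 0x7fff takes the absolute value.
// Host-side restatement on the raw half bits (no fp16 type needed); 0x3c00 is 1.0 and
// 0xbc00 is -1.0 in half precision, as used elsewhere in this patch (b2f16, fsign).
#include <cstdint>
static inline uint16_t fneg_f16_bits(uint16_t h) { return (uint16_t)(h ^ 0x8000u); }
static inline uint16_t fabs_f16_bits(uint16_t h) { return (uint16_t)(h & 0x7fffu); }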
nir_print_instr(&instr->instr, stderr); @@ -1674,9 +1929,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_frcp: { - if (dst.size() == 1) { - emit_rcp(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0])); - } else if (dst.size() == 2) { + Temp src = get_alu_src(ctx, instr->src[0]); + if (dst.regClass() == v2b) { + Temp tmp = bld.vop1(aco_opcode::v_rcp_f16, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { + emit_rcp(ctx, bld, Definition(dst), src); + } else if (dst.regClass() == v2) { emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f64, dst); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1686,7 +1945,11 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fexp2: { - if (dst.size() == 1) { + if (dst.regClass() == v2b) { + Temp src = get_alu_src(ctx, instr->src[0]); + Temp tmp = bld.vop1(aco_opcode::v_exp_f16, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_exp_f32, dst); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1696,9 +1959,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fsqrt: { - if (dst.size() == 1) { - emit_sqrt(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0])); - } else if (dst.size() == 2) { + Temp src = get_alu_src(ctx, instr->src[0]); + if (dst.regClass() == v2b) { + Temp tmp = bld.vop1(aco_opcode::v_sqrt_f16, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { + emit_sqrt(ctx, bld, Definition(dst), src); + } else if (dst.regClass() == v2) { emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f64, dst); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1708,9 +1975,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_ffract: { - if (dst.size() == 1) { + if (dst.regClass() == v2b) { + Temp src = get_alu_src(ctx, instr->src[0]); + Temp tmp = bld.vop1(aco_opcode::v_fract_f16, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_fract_f32, dst); - } else if (dst.size() == 2) { + } else if (dst.regClass() == v2) { emit_vop1_instruction(ctx, instr, aco_opcode::v_fract_f64, dst); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1720,10 +1991,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_ffloor: { - if (dst.size() == 1) { + Temp src = get_alu_src(ctx, instr->src[0]); + if (dst.regClass() == v2b) { + Temp tmp = bld.vop1(aco_opcode::v_floor_f16, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_floor_f32, dst); - } else if (dst.size() == 2) { - emit_floor_f64(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0])); + } else if (dst.regClass() == v2) { + emit_floor_f64(ctx, bld, Definition(dst), src); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); nir_print_instr(&instr->instr, stderr); @@ -1732,15 +2007,17 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fceil: { - if (dst.size() == 1) { + 
Temp src0 = get_alu_src(ctx, instr->src[0]); + if (dst.regClass() == v2b) { + Temp tmp = bld.vop1(aco_opcode::v_ceil_f16, bld.def(v1), src0); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_ceil_f32, dst); - } else if (dst.size() == 2) { + } else if (dst.regClass() == v2) { if (ctx->options->chip_class >= GFX7) { emit_vop1_instruction(ctx, instr, aco_opcode::v_ceil_f64, dst); } else { /* GFX6 doesn't support V_CEIL_F64, lower it. */ - Temp src0 = get_alu_src(ctx, instr->src[0]); - /* trunc = trunc(src0) * if (src0 > 0.0 && src0 != trunc) * trunc += 1.0 @@ -1761,10 +2038,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_ftrunc: { - if (dst.size() == 1) { + Temp src = get_alu_src(ctx, instr->src[0]); + if (dst.regClass() == v2b) { + Temp tmp = bld.vop1(aco_opcode::v_trunc_f16, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_trunc_f32, dst); - } else if (dst.size() == 2) { - emit_trunc_f64(ctx, bld, Definition(dst), get_alu_src(ctx, instr->src[0])); + } else if (dst.regClass() == v2) { + emit_trunc_f64(ctx, bld, Definition(dst), src); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); nir_print_instr(&instr->instr, stderr); @@ -1773,15 +2054,17 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fround_even: { - if (dst.size() == 1) { + Temp src0 = get_alu_src(ctx, instr->src[0]); + if (dst.regClass() == v2b) { + Temp tmp = bld.vop1(aco_opcode::v_rndne_f16, bld.def(v1), src0); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_rndne_f32, dst); - } else if (dst.size() == 2) { + } else if (dst.regClass() == v2) { if (ctx->options->chip_class >= GFX7) { emit_vop1_instruction(ctx, instr, aco_opcode::v_rndne_f64, dst); } else { /* GFX6 doesn't support V_RNDNE_F64, lower it. */ - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src0_lo = bld.tmp(v1), src0_hi = bld.tmp(v1); bld.pseudo(aco_opcode::p_split_vector, Definition(src0_lo), Definition(src0_hi), src0); @@ -1813,11 +2096,16 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_fsin: case nir_op_fcos: { - Temp src = get_alu_src(ctx, instr->src[0]); + Temp src = as_vgpr(ctx, get_alu_src(ctx, instr->src[0])); aco_ptr norm; - if (dst.size() == 1) { - Temp half_pi = bld.copy(bld.def(s1), Operand(0x3e22f983u)); - Temp tmp = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), half_pi, as_vgpr(ctx, src)); + Temp half_pi = bld.copy(bld.def(s1), Operand(0x3e22f983u)); + if (dst.regClass() == v2b) { + Temp tmp = bld.vop2(aco_opcode::v_mul_f16, bld.def(v1), half_pi, src); + aco_opcode opcode = instr->op == nir_op_fsin ? 
aco_opcode::v_sin_f16 : aco_opcode::v_cos_f16; + tmp = bld.vop1(opcode, bld.def(v1), tmp); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { + Temp tmp = bld.vop2(aco_opcode::v_mul_f32, bld.def(v1), half_pi, src); /* before GFX9, v_sin_f32 and v_cos_f32 had a valid input domain of [-256, +256] */ if (ctx->options->chip_class < GFX9) @@ -1833,14 +2121,16 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_ldexp: { - if (dst.size() == 1) { - bld.vop3(aco_opcode::v_ldexp_f32, Definition(dst), - as_vgpr(ctx, get_alu_src(ctx, instr->src[0])), - get_alu_src(ctx, instr->src[1])); - } else if (dst.size() == 2) { - bld.vop3(aco_opcode::v_ldexp_f64, Definition(dst), - as_vgpr(ctx, get_alu_src(ctx, instr->src[0])), - get_alu_src(ctx, instr->src[1])); + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = get_alu_src(ctx, instr->src[1]); + if (dst.regClass() == v2b) { + Temp tmp = bld.tmp(v1); + emit_vop2_instruction(ctx, instr, aco_opcode::v_ldexp_f16, tmp, false); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { + bld.vop3(aco_opcode::v_ldexp_f32, Definition(dst), as_vgpr(ctx, src0), src1); + } else if (dst.regClass() == v2) { + bld.vop3(aco_opcode::v_ldexp_f64, Definition(dst), as_vgpr(ctx, src0), src1); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); nir_print_instr(&instr->instr, stderr); @@ -1849,12 +2139,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_frexp_sig: { - if (dst.size() == 1) { - bld.vop1(aco_opcode::v_frexp_mant_f32, Definition(dst), - get_alu_src(ctx, instr->src[0])); - } else if (dst.size() == 2) { - bld.vop1(aco_opcode::v_frexp_mant_f64, Definition(dst), - get_alu_src(ctx, instr->src[0])); + Temp src = get_alu_src(ctx, instr->src[0]); + if (dst.regClass() == v2b) { + Temp tmp = bld.vop1(aco_opcode::v_frexp_mant_f16, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { + bld.vop1(aco_opcode::v_frexp_mant_f32, Definition(dst), src); + } else if (dst.regClass() == v2) { + bld.vop1(aco_opcode::v_frexp_mant_f64, Definition(dst), src); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); nir_print_instr(&instr->instr, stderr); @@ -1863,12 +2155,15 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_frexp_exp: { - if (instr->src[0].src.ssa->bit_size == 32) { - bld.vop1(aco_opcode::v_frexp_exp_i32_f32, Definition(dst), - get_alu_src(ctx, instr->src[0])); + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 16) { + Temp tmp = bld.vop1(aco_opcode::v_frexp_exp_i16_f16, bld.def(v1), src); + tmp = bld.pseudo(aco_opcode::p_extract_vector, bld.def(v1b), tmp, Operand(0u)); + convert_int(bld, tmp, 8, 32, true, dst); + } else if (instr->src[0].src.ssa->bit_size == 32) { + bld.vop1(aco_opcode::v_frexp_exp_i32_f32, Definition(dst), src); } else if (instr->src[0].src.ssa->bit_size == 64) { - bld.vop1(aco_opcode::v_frexp_exp_i32_f64, Definition(dst), - get_alu_src(ctx, instr->src[0])); + bld.vop1(aco_opcode::v_frexp_exp_i32_f64, Definition(dst), src); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); nir_print_instr(&instr->instr, stderr); @@ -1878,12 +2173,20 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_fsign: { Temp src = as_vgpr(ctx, get_alu_src(ctx, instr->src[0])); - 
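// v_sin_f32/v_cos_f32 (and the f16 forms used above) take their input in revolutions, not
// radians, so the code pre-multiplies by the constant 0x3e22f983, which is 1/(2*pi)
// ~= 0.15915494 as an IEEE float (despite the local variable name half_pi). A quick
// host-side check of that constant; helper name is illustrative.
#include <cassert>
#include <cstdint>
#include <cstring>
static float bits_to_float(uint32_t u) { float f; std::memcpy(&f, &u, sizeof f); return f; }
int main()
{
    float scale = bits_to_float(0x3e22f983u);
    assert(scale > 0.159154f && scale < 0.159156f);   // 1 / (2 * pi)
}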
if (dst.size() == 1) { + if (dst.regClass() == v2b) { + Temp one = bld.copy(bld.def(v1), Operand(0x3c00u)); + Temp minus_one = bld.copy(bld.def(v1), Operand(0xbc00u)); + Temp cond = bld.vopc(aco_opcode::v_cmp_nlt_f16, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), src); + src = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), one, src, cond); + cond = bld.vopc(aco_opcode::v_cmp_le_f16, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), src); + Temp tmp = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), minus_one, src, cond); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else if (dst.regClass() == v1) { Temp cond = bld.vopc(aco_opcode::v_cmp_nlt_f32, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), src); src = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0x3f800000u), src, cond); cond = bld.vopc(aco_opcode::v_cmp_le_f32, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), src); bld.vop2(aco_opcode::v_cndmask_b32, Definition(dst), Operand(0xbf800000u), src, cond); - } else if (dst.size() == 2) { + } else if (dst.regClass() == v2) { Temp cond = bld.vopc(aco_opcode::v_cmp_nlt_f64, bld.hint_vcc(bld.def(bld.lm)), Operand(0u), src); Temp tmp = bld.vop1(aco_opcode::v_mov_b32, bld.def(v1), Operand(0x3FF00000u)); Temp upper = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), tmp, emit_extract_vector(ctx, src, 1, v1), cond); @@ -1900,8 +2203,27 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } break; } + case nir_op_f2f16: + case nir_op_f2f16_rtne: { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 64) + src = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src); + src = bld.vop1(aco_opcode::v_cvt_f16_f32, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src); + break; + } + case nir_op_f2f16_rtz: { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 64) + src = bld.vop1(aco_opcode::v_cvt_f32_f64, bld.def(v1), src); + src = bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32, bld.def(v1), src, Operand(0u)); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), src); + break; + } case nir_op_f2f32: { - if (instr->src[0].src.ssa->bit_size == 64) { + if (instr->src[0].src.ssa->bit_size == 16) { + emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_f16, dst); + } else if (instr->src[0].src.ssa->bit_size == 64) { emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_f64, dst); } else { fprintf(stderr, "Unimplemented NIR instr bit size: "); @@ -1911,23 +2233,35 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_f2f64: { - if (instr->src[0].src.ssa->bit_size == 32) { - emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f64_f32, dst); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 16) + src = bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src); + bld.vop1(aco_opcode::v_cvt_f64_f32, Definition(dst), src); + break; + } + case nir_op_i2f16: { + assert(dst.regClass() == v2b); + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 8) + src = convert_int(bld, src, 8, 16, true); + Temp tmp = bld.vop1(aco_opcode::v_cvt_f16_i16, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); break; } case nir_op_i2f32: { assert(dst.size() == 1); - 
emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_i32, dst); + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size <= 16) + src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, true); + bld.vop1(aco_opcode::v_cvt_f32_i32, Definition(dst), src); break; } case nir_op_i2f64: { - if (instr->src[0].src.ssa->bit_size == 32) { - emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f64_i32, dst); + if (instr->src[0].src.ssa->bit_size <= 32) { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size <= 16) + src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, true); + bld.vop1(aco_opcode::v_cvt_f64_i32, Definition(dst), src); } else if (instr->src[0].src.ssa->bit_size == 64) { Temp src = get_alu_src(ctx, instr->src[0]); RegClass rc = RegClass(src.type(), 1); @@ -1945,14 +2279,34 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } break; } + case nir_op_u2f16: { + assert(dst.regClass() == v2b); + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 8) + src = convert_int(bld, src, 8, 16, false); + Temp tmp = bld.vop1(aco_opcode::v_cvt_f16_u16, bld.def(v1), src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + break; + } case nir_op_u2f32: { assert(dst.size() == 1); - emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_u32, dst); + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 8) { + //TODO: we should use v_cvt_f32_ubyte1/v_cvt_f32_ubyte2/etc depending on the register assignment + bld.vop1(aco_opcode::v_cvt_f32_ubyte0, Definition(dst), src); + } else { + if (instr->src[0].src.ssa->bit_size == 16) + src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, true); + bld.vop1(aco_opcode::v_cvt_f32_u32, Definition(dst), src); + } break; } case nir_op_u2f64: { - if (instr->src[0].src.ssa->bit_size == 32) { - emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f64_u32, dst); + if (instr->src[0].src.ssa->bit_size <= 32) { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size <= 16) + src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, false); + bld.vop1(aco_opcode::v_cvt_f64_u32, Definition(dst), src); } else if (instr->src[0].src.ssa->bit_size == 64) { Temp src = get_alu_src(ctx, instr->src[0]); RegClass rc = RegClass(src.type(), 1); @@ -1969,9 +2323,49 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } break; } + case nir_op_f2i8: + case nir_op_f2i16: { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 16) + src = bld.vop1(aco_opcode::v_cvt_i16_f16, bld.def(v1), src); + else if (instr->src[0].src.ssa->bit_size == 32) + src = bld.vop1(aco_opcode::v_cvt_i32_f32, bld.def(v1), src); + else + src = bld.vop1(aco_opcode::v_cvt_i32_f64, bld.def(v1), src); + + if (dst.type() == RegType::vgpr) + bld.pseudo(aco_opcode::p_extract_vector, Definition(dst), src, Operand(0u)); + else + bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), src); + break; + } + case nir_op_f2u8: + case nir_op_f2u16: { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 16) + src = bld.vop1(aco_opcode::v_cvt_u16_f16, bld.def(v1), src); + else if (instr->src[0].src.ssa->bit_size == 32) + src = bld.vop1(aco_opcode::v_cvt_u32_f32, bld.def(v1), src); + else + src = bld.vop1(aco_opcode::v_cvt_u32_f64, bld.def(v1), src); + + if (dst.type() == RegType::vgpr) + bld.pseudo(aco_opcode::p_extract_vector, 
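// convert_int(), used by the small-integer-to-float cases above, widens an 8/16-bit value
// to 32 bits (s_sext_i32_i8/i16 or s_and_b32 on SGPRs, an SDWA v_mov_b32 on VGPRs) and, for
// 64-bit destinations, derives the high dword from the sign. A host-side sketch of the
// widening paths only (truncation and NaN-free assumptions apply; name illustrative):
#include <cstdint>
static int64_t convert_int_model(uint32_t src, unsigned src_bits, unsigned dst_bits, bool is_signed)
{
    uint32_t lo;
    if (is_signed) {
        int32_t v = (int32_t)(src << (32 - src_bits));
        lo = (uint32_t)(v >> (32 - src_bits));                 // sign-extend (s_sext_i32_i8/i16)
    } else {
        lo = src & (src_bits == 32 ? ~0u : (1u << src_bits) - 1u);  // zero-extend (s_and_b32)
    }
    if (dst_bits < 64)
        return lo;
    uint32_t hi = is_signed && (int32_t)lo < 0 ? ~0u : 0u;     // s_ashr_i32 ..., 31  or  0
    return (int64_t)(((uint64_t)hi << 32) | lo);
}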
Definition(dst), src, Operand(0u)); + else + bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), src); + break; + } case nir_op_f2i32: { Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 32) { + if (instr->src[0].src.ssa->bit_size == 16) { + Temp tmp = bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src); + if (dst.type() == RegType::vgpr) { + bld.vop1(aco_opcode::v_cvt_i32_f32, Definition(dst), tmp); + } else { + bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), + bld.vop1(aco_opcode::v_cvt_i32_f32, bld.def(v1), tmp)); + } + } else if (instr->src[0].src.ssa->bit_size == 32) { if (dst.type() == RegType::vgpr) bld.vop1(aco_opcode::v_cvt_i32_f32, Definition(dst), src); else @@ -1994,7 +2388,15 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_f2u32: { Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 32) { + if (instr->src[0].src.ssa->bit_size == 16) { + Temp tmp = bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src); + if (dst.type() == RegType::vgpr) { + bld.vop1(aco_opcode::v_cvt_u32_f32, Definition(dst), tmp); + } else { + bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), + bld.vop1(aco_opcode::v_cvt_u32_f32, bld.def(v1), tmp)); + } + } else if (instr->src[0].src.ssa->bit_size == 32) { if (dst.type() == RegType::vgpr) bld.vop1(aco_opcode::v_cvt_u32_f32, Definition(dst), src); else @@ -2017,7 +2419,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_f2i64: { Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 32 && dst.type() == RegType::vgpr) { + if (instr->src[0].src.ssa->bit_size == 16) + src = bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src); + + if (instr->src[0].src.ssa->bit_size <= 32 && dst.type() == RegType::vgpr) { Temp exponent = bld.vop1(aco_opcode::v_frexp_exp_i32_f32, bld.def(v1), src); exponent = bld.vop3(aco_opcode::v_med3_i32, bld.def(v1), Operand(0x0u), exponent, Operand(64u)); Temp mantissa = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0x7fffffu), src); @@ -2043,13 +2448,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp new_upper = bld.vsub32(bld.def(v1), upper, sign, false, borrow); bld.pseudo(aco_opcode::p_create_vector, Definition(dst), new_lower, new_upper); - } else if (instr->src[0].src.ssa->bit_size == 32 && dst.type() == RegType::sgpr) { + } else if (instr->src[0].src.ssa->bit_size <= 32 && dst.type() == RegType::sgpr) { if (src.type() == RegType::vgpr) src = bld.as_uniform(src); Temp exponent = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), src, Operand(0x80017u)); - exponent = bld.sop2(aco_opcode::s_sub_u32, bld.def(s1), bld.def(s1, scc), exponent, Operand(126u)); - exponent = bld.sop2(aco_opcode::s_max_u32, bld.def(s1), bld.def(s1, scc), Operand(0u), exponent); - exponent = bld.sop2(aco_opcode::s_min_u32, bld.def(s1), bld.def(s1, scc), Operand(64u), exponent); + exponent = bld.sop2(aco_opcode::s_sub_i32, bld.def(s1), bld.def(s1, scc), exponent, Operand(126u)); + exponent = bld.sop2(aco_opcode::s_max_i32, bld.def(s1), bld.def(s1, scc), Operand(0u), exponent); + exponent = bld.sop2(aco_opcode::s_min_i32, bld.def(s1), bld.def(s1, scc), Operand(64u), exponent); Temp mantissa = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), Operand(0x7fffffu), src); Temp sign = bld.sop2(aco_opcode::s_ashr_i32, bld.def(s1), bld.def(s1, scc), src, Operand(31u)); mantissa = bld.sop2(aco_opcode::s_or_b32, bld.def(s1), bld.def(s1, scc), 
Operand(0x800000u), mantissa); @@ -2093,7 +2498,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_f2u64: { Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 32 && dst.type() == RegType::vgpr) { + if (instr->src[0].src.ssa->bit_size == 16) + src = bld.vop1(aco_opcode::v_cvt_f32_f16, bld.def(v1), src); + + if (instr->src[0].src.ssa->bit_size <= 32 && dst.type() == RegType::vgpr) { Temp exponent = bld.vop1(aco_opcode::v_frexp_exp_i32_f32, bld.def(v1), src); Temp exponent_in_range = bld.vopc(aco_opcode::v_cmp_ge_i32, bld.hint_vcc(bld.def(bld.lm)), Operand(64u), exponent); exponent = bld.vop2(aco_opcode::v_max_i32, bld.def(v1), Operand(0x0u), exponent); @@ -2116,12 +2524,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) upper = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0xffffffffu), upper, exponent_in_range); bld.pseudo(aco_opcode::p_create_vector, Definition(dst), lower, upper); - } else if (instr->src[0].src.ssa->bit_size == 32 && dst.type() == RegType::sgpr) { + } else if (instr->src[0].src.ssa->bit_size <= 32 && dst.type() == RegType::sgpr) { if (src.type() == RegType::vgpr) src = bld.as_uniform(src); Temp exponent = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), src, Operand(0x80017u)); - exponent = bld.sop2(aco_opcode::s_sub_u32, bld.def(s1), bld.def(s1, scc), exponent, Operand(126u)); - exponent = bld.sop2(aco_opcode::s_max_u32, bld.def(s1), bld.def(s1, scc), Operand(0u), exponent); + exponent = bld.sop2(aco_opcode::s_sub_i32, bld.def(s1), bld.def(s1, scc), exponent, Operand(126u)); + exponent = bld.sop2(aco_opcode::s_max_i32, bld.def(s1), bld.def(s1, scc), Operand(0u), exponent); Temp mantissa = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), Operand(0x7fffffu), src); mantissa = bld.sop2(aco_opcode::s_or_b32, bld.def(s1), bld.def(s1, scc), Operand(0x800000u), mantissa); Temp exponent_small = bld.sop2(aco_opcode::s_sub_u32, bld.def(s1), bld.def(s1, scc), Operand(24u), exponent); @@ -2160,6 +2568,22 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } break; } + case nir_op_b2f16: { + Temp src = get_alu_src(ctx, instr->src[0]); + assert(src.regClass() == bld.lm); + + if (dst.regClass() == s1) { + src = bool_to_scalar_condition(ctx, src); + bld.sop2(aco_opcode::s_mul_i32, Definition(dst), Operand(0x3c00u), src); + } else if (dst.regClass() == v2b) { + Temp one = bld.copy(bld.def(v1), Operand(0x3c00u)); + Temp tmp = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0u), one, src); + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); + } else { + unreachable("Wrong destination register class for nir_op_b2f16."); + } + break; + } case nir_op_b2f32: { Temp src = get_alu_src(ctx, instr->src[0]); assert(src.regClass() == bld.lm); @@ -2190,63 +2614,23 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } break; } - case nir_op_i2i32: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 64) { - /* we can actually just say dst = src, as it would map the lower register */ - emit_extract_vector(ctx, src, 0, dst); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } - break; - } - case nir_op_u2u32: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 16) { - if (dst.regClass() == s1) { - bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), 
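// The f2i64/f2u64 paths above lower the conversion by hand because there is no direct
// f32->i64 instruction: the exponent is pulled out with s_bfe_u32(src, 0x80017) (width 8 at
// bit 23), rebased and clamped, the implicit leading 1 is OR'd into the 23-bit mantissa,
// the mantissa is shifted, and the sign applied. A host-side model of the signed case,
// ignoring NaN and out-of-range inputs just as the emitted code does; name illustrative.
#include <cstdint>
#include <cstring>
static int64_t f2i64_model(float f)
{
    uint32_t src;  std::memcpy(&src, &f, sizeof src);
    int32_t exponent = (int32_t)((src >> 23) & 0xffu) - 126;   // s_bfe_u32 + s_sub_i32 126
    if (exponent < 0)  exponent = 0;                           // s_max_i32
    if (exponent > 64) exponent = 64;                          // s_min_i32
    uint64_t mantissa = (src & 0x7fffffu) | 0x800000u;         // implicit leading one
    uint64_t magnitude = exponent < 24 ? mantissa >> (24 - exponent)
                                       : mantissa << (exponent - 24);
    int64_t sign = (int32_t)src < 0 ? -1 : 0;                  // s_ashr_i32 ..., 31
    return ((int64_t)magnitude ^ sign) - sign;                 // conditional negate
}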
Operand(0xFFFFu), src); - } else { - // TODO: do better with SDWA - bld.vop2(aco_opcode::v_and_b32, Definition(dst), Operand(0xFFFFu), src); - } - } else if (instr->src[0].src.ssa->bit_size == 64) { - /* we can actually just say dst = src, as it would map the lower register */ - emit_extract_vector(ctx, src, 0, dst); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } - break; - } + case nir_op_i2i8: + case nir_op_i2i16: + case nir_op_i2i32: case nir_op_i2i64: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (src.regClass() == s1) { - Temp high = bld.sop2(aco_opcode::s_ashr_i32, bld.def(s1), bld.def(s1, scc), src, Operand(31u)); - bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src, high); - } else if (src.regClass() == v1) { - Temp high = bld.vop2(aco_opcode::v_ashrrev_i32, bld.def(v1), Operand(31u), src); - bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src, high); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } + convert_int(bld, get_alu_src(ctx, instr->src[0]), + instr->src[0].src.ssa->bit_size, instr->dest.dest.ssa.bit_size, true, dst); break; } + case nir_op_u2u8: + case nir_op_u2u16: + case nir_op_u2u32: case nir_op_u2u64: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 32) { - bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src, Operand(0u)); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } + convert_int(bld, get_alu_src(ctx, instr->src[0]), + instr->src[0].src.ssa->bit_size, instr->dest.dest.ssa.bit_size, false, dst); break; } + case nir_op_b2b32: case nir_op_b2i32: { Temp src = get_alu_src(ctx, instr->src[0]); assert(src.regClass() == bld.lm); @@ -2261,6 +2645,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } break; } + case nir_op_b2b1: case nir_op_i2b1: { Temp src = get_alu_src(ctx, instr->src[0]); assert(dst.regClass() == bld.lm); @@ -2290,12 +2675,40 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src0, src1); break; } - case nir_op_unpack_64_2x32_split_x: - bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(dst.regClass()), get_alu_src(ctx, instr->src[0])); - break; - case nir_op_unpack_64_2x32_split_y: - bld.pseudo(aco_opcode::p_split_vector, bld.def(dst.regClass()), Definition(dst), get_alu_src(ctx, instr->src[0])); - break; + case nir_op_unpack_64_2x32_split_x: + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(dst.regClass()), get_alu_src(ctx, instr->src[0])); + break; + case nir_op_unpack_64_2x32_split_y: + bld.pseudo(aco_opcode::p_split_vector, bld.def(dst.regClass()), Definition(dst), get_alu_src(ctx, instr->src[0])); + break; + case nir_op_unpack_32_2x16_split_x: + if (dst.type() == RegType::vgpr) { + bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(dst.regClass()), get_alu_src(ctx, instr->src[0])); + } else { + bld.copy(Definition(dst), get_alu_src(ctx, instr->src[0])); + } + break; + case nir_op_unpack_32_2x16_split_y: + if (dst.type() == RegType::vgpr) { + bld.pseudo(aco_opcode::p_split_vector, bld.def(dst.regClass()), Definition(dst), get_alu_src(ctx, instr->src[0])); + } else { + bld.sop2(aco_opcode::s_bfe_u32, Definition(dst), bld.def(s1, scc), get_alu_src(ctx, instr->src[0]), 
Operand(uint32_t(16 << 16 | 16))); + } + break; + case nir_op_pack_32_2x16_split: { + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = get_alu_src(ctx, instr->src[1]); + if (dst.regClass() == v1) { + src0 = emit_extract_vector(ctx, src0, 0, v2b); + src1 = emit_extract_vector(ctx, src1, 0, v2b); + bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src0, src1); + } else { + src0 = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), src0, Operand(0xFFFFu)); + src1 = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), src1, Operand(16u)); + bld.sop2(aco_opcode::s_or_b32, Definition(dst), bld.def(s1, scc), src0, src1); + } + break; + } case nir_op_pack_half_2x16: { Temp src = get_alu_src(ctx, instr->src[0], 2); @@ -2504,34 +2917,34 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_flt: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_f32, aco_opcode::v_cmp_lt_f64); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_f16, aco_opcode::v_cmp_lt_f32, aco_opcode::v_cmp_lt_f64); break; } case nir_op_fge: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_f32, aco_opcode::v_cmp_ge_f64); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_f16, aco_opcode::v_cmp_ge_f32, aco_opcode::v_cmp_ge_f64); break; } case nir_op_feq: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_f32, aco_opcode::v_cmp_eq_f64); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_f16, aco_opcode::v_cmp_eq_f32, aco_opcode::v_cmp_eq_f64); break; } case nir_op_fne: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_neq_f32, aco_opcode::v_cmp_neq_f64); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_neq_f16, aco_opcode::v_cmp_neq_f32, aco_opcode::v_cmp_neq_f64); break; } case nir_op_ilt: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_i32, aco_opcode::v_cmp_lt_i64, aco_opcode::s_cmp_lt_i32); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_i16, aco_opcode::v_cmp_lt_i32, aco_opcode::v_cmp_lt_i64, aco_opcode::s_cmp_lt_i32); break; } case nir_op_ige: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_i32, aco_opcode::v_cmp_ge_i64, aco_opcode::s_cmp_ge_i32); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_i16, aco_opcode::v_cmp_ge_i32, aco_opcode::v_cmp_ge_i64, aco_opcode::s_cmp_ge_i32); break; } case nir_op_ieq: { if (instr->src[0].src.ssa->bit_size == 1) emit_boolean_logic(ctx, instr, Builder::s_xnor, dst); else - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_i32, aco_opcode::v_cmp_eq_i64, aco_opcode::s_cmp_eq_i32, + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_eq_i16, aco_opcode::v_cmp_eq_i32, aco_opcode::v_cmp_eq_i64, aco_opcode::s_cmp_eq_i32, ctx->program->chip_class >= GFX8 ? aco_opcode::s_cmp_eq_u64 : aco_opcode::num_opcodes); break; } @@ -2539,16 +2952,16 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) if (instr->src[0].src.ssa->bit_size == 1) emit_boolean_logic(ctx, instr, Builder::s_xor, dst); else - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lg_i32, aco_opcode::v_cmp_lg_i64, aco_opcode::s_cmp_lg_i32, + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lg_i16, aco_opcode::v_cmp_lg_i32, aco_opcode::v_cmp_lg_i64, aco_opcode::s_cmp_lg_i32, ctx->program->chip_class >= GFX8 ? 
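// The SGPR fallbacks above for unpack_32_2x16_split_y and pack_32_2x16_split are plain
// shift-and-mask arithmetic; a host-side restatement (helper names illustrative):
#include <cstdint>
static uint32_t pack_32_2x16_model(uint32_t lo16, uint32_t hi16)
{
    return (lo16 & 0xffffu) | (hi16 << 16);   // s_and_b32, s_lshl_b32, s_or_b32
}
static uint32_t unpack_32_2x16_split_y_model(uint32_t v)
{
    return (v >> 16) & 0xffffu;               // s_bfe_u32 with src1 = (16 << 16) | 16
}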
aco_opcode::s_cmp_lg_u64 : aco_opcode::num_opcodes); break; } case nir_op_ult: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_u32, aco_opcode::v_cmp_lt_u64, aco_opcode::s_cmp_lt_u32); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_lt_u16, aco_opcode::v_cmp_lt_u32, aco_opcode::v_cmp_lt_u64, aco_opcode::s_cmp_lt_u32); break; } case nir_op_uge: { - emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_u32, aco_opcode::v_cmp_ge_u64, aco_opcode::s_cmp_ge_u32); + emit_comparison(ctx, instr, dst, aco_opcode::v_cmp_ge_u16, aco_opcode::v_cmp_ge_u32, aco_opcode::v_cmp_ge_u64, aco_opcode::s_cmp_ge_u32); break; } case nir_op_fddx: @@ -2609,6 +3022,12 @@ void visit_load_const(isel_context *ctx, nir_load_const_instr *instr) int val = instr->value[0].b ? -1 : 0; Operand op = bld.lm.size() == 1 ? Operand((uint32_t) val) : Operand((uint64_t) val); bld.sop1(Builder::s_mov, Definition(dst), op); + } else if (instr->def.bit_size == 8) { + /* ensure that the value is correctly represented in the low byte of the register */ + bld.sopk(aco_opcode::s_movk_i32, Definition(dst), instr->value[0].u8); + } else if (instr->def.bit_size == 16) { + /* ensure that the value is correctly represented in the low half of the register */ + bld.sopk(aco_opcode::s_movk_i32, Definition(dst), instr->value[0].u16); } else if (dst.size() == 1) { bld.copy(Definition(dst), Operand(instr->value[0].u32)); } else { @@ -3349,11 +3768,8 @@ bool load_input_from_temps(isel_context *ctx, nir_intrinsic_instr *instr, Temp d unsigned idx = nir_intrinsic_base(instr) + nir_intrinsic_component(instr) + 4 * nir_src_as_uint(*off_src); Temp *src = &ctx->inputs.temps[idx]; - Temp vec = create_vec_from_array(ctx, src, dst.size(), dst.regClass().type(), 4u); - assert(vec.size() == dst.size()); + create_vec_from_array(ctx, src, dst.size(), dst.regClass().type(), 4u, 0, dst); - Builder bld(ctx->program, ctx->block); - bld.copy(Definition(dst), vec); return true; } @@ -3366,8 +3782,13 @@ void visit_store_ls_or_es_output(isel_context *ctx, nir_intrinsic_instr *instr) unsigned write_mask = nir_intrinsic_write_mask(instr); unsigned elem_size_bytes = instr->src[0].ssa->bit_size / 8u; - if (ctx->tcs_in_out_eq) - store_output_to_temps(ctx, instr); + if (ctx->tcs_in_out_eq && store_output_to_temps(ctx, instr)) { + /* When the TCS only reads this output directly and for the same vertices as its invocation id, it is unnecessary to store the VS output to LDS. */ + bool indirect_write; + bool temp_only_input = tcs_driver_location_matches_api_mask(ctx, instr, true, ctx->tcs_temp_only_inputs, &indirect_write); + if (temp_only_input && !indirect_write) + return; + } if (ctx->stage == vertex_es || ctx->stage == tess_eval_es) { /* GFX6-8: ES stage is not merged into GS, data is passed from ES to GS in VMEM. 
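/* A small standalone sketch of the 8/16-bit constant path above, assuming only <cstdint> and
 * <cassert>: s_movk_i32 sign-extends its 16-bit immediate, which is sufficient because only
 * the low byte or low half of the register has to hold the value.  Names are illustrative. */
#include <cstdint>
#include <cassert>

static inline uint32_t s_movk_i32_model(uint16_t imm)
{
   return (uint32_t)(int32_t)(int16_t)imm;   /* sign-extend imm[15:0] to 32 bits */
}

static inline void check_small_constants()
{
   uint8_t  c8  = 0xAB;
   uint16_t c16 = 0xBEEF;
   assert((s_movk_i32_model(c8)  & 0xffu)   == c8);   /* low byte holds the 8-bit value */
   assert((s_movk_i32_model(c16) & 0xffffu) == c16);  /* low half holds the 16-bit value */
}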
*/ @@ -3478,6 +3899,8 @@ void visit_store_output(isel_context *ctx, nir_intrinsic_instr *instr) if (ctx->stage == vertex_vs || ctx->stage == tess_eval_vs || ctx->stage == fragment_fs || + ctx->stage == ngg_vertex_gs || + ctx->stage == ngg_tess_eval_gs || ctx->shader->info.stage == MESA_SHADER_GEOMETRY) { bool stored_to_temps = store_output_to_temps(ctx, instr); if (!stored_to_temps) { @@ -4158,20 +4581,28 @@ void visit_load_resource(isel_context *ctx, nir_intrinsic_instr *instr) bld.copy(Definition(get_ssa_temp(ctx, &instr->dest.ssa)), index); } -void load_buffer(isel_context *ctx, unsigned num_components, Temp dst, - Temp rsrc, Temp offset, bool glc=false, bool readonly=true) +void load_buffer(isel_context *ctx, unsigned num_components, unsigned component_size, + Temp dst, Temp rsrc, Temp offset, int byte_align, + bool glc=false, bool readonly=true) { Builder bld(ctx->program, ctx->block); - - unsigned num_bytes = dst.size() * 4; bool dlc = glc && ctx->options->chip_class >= GFX10; + unsigned num_bytes = num_components * component_size; aco_opcode op; - if (dst.type() == RegType::vgpr || (ctx->options->chip_class < GFX8 && !readonly)) { + if (dst.type() == RegType::vgpr || ((ctx->options->chip_class < GFX8 || component_size < 4) && !readonly)) { Operand vaddr = offset.type() == RegType::vgpr ? Operand(offset) : Operand(v1); Operand soffset = offset.type() == RegType::sgpr ? Operand(offset) : Operand((uint32_t) 0); unsigned const_offset = 0; + /* for small bit sizes add buffer for unaligned loads */ + if (byte_align) { + if (num_bytes > 2) + num_bytes += byte_align == -1 ? 4 - component_size : byte_align; + else + byte_align = 0; + } + Temp lower = Temp(); if (num_bytes > 16) { assert(num_components == 3 || num_components == 4); @@ -4197,12 +4628,23 @@ void load_buffer(isel_context *ctx, unsigned num_components, Temp dst, } switch (num_bytes) { + case 1: + op = aco_opcode::buffer_load_ubyte; + break; + case 2: + op = aco_opcode::buffer_load_ushort; + break; + case 3: case 4: op = aco_opcode::buffer_load_dword; break; + case 5: + case 6: + case 7: case 8: op = aco_opcode::buffer_load_dwordx2; break; + case 10: case 12: assert(ctx->options->chip_class > GFX6); op = aco_opcode::buffer_load_dwordx3; @@ -4225,7 +4667,41 @@ void load_buffer(isel_context *ctx, unsigned num_components, Temp dst, mubuf->offset = const_offset; aco_ptr instr = std::move(mubuf); - if (dst.size() > 4) { + if (component_size < 4) { + Temp vec = num_bytes <= 4 ? bld.tmp(v1) : num_bytes <= 8 ? bld.tmp(v2) : bld.tmp(v3); + instr->definitions[0] = Definition(vec); + bld.insert(std::move(instr)); + + if (byte_align == -1 || (byte_align && dst.type() == RegType::sgpr)) { + Operand align = byte_align == -1 ? 
Operand(offset) : Operand((uint32_t)byte_align); + Temp tmp[3] = {vec, vec, vec}; + + if (vec.size() == 3) { + tmp[0] = bld.tmp(v1), tmp[1] = bld.tmp(v1), tmp[2] = bld.tmp(v1); + bld.pseudo(aco_opcode::p_split_vector, Definition(tmp[0]), Definition(tmp[1]), Definition(tmp[2]), vec); + } else if (vec.size() == 2) { + tmp[0] = bld.tmp(v1), tmp[1] = bld.tmp(v1), tmp[2] = tmp[1]; + bld.pseudo(aco_opcode::p_split_vector, Definition(tmp[0]), Definition(tmp[1]), vec); + } + for (unsigned i = 0; i < dst.size(); i++) + tmp[i] = bld.vop3(aco_opcode::v_alignbyte_b32, bld.def(v1), tmp[i + 1], tmp[i], align); + + vec = tmp[0]; + if (dst.size() == 2) + vec = bld.pseudo(aco_opcode::p_create_vector, bld.def(v2), tmp[0], tmp[1]); + + byte_align = 0; + } + + if (dst.type() == RegType::vgpr && num_components == 1) { + bld.pseudo(aco_opcode::p_extract_vector, Definition(dst), vec, Operand(byte_align / component_size)); + } else { + trim_subdword_vector(ctx, vec, dst, 4 * vec.size() / component_size, ((1 << num_components) - 1) << byte_align / component_size); + } + + return; + + } else if (dst.size() > 4) { assert(lower != Temp()); Temp upper = bld.tmp(RegType::vgpr, dst.size() - lower.size()); instr->definitions[0] = Definition(upper); @@ -4261,13 +4737,24 @@ void load_buffer(isel_context *ctx, unsigned num_components, Temp dst, emit_split_vector(ctx, dst, num_components); } } else { + /* for small bit sizes add buffer for unaligned loads */ + if (byte_align) + num_bytes += byte_align == -1 ? 4 - component_size : byte_align; + switch (num_bytes) { + case 1: + case 2: + case 3: case 4: op = aco_opcode::s_buffer_load_dword; break; + case 5: + case 6: + case 7: case 8: op = aco_opcode::s_buffer_load_dwordx2; break; + case 10: case 12: case 16: op = aco_opcode::s_buffer_load_dwordx4; @@ -4279,9 +4766,10 @@ void load_buffer(isel_context *ctx, unsigned num_components, Temp dst, default: unreachable("Load SSBO not implemented for this size."); } + offset = bld.as_uniform(offset); aco_ptr load{create_instruction(op, Format::SMEM, 2, 1)}; load->operands[0] = Operand(rsrc); - load->operands[1] = Operand(bld.as_uniform(offset)); + load->operands[1] = Operand(offset); assert(load->operands[1].getTemp().type() == RegType::sgpr); load->definitions[0] = Definition(dst); load->glc = glc; @@ -4290,8 +4778,16 @@ void load_buffer(isel_context *ctx, unsigned num_components, Temp dst, load->can_reorder = false; // FIXME: currently, it doesn't seem beneficial due to how our scheduler works assert(ctx->options->chip_class >= GFX8 || !glc); + /* adjust misaligned small bit size loads */ + if (byte_align) { + Temp vec = num_bytes <= 4 ? bld.tmp(s1) : num_bytes <= 8 ? bld.tmp(s2) : bld.tmp(s4); + load->definitions[0] = Definition(vec); + bld.insert(std::move(load)); + Operand byte_offset = byte_align > 0 ? 
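/* A standalone model of the v_alignbyte_b32 re-alignment loop above, assuming only <cstdint>:
 * the dword-aligned load fetches one extra dword, and each result dword is rebuilt from two
 * consecutive loaded dwords shifted right by the byte offset.  Helper names are illustrative,
 * not ACO API. */
#include <cstdint>

/* D = low dword of ({hi, lo} >> (8 * (byte_off & 3))), i.e. v_alignbyte_b32. */
static inline uint32_t v_alignbyte_b32_model(uint32_t hi, uint32_t lo, uint32_t byte_off)
{
   uint64_t concat = ((uint64_t)hi << 32) | lo;
   return (uint32_t)(concat >> (8u * (byte_off & 3u)));
}

/* Rebuild dst_dwords aligned dwords from dst_dwords + 1 dwords loaded from the
 * dword-aligned address below the real (unaligned) one. */
static void realign_loaded_dwords(const uint32_t *loaded, uint32_t *dst,
                                  unsigned dst_dwords, unsigned byte_off)
{
   for (unsigned i = 0; i < dst_dwords; i++)
      dst[i] = v_alignbyte_b32_model(loaded[i + 1], loaded[i], byte_off);
}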
Operand(uint32_t(byte_align)) : Operand(offset); + byte_align_scalar(ctx, vec, byte_offset, dst); + /* trim vector */ - if (dst.size() == 3) { + } else if (dst.size() == 3) { Temp vec = bld.tmp(s4); load->definitions[0] = Definition(vec); bld.insert(std::move(load)); @@ -4353,20 +4849,25 @@ void visit_load_ubo(isel_context *ctx, nir_intrinsic_instr *instr) rsrc = convert_pointer_to_64_bit(ctx, rsrc); rsrc = bld.smem(aco_opcode::s_load_dwordx4, bld.def(s4), rsrc, Operand(0u)); } - - load_buffer(ctx, instr->num_components, dst, rsrc, get_ssa_temp(ctx, instr->src[1].ssa)); + unsigned size = instr->dest.ssa.bit_size / 8; + int byte_align = 0; + if (size < 4) { + unsigned align_mul = nir_intrinsic_align_mul(instr); + unsigned align_offset = nir_intrinsic_align_offset(instr); + byte_align = align_mul % 4 == 0 ? align_offset : -1; + } + load_buffer(ctx, instr->num_components, size, dst, rsrc, get_ssa_temp(ctx, instr->src[1].ssa), byte_align); } void visit_load_push_constant(isel_context *ctx, nir_intrinsic_instr *instr) { Builder bld(ctx->program, ctx->block); Temp dst = get_ssa_temp(ctx, &instr->dest.ssa); - unsigned offset = nir_intrinsic_base(instr); + unsigned count = instr->dest.ssa.num_components; nir_const_value *index_cv = nir_src_as_const_value(instr->src[0]); - if (index_cv && instr->dest.ssa.bit_size == 32) { - unsigned count = instr->dest.ssa.num_components; + if (index_cv && instr->dest.ssa.bit_size == 32) { unsigned start = (offset + index_cv->u32) / 4u; start -= ctx->args->ac.base_inline_push_consts; if (start + count <= ctx->args->ac.num_inline_push_consts) { @@ -4389,9 +4890,22 @@ void visit_load_push_constant(isel_context *ctx, nir_intrinsic_instr *instr) Temp ptr = convert_pointer_to_64_bit(ctx, get_arg(ctx, ctx->args->ac.push_constants)); Temp vec = dst; bool trim = false; + bool aligned = true; + + if (instr->dest.ssa.bit_size == 8) { + aligned = index_cv && (offset + index_cv->u32) % 4 == 0; + bool fits_in_dword = count == 1 || (index_cv && ((offset + index_cv->u32) % 4 + count) <= 4); + if (!aligned) + vec = fits_in_dword ? bld.tmp(s1) : bld.tmp(s2); + } else if (instr->dest.ssa.bit_size == 16) { + aligned = index_cv && (offset + index_cv->u32) % 4 == 0; + if (!aligned) + vec = count == 4 ? bld.tmp(s4) : count > 1 ? bld.tmp(s2) : bld.tmp(s1); + } + aco_opcode op; - switch (dst.size()) { + switch (vec.size()) { case 1: op = aco_opcode::s_load_dword; break; @@ -4416,6 +4930,12 @@ void visit_load_push_constant(isel_context *ctx, nir_intrinsic_instr *instr) bld.smem(op, Definition(vec), ptr, index); + if (!aligned) { + Operand byte_offset = index_cv ? Operand((offset + index_cv->u32) % 4) : Operand(index); + byte_align_scalar(ctx, vec, byte_offset, dst); + return; + } + if (trim) { emit_split_vector(ctx, vec, 4); RegClass rc = dst.size() == 3 ? s1 : s2; @@ -4460,8 +4980,10 @@ void visit_load_constant(isel_context *ctx, nir_intrinsic_instr *instr) bld.sop1(aco_opcode::p_constaddr, bld.def(s2), bld.def(s1, scc), Operand(ctx->constant_data_offset)), Operand(MIN2(base + range, ctx->shader->constant_data_size)), Operand(desc_type)); - - load_buffer(ctx, instr->num_components, dst, rsrc, offset); + unsigned size = instr->dest.ssa.bit_size / 8; + // TODO: get alignment information for subdword constants + unsigned byte_align = size < 4 ? 
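/* A sketch of how the byte_align argument above is derived from the NIR alignment info:
 * a compile-time sub-dword offset when align_mul is a multiple of 4, otherwise -1 to
 * request run-time re-alignment from the offset operand.  The helper name is illustrative. */
static inline int compute_byte_align(unsigned component_size,
                                     unsigned align_mul, unsigned align_offset)
{
   if (component_size >= 4)
      return 0;                                   /* dword-sized or larger: no fixup needed */
   return align_mul % 4 == 0 ? (int)align_offset  /* offset inside the dword is known */
                             : -1;                /* unknown, re-align at run time */
}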
-1 : 0; + load_buffer(ctx, instr->num_components, size, dst, rsrc, offset, byte_align); } void visit_discard_if(isel_context *ctx, nir_intrinsic_instr *instr) @@ -5274,7 +5796,14 @@ void visit_load_ssbo(isel_context *ctx, nir_intrinsic_instr *instr) rsrc = bld.smem(aco_opcode::s_load_dwordx4, bld.def(s4), rsrc, Operand(0u)); bool glc = nir_intrinsic_access(instr) & (ACCESS_VOLATILE | ACCESS_COHERENT); - load_buffer(ctx, num_components, dst, rsrc, get_ssa_temp(ctx, instr->src[1].ssa), glc, false); + unsigned size = instr->dest.ssa.bit_size / 8; + int byte_align = 0; + if (size < 4) { + unsigned align_mul = nir_intrinsic_align_mul(instr); + unsigned align_offset = nir_intrinsic_align_offset(instr); + byte_align = align_mul % 4 == 0 ? align_offset : -1; + } + load_buffer(ctx, num_components, size, dst, rsrc, get_ssa_temp(ctx, instr->src[1].ssa), byte_align, glc, false); } void visit_store_ssbo(isel_context *ctx, nir_intrinsic_instr *instr) @@ -5289,7 +5818,8 @@ void visit_store_ssbo(isel_context *ctx, nir_intrinsic_instr *instr) rsrc = bld.smem(aco_opcode::s_load_dwordx4, bld.def(s4), rsrc, Operand(0u)); bool smem = !ctx->divergent_vals[instr->src[2].ssa->index] && - ctx->options->chip_class >= GFX8; + ctx->options->chip_class >= GFX8 && + elem_size_bytes >= 4; if (smem) offset = bld.as_uniform(offset); bool smem_nonfs = smem && ctx->stage != fragment_fs; @@ -5304,6 +5834,15 @@ void visit_store_ssbo(isel_context *ctx, nir_intrinsic_instr *instr) } int num_bytes = count * elem_size_bytes; + /* dword or larger stores have to be dword-aligned */ + if (elem_size_bytes < 4 && num_bytes > 2) { + // TODO: improve alignment check of sub-dword stores + unsigned count_new = 2 / elem_size_bytes; + writemask |= ((1 << (count - count_new)) - 1) << (start + count_new); + count = count_new; + num_bytes = 2; + } + if (num_bytes > 16) { assert(elem_size_bytes == 8); writemask |= (((count - 2) << 1) - 1) << (start + 2); @@ -5311,12 +5850,20 @@ void visit_store_ssbo(isel_context *ctx, nir_intrinsic_instr *instr) num_bytes = 16; } - // TODO: check alignment of sub-dword stores - // TODO: split 3 bytes. 
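/* A standalone model of the sub-dword store splitting above, assuming only <cstdint>: the
 * store is capped at 2 bytes (buffer_store_byte/short) and the remaining components are
 * pushed back into the writemask for a later iteration.  Names are illustrative, not ACO API. */
#include <cstdint>

struct store_chunk {
   unsigned count;               /* components written by this store */
   unsigned num_bytes;           /* 1 or 2 in the sub-dword case */
   uint32_t remaining_writemask; /* components deferred to a later store */
};

static store_chunk limit_subdword_store(unsigned start, unsigned count,
                                        unsigned elem_size_bytes, uint32_t writemask)
{
   unsigned num_bytes = count * elem_size_bytes;
   if (elem_size_bytes < 4 && num_bytes > 2) {
      unsigned count_new = 2 / elem_size_bytes;   /* two 8-bit or one 16-bit component */
      writemask |= ((1u << (count - count_new)) - 1u) << (start + count_new);
      count = count_new;
      num_bytes = 2;
   }
   return {count, num_bytes, writemask};
}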
there is no store instruction for that - Temp write_data; - if (count != instr->num_components) { - emit_split_vector(ctx, data, instr->num_components); + if (elem_size_bytes < 4) { + if (data.type() == RegType::sgpr) { + data = as_vgpr(ctx, data); + emit_split_vector(ctx, data, 4 * data.size() / elem_size_bytes); + } + RegClass rc = RegClass(RegType::vgpr, elem_size_bytes).as_subdword(); + aco_ptr vec{create_instruction(aco_opcode::p_create_vector, Format::PSEUDO, count, 1)}; + for (int i = 0; i < count; i++) + vec->operands[i] = Operand(emit_extract_vector(ctx, data, start + i, rc)); + write_data = bld.tmp(RegClass(RegType::vgpr, num_bytes).as_subdword()); + vec->definitions[0] = Definition(write_data); + bld.insert(std::move(vec)); + } else if (count != instr->num_components) { aco_ptr vec{create_instruction(aco_opcode::p_create_vector, Format::PSEUDO, count, 1)}; for (int i = 0; i < count; i++) { Temp elem = emit_extract_vector(ctx, data, start + i, RegClass(data.type(), elem_size_bytes / 4)); @@ -5335,8 +5882,14 @@ void visit_store_ssbo(isel_context *ctx, nir_intrinsic_instr *instr) write_data = data; } - aco_opcode vmem_op, smem_op; + aco_opcode vmem_op, smem_op = aco_opcode::last_opcode; switch (num_bytes) { + case 1: + vmem_op = aco_opcode::buffer_store_byte; + break; + case 2: + vmem_op = aco_opcode::buffer_store_short; + break; case 4: vmem_op = aco_opcode::buffer_store_dword; smem_op = aco_opcode::s_buffer_store_dword; @@ -5347,7 +5900,6 @@ void visit_store_ssbo(isel_context *ctx, nir_intrinsic_instr *instr) break; case 12: vmem_op = aco_opcode::buffer_store_dwordx3; - smem_op = aco_opcode::last_opcode; assert(!smem && ctx->options->chip_class > GFX6); break; case 16: @@ -7826,7 +8378,7 @@ void visit_tex(isel_context *ctx, nir_tex_instr *instr) if (instr->sampler_dim == GLSL_SAMPLER_DIM_1D && ctx->options->chip_class == GFX9) { assert(has_ddx && has_ddy && ddx.size() == 1 && ddy.size() == 1); Temp zero = bld.copy(bld.def(v1), Operand(0u)); - derivs = {ddy, zero, ddy, zero}; + derivs = {ddx, zero, ddy, zero}; } else { for (unsigned i = 0; has_ddx && i < ddx.size(); i++) derivs.emplace_back(emit_extract_vector(ctx, ddx, i, v1)); @@ -8863,11 +9415,98 @@ static void end_divergent_if(isel_context *ctx, if_context *ic) } } +static void begin_uniform_if_then(isel_context *ctx, if_context *ic, Temp cond) +{ + assert(cond.regClass() == s1); + + append_logical_end(ctx->block); + ctx->block->kind |= block_kind_uniform; + + aco_ptr branch; + aco_opcode branch_opcode = aco_opcode::p_cbranch_z; + branch.reset(create_instruction(branch_opcode, Format::PSEUDO_BRANCH, 1, 0)); + branch->operands[0] = Operand(cond); + branch->operands[0].setFixed(scc); + ctx->block->instructions.emplace_back(std::move(branch)); + + ic->BB_if_idx = ctx->block->index; + ic->BB_endif = Block(); + ic->BB_endif.loop_nest_depth = ctx->cf_info.loop_nest_depth; + ic->BB_endif.kind |= ctx->block->kind & block_kind_top_level; + + ctx->cf_info.has_branch = false; + ctx->cf_info.parent_loop.has_divergent_branch = false; + + /** emit then block */ + Block* BB_then = ctx->program->create_and_insert_block(); + BB_then->loop_nest_depth = ctx->cf_info.loop_nest_depth; + add_edge(ic->BB_if_idx, BB_then); + append_logical_start(BB_then); + ctx->block = BB_then; +} + +static void begin_uniform_if_else(isel_context *ctx, if_context *ic) +{ + Block *BB_then = ctx->block; + + ic->uniform_has_then_branch = ctx->cf_info.has_branch; + ic->then_branch_divergent = ctx->cf_info.parent_loop.has_divergent_branch; + + if 
(!ic->uniform_has_then_branch) { + append_logical_end(BB_then); + /* branch from then block to endif block */ + aco_ptr branch; + branch.reset(create_instruction(aco_opcode::p_branch, Format::PSEUDO_BRANCH, 0, 0)); + BB_then->instructions.emplace_back(std::move(branch)); + add_linear_edge(BB_then->index, &ic->BB_endif); + if (!ic->then_branch_divergent) + add_logical_edge(BB_then->index, &ic->BB_endif); + BB_then->kind |= block_kind_uniform; + } + + ctx->cf_info.has_branch = false; + ctx->cf_info.parent_loop.has_divergent_branch = false; + + /** emit else block */ + Block* BB_else = ctx->program->create_and_insert_block(); + BB_else->loop_nest_depth = ctx->cf_info.loop_nest_depth; + add_edge(ic->BB_if_idx, BB_else); + append_logical_start(BB_else); + ctx->block = BB_else; +} + +static void end_uniform_if(isel_context *ctx, if_context *ic) +{ + Block *BB_else = ctx->block; + + if (!ctx->cf_info.has_branch) { + append_logical_end(BB_else); + /* branch from then block to endif block */ + aco_ptr branch; + branch.reset(create_instruction(aco_opcode::p_branch, Format::PSEUDO_BRANCH, 0, 0)); + BB_else->instructions.emplace_back(std::move(branch)); + add_linear_edge(BB_else->index, &ic->BB_endif); + if (!ctx->cf_info.parent_loop.has_divergent_branch) + add_logical_edge(BB_else->index, &ic->BB_endif); + BB_else->kind |= block_kind_uniform; + } + + ctx->cf_info.has_branch &= ic->uniform_has_then_branch; + ctx->cf_info.parent_loop.has_divergent_branch &= ic->then_branch_divergent; + + /** emit endif merge block */ + if (!ctx->cf_info.has_branch) { + ctx->block = ctx->program->insert_block(std::move(ic->BB_endif)); + append_logical_start(ctx->block); + } +} + static bool visit_if(isel_context *ctx, nir_if *if_stmt) { Temp cond = get_ssa_temp(ctx, if_stmt->condition.ssa); Builder bld(ctx->program, ctx->block); aco_ptr branch; + if_context ic; if (!ctx->divergent_vals[if_stmt->condition.ssa->index]) { /* uniform condition */ /** @@ -8885,77 +9524,19 @@ static bool visit_if(isel_context *ctx, nir_if *if_stmt) * to the loop exit/entry block. Otherwise, it branches to the next * merge block. 
**/ - append_logical_end(ctx->block); - ctx->block->kind |= block_kind_uniform; - /* emit branch */ - assert(cond.regClass() == bld.lm); // TODO: in a post-RA optimizer, we could check if the condition is in VCC and omit this instruction + assert(cond.regClass() == ctx->program->lane_mask); cond = bool_to_scalar_condition(ctx, cond); - branch.reset(create_instruction(aco_opcode::p_cbranch_z, Format::PSEUDO_BRANCH, 1, 0)); - branch->operands[0] = Operand(cond); - branch->operands[0].setFixed(scc); - ctx->block->instructions.emplace_back(std::move(branch)); - - unsigned BB_if_idx = ctx->block->index; - Block BB_endif = Block(); - BB_endif.loop_nest_depth = ctx->cf_info.loop_nest_depth; - BB_endif.kind |= ctx->block->kind & block_kind_top_level; - - /** emit then block */ - Block* BB_then = ctx->program->create_and_insert_block(); - BB_then->loop_nest_depth = ctx->cf_info.loop_nest_depth; - add_edge(BB_if_idx, BB_then); - append_logical_start(BB_then); - ctx->block = BB_then; + begin_uniform_if_then(ctx, &ic, cond); visit_cf_list(ctx, &if_stmt->then_list); - BB_then = ctx->block; - bool then_branch = ctx->cf_info.has_branch; - bool then_branch_divergent = ctx->cf_info.parent_loop.has_divergent_branch; - - if (!then_branch) { - append_logical_end(BB_then); - /* branch from then block to endif block */ - branch.reset(create_instruction(aco_opcode::p_branch, Format::PSEUDO_BRANCH, 0, 0)); - BB_then->instructions.emplace_back(std::move(branch)); - add_linear_edge(BB_then->index, &BB_endif); - if (!then_branch_divergent) - add_logical_edge(BB_then->index, &BB_endif); - BB_then->kind |= block_kind_uniform; - } - - ctx->cf_info.has_branch = false; - ctx->cf_info.parent_loop.has_divergent_branch = false; - /** emit else block */ - Block* BB_else = ctx->program->create_and_insert_block(); - BB_else->loop_nest_depth = ctx->cf_info.loop_nest_depth; - add_edge(BB_if_idx, BB_else); - append_logical_start(BB_else); - ctx->block = BB_else; + begin_uniform_if_else(ctx, &ic); visit_cf_list(ctx, &if_stmt->else_list); - BB_else = ctx->block; - - if (!ctx->cf_info.has_branch) { - append_logical_end(BB_else); - /* branch from then block to endif block */ - branch.reset(create_instruction(aco_opcode::p_branch, Format::PSEUDO_BRANCH, 0, 0)); - BB_else->instructions.emplace_back(std::move(branch)); - add_linear_edge(BB_else->index, &BB_endif); - if (!ctx->cf_info.parent_loop.has_divergent_branch) - add_logical_edge(BB_else->index, &BB_endif); - BB_else->kind |= block_kind_uniform; - } - ctx->cf_info.has_branch &= then_branch; - ctx->cf_info.parent_loop.has_divergent_branch &= then_branch_divergent; + end_uniform_if(ctx, &ic); - /** emit endif merge block */ - if (!ctx->cf_info.has_branch) { - ctx->block = ctx->program->insert_block(std::move(BB_endif)); - append_logical_start(ctx->block); - } return !ctx->cf_info.has_branch; } else { /* non-uniform condition */ /** @@ -8983,8 +9564,6 @@ static bool visit_if(isel_context *ctx, nir_if *if_stmt) * *) Exceptions may be due to break and continue statements within loops **/ - if_context ic; - begin_divergent_if_then(ctx, &ic, cond); visit_cf_list(ctx, &if_stmt->then_list); @@ -9036,9 +9615,11 @@ static bool export_vs_varying(isel_context *ctx, int slot, bool is_pos, int *nex { assert(ctx->stage == vertex_vs || ctx->stage == tess_eval_vs || - ctx->stage == gs_copy_vs); + ctx->stage == gs_copy_vs || + ctx->stage == ngg_vertex_gs || + ctx->stage == ngg_tess_eval_gs); - int offset = ctx->stage == tess_eval_vs + int offset = (ctx->stage & sw_tes) ? 
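/* A tiny wave-execution model of the two strategies visit_if chooses between above, assuming
 * only <cstdint>: a condition proven uniform becomes a real scalar branch on SCC, while a
 * divergent condition runs both sides under complementary exec masks (what the
 * begin/end_divergent_if helpers arrange).  Function names are illustrative. */
#include <cstdint>

/* Uniform condition: one scalar branch, the whole wave takes a single side. */
static void run_uniform_if(bool scc, void (*then_side)(), void (*else_side)())
{
   if (scc)
      then_side();
   else
      else_side();
}

/* Divergent condition: no real branch, both sides execute under lane masks. */
static void run_divergent_if(uint64_t exec, uint64_t cond,
                             void (*then_side)(uint64_t), void (*else_side)(uint64_t))
{
   then_side(exec & cond);
   else_side(exec & ~cond);
}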
ctx->program->info->tes.outinfo.vs_output_param_offset[slot] : ctx->program->info->vs.outinfo.vs_output_param_offset[slot]; uint64_t mask = ctx->outputs.mask[slot]; @@ -9106,17 +9687,46 @@ static void export_vs_psiz_layer_viewport(isel_context *ctx, int *next_pos) ctx->block->instructions.emplace_back(std::move(exp)); } +static void create_export_phis(isel_context *ctx) +{ + /* Used when exports are needed, but the output temps are defined in a preceding block. + * This function will set up phis in order to access the outputs in the next block. + */ + + assert(ctx->block->instructions.back()->opcode == aco_opcode::p_logical_start); + aco_ptr logical_start = aco_ptr(ctx->block->instructions.back().release()); + ctx->block->instructions.pop_back(); + + Builder bld(ctx->program, ctx->block); + + for (unsigned slot = 0; slot <= VARYING_SLOT_VAR31; ++slot) { + uint64_t mask = ctx->outputs.mask[slot]; + for (unsigned i = 0; i < 4; ++i) { + if (!(mask & (1 << i))) + continue; + + Temp old = ctx->outputs.temps[slot * 4 + i]; + Temp phi = bld.pseudo(aco_opcode::p_phi, bld.def(v1), old, Operand(v1)); + ctx->outputs.temps[slot * 4 + i] = phi; + } + } + + bld.insert(std::move(logical_start)); +} + static void create_vs_exports(isel_context *ctx) { assert(ctx->stage == vertex_vs || ctx->stage == tess_eval_vs || - ctx->stage == gs_copy_vs); + ctx->stage == gs_copy_vs || + ctx->stage == ngg_vertex_gs || + ctx->stage == ngg_tess_eval_gs); - radv_vs_output_info *outinfo = ctx->stage == tess_eval_vs + radv_vs_output_info *outinfo = (ctx->stage & sw_tes) ? &ctx->program->info->tes.outinfo : &ctx->program->info->vs.outinfo; - if (outinfo->export_prim_id) { + if (outinfo->export_prim_id && !(ctx->stage & hw_ngg_gs)) { ctx->outputs.mask[VARYING_SLOT_PRIMITIVE_ID] |= 0x1; ctx->outputs.temps[VARYING_SLOT_PRIMITIVE_ID * 4u] = get_arg(ctx, ctx->args->vs_prim_id); } @@ -9146,8 +9756,10 @@ static void create_vs_exports(isel_context *ctx) } for (unsigned i = 0; i <= VARYING_SLOT_VAR31; ++i) { - if (i < VARYING_SLOT_VAR0 && i != VARYING_SLOT_LAYER && - i != VARYING_SLOT_PRIMITIVE_ID) + if (i < VARYING_SLOT_VAR0 && + i != VARYING_SLOT_LAYER && + i != VARYING_SLOT_PRIMITIVE_ID && + i != VARYING_SLOT_VIEWPORT) continue; export_vs_varying(ctx, i, false, NULL); @@ -9506,7 +10118,7 @@ static void emit_stream_output(isel_context *ctx, Temp out[4]; bool all_undef = true; - assert(ctx->stage == vertex_vs || ctx->stage == gs_copy_vs); + assert(ctx->stage & hw_vs); for (unsigned i = 0; i < num_comps; i++) { out[i] = ctx->outputs.temps[loc * 4 + start + i]; all_undef = all_undef && !out[i].id(); @@ -9784,6 +10396,239 @@ void cleanup_cfg(Program *program) } } +Temp merged_wave_info_to_mask(isel_context *ctx, unsigned i) +{ + Builder bld(ctx->program, ctx->block); + + /* The s_bfm only cares about s0.u[5:0] so we don't need either s_bfe nor s_and here */ + Temp count = i == 0 + ? 
get_arg(ctx, ctx->args->merged_wave_info) + : bld.sop2(aco_opcode::s_lshr_b32, bld.def(s1), bld.def(s1, scc), + get_arg(ctx, ctx->args->merged_wave_info), Operand(i * 8u)); + + Temp mask = bld.sop2(aco_opcode::s_bfm_b64, bld.def(s2), count, Operand(0u)); + Temp cond; + + if (ctx->program->wave_size == 64) { + /* Special case for 64 active invocations, because 64 doesn't work with s_bfm */ + Temp active_64 = bld.sopc(aco_opcode::s_bitcmp1_b32, bld.def(s1, scc), count, Operand(6u /* log2(64) */)); + cond = bld.sop2(Builder::s_cselect, bld.def(bld.lm), Operand(-1u), mask, bld.scc(active_64)); + } else { + /* We use s_bfm_b64 (not _b32) which works with 32, but we need to extract the lower half of the register */ + cond = emit_extract_vector(ctx, mask, 0, bld.lm); + } + + return cond; +} + +bool ngg_early_prim_export(isel_context *ctx) +{ + /* TODO: Check edge flags, and if they are written, return false. (Needed for OpenGL, not for Vulkan.) */ + return true; +} + +void ngg_emit_sendmsg_gs_alloc_req(isel_context *ctx) +{ + Builder bld(ctx->program, ctx->block); + + /* It is recommended to do the GS_ALLOC_REQ as soon and as quickly as possible, so we set the maximum priority (3). */ + bld.sopp(aco_opcode::s_setprio, -1u, 0x3u); + + /* Get the id of the current wave within the threadgroup (workgroup) */ + Builder::Result wave_id_in_tg = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), + get_arg(ctx, ctx->args->merged_wave_info), Operand(24u | (4u << 16))); + + /* Execute the following code only on the first wave (wave id 0), + * use the SCC def to tell if the wave id is zero or not. + */ + Temp cond = wave_id_in_tg.def(1).getTemp(); + if_context ic; + begin_uniform_if_then(ctx, &ic, cond); + begin_uniform_if_else(ctx, &ic); + bld.reset(ctx->block); + + /* Number of vertices output by VS/TES */ + Temp vtx_cnt = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), + get_arg(ctx, ctx->args->gs_tg_info), Operand(12u | (9u << 16u))); + /* Number of primitives output by VS/TES */ + Temp prm_cnt = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), + get_arg(ctx, ctx->args->gs_tg_info), Operand(22u | (9u << 16u))); + + /* Put the number of vertices and primitives into m0 for the GS_ALLOC_REQ */ + Temp tmp = bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), prm_cnt, Operand(12u)); + tmp = bld.sop2(aco_opcode::s_or_b32, bld.m0(bld.def(s1)), bld.def(s1, scc), tmp, vtx_cnt); + + /* Request the SPI to allocate space for the primitives and vertices that will be exported by the threadgroup. */ + bld.sopp(aco_opcode::s_sendmsg, bld.m0(tmp), -1, sendmsg_gs_alloc_req); + + /* After the GS_ALLOC_REQ is done, reset priority to default (0). */ + bld.sopp(aco_opcode::s_setprio, -1u, 0x0u); + + end_uniform_if(ctx, &ic); +} + +Temp ngg_get_prim_exp_arg(isel_context *ctx, unsigned num_vertices, const Temp vtxindex[]) +{ + Builder bld(ctx->program, ctx->block); + + if (ctx->args->options->key.vs_common_out.as_ngg_passthrough) { + return get_arg(ctx, ctx->args->gs_vtx_offset[0]); + } + + Temp gs_invocation_id = get_arg(ctx, ctx->args->ac.gs_invocation_id); + Temp tmp; + + for (unsigned i = 0; i < num_vertices; ++i) { + assert(vtxindex[i].id()); + + if (i) + tmp = bld.vop3(aco_opcode::v_lshl_add_u32, bld.def(v1), vtxindex[i], Operand(10u * i), tmp); + else + tmp = vtxindex[i]; + + /* The initial edge flag is always false in tess eval shaders. 
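/* A standalone model of merged_wave_info_to_mask above, assuming only <cstdint>: wave i's
 * live-lane count sits in bits [8*i+7 : 8*i], s_bfm_b64 builds (1 << count) - 1 from the low
 * 6 bits of the count, and count == 64 needs the explicit s_bitcmp1/s_cselect because the
 * 6-bit size field wraps to 0.  Helper names are illustrative, not ACO API. */
#include <cstdint>

static inline uint64_t s_bfm_b64_model(uint32_t size, uint32_t offset)
{
   return ((1ull << (size & 0x3fu)) - 1ull) << (offset & 0x3fu);
}

static inline uint64_t live_lane_mask(uint32_t merged_wave_info,
                                      unsigned wave, unsigned wave_size)
{
   uint32_t count = (merged_wave_info >> (8 * wave)) & 0xffu;
   uint64_t mask  = s_bfm_b64_model(count, 0);
   if (wave_size == 64 && (count & 0x40u))      /* count == 64: s_bfm would yield 0 */
      mask = ~0ull;
   return wave_size == 64 ? mask : (mask & 0xffffffffull);  /* wave32: keep the low half */
}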
*/ + if (ctx->stage == ngg_vertex_gs) { + Temp edgeflag = bld.vop3(aco_opcode::v_bfe_u32, bld.def(v1), gs_invocation_id, Operand(8 + i), Operand(1u)); + tmp = bld.vop3(aco_opcode::v_lshl_add_u32, bld.def(v1), edgeflag, Operand(10u * i + 9u), tmp); + } + } + + /* TODO: Set isnull field in case of merged NGG VS+GS. */ + + return tmp; +} + +void ngg_emit_prim_export(isel_context *ctx, unsigned num_vertices_per_primitive, const Temp vtxindex[]) +{ + Builder bld(ctx->program, ctx->block); + Temp prim_exp_arg = ngg_get_prim_exp_arg(ctx, num_vertices_per_primitive, vtxindex); + + bld.exp(aco_opcode::exp, prim_exp_arg, Operand(v1), Operand(v1), Operand(v1), + 1 /* enabled mask */, V_008DFC_SQ_EXP_PRIM /* dest */, + false /* compressed */, true/* done */, false /* valid mask */); +} + +void ngg_emit_nogs_gsthreads(isel_context *ctx) +{ + /* Emit the things that NGG GS threads need to do, for shaders that don't have SW GS. + * These must always come before VS exports. + * + * It is recommended to do these as early as possible. They can be at the beginning when + * there is no SW GS and the shader doesn't write edge flags. + */ + + if_context ic; + Temp is_gs_thread = merged_wave_info_to_mask(ctx, 1); + begin_divergent_if_then(ctx, &ic, is_gs_thread); + + Builder bld(ctx->program, ctx->block); + constexpr unsigned max_vertices_per_primitive = 3; + unsigned num_vertices_per_primitive = max_vertices_per_primitive; + + if (ctx->stage == ngg_vertex_gs) { + /* TODO: optimize for points & lines */ + } else if (ctx->stage == ngg_tess_eval_gs) { + if (ctx->shader->info.tess.point_mode) + num_vertices_per_primitive = 1; + else if (ctx->shader->info.tess.primitive_mode == GL_ISOLINES) + num_vertices_per_primitive = 2; + } else { + unreachable("Unsupported NGG shader stage"); + } + + Temp vtxindex[max_vertices_per_primitive]; + vtxindex[0] = bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0xffffu), + get_arg(ctx, ctx->args->gs_vtx_offset[0])); + vtxindex[1] = num_vertices_per_primitive < 2 ? Temp(0, v1) : + bld.vop3(aco_opcode::v_bfe_u32, bld.def(v1), + get_arg(ctx, ctx->args->gs_vtx_offset[0]), Operand(16u), Operand(16u)); + vtxindex[2] = num_vertices_per_primitive < 3 ? Temp(0, v1) : + bld.vop2(aco_opcode::v_and_b32, bld.def(v1), Operand(0xffffu), + get_arg(ctx, ctx->args->gs_vtx_offset[2])); + + /* Export primitive data to the index buffer. */ + ngg_emit_prim_export(ctx, num_vertices_per_primitive, vtxindex); + + /* Export primitive ID. */ + if (ctx->stage == ngg_vertex_gs && ctx->args->options->key.vs_common_out.export_prim_id) { + /* Copy Primitive IDs from GS threads to the LDS address corresponding to the ES thread of the provoking vertex. */ + Temp prim_id = get_arg(ctx, ctx->args->ac.gs_prim_id); + Temp provoking_vtx_index = vtxindex[0]; + Temp addr = bld.v_mul_imm(bld.def(v1), provoking_vtx_index, 4u); + + store_lds(ctx, 4, prim_id, 0x1u, addr, 0u, 4u); + } + + begin_divergent_if_else(ctx, &ic); + end_divergent_if(ctx, &ic); +} + +void ngg_emit_nogs_output(isel_context *ctx) +{ + /* Emits NGG GS output, for stages that don't have SW GS. */ + + if_context ic; + Builder bld(ctx->program, ctx->block); + bool late_prim_export = !ngg_early_prim_export(ctx); + + /* NGG streamout is currently disabled by default. */ + assert(!ctx->args->shader_info->so.num_outputs); + + if (late_prim_export) { + /* VS exports are output to registers in a predecessor block. Emit phis to get them into this block. */ + create_export_phis(ctx); + /* Do what we need to do in the GS threads. 
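/* A standalone model of the primitive export argument assembled above for NGG shaders without
 * a SW GS, assuming only <cstdint>.  Helper names are illustrative, not ACO API. */
#include <cstdint>

/* gs_vtx_offset[0] carries vertex indices 0 and 1 in its 16-bit halves,
 * gs_vtx_offset[2] carries vertex index 2 in its low half. */
static inline void unpack_gs_vtx_offsets_model(uint32_t gs_vtx_offset0, uint32_t gs_vtx_offset2,
                                               uint32_t vtxindex[3])
{
   vtxindex[0] = gs_vtx_offset0 & 0xffffu;
   vtxindex[1] = gs_vtx_offset0 >> 16;
   vtxindex[2] = gs_vtx_offset2 & 0xffffu;
}

/* Each vertex occupies a 10-bit group: the index in the low bits and the edge flag at
 * bit 10*i + 9; the isnull bit is not set here, matching the TODO above. */
static inline uint32_t ngg_prim_exp_arg_model(unsigned num_vertices,
                                              const uint32_t vtxindex[3],
                                              const uint32_t edgeflag[3])
{
   uint32_t arg = 0;
   for (unsigned i = 0; i < num_vertices; i++)
      arg |= (vtxindex[i] << (10 * i)) | ((edgeflag[i] & 1u) << (10 * i + 9));
   return arg;
}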
*/ + ngg_emit_nogs_gsthreads(ctx); + + /* What comes next should be executed on ES threads. */ + Temp is_es_thread = merged_wave_info_to_mask(ctx, 0); + begin_divergent_if_then(ctx, &ic, is_es_thread); + bld.reset(ctx->block); + } + + /* Export VS outputs */ + ctx->block->kind |= block_kind_export_end; + create_vs_exports(ctx); + + /* Export primitive ID */ + if (ctx->args->options->key.vs_common_out.export_prim_id) { + Temp prim_id; + + if (ctx->stage == ngg_vertex_gs) { + /* Wait for GS threads to store primitive ID in LDS. */ + bld.barrier(aco_opcode::p_memory_barrier_shared); + bld.sopp(aco_opcode::s_barrier); + + /* Calculate LDS address where the GS threads stored the primitive ID. */ + Temp wave_id_in_tg = bld.sop2(aco_opcode::s_bfe_u32, bld.def(s1), bld.def(s1, scc), + get_arg(ctx, ctx->args->merged_wave_info), Operand(24u | (4u << 16))); + Temp thread_id_in_wave = emit_mbcnt(ctx, bld.def(v1)); + Temp wave_id_mul = bld.v_mul_imm(bld.def(v1), as_vgpr(ctx, wave_id_in_tg), ctx->program->wave_size); + Temp thread_id_in_tg = bld.vadd32(bld.def(v1), Operand(wave_id_mul), Operand(thread_id_in_wave)); + Temp addr = bld.v_mul_imm(bld.def(v1), thread_id_in_tg, 4u); + + /* Load primitive ID from LDS. */ + prim_id = load_lds(ctx, 4, bld.tmp(v1), addr, 0u, 4u); + } else if (ctx->stage == ngg_tess_eval_gs) { + /* TES: Just use the patch ID as the primitive ID. */ + prim_id = get_arg(ctx, ctx->args->ac.tes_patch_id); + } else { + unreachable("unsupported NGG shader stage."); + } + + ctx->outputs.mask[VARYING_SLOT_PRIMITIVE_ID] |= 0x1; + ctx->outputs.temps[VARYING_SLOT_PRIMITIVE_ID * 4u] = prim_id; + + export_vs_varying(ctx, VARYING_SLOT_PRIMITIVE_ID, false, nullptr); + } + + if (late_prim_export) { + begin_divergent_if_else(ctx, &ic); + end_divergent_if(ctx, &ic); + bld.reset(ctx->block); + } +} + void select_program(Program *program, unsigned shader_count, struct nir_shader *const *shaders, @@ -9792,6 +10637,7 @@ void select_program(Program *program, { isel_context ctx = setup_isel_context(program, shader_count, shaders, config, args, false); if_context ic_merged_wave_info; + bool ngg_no_gs = ctx.stage == ngg_vertex_gs || ctx.stage == ngg_tess_eval_gs; for (unsigned i = 0; i < shader_count; i++) { nir_shader *nir = shaders[i]; @@ -9810,6 +10656,13 @@ void select_program(Program *program, split_arguments(&ctx, startpgm); } + if (ngg_no_gs) { + ngg_emit_sendmsg_gs_alloc_req(&ctx); + + if (ngg_early_prim_export(&ctx)) + ngg_emit_nogs_gsthreads(&ctx); + } + /* In a merged VS+TCS HS, the VS implementation can be completely empty. */ nir_function_impl *func = nir_shader_get_entrypoint(nir); bool empty_shader = nir_cf_list_is_empty_block(&func->body) && @@ -9818,28 +10671,10 @@ void select_program(Program *program, (nir->info.stage == MESA_SHADER_TESS_EVAL && ctx.stage == tess_eval_geometry_gs)); - bool check_merged_wave_info = ctx.tcs_in_out_eq ? i == 0 : (shader_count >= 2 && !empty_shader); + bool check_merged_wave_info = ctx.tcs_in_out_eq ? i == 0 : ((shader_count >= 2 && !empty_shader) || ngg_no_gs); bool endif_merged_wave_info = ctx.tcs_in_out_eq ? i == 1 : check_merged_wave_info; if (check_merged_wave_info) { - Builder bld(ctx.program, ctx.block); - - /* The s_bfm only cares about s0.u[5:0] so we don't need either s_bfe nor s_and here */ - Temp count = i == 0 ? 
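/* A sketch of the LDS addressing used for the primitive-id hand-off above, assuming only
 * <cstdint>: the GS thread writes one dword per provoking vertex, and after the barrier each
 * ES thread reads back at 4 * its thread id within the threadgroup, with the wave id taken
 * from merged_wave_info bits [27:24].  Helper names are illustrative. */
#include <cstdint>

/* Address the GS thread writes to: one dword per provoking-vertex ES thread. */
static inline uint32_t prim_id_lds_addr_gs(uint32_t provoking_vtx_index)
{
   return 4u * provoking_vtx_index;
}

/* Address the ES thread reads back from after the barrier. */
static inline uint32_t prim_id_lds_addr_es(uint32_t merged_wave_info,
                                           unsigned lane, unsigned wave_size)
{
   uint32_t wave_id_in_tg   = (merged_wave_info >> 24) & 0xfu;  /* s_bfe_u32 24 | (4 << 16) */
   uint32_t thread_id_in_tg = wave_id_in_tg * wave_size + lane;
   return 4u * thread_id_in_tg;
}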
get_arg(&ctx, args->merged_wave_info) - : bld.sop2(aco_opcode::s_lshr_b32, bld.def(s1), bld.def(s1, scc), - get_arg(&ctx, args->merged_wave_info), Operand(i * 8u)); - - Temp mask = bld.sop2(aco_opcode::s_bfm_b64, bld.def(s2), count, Operand(0u)); - Temp cond; - - if (ctx.program->wave_size == 64) { - /* Special case for 64 active invocations, because 64 doesn't work with s_bfm */ - Temp active_64 = bld.sopc(aco_opcode::s_bitcmp1_b32, bld.def(s1, scc), count, Operand(6u /* log2(64) */)); - cond = bld.sop2(Builder::s_cselect, bld.def(bld.lm), Operand(-1u), mask, bld.scc(active_64)); - } else { - /* We use s_bfm_b64 (not _b32) which works with 32, but we need to extract the lower half of the register */ - cond = emit_extract_vector(&ctx, mask, 0, bld.lm); - } - + Temp cond = merged_wave_info_to_mask(&ctx, i); begin_divergent_if_then(&ctx, &ic_merged_wave_info, cond); } @@ -9860,11 +10695,14 @@ void select_program(Program *program, visit_cf_list(&ctx, &func->body); - if (ctx.program->info->so.num_outputs && (ctx.stage == vertex_vs || ctx.stage == tess_eval_vs)) + if (ctx.program->info->so.num_outputs && (ctx.stage & hw_vs)) emit_streamout(&ctx, 0); - if (ctx.stage == vertex_vs || ctx.stage == tess_eval_vs) { + if (ctx.stage & hw_vs) { create_vs_exports(&ctx); + ctx.block->kind |= block_kind_export_end; + } else if (ngg_no_gs && ngg_early_prim_export(&ctx)) { + ngg_emit_nogs_output(&ctx); } else if (nir->info.stage == MESA_SHADER_GEOMETRY) { Builder bld(ctx.program, ctx.block); bld.barrier(aco_opcode::p_memory_barrier_gs_data); @@ -9873,14 +10711,19 @@ void select_program(Program *program, write_tcs_tess_factors(&ctx); } - if (ctx.stage == fragment_fs) + if (ctx.stage == fragment_fs) { create_fs_exports(&ctx); + ctx.block->kind |= block_kind_export_end; + } if (endif_merged_wave_info) { begin_divergent_if_else(&ctx, &ic_merged_wave_info); end_divergent_if(&ctx, &ic_merged_wave_info); } + if (ngg_no_gs && !ngg_early_prim_export(&ctx)) + ngg_emit_nogs_output(&ctx); + ralloc_free(ctx.divergent_vals); if (i == 0 && ctx.stage == vertex_tess_control_hs && ctx.tcs_in_out_eq) { @@ -9893,7 +10736,7 @@ void select_program(Program *program, program->config->float_mode = program->blocks[0].fp_mode.val; append_logical_end(ctx.block); - ctx.block->kind |= block_kind_uniform | block_kind_export_end; + ctx.block->kind |= block_kind_uniform; Builder bld(ctx.program, ctx.block); if (ctx.program->wb_smem_l1_on_end) bld.smem(aco_opcode::s_dcache_wb, false);
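/* The checks above moved from equality tests (stage == vertex_vs || ...) to bitmask tests
 * (stage & hw_vs, stage & sw_tes), which keep working once the NGG stages are added.  The
 * sketch below assumes a bitmask encoding with separate software- and hardware-stage bits,
 * as those tests suggest; the flag values here are placeholders, the real ones are defined
 * elsewhere in ACO and may differ. */
enum stage_bits_sketch : unsigned {
   sw_vs_bit     = 1u << 0,   /* placeholder values */
   sw_tes_bit    = 1u << 1,
   hw_vs_bit     = 1u << 8,
   hw_ngg_gs_bit = 1u << 9,
};

static inline bool exports_through_hw_vs(unsigned stage) { return stage & hw_vs_bit; }
static inline bool is_tes_variant(unsigned stage)        { return stage & sw_tes_bit; }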