From: Rhys Perry Date: Tue, 14 Apr 2020 15:39:58 +0000 (+0100) Subject: aco: implement various 8/16-bit conversions X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=ac74367befcf51917025f9fe2ce1dc431c2875fd;p=mesa.git aco: implement various 8/16-bit conversions Signed-off-by: Rhys Perry Reviewed-by: Daniel Schürmann Part-of: --- diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 21ba2ec2cf6..d1b9f9238d7 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -950,6 +950,58 @@ Temp emit_floor_f64(isel_context *ctx, Builder& bld, Definition dst, Temp val) return add->definitions[0].getTemp(); } +Temp convert_int(Builder& bld, Temp src, unsigned src_bits, unsigned dst_bits, bool is_signed, Temp dst=Temp()) { /* Resizes the integer in src from src_bits to dst_bits. Narrowing emits a plain copy or a p_extract_vector of element 0; widening sign-extends when is_signed and zero-extends otherwise. The result is written to dst, or to a freshly allocated temporary when dst has no id, and is returned. */ + if (!dst.id()) { + if (dst_bits % 32 == 0 || src.type() == RegType::sgpr) + dst = bld.tmp(src.type(), DIV_ROUND_UP(dst_bits, 32u)); + else + dst = bld.tmp(RegClass(RegType::vgpr, dst_bits / 8u).as_subdword()); + } + + if (dst.bytes() == src.bytes() && dst_bits < src_bits) + return bld.copy(Definition(dst), src); + else if (dst.bytes() < src.bytes()) + return bld.pseudo(aco_opcode::p_extract_vector, Definition(dst), src, Operand(0u)); + + Temp tmp = dst; + if (dst_bits == 64) + tmp = src_bits == 32 ? src : bld.tmp(src.type(), 1); + + if (tmp == src) { /* src already serves as the low dword, so no extension instruction is emitted */ + } else if (src.regClass() == s1) { + if (is_signed) + bld.sop1(src_bits == 8 ? aco_opcode::s_sext_i32_i8 : aco_opcode::s_sext_i32_i16, Definition(tmp), src); + else + bld.sop2(aco_opcode::s_and_b32, Definition(tmp), bld.def(s1, scc), Operand(src_bits == 8 ? 
0xFFu : 0xFFFFu), src); + } else { + assert(src_bits != 8 || src.regClass() == v1b); + assert(src_bits != 16 || src.regClass() == v2b); + aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; + sdwa->operands[0] = Operand(src); + sdwa->definitions[0] = Definition(tmp); + if (is_signed) + sdwa->sel[0] = src_bits == 8 ? sdwa_sbyte : sdwa_sword; + else + sdwa->sel[0] = src_bits == 8 ? sdwa_ubyte : sdwa_uword; + sdwa->dst_sel = tmp.bytes() == 2 ? sdwa_uword : sdwa_udword; + bld.insert(std::move(sdwa)); + } + + if (dst_bits == 64) { /* assemble the 64-bit result: tmp is the low dword, the high dword is tmp shifted right arithmetically by 31 (signed) or zero */ + if (is_signed && dst.regClass() == s2) { + Temp high = bld.sop2(aco_opcode::s_ashr_i32, bld.def(s1), bld.def(s1, scc), tmp, Operand(31u)); + bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, high); + } else if (is_signed && dst.regClass() == v2) { + Temp high = bld.vop2(aco_opcode::v_ashrrev_i32, bld.def(v1), Operand(31u), tmp); + bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, high); + } else { + bld.pseudo(aco_opcode::p_create_vector, Definition(dst), tmp, Operand(0u)); + } + } + + return dst; +} + void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) { if (!instr->dest.dest.is_ssa) { @@ -2114,12 +2166,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) Temp src = get_alu_src(ctx, instr->src[0]); if (instr->src[0].src.ssa->bit_size == 16) { Temp tmp = bld.vop1(aco_opcode::v_frexp_exp_i16_f16, bld.def(v1), src); - aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; - sdwa->operands[0] = Operand(tmp); - sdwa->definitions[0] = Definition(dst); - sdwa->sel[0] = sdwa_sbyte; - sdwa->dst_sel = sdwa_sdword; - ctx->block->instructions.emplace_back(std::move(sdwa)); + tmp = bld.pseudo(aco_opcode::p_extract_vector, bld.def(v1b), tmp, Operand(0u)); + convert_int(bld, tmp, 8, 32, true, dst); } else if 
(instr->src[0].src.ssa->bit_size == 64) { @@ -2201,19 +2249,27 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_i2f16: { assert(dst.regClass() == v2b); - Temp tmp = bld.vop1(aco_opcode::v_cvt_f16_i16, bld.def(v1), - get_alu_src(ctx, instr->src[0])); + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 8) + src = convert_int(bld, src, 8, 16, true); + Temp tmp = bld.vop1(aco_opcode::v_cvt_f16_i16, bld.def(v1), src); bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); break; } case nir_op_i2f32: { assert(dst.size() == 1); - emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_i32, dst); + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size <= 16) + src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, true); + bld.vop1(aco_opcode::v_cvt_f32_i32, Definition(dst), src); break; } case nir_op_i2f64: { - if (instr->src[0].src.ssa->bit_size == 32) { - emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f64_i32, dst); + if (instr->src[0].src.ssa->bit_size <= 32) { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size <= 16) + src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, true); + bld.vop1(aco_opcode::v_cvt_f64_i32, Definition(dst), src); } else if (instr->src[0].src.ssa->bit_size == 64) { Temp src = get_alu_src(ctx, instr->src[0]); RegClass rc = RegClass(src.type(), 1); @@ -2233,19 +2289,32 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } case nir_op_u2f16: { assert(dst.regClass() == v2b); - Temp tmp = bld.vop1(aco_opcode::v_cvt_f16_u16, bld.def(v1), - get_alu_src(ctx, instr->src[0])); + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 8) + src = convert_int(bld, src, 8, 16, false); + Temp tmp = bld.vop1(aco_opcode::v_cvt_f16_u16, bld.def(v1), src); bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(v2b), tmp); break; } case 
nir_op_u2f32: { assert(dst.size() == 1); - emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f32_u32, dst); + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size == 8) { + //TODO: we should use v_cvt_f32_ubyte1/v_cvt_f32_ubyte2/etc depending on the register assignment + bld.vop1(aco_opcode::v_cvt_f32_ubyte0, Definition(dst), src); + } else { + if (instr->src[0].src.ssa->bit_size == 16) + src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, false); /* zero-extend: the source is unsigned, sign-extending would corrupt values with bit 15 set before v_cvt_f32_u32 */ + bld.vop1(aco_opcode::v_cvt_f32_u32, Definition(dst), src); + } break; } case nir_op_u2f64: { - if (instr->src[0].src.ssa->bit_size == 32) { - emit_vop1_instruction(ctx, instr, aco_opcode::v_cvt_f64_u32, dst); + if (instr->src[0].src.ssa->bit_size <= 32) { + Temp src = get_alu_src(ctx, instr->src[0]); + if (instr->src[0].src.ssa->bit_size <= 16) + src = convert_int(bld, src, instr->src[0].src.ssa->bit_size, 32, false); + bld.vop1(aco_opcode::v_cvt_f64_u32, Definition(dst), src); } else if (instr->src[0].src.ssa->bit_size == 64) { Temp src = get_alu_src(ctx, instr->src[0]); RegClass rc = RegClass(src.type(), 1); @@ -2554,159 +2623,19 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_i2i8: - case nir_op_u2u8: { - Temp src = get_alu_src(ctx, instr->src[0]); - /* we can actually just say dst = src */ - if (src.regClass() == s1) - bld.copy(Definition(dst), src); - else - emit_extract_vector(ctx, src, 0, dst); - break; - } - case nir_op_i2i16: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 8) { - if (dst.regClass() == s1) { - bld.sop1(aco_opcode::s_sext_i32_i8, Definition(dst), Operand(src)); - } else { - assert(src.regClass() == v1b); - aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; - sdwa->operands[0] = Operand(src); - sdwa->definitions[0] = Definition(dst); - sdwa->sel[0] = sdwa_sbyte; - sdwa->dst_sel = sdwa_sword; - 
ctx->block->instructions.emplace_back(std::move(sdwa)); - } - } else { - Temp src = get_alu_src(ctx, instr->src[0]); - /* we can actually just say dst = src */ - if (src.regClass() == s1) - bld.copy(Definition(dst), src); - else - emit_extract_vector(ctx, src, 0, dst); - } - break; - } - case nir_op_u2u16: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 8) { - if (dst.regClass() == s1) - bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFu), src); - else { - assert(src.regClass() == v1b); - aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; - sdwa->operands[0] = Operand(src); - sdwa->definitions[0] = Definition(dst); - sdwa->sel[0] = sdwa_ubyte; - sdwa->dst_sel = sdwa_uword; - ctx->block->instructions.emplace_back(std::move(sdwa)); - } - } else { - Temp src = get_alu_src(ctx, instr->src[0]); - /* we can actually just say dst = src */ - if (src.regClass() == s1) - bld.copy(Definition(dst), src); - else - emit_extract_vector(ctx, src, 0, dst); - } - break; - } - case nir_op_i2i32: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 8) { - if (dst.regClass() == s1) { - bld.sop1(aco_opcode::s_sext_i32_i8, Definition(dst), Operand(src)); - } else { - assert(src.regClass() == v1b); - aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; - sdwa->operands[0] = Operand(src); - sdwa->definitions[0] = Definition(dst); - sdwa->sel[0] = sdwa_sbyte; - sdwa->dst_sel = sdwa_sdword; - ctx->block->instructions.emplace_back(std::move(sdwa)); - } - } else if (instr->src[0].src.ssa->bit_size == 16) { - if (dst.regClass() == s1) { - bld.sop1(aco_opcode::s_sext_i32_i16, Definition(dst), Operand(src)); - } else { - assert(src.regClass() == v2b); - aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; - sdwa->operands[0] = Operand(src); - sdwa->definitions[0] = Definition(dst); - 
sdwa->sel[0] = sdwa_sword; - sdwa->dst_sel = sdwa_udword; - ctx->block->instructions.emplace_back(std::move(sdwa)); - } - } else if (instr->src[0].src.ssa->bit_size == 64) { - /* we can actually just say dst = src, as it would map the lower register */ - emit_extract_vector(ctx, src, 0, dst); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } - break; - } - case nir_op_u2u32: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 8) { - if (dst.regClass() == s1) - bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFu), src); - else { - assert(src.regClass() == v1b); - aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; - sdwa->operands[0] = Operand(src); - sdwa->definitions[0] = Definition(dst); - sdwa->sel[0] = sdwa_ubyte; - sdwa->dst_sel = sdwa_udword; - ctx->block->instructions.emplace_back(std::move(sdwa)); - } - } else if (instr->src[0].src.ssa->bit_size == 16) { - if (dst.regClass() == s1) { - bld.sop2(aco_opcode::s_and_b32, Definition(dst), bld.def(s1, scc), Operand(0xFFFFu), src); - } else { - assert(src.regClass() == v2b); - aco_ptr<SDWA_instruction> sdwa{create_instruction<SDWA_instruction>(aco_opcode::v_mov_b32, asSDWA(Format::VOP1), 1, 1)}; - sdwa->operands[0] = Operand(src); - sdwa->definitions[0] = Definition(dst); - sdwa->sel[0] = sdwa_uword; - sdwa->dst_sel = sdwa_udword; - ctx->block->instructions.emplace_back(std::move(sdwa)); - } - } else if (instr->src[0].src.ssa->bit_size == 64) { - /* we can actually just say dst = src, as it would map the lower register */ - emit_extract_vector(ctx, src, 0, dst); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } - break; - } + case nir_op_i2i16: + case nir_op_i2i32: case nir_op_i2i64: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (src.regClass() == s1) { - Temp high = 
bld.sop2(aco_opcode::s_ashr_i32, bld.def(s1), bld.def(s1, scc), src, Operand(31u)); - bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src, high); - } else if (src.regClass() == v1) { - Temp high = bld.vop2(aco_opcode::v_ashrrev_i32, bld.def(v1), Operand(31u), src); - bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src, high); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } + convert_int(bld, get_alu_src(ctx, instr->src[0]), + instr->src[0].src.ssa->bit_size, instr->dest.dest.ssa.bit_size, true, dst); break; } + case nir_op_u2u8: + case nir_op_u2u16: + case nir_op_u2u32: case nir_op_u2u64: { - Temp src = get_alu_src(ctx, instr->src[0]); - if (instr->src[0].src.ssa->bit_size == 32) { - bld.pseudo(aco_opcode::p_create_vector, Definition(dst), src, Operand(0u)); - } else { - fprintf(stderr, "Unimplemented NIR instr bit size: "); - nir_print_instr(&instr->instr, stderr); - fprintf(stderr, "\n"); - } + convert_int(bld, get_alu_src(ctx, instr->src[0]), + instr->src[0].src.ssa->bit_size, instr->dest.dest.ssa.bit_size, false, dst); break; } case nir_op_b2b32: