From a79dad950b1f10ddeca2c907025a0f649b470cb9 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Daniel=20Sch=C3=BCrmann?= Date: Thu, 18 Jun 2020 15:14:20 +0100 Subject: [PATCH] nir,amd: remove trinary_minmax opcodes MIME-Version: 1.0 Content-Type: text/plain; charset=utf8 Content-Transfer-Encoding: 8bit These consist of the variations nir_op_{i|u|f}{min|max|med}3 which are either lowered in the backend (LLVM) anyway or can be recombined by the backend (ACO). Reviewed-by: Marek Olšák Part-of: --- .../compiler/aco_instruction_selection.cpp | 78 ------------------- .../aco_instruction_selection_setup.cpp | 3 - src/amd/llvm/ac_llvm_build.c | 48 ------------ src/amd/llvm/ac_nir_to_llvm.c | 51 ------------ src/compiler/nir/nir_lower_int64.c | 18 ----- src/compiler/nir/nir_opcodes.py | 14 ---- src/compiler/nir/nir_opt_algebraic.py | 4 - src/compiler/nir/nir_range_analysis.c | 14 ---- src/compiler/spirv/vtn_amd.c | 29 ++++--- 9 files changed, 20 insertions(+), 239 deletions(-) diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 6f1f8b4e07e..737a88e8d19 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -1793,84 +1793,6 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } break; } - case nir_op_fmax3: { - if (dst.regClass() == v2b) { - emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f16, dst, false); - } else if (dst.regClass() == v1) { - emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f32, dst, ctx->block->fp_mode.must_flush_denorms32); - } else { - isel_err(&instr->instr, "Unimplemented NIR instr bit size"); - } - break; - } - case nir_op_fmin3: { - if (dst.regClass() == v2b) { - emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f16, dst, false); - } else if (dst.regClass() == v1) { - emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f32, dst, ctx->block->fp_mode.must_flush_denorms32); - } else { - isel_err(&instr->instr, "Unimplemented NIR instr bit size"); - } - break; - } - case nir_op_fmed3: { - if (dst.regClass() == v2b) { - emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f16, dst, false); - } else if (dst.regClass() == v1) { - emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f32, dst, ctx->block->fp_mode.must_flush_denorms32); - } else { - isel_err(&instr->instr, "Unimplemented NIR instr bit size"); - } - break; - } - case nir_op_umax3: { - if (dst.size() == 1) { - emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_u32, dst); - } else { - isel_err(&instr->instr, "Unimplemented NIR instr bit size"); - } - break; - } - case nir_op_umin3: { - if (dst.size() == 1) { - emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_u32, dst); - } else { - isel_err(&instr->instr, "Unimplemented NIR instr bit size"); - } - break; - } - case nir_op_umed3: { - if (dst.size() == 1) { - emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_u32, dst); - } else { - isel_err(&instr->instr, "Unimplemented NIR instr bit size"); - } - break; - } - case nir_op_imax3: { - if (dst.size() == 1) { - emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_i32, dst); - } else { - isel_err(&instr->instr, "Unimplemented NIR instr bit size"); - } - break; - } - case nir_op_imin3: { - if (dst.size() == 1) { - emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_i32, dst); - } else { - isel_err(&instr->instr, "Unimplemented NIR instr bit size"); - } - break; - } - case nir_op_imed3: { - if (dst.size() == 1) { - emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_i32, dst); - } else { - isel_err(&instr->instr, "Unimplemented NIR instr bit size"); - } - break; - } case nir_op_cube_face_coord: { Temp in = get_alu_src(ctx, instr->src[0], 3); Temp src[3] = { emit_extract_vector(ctx, in, 0, v1), diff --git a/src/amd/compiler/aco_instruction_selection_setup.cpp b/src/amd/compiler/aco_instruction_selection_setup.cpp index 874e015ca78..53f7ced4bf5 100644 --- a/src/amd/compiler/aco_instruction_selection_setup.cpp +++ b/src/amd/compiler/aco_instruction_selection_setup.cpp @@ -600,9 +600,6 @@ void init_context(isel_context *ctx, nir_shader *shader) case nir_op_fsub: case nir_op_fmax: case nir_op_fmin: - case nir_op_fmax3: - case nir_op_fmin3: - case nir_op_fmed3: case nir_op_fneg: case nir_op_fabs: case nir_op_fsat: diff --git a/src/amd/llvm/ac_llvm_build.c b/src/amd/llvm/ac_llvm_build.c index 8be8433c997..77d3f7e73fb 100644 --- a/src/amd/llvm/ac_llvm_build.c +++ b/src/amd/llvm/ac_llvm_build.c @@ -2727,54 +2727,6 @@ void ac_build_waitcnt(struct ac_llvm_context *ctx, unsigned wait_flags) ctx->voidt, args, 1, 0); } -LLVMValueRef ac_build_fmed3(struct ac_llvm_context *ctx, LLVMValueRef src0, - LLVMValueRef src1, LLVMValueRef src2, - unsigned bitsize) -{ - LLVMValueRef result; - - if (bitsize == 64 || (bitsize == 16 && ctx->chip_class <= GFX8)) { - /* Lower 64-bit fmed because LLVM doesn't expose an intrinsic, - * or lower 16-bit fmed because it's only supported on GFX9+. - */ - LLVMValueRef min1, min2, max1; - - min1 = ac_build_fmin(ctx, src0, src1); - max1 = ac_build_fmax(ctx, src0, src1); - min2 = ac_build_fmin(ctx, max1, src2); - - result = ac_build_fmax(ctx, min2, min1); - } else { - LLVMTypeRef type; - char *intr; - - if (bitsize == 16) { - intr = "llvm.amdgcn.fmed3.f16"; - type = ctx->f16; - } else { - assert(bitsize == 32); - intr = "llvm.amdgcn.fmed3.f32"; - type = ctx->f32; - } - - LLVMValueRef params[] = { - src0, - src1, - src2, - }; - - result = ac_build_intrinsic(ctx, intr, type, params, 3, - AC_FUNC_ATTR_READNONE); - } - - if (ctx->chip_class < GFX9 && bitsize == 32) { - /* Only pre-GFX9 chips do not flush denorms. */ - result = ac_build_canonicalize(ctx, result, bitsize); - } - - return result; -} - LLVMValueRef ac_build_fract(struct ac_llvm_context *ctx, LLVMValueRef src0, unsigned bitsize) { diff --git a/src/amd/llvm/ac_nir_to_llvm.c b/src/amd/llvm/ac_nir_to_llvm.c index 37a483e3ba6..1b6ef264eef 100644 --- a/src/amd/llvm/ac_nir_to_llvm.c +++ b/src/amd/llvm/ac_nir_to_llvm.c @@ -1174,57 +1174,6 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr) break; } - case nir_op_fmin3: - result = emit_intrin_2f_param(&ctx->ac, "llvm.minnum", - ac_to_float_type(&ctx->ac, def_type), src[0], src[1]); - result = emit_intrin_2f_param(&ctx->ac, "llvm.minnum", - ac_to_float_type(&ctx->ac, def_type), result, src[2]); - break; - case nir_op_umin3: - result = ac_build_umin(&ctx->ac, src[0], src[1]); - result = ac_build_umin(&ctx->ac, result, src[2]); - break; - case nir_op_imin3: - result = ac_build_imin(&ctx->ac, src[0], src[1]); - result = ac_build_imin(&ctx->ac, result, src[2]); - break; - case nir_op_fmax3: - result = emit_intrin_2f_param(&ctx->ac, "llvm.maxnum", - ac_to_float_type(&ctx->ac, def_type), src[0], src[1]); - result = emit_intrin_2f_param(&ctx->ac, "llvm.maxnum", - ac_to_float_type(&ctx->ac, def_type), result, src[2]); - break; - case nir_op_umax3: - result = ac_build_umax(&ctx->ac, src[0], src[1]); - result = ac_build_umax(&ctx->ac, result, src[2]); - break; - case nir_op_imax3: - result = ac_build_imax(&ctx->ac, src[0], src[1]); - result = ac_build_imax(&ctx->ac, result, src[2]); - break; - case nir_op_fmed3: { - src[0] = ac_to_float(&ctx->ac, src[0]); - src[1] = ac_to_float(&ctx->ac, src[1]); - src[2] = ac_to_float(&ctx->ac, src[2]); - result = ac_build_fmed3(&ctx->ac, src[0], src[1], src[2], - instr->dest.dest.ssa.bit_size); - break; - } - case nir_op_imed3: { - LLVMValueRef tmp1 = ac_build_imin(&ctx->ac, src[0], src[1]); - LLVMValueRef tmp2 = ac_build_imax(&ctx->ac, src[0], src[1]); - tmp2 = ac_build_imin(&ctx->ac, tmp2, src[2]); - result = ac_build_imax(&ctx->ac, tmp1, tmp2); - break; - } - case nir_op_umed3: { - LLVMValueRef tmp1 = ac_build_umin(&ctx->ac, src[0], src[1]); - LLVMValueRef tmp2 = ac_build_umax(&ctx->ac, src[0], src[1]); - tmp2 = ac_build_umin(&ctx->ac, tmp2, src[2]); - result = ac_build_umax(&ctx->ac, tmp1, tmp2); - break; - } - default: fprintf(stderr, "Unknown NIR alu instr: "); nir_print_instr(&instr->instr, stderr); diff --git a/src/compiler/nir/nir_lower_int64.c b/src/compiler/nir/nir_lower_int64.c index 0c14fe58853..e780948c37d 100644 --- a/src/compiler/nir/nir_lower_int64.c +++ b/src/compiler/nir/nir_lower_int64.c @@ -838,12 +838,6 @@ nir_lower_int64_op_to_options_mask(nir_op opcode) case nir_op_imax: case nir_op_umin: case nir_op_umax: - case nir_op_imin3: - case nir_op_imax3: - case nir_op_umin3: - case nir_op_umax3: - case nir_op_imed3: - case nir_op_umed3: return nir_lower_minmax64; case nir_op_iabs: return nir_lower_iabs64; @@ -944,18 +938,6 @@ lower_int64_alu_instr(nir_builder *b, nir_instr *instr, void *_state) return lower_umin64(b, src[0], src[1]); case nir_op_umax: return lower_umax64(b, src[0], src[1]); - case nir_op_imin3: - return lower_imin64(b, src[0], lower_imin64(b, src[1], src[2])); - case nir_op_imax3: - return lower_imax64(b, src[0], lower_imax64(b, src[1], src[2])); - case nir_op_umin3: - return lower_umin64(b, src[0], lower_umin64(b, src[1], src[2])); - case nir_op_umax3: - return lower_umax64(b, src[0], lower_umax64(b, src[1], src[2])); - case nir_op_imed3: - return lower_imax64(b, lower_imin64(b, lower_imax64(b, src[0], src[1]), src[2]), lower_imin64(b, src[0], src[1])); - case nir_op_umed3: - return lower_umax64(b, lower_umin64(b, lower_umax64(b, src[0], src[1]), src[2]), lower_umin64(b, src[0], src[1])); case nir_op_iabs: return lower_iabs64(b, src[0]); case nir_op_ineg: diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py index 87b5e4efac1..e19d7b00a7d 100644 --- a/src/compiler/nir/nir_opcodes.py +++ b/src/compiler/nir/nir_opcodes.py @@ -950,22 +950,8 @@ triop("flrp", tfloat, "", "src0 * (1 - src2) + src1 * src2") # component on vectors). There are two versions, one for floating point # bools (0.0 vs 1.0) and one for integer bools (0 vs ~0). - triop("fcsel", tfloat32, "", "(src0 != 0.0f) ? src1 : src2") -# 3 way min/max/med -triop("fmin3", tfloat, "", "fminf(src0, fminf(src1, src2))") -triop("imin3", tint, "", "MIN2(src0, MIN2(src1, src2))") -triop("umin3", tuint, "", "MIN2(src0, MIN2(src1, src2))") - -triop("fmax3", tfloat, "", "fmaxf(src0, fmaxf(src1, src2))") -triop("imax3", tint, "", "MAX2(src0, MAX2(src1, src2))") -triop("umax3", tuint, "", "MAX2(src0, MAX2(src1, src2))") - -triop("fmed3", tfloat, "", "fmaxf(fminf(fmaxf(src0, src1), src2), fminf(src0, src1))") -triop("imed3", tint, "", "MAX2(MIN2(MAX2(src0, src1), src2), MIN2(src0, src1))") -triop("umed3", tuint, "", "MAX2(MIN2(MAX2(src0, src1), src2), MIN2(src0, src1))") - opcode("bcsel", 0, tuint, [0, 0, 0], [tbool1, tuint, tuint], False, "", "src0 ? src1 : src2") opcode("b8csel", 0, tuint, [0, 0, 0], diff --git a/src/compiler/nir/nir_opt_algebraic.py b/src/compiler/nir/nir_opt_algebraic.py index 4a2efa8252f..c394b07a4e9 100644 --- a/src/compiler/nir/nir_opt_algebraic.py +++ b/src/compiler/nir/nir_opt_algebraic.py @@ -1153,10 +1153,6 @@ optimizations.extend([ (('bcsel', a, ('bcsel', b, c, d), d), ('bcsel', ('iand', a, b), c, d)), (('bcsel', a, b, ('bcsel', c, b, d)), ('bcsel', ('ior', a, c), b, d)), - (('fmin3@64', a, b, c), ('fmin@64', a, ('fmin@64', b, c))), - (('fmax3@64', a, b, c), ('fmax@64', a, ('fmax@64', b, c))), - (('fmed3@64', a, b, c), ('fmax@64', ('fmin@64', ('fmax@64', a, b), c), ('fmin@64', a, b))), - # Misc. lowering (('fmod', a, b), ('fsub', a, ('fmul', b, ('ffloor', ('fdiv', a, b)))), 'options->lower_fmod'), (('frem', a, b), ('fsub', a, ('fmul', b, ('ftrunc', ('fdiv', a, b)))), 'options->lower_fmod'), diff --git a/src/compiler/nir/nir_range_analysis.c b/src/compiler/nir/nir_range_analysis.c index 5ef66ad8922..e23c7c4fdb7 100644 --- a/src/compiler/nir/nir_range_analysis.c +++ b/src/compiler/nir/nir_range_analysis.c @@ -1319,10 +1319,6 @@ nir_unsigned_upper_bound(nir_shader *shader, struct hash_table *range_ht, case nir_op_udiv: case nir_op_bcsel: case nir_op_b32csel: - case nir_op_imax3: - case nir_op_imin3: - case nir_op_umax3: - case nir_op_umin3: case nir_op_ubfe: case nir_op_bfm: case nir_op_f2u32: @@ -1405,16 +1401,6 @@ nir_unsigned_upper_bound(nir_shader *shader, struct hash_table *range_ht, case nir_op_b32csel: res = src1 > src2 ? src1 : src2; break; - case nir_op_imax3: - case nir_op_imin3: - case nir_op_umax3: - src0 = src0 > src1 ? src0 : src1; - res = src0 > src2 ? src0 : src2; - break; - case nir_op_umin3: - src0 = src0 < src1 ? src0 : src1; - res = src0 < src2 ? src0 : src2; - break; case nir_op_ubfe: res = bitmask(MIN2(src2, scalar.def->bit_size)); break; diff --git a/src/compiler/spirv/vtn_amd.c b/src/compiler/spirv/vtn_amd.c index 4ba8193b532..55000418dcd 100644 --- a/src/compiler/spirv/vtn_amd.c +++ b/src/compiler/spirv/vtn_amd.c @@ -126,34 +126,45 @@ vtn_handle_amd_shader_trinary_minmax_instruction(struct vtn_builder *b, SpvOp ex for (unsigned i = 0; i < num_inputs; i++) src[i] = vtn_get_nir_ssa(b, w[i + 5]); + /* place constants at src[1-2] for easier constant-folding */ + for (unsigned i = 1; i <= 2; i++) { + if (nir_src_as_const_value(nir_src_for_ssa(src[0]))) { + nir_ssa_def* tmp = src[i]; + src[i] = src[0]; + src[0] = tmp; + } + } nir_ssa_def *def; switch ((enum ShaderTrinaryMinMaxAMD)ext_opcode) { case FMin3AMD: - def = nir_fmin3(nb, src[0], src[1], src[2]); + def = nir_fmin(nb, src[0], nir_fmin(nb, src[1], src[2])); break; case UMin3AMD: - def = nir_umin3(nb, src[0], src[1], src[2]); + def = nir_umin(nb, src[0], nir_umin(nb, src[1], src[2])); break; case SMin3AMD: - def = nir_imin3(nb, src[0], src[1], src[2]); + def = nir_imin(nb, src[0], nir_imin(nb, src[1], src[2])); break; case FMax3AMD: - def = nir_fmax3(nb, src[0], src[1], src[2]); + def = nir_fmax(nb, src[0], nir_fmax(nb, src[1], src[2])); break; case UMax3AMD: - def = nir_umax3(nb, src[0], src[1], src[2]); + def = nir_umax(nb, src[0], nir_umax(nb, src[1], src[2])); break; case SMax3AMD: - def = nir_imax3(nb, src[0], src[1], src[2]); + def = nir_imax(nb, src[0], nir_imax(nb, src[1], src[2])); break; case FMid3AMD: - def = nir_fmed3(nb, src[0], src[1], src[2]); + def = nir_fmin(nb, nir_fmax(nb, src[0], nir_fmin(nb, src[1], src[2])), + nir_fmax(nb, src[1], src[2])); break; case UMid3AMD: - def = nir_umed3(nb, src[0], src[1], src[2]); + def = nir_umin(nb, nir_umax(nb, src[0], nir_umin(nb, src[1], src[2])), + nir_umax(nb, src[1], src[2])); break; case SMid3AMD: - def = nir_imed3(nb, src[0], src[1], src[2]); + def = nir_imin(nb, nir_imax(nb, src[0], nir_imin(nb, src[1], src[2])), + nir_imax(nb, src[1], src[2])); break; default: unreachable("unknown opcode\n"); -- 2.30.2