nir,amd: remove trinary_minmax opcodes
authorDaniel Schürmann <daniel@schuermann.dev>
Thu, 18 Jun 2020 14:14:20 +0000 (15:14 +0100)
committerMarge Bot <eric+marge@anholt.net>
Mon, 24 Aug 2020 20:56:11 +0000 (20:56 +0000)
These consist of the variations nir_op_{i|u|f}{min|max|med}3 which are either
lowered in the backend (LLVM) anyway or can be recombined by the backend (ACO).

Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/6421>

src/amd/compiler/aco_instruction_selection.cpp
src/amd/compiler/aco_instruction_selection_setup.cpp
src/amd/llvm/ac_llvm_build.c
src/amd/llvm/ac_nir_to_llvm.c
src/compiler/nir/nir_lower_int64.c
src/compiler/nir/nir_opcodes.py
src/compiler/nir/nir_opt_algebraic.py
src/compiler/nir/nir_range_analysis.c
src/compiler/spirv/vtn_amd.c

index 6f1f8b4e07e701d59d173dd70c4a87f0eeb497bb..737a88e8d1904130c6d5941afe659735443b178a 100644 (file)
@@ -1793,84 +1793,6 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
       }
       break;
    }
       }
       break;
    }
-   case nir_op_fmax3: {
-      if (dst.regClass() == v2b) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f16, dst, false);
-      } else if (dst.regClass() == v1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_f32, dst, ctx->block->fp_mode.must_flush_denorms32);
-      } else {
-         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
-      }
-      break;
-   }
-   case nir_op_fmin3: {
-      if (dst.regClass() == v2b) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f16, dst, false);
-      } else if (dst.regClass() == v1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_f32, dst, ctx->block->fp_mode.must_flush_denorms32);
-      } else {
-         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
-      }
-      break;
-   }
-   case nir_op_fmed3: {
-      if (dst.regClass() == v2b) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f16, dst, false);
-      } else if (dst.regClass() == v1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_f32, dst, ctx->block->fp_mode.must_flush_denorms32);
-      } else {
-         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
-      }
-      break;
-   }
-   case nir_op_umax3: {
-      if (dst.size() == 1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_u32, dst);
-      } else {
-         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
-      }
-      break;
-   }
-   case nir_op_umin3: {
-      if (dst.size() == 1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_u32, dst);
-      } else {
-         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
-      }
-      break;
-   }
-   case nir_op_umed3: {
-      if (dst.size() == 1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_u32, dst);
-      } else {
-         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
-      }
-      break;
-   }
-   case nir_op_imax3: {
-      if (dst.size() == 1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_max3_i32, dst);
-      } else {
-         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
-      }
-      break;
-   }
-   case nir_op_imin3: {
-      if (dst.size() == 1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_min3_i32, dst);
-      } else {
-         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
-      }
-      break;
-   }
-   case nir_op_imed3: {
-      if (dst.size() == 1) {
-         emit_vop3a_instruction(ctx, instr, aco_opcode::v_med3_i32, dst);
-      } else {
-         isel_err(&instr->instr, "Unimplemented NIR instr bit size");
-      }
-      break;
-   }
    case nir_op_cube_face_coord: {
       Temp in = get_alu_src(ctx, instr->src[0], 3);
       Temp src[3] = { emit_extract_vector(ctx, in, 0, v1),
    case nir_op_cube_face_coord: {
       Temp in = get_alu_src(ctx, instr->src[0], 3);
       Temp src[3] = { emit_extract_vector(ctx, in, 0, v1),
index 874e015ca78845aa23ae72835197b3b50b7ba341..53f7ced4bf5be4058a490d48dee6b976b71b9687 100644 (file)
@@ -600,9 +600,6 @@ void init_context(isel_context *ctx, nir_shader *shader)
                   case nir_op_fsub:
                   case nir_op_fmax:
                   case nir_op_fmin:
                   case nir_op_fsub:
                   case nir_op_fmax:
                   case nir_op_fmin:
-                  case nir_op_fmax3:
-                  case nir_op_fmin3:
-                  case nir_op_fmed3:
                   case nir_op_fneg:
                   case nir_op_fabs:
                   case nir_op_fsat:
                   case nir_op_fneg:
                   case nir_op_fabs:
                   case nir_op_fsat:
index 8be8433c997e207ee15312981ce51a02811dab3e..77d3f7e73fbe84cf49548ad95d247b14a8a04d0b 100644 (file)
@@ -2727,54 +2727,6 @@ void ac_build_waitcnt(struct ac_llvm_context *ctx, unsigned wait_flags)
                           ctx->voidt, args, 1, 0);
 }
 
                           ctx->voidt, args, 1, 0);
 }
 
-LLVMValueRef ac_build_fmed3(struct ac_llvm_context *ctx, LLVMValueRef src0,
-                           LLVMValueRef src1, LLVMValueRef src2,
-                           unsigned bitsize)
-{
-       LLVMValueRef result;
-
-       if (bitsize == 64 || (bitsize == 16 && ctx->chip_class <= GFX8)) {
-               /* Lower 64-bit fmed because LLVM doesn't expose an intrinsic,
-                * or lower 16-bit fmed because it's only supported on GFX9+.
-                */
-               LLVMValueRef min1, min2, max1;
-
-               min1 = ac_build_fmin(ctx, src0, src1);
-               max1 = ac_build_fmax(ctx, src0, src1);
-               min2 = ac_build_fmin(ctx, max1, src2);
-
-               result = ac_build_fmax(ctx, min2, min1);
-       } else {
-               LLVMTypeRef type;
-               char *intr;
-
-               if (bitsize == 16) {
-                       intr = "llvm.amdgcn.fmed3.f16";
-                       type = ctx->f16;
-               } else {
-                       assert(bitsize == 32);
-                       intr = "llvm.amdgcn.fmed3.f32";
-                       type = ctx->f32;
-               }
-
-               LLVMValueRef params[] = {
-                       src0,
-                       src1,
-                       src2,
-               };
-
-               result = ac_build_intrinsic(ctx, intr, type, params, 3,
-                                           AC_FUNC_ATTR_READNONE);
-       }
-
-       if (ctx->chip_class < GFX9 && bitsize == 32) {
-               /* Only pre-GFX9 chips do not flush denorms. */
-               result = ac_build_canonicalize(ctx, result, bitsize);
-       }
-
-       return result;
-}
-
 LLVMValueRef ac_build_fract(struct ac_llvm_context *ctx, LLVMValueRef src0,
                            unsigned bitsize)
 {
 LLVMValueRef ac_build_fract(struct ac_llvm_context *ctx, LLVMValueRef src0,
                            unsigned bitsize)
 {
index 37a483e3ba6bf4720eeb6ca0ff60bd77721adbcd..1b6ef264eef3bb7024eab5b7ce59d728e520876e 100644 (file)
@@ -1174,57 +1174,6 @@ static void visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
                break;
        }
 
                break;
        }
 
-       case nir_op_fmin3:
-               result = emit_intrin_2f_param(&ctx->ac, "llvm.minnum",
-                                               ac_to_float_type(&ctx->ac, def_type), src[0], src[1]);
-               result = emit_intrin_2f_param(&ctx->ac, "llvm.minnum",
-                                               ac_to_float_type(&ctx->ac, def_type), result, src[2]);
-               break;
-       case nir_op_umin3:
-               result = ac_build_umin(&ctx->ac, src[0], src[1]);
-               result = ac_build_umin(&ctx->ac, result, src[2]);
-               break;
-       case nir_op_imin3:
-               result = ac_build_imin(&ctx->ac, src[0], src[1]);
-               result = ac_build_imin(&ctx->ac, result, src[2]);
-               break;
-       case nir_op_fmax3:
-               result = emit_intrin_2f_param(&ctx->ac, "llvm.maxnum",
-                                               ac_to_float_type(&ctx->ac, def_type), src[0], src[1]);
-               result = emit_intrin_2f_param(&ctx->ac, "llvm.maxnum",
-                                               ac_to_float_type(&ctx->ac, def_type), result, src[2]);
-               break;
-       case nir_op_umax3:
-               result = ac_build_umax(&ctx->ac, src[0], src[1]);
-               result = ac_build_umax(&ctx->ac, result, src[2]);
-               break;
-       case nir_op_imax3:
-               result = ac_build_imax(&ctx->ac, src[0], src[1]);
-               result = ac_build_imax(&ctx->ac, result, src[2]);
-               break;
-       case nir_op_fmed3: {
-               src[0] = ac_to_float(&ctx->ac, src[0]);
-               src[1] = ac_to_float(&ctx->ac, src[1]);
-               src[2] = ac_to_float(&ctx->ac, src[2]);
-               result = ac_build_fmed3(&ctx->ac, src[0], src[1], src[2],
-                                       instr->dest.dest.ssa.bit_size);
-               break;
-       }
-       case nir_op_imed3: {
-               LLVMValueRef tmp1 = ac_build_imin(&ctx->ac, src[0], src[1]);
-               LLVMValueRef tmp2 = ac_build_imax(&ctx->ac, src[0], src[1]);
-               tmp2 = ac_build_imin(&ctx->ac, tmp2, src[2]);
-               result = ac_build_imax(&ctx->ac, tmp1, tmp2);
-               break;
-       }
-       case nir_op_umed3: {
-               LLVMValueRef tmp1 = ac_build_umin(&ctx->ac, src[0], src[1]);
-               LLVMValueRef tmp2 = ac_build_umax(&ctx->ac, src[0], src[1]);
-               tmp2 = ac_build_umin(&ctx->ac, tmp2, src[2]);
-               result = ac_build_umax(&ctx->ac, tmp1, tmp2);
-               break;
-       }
-
        default:
                fprintf(stderr, "Unknown NIR alu instr: ");
                nir_print_instr(&instr->instr, stderr);
        default:
                fprintf(stderr, "Unknown NIR alu instr: ");
                nir_print_instr(&instr->instr, stderr);
index 0c14fe58853e9f55956b791843840ca7e837f145..e780948c37d3362ae768e5c0f9f1c8764b7dd725 100644 (file)
@@ -838,12 +838,6 @@ nir_lower_int64_op_to_options_mask(nir_op opcode)
    case nir_op_imax:
    case nir_op_umin:
    case nir_op_umax:
    case nir_op_imax:
    case nir_op_umin:
    case nir_op_umax:
-   case nir_op_imin3:
-   case nir_op_imax3:
-   case nir_op_umin3:
-   case nir_op_umax3:
-   case nir_op_imed3:
-   case nir_op_umed3:
       return nir_lower_minmax64;
    case nir_op_iabs:
       return nir_lower_iabs64;
       return nir_lower_minmax64;
    case nir_op_iabs:
       return nir_lower_iabs64;
@@ -944,18 +938,6 @@ lower_int64_alu_instr(nir_builder *b, nir_instr *instr, void *_state)
       return lower_umin64(b, src[0], src[1]);
    case nir_op_umax:
       return lower_umax64(b, src[0], src[1]);
       return lower_umin64(b, src[0], src[1]);
    case nir_op_umax:
       return lower_umax64(b, src[0], src[1]);
-   case nir_op_imin3:
-      return lower_imin64(b, src[0], lower_imin64(b, src[1], src[2]));
-   case nir_op_imax3:
-      return lower_imax64(b, src[0], lower_imax64(b, src[1], src[2]));
-   case nir_op_umin3:
-      return lower_umin64(b, src[0], lower_umin64(b, src[1], src[2]));
-   case nir_op_umax3:
-      return lower_umax64(b, src[0], lower_umax64(b, src[1], src[2]));
-   case nir_op_imed3:
-      return lower_imax64(b, lower_imin64(b, lower_imax64(b, src[0], src[1]), src[2]), lower_imin64(b, src[0], src[1]));
-   case nir_op_umed3:
-      return lower_umax64(b, lower_umin64(b, lower_umax64(b, src[0], src[1]), src[2]), lower_umin64(b, src[0], src[1]));
    case nir_op_iabs:
       return lower_iabs64(b, src[0]);
    case nir_op_ineg:
    case nir_op_iabs:
       return lower_iabs64(b, src[0]);
    case nir_op_ineg:
index 87b5e4efac15b09220903a153c5738d29b70e10d..e19d7b00a7d3e956a57334ad32b674dfbc9f28ad 100644 (file)
@@ -950,22 +950,8 @@ triop("flrp", tfloat, "", "src0 * (1 - src2) + src1 * src2")
 # component on vectors). There are two versions, one for floating point
 # bools (0.0 vs 1.0) and one for integer bools (0 vs ~0).
 
 # component on vectors). There are two versions, one for floating point
 # bools (0.0 vs 1.0) and one for integer bools (0 vs ~0).
 
-
 triop("fcsel", tfloat32, "", "(src0 != 0.0f) ? src1 : src2")
 
 triop("fcsel", tfloat32, "", "(src0 != 0.0f) ? src1 : src2")
 
-# 3 way min/max/med
-triop("fmin3", tfloat, "", "fminf(src0, fminf(src1, src2))")
-triop("imin3", tint, "", "MIN2(src0, MIN2(src1, src2))")
-triop("umin3", tuint, "", "MIN2(src0, MIN2(src1, src2))")
-
-triop("fmax3", tfloat, "", "fmaxf(src0, fmaxf(src1, src2))")
-triop("imax3", tint, "", "MAX2(src0, MAX2(src1, src2))")
-triop("umax3", tuint, "", "MAX2(src0, MAX2(src1, src2))")
-
-triop("fmed3", tfloat, "", "fmaxf(fminf(fmaxf(src0, src1), src2), fminf(src0, src1))")
-triop("imed3", tint, "", "MAX2(MIN2(MAX2(src0, src1), src2), MIN2(src0, src1))")
-triop("umed3", tuint, "", "MAX2(MIN2(MAX2(src0, src1), src2), MIN2(src0, src1))")
-
 opcode("bcsel", 0, tuint, [0, 0, 0],
        [tbool1, tuint, tuint], False, "", "src0 ? src1 : src2")
 opcode("b8csel", 0, tuint, [0, 0, 0],
 opcode("bcsel", 0, tuint, [0, 0, 0],
        [tbool1, tuint, tuint], False, "", "src0 ? src1 : src2")
 opcode("b8csel", 0, tuint, [0, 0, 0],
index 4a2efa8252f62963e1a8b294d22ad8dcfa54b43a..c394b07a4e91f3347543d8b94c7441edf846c045 100644 (file)
@@ -1153,10 +1153,6 @@ optimizations.extend([
    (('bcsel', a, ('bcsel', b, c, d), d), ('bcsel', ('iand', a, b), c, d)),
    (('bcsel', a, b, ('bcsel', c, b, d)), ('bcsel', ('ior', a, c), b, d)),
 
    (('bcsel', a, ('bcsel', b, c, d), d), ('bcsel', ('iand', a, b), c, d)),
    (('bcsel', a, b, ('bcsel', c, b, d)), ('bcsel', ('ior', a, c), b, d)),
 
-   (('fmin3@64', a, b, c), ('fmin@64', a, ('fmin@64', b, c))),
-   (('fmax3@64', a, b, c), ('fmax@64', a, ('fmax@64', b, c))),
-   (('fmed3@64', a, b, c), ('fmax@64', ('fmin@64', ('fmax@64', a, b), c), ('fmin@64', a, b))),
-
    # Misc. lowering
    (('fmod', a, b), ('fsub', a, ('fmul', b, ('ffloor', ('fdiv', a, b)))), 'options->lower_fmod'),
    (('frem', a, b), ('fsub', a, ('fmul', b, ('ftrunc', ('fdiv', a, b)))), 'options->lower_fmod'),
    # Misc. lowering
    (('fmod', a, b), ('fsub', a, ('fmul', b, ('ffloor', ('fdiv', a, b)))), 'options->lower_fmod'),
    (('frem', a, b), ('fsub', a, ('fmul', b, ('ftrunc', ('fdiv', a, b)))), 'options->lower_fmod'),
index 5ef66ad892222063a10452e0f37d808fa7a0ba9c..e23c7c4fdb76e012330b35de66d28d1a74df0b2d 100644 (file)
@@ -1319,10 +1319,6 @@ nir_unsigned_upper_bound(nir_shader *shader, struct hash_table *range_ht,
       case nir_op_udiv:
       case nir_op_bcsel:
       case nir_op_b32csel:
       case nir_op_udiv:
       case nir_op_bcsel:
       case nir_op_b32csel:
-      case nir_op_imax3:
-      case nir_op_imin3:
-      case nir_op_umax3:
-      case nir_op_umin3:
       case nir_op_ubfe:
       case nir_op_bfm:
       case nir_op_f2u32:
       case nir_op_ubfe:
       case nir_op_bfm:
       case nir_op_f2u32:
@@ -1405,16 +1401,6 @@ nir_unsigned_upper_bound(nir_shader *shader, struct hash_table *range_ht,
       case nir_op_b32csel:
          res = src1 > src2 ? src1 : src2;
          break;
       case nir_op_b32csel:
          res = src1 > src2 ? src1 : src2;
          break;
-      case nir_op_imax3:
-      case nir_op_imin3:
-      case nir_op_umax3:
-         src0 = src0 > src1 ? src0 : src1;
-         res = src0 > src2 ? src0 : src2;
-         break;
-      case nir_op_umin3:
-         src0 = src0 < src1 ? src0 : src1;
-         res = src0 < src2 ? src0 : src2;
-         break;
       case nir_op_ubfe:
          res = bitmask(MIN2(src2, scalar.def->bit_size));
          break;
       case nir_op_ubfe:
          res = bitmask(MIN2(src2, scalar.def->bit_size));
          break;
index 4ba8193b532dd2df07a76c4f61c1c466818f7471..55000418dcd49a911662ca3346cf8c215c08a405 100644 (file)
@@ -126,34 +126,45 @@ vtn_handle_amd_shader_trinary_minmax_instruction(struct vtn_builder *b, SpvOp ex
    for (unsigned i = 0; i < num_inputs; i++)
       src[i] = vtn_get_nir_ssa(b, w[i + 5]);
 
    for (unsigned i = 0; i < num_inputs; i++)
       src[i] = vtn_get_nir_ssa(b, w[i + 5]);
 
+   /* place constants at src[1-2] for easier constant-folding */
+   for (unsigned i = 1; i <= 2; i++) {
+      if (nir_src_as_const_value(nir_src_for_ssa(src[0]))) {
+         nir_ssa_def* tmp = src[i];
+         src[i] = src[0];
+         src[0] = tmp;
+      }
+   }
    nir_ssa_def *def;
    switch ((enum ShaderTrinaryMinMaxAMD)ext_opcode) {
    case FMin3AMD:
    nir_ssa_def *def;
    switch ((enum ShaderTrinaryMinMaxAMD)ext_opcode) {
    case FMin3AMD:
-      def = nir_fmin3(nb, src[0], src[1], src[2]);
+      def = nir_fmin(nb, src[0], nir_fmin(nb, src[1], src[2]));
       break;
    case UMin3AMD:
       break;
    case UMin3AMD:
-      def = nir_umin3(nb, src[0], src[1], src[2]);
+      def = nir_umin(nb, src[0], nir_umin(nb, src[1], src[2]));
       break;
    case SMin3AMD:
       break;
    case SMin3AMD:
-      def = nir_imin3(nb, src[0], src[1], src[2]);
+      def = nir_imin(nb, src[0], nir_imin(nb, src[1], src[2]));
       break;
    case FMax3AMD:
       break;
    case FMax3AMD:
-      def = nir_fmax3(nb, src[0], src[1], src[2]);
+      def = nir_fmax(nb, src[0], nir_fmax(nb, src[1], src[2]));
       break;
    case UMax3AMD:
       break;
    case UMax3AMD:
-      def = nir_umax3(nb, src[0], src[1], src[2]);
+      def = nir_umax(nb, src[0], nir_umax(nb, src[1], src[2]));
       break;
    case SMax3AMD:
       break;
    case SMax3AMD:
-      def = nir_imax3(nb, src[0], src[1], src[2]);
+      def = nir_imax(nb, src[0], nir_imax(nb, src[1], src[2]));
       break;
    case FMid3AMD:
       break;
    case FMid3AMD:
-      def = nir_fmed3(nb, src[0], src[1], src[2]);
+      def = nir_fmin(nb, nir_fmax(nb, src[0], nir_fmin(nb, src[1], src[2])),
+                     nir_fmax(nb, src[1], src[2]));
       break;
    case UMid3AMD:
       break;
    case UMid3AMD:
-      def = nir_umed3(nb, src[0], src[1], src[2]);
+      def = nir_umin(nb, nir_umax(nb, src[0], nir_umin(nb, src[1], src[2])),
+                     nir_umax(nb, src[1], src[2]));
       break;
    case SMid3AMD:
       break;
    case SMid3AMD:
-      def = nir_imed3(nb, src[0], src[1], src[2]);
+      def = nir_imin(nb, nir_imax(nb, src[0], nir_imin(nb, src[1], src[2])),
+                     nir_imax(nb, src[1], src[2]));
       break;
    default:
       unreachable("unknown opcode\n");
       break;
    default:
       unreachable("unknown opcode\n");