nir,amd: remove trinary_minmax opcodes
[mesa.git] / src / compiler / spirv / vtn_amd.c
index 4ba8193b532dd2df07a76c4f61c1c466818f7471..55000418dcd49a911662ca3346cf8c215c08a405 100644 (file)
@@ -126,34 +126,45 @@ vtn_handle_amd_shader_trinary_minmax_instruction(struct vtn_builder *b, SpvOp ex
    for (unsigned i = 0; i < num_inputs; i++)
       src[i] = vtn_get_nir_ssa(b, w[i + 5]);
 
+   /* place constants at src[1-2] for easier constant-folding */
+   for (unsigned i = 1; i <= 2; i++) {
+      if (nir_src_as_const_value(nir_src_for_ssa(src[0]))) {
+         nir_ssa_def* tmp = src[i];
+         src[i] = src[0];
+         src[0] = tmp;
+      }
+   }
    nir_ssa_def *def;
    switch ((enum ShaderTrinaryMinMaxAMD)ext_opcode) {
    case FMin3AMD:
-      def = nir_fmin3(nb, src[0], src[1], src[2]);
+      def = nir_fmin(nb, src[0], nir_fmin(nb, src[1], src[2]));
       break;
    case UMin3AMD:
-      def = nir_umin3(nb, src[0], src[1], src[2]);
+      def = nir_umin(nb, src[0], nir_umin(nb, src[1], src[2]));
       break;
    case SMin3AMD:
-      def = nir_imin3(nb, src[0], src[1], src[2]);
+      def = nir_imin(nb, src[0], nir_imin(nb, src[1], src[2]));
       break;
    case FMax3AMD:
-      def = nir_fmax3(nb, src[0], src[1], src[2]);
+      def = nir_fmax(nb, src[0], nir_fmax(nb, src[1], src[2]));
       break;
    case UMax3AMD:
-      def = nir_umax3(nb, src[0], src[1], src[2]);
+      def = nir_umax(nb, src[0], nir_umax(nb, src[1], src[2]));
       break;
    case SMax3AMD:
-      def = nir_imax3(nb, src[0], src[1], src[2]);
+      def = nir_imax(nb, src[0], nir_imax(nb, src[1], src[2]));
       break;
    case FMid3AMD:
-      def = nir_fmed3(nb, src[0], src[1], src[2]);
+      def = nir_fmin(nb, nir_fmax(nb, src[0], nir_fmin(nb, src[1], src[2])),
+                     nir_fmax(nb, src[1], src[2]));
       break;
    case UMid3AMD:
-      def = nir_umed3(nb, src[0], src[1], src[2]);
+      def = nir_umin(nb, nir_umax(nb, src[0], nir_umin(nb, src[1], src[2])),
+                     nir_umax(nb, src[1], src[2]));
       break;
    case SMid3AMD:
-      def = nir_imed3(nb, src[0], src[1], src[2]);
+      def = nir_imin(nb, nir_imax(nb, src[0], nir_imin(nb, src[1], src[2])),
+                     nir_imax(nb, src[1], src[2]));
       break;
    default:
       unreachable("unknown opcode\n");