nir: Return progress from nir_lower_load_const_to_scalar().
[mesa.git] / src / compiler / spirv / vtn_alu.c
index 8b9a63ce7606bd991e2b0ce5927cd5ebab27a54f..9e4beedf926676a5ef09cee59d19b668ebd1052b 100644 (file)
@@ -211,7 +211,8 @@ vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
 }
 
 nir_op
-vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap)
+vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap,
+                                nir_alu_type src, nir_alu_type dst)
 {
    /* Indicates that the first two arguments should be swapped.  This is
     * used for implementing greater-than and less-than-or-equal.
@@ -257,7 +258,10 @@ vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap)
    case SpvOpBitReverse:            return nir_op_bitfield_reverse;
    case SpvOpBitCount:              return nir_op_bit_count;
 
-   /* Comparisons: (TODO: How do we want to handled ordered/unordered?) */
+   /* The ordered / unordered operators need special implementation besides
+    * the logical operator to use since they also need to check if operands are
+    * ordered.
+    */
    case SpvOpFOrdEqual:                            return nir_op_feq;
    case SpvOpFUnordEqual:                          return nir_op_feq;
    case SpvOpINotEqual:                            return nir_op_ine;
@@ -281,16 +285,16 @@ vtn_nir_alu_op_for_spirv_opcode(SpvOp opcode, bool *swap)
    case SpvOpFUnordGreaterThanEqual:               return nir_op_fge;
 
    /* Conversions: */
-   case SpvOpConvertFToU:           return nir_op_f2u;
-   case SpvOpConvertFToS:           return nir_op_f2i;
-   case SpvOpConvertSToF:           return nir_op_i2f;
-   case SpvOpConvertUToF:           return nir_op_u2f;
    case SpvOpBitcast:               return nir_op_imov;
-   case SpvOpUConvert:
    case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
-   /* TODO: NIR is 32-bit only; these are no-ops. */
-   case SpvOpSConvert:              return nir_op_imov;
-   case SpvOpFConvert:              return nir_op_fmov;
+   case SpvOpUConvert:
+   case SpvOpConvertFToU:
+   case SpvOpConvertFToS:
+   case SpvOpConvertSToF:
+   case SpvOpConvertUToF:
+   case SpvOpSConvert:
+   case SpvOpFConvert:
+      return nir_type_conversion_op(src, dst);
 
    /* Derivatives: */
    case SpvOpDPdx:         return nir_op_fddx;
@@ -356,6 +360,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
          case 2:  op = nir_op_bany_inequal2; break;
          case 3:  op = nir_op_bany_inequal3; break;
          case 4:  op = nir_op_bany_inequal4; break;
+         default: unreachable("invalid number of components");
          }
          val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                        nir_imm_int(&b->nb, NIR_FALSE),
@@ -372,6 +377,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
          case 2:  op = nir_op_ball_iequal2;  break;
          case 3:  op = nir_op_ball_iequal3;  break;
          case 4:  op = nir_op_ball_iequal4;  break;
+         default: unreachable("invalid number of components");
          }
          val->ssa->def = nir_build_alu(&b->nb, op, src[0],
                                        nir_imm_int(&b->nb, NIR_TRUE),
@@ -441,13 +447,67 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       break;
 
    case SpvOpIsInf:
-      val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
+      val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]),
                                       nir_imm_float(&b->nb, INFINITY));
       break;
 
+   case SpvOpFUnordEqual:
+   case SpvOpFUnordNotEqual:
+   case SpvOpFUnordLessThan:
+   case SpvOpFUnordGreaterThan:
+   case SpvOpFUnordLessThanEqual:
+   case SpvOpFUnordGreaterThanEqual: {
+      bool swap;
+      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
+      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);
+
+      if (swap) {
+         nir_ssa_def *tmp = src[0];
+         src[0] = src[1];
+         src[1] = tmp;
+      }
+
+      val->ssa->def =
+         nir_ior(&b->nb,
+                 nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
+                 nir_ior(&b->nb,
+                         nir_fne(&b->nb, src[0], src[0]),
+                         nir_fne(&b->nb, src[1], src[1])));
+      break;
+   }
+
+   case SpvOpFOrdEqual:
+   case SpvOpFOrdNotEqual:
+   case SpvOpFOrdLessThan:
+   case SpvOpFOrdGreaterThan:
+   case SpvOpFOrdLessThanEqual:
+   case SpvOpFOrdGreaterThanEqual: {
+      bool swap;
+      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
+      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);
+
+      if (swap) {
+         nir_ssa_def *tmp = src[0];
+         src[0] = src[1];
+         src[1] = tmp;
+      }
+
+      val->ssa->def =
+         nir_iand(&b->nb,
+                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
+                  nir_iand(&b->nb,
+                          nir_feq(&b->nb, src[0], src[0]),
+                          nir_feq(&b->nb, src[1], src[1])));
+      break;
+   }
+
    default: {
       bool swap;
-      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap);
+      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
+      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(opcode, &swap, src_alu_type, dst_alu_type);
 
       if (swap) {
          nir_ssa_def *tmp = src[0];