nir/vtn: Support OpOrdered and OpUnordered opcodes
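
SPIR-V's OpOrdered and OpUnordered test whether two floating-point
operands are both ordered, or whether at least one is unordered
(i.e. NaN).  Neither maps to a single NIR ALU op, so lower them in
vtn_handle_alu() using NaN self-comparison (a float compares equal to
itself unless it is NaN):

    OpOrdered(x, y)   -> iand(feq(x, x), feq(y, y))
    OpUnordered(x, y) -> ior(fne(x, x), fne(y, y))

While here, handle the deprecated SpvOpLessOrGreater the same way as
SpvOpFOrdNotEqual, map SpvOpIsNormal and SpvOpIsFinite to
nir_op_fisnormal and nir_op_fisfinite, and rework result handling:
vtn_handle_matrix_alu() now returns its vtn_ssa_value instead of
writing through a vtn_value, and vtn_handle_alu() and
vtn_handle_bitcast() push their results with the vtn_push_ssa_value()
and vtn_push_nir_ssa() helpers.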
diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index 60e88144ceb196175c50b5752fb247297b790c46..e92a48f24a63962b5cc86e828f483ffe7ec9ad37 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -43,7 +43,7 @@ wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
       return val;
 
    struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
-   dest->type = val->type;
+   dest->type = glsl_get_bare_type(val->type);
    dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
    dest->elems[0] = val;
 
@@ -153,48 +153,46 @@ mat_times_scalar(struct vtn_builder *b,
    return dest;
 }
 
-static void
+static struct vtn_ssa_value *
 vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
-                      struct vtn_value *dest,
                       struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
 {
    switch (opcode) {
    case SpvOpFNegate: {
-      dest->ssa = vtn_create_ssa_value(b, src0->type);
+      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
       unsigned cols = glsl_get_matrix_columns(src0->type);
       for (unsigned i = 0; i < cols; i++)
-         dest->ssa->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
-      break;
+         dest->elems[i]->def = nir_fneg(&b->nb, src0->elems[i]->def);
+      return dest;
    }
 
    case SpvOpFAdd: {
-      dest->ssa = vtn_create_ssa_value(b, src0->type);
+      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
       unsigned cols = glsl_get_matrix_columns(src0->type);
       for (unsigned i = 0; i < cols; i++)
-         dest->ssa->elems[i]->def =
+         dest->elems[i]->def =
             nir_fadd(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
-      break;
+      return dest;
    }
 
    case SpvOpFSub: {
-      dest->ssa = vtn_create_ssa_value(b, src0->type);
+      struct vtn_ssa_value *dest = vtn_create_ssa_value(b, src0->type);
       unsigned cols = glsl_get_matrix_columns(src0->type);
       for (unsigned i = 0; i < cols; i++)
-         dest->ssa->elems[i]->def =
+         dest->elems[i]->def =
             nir_fsub(&b->nb, src0->elems[i]->def, src1->elems[i]->def);
-      break;
+      return dest;
    }
 
    case SpvOpTranspose:
-      dest->ssa = vtn_ssa_transpose(b, src0);
-      break;
+      return vtn_ssa_transpose(b, src0);
 
    case SpvOpMatrixTimesScalar:
       if (src0->transposed) {
-         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
-                                                           src1->def));
+         return vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
+                                                         src1->def));
       } else {
-         dest->ssa = mat_times_scalar(b, src0, src1->def);
+         return mat_times_scalar(b, src0, src1->def);
       }
       break;
 
@@ -202,9 +200,9 @@ vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
    case SpvOpMatrixTimesVector:
    case SpvOpMatrixTimesMatrix:
       if (opcode == SpvOpVectorTimesMatrix) {
-         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
+         return matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
       } else {
-         dest->ssa = matrix_multiply(b, src0, src1);
+         return matrix_multiply(b, src0, src1);
       }
       break;
 
@@ -283,6 +281,7 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
    case SpvOpFOrdEqual:                            return nir_op_feq;
    case SpvOpFUnordEqual:                          return nir_op_feq;
    case SpvOpINotEqual:                            return nir_op_ine;
+   case SpvOpLessOrGreater:                        /* Deprecated, use OrdNotEqual */
    case SpvOpFOrdNotEqual:                         return nir_op_fne;
    case SpvOpFUnordNotEqual:                       return nir_op_fne;
    case SpvOpULessThan:                            return nir_op_ult;
@@ -355,6 +354,9 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
    case SpvOpDPdxCoarse:   return nir_op_fddx_coarse;
    case SpvOpDPdyCoarse:   return nir_op_fddy_coarse;
 
+   case SpvOpIsNormal:     return nir_op_fisnormal;
+   case SpvOpIsFinite:     return nir_op_fisfinite;
+
    default:
       vtn_fail("No NIR equivalent: %u", opcode);
    }
@@ -414,10 +416,10 @@ void
 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                const uint32_t *w, unsigned count)
 {
-   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-   const struct glsl_type *type = vtn_get_type(b, w[1])->type;
+   struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
+   const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
 
-   vtn_foreach_decoration(b, val, handle_no_contraction, NULL);
+   vtn_foreach_decoration(b, dest_val, handle_no_contraction, NULL);
 
    /* Collect the various SSA sources */
    const unsigned num_inputs = count - 3;
@@ -427,12 +429,13 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
 
    if (glsl_type_is_matrix(vtn_src[0]->type) ||
        (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
-      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
+      vtn_push_ssa_value(b, w[2],
+         vtn_handle_matrix_alu(b, opcode, vtn_src[0], vtn_src[1]));
       b->nb.exact = b->exact;
       return;
    }
 
-   val->ssa = vtn_create_ssa_value(b, type);
+   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
    nir_ssa_def *src[4] = { NULL, };
    for (unsigned i = 0; i < num_inputs; i++) {
       vtn_assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
@@ -441,81 +444,97 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
 
    switch (opcode) {
    case SpvOpAny:
-      val->ssa->def = nir_bany(&b->nb, src[0]);
+      dest->def = nir_bany(&b->nb, src[0]);
       break;
 
    case SpvOpAll:
-      val->ssa->def = nir_ball(&b->nb, src[0]);
+      dest->def = nir_ball(&b->nb, src[0]);
       break;
 
    case SpvOpOuterProduct: {
       for (unsigned i = 0; i < src[1]->num_components; i++) {
-         val->ssa->elems[i]->def =
+         dest->elems[i]->def =
             nir_fmul(&b->nb, src[0], nir_channel(&b->nb, src[1], i));
       }
       break;
    }
 
    case SpvOpDot:
-      val->ssa->def = nir_fdot(&b->nb, src[0], src[1]);
+      dest->def = nir_fdot(&b->nb, src[0], src[1]);
       break;
 
    case SpvOpIAddCarry:
-      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
-      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
+      dest->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
+      dest->elems[1]->def = nir_uadd_carry(&b->nb, src[0], src[1]);
       break;
 
    case SpvOpISubBorrow:
-      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
-      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
+      dest->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
+      dest->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
       break;
 
    case SpvOpUMulExtended: {
-      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
       nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
-      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
-      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
+      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
+      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
       break;
    }
 
    case SpvOpSMulExtended: {
-      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
+      vtn_assert(glsl_type_is_struct_or_ifc(dest_type));
       nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
-      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
-      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
+      dest->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
+      dest->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
       break;
    }
 
    case SpvOpFwidth:
-      val->ssa->def = nir_fadd(&b->nb,
+      dest->def = nir_fadd(&b->nb,
                                nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
                                nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
       break;
    case SpvOpFwidthFine:
-      val->ssa->def = nir_fadd(&b->nb,
+      dest->def = nir_fadd(&b->nb,
                                nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
                                nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
       break;
    case SpvOpFwidthCoarse:
-      val->ssa->def = nir_fadd(&b->nb,
+      dest->def = nir_fadd(&b->nb,
                                nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
                                nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
       break;
 
    case SpvOpVectorTimesScalar:
       /* The builder will take care of splatting for us. */
-      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
+      dest->def = nir_fmul(&b->nb, src[0], src[1]);
       break;
 
    case SpvOpIsNan:
-      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
+      dest->def = nir_fne(&b->nb, src[0], src[0]);
+      break;
+
+   case SpvOpOrdered:
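+      /* A floating-point value compares equal to itself unless it is NaN,
+       * so this is true exactly when neither operand is NaN.
+       */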
+      dest->def = nir_iand(&b->nb, nir_feq(&b->nb, src[0], src[0]),
+                                   nir_feq(&b->nb, src[1], src[1]));
+      break;
+
+   case SpvOpUnordered:
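+      /* Conversely, fne(x, x) holds only for NaN, so this is true when
+       * at least one operand is NaN.
+       */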
+      dest->def = nir_ior(&b->nb, nir_fne(&b->nb, src[0], src[0]),
+                                  nir_fne(&b->nb, src[1], src[1]));
       break;
 
    case SpvOpIsInf: {
       nir_ssa_def *inf = nir_imm_floatN_t(&b->nb, INFINITY, src[0]->bit_size);
-      val->ssa->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
+      dest->def = nir_ieq(&b->nb, nir_fabs(&b->nb, src[0]), inf);
       break;
    }
 
@@ -527,7 +540,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    case SpvOpFUnordGreaterThanEqual: {
       bool swap;
       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
-      unsigned dst_bit_size = glsl_get_bit_size(type);
+      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                   src_bit_size, dst_bit_size);
 
@@ -537,7 +550,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
          src[1] = tmp;
       }
 
-      val->ssa->def =
+      dest->def =
          nir_ior(&b->nb,
                  nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                  nir_ior(&b->nb,
@@ -546,6 +559,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
+   case SpvOpLessOrGreater:
    case SpvOpFOrdNotEqual: {
       /* For all the SpvOpFOrd* comparisons apart from NotEqual, the value
        * from the ALU will probably already be false if the operands are not
@@ -553,13 +567,13 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
        */
       bool swap;
       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
-      unsigned dst_bit_size = glsl_get_bit_size(type);
+      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                   src_bit_size, dst_bit_size);
 
       assert(!swap);
 
-      val->ssa->def =
+      dest->def =
          nir_iand(&b->nb,
                   nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL),
                   nir_iand(&b->nb,
@@ -570,13 +584,13 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
 
    case SpvOpFConvert: {
       nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
-      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
+      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(dest_type);
       nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
 
-      vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
+      vtn_foreach_decoration(b, dest_val, handle_rounding_mode, &rounding_mode);
       nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);
 
-      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
+      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
       break;
    }
 
@@ -588,7 +602,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    case SpvOpShiftRightLogical: {
       bool swap;
       unsigned src0_bit_size = glsl_get_bit_size(vtn_src[0]->type);
-      unsigned dst_bit_size = glsl_get_bit_size(type);
+      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                   src0_bit_size, dst_bit_size);
 
@@ -611,17 +625,17 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
             src[i] = nir_u2u32(&b->nb, src[i]);
          }
       }
-      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
+      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
       break;
    }
 
    case SpvOpSignBitSet:
-      val->ssa->def = nir_i2b(&b->nb,
+      dest->def = nir_i2b(&b->nb,
          nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src[0]->bit_size - 1)));
       break;
 
    case SpvOpUCountTrailingZerosINTEL:
-      val->ssa->def = nir_umin(&b->nb,
+      dest->def = nir_umin(&b->nb,
                                nir_find_lsb(&b->nb, src[0]),
                                nir_imm_int(&b->nb, 32u));
       break;
@@ -629,7 +643,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    default: {
       bool swap;
       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
-      unsigned dst_bit_size = glsl_get_bit_size(type);
+      unsigned dst_bit_size = glsl_get_bit_size(dest_type);
       nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
                                                   src_bit_size, dst_bit_size);
 
@@ -650,7 +664,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
          break;
       }
 
-      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
+      dest->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
       break;
    } /* default */
    }
@@ -661,8 +675,8 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    case SpvOpISub:
    case SpvOpShiftLeftLogical:
    case SpvOpSNegate: {
-      nir_alu_instr *alu = nir_instr_as_alu(val->ssa->def->parent_instr);
-      vtn_foreach_decoration(b, val, handle_no_wrap, alu);
+      nir_alu_instr *alu = nir_instr_as_alu(dest->def->parent_instr);
+      vtn_foreach_decoration(b, dest_val, handle_no_wrap, alu);
       break;
    }
    default:
@@ -670,6 +684,8 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
+   vtn_push_ssa_value(b, w[2], dest);
+
    b->nb.exact = b->exact;
 }
 
@@ -697,16 +713,13 @@ vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w, unsigned count)
     */
 
    struct vtn_type *type = vtn_get_type(b, w[1]);
-   struct vtn_ssa_value *vtn_src = vtn_ssa_value(b, w[3]);
-   struct nir_ssa_def *src = vtn_src->def;
-   struct vtn_ssa_value *val = vtn_create_ssa_value(b, type->type);
-
-   vtn_assert(glsl_type_is_vector_or_scalar(vtn_src->type));
+   struct nir_ssa_def *src = vtn_get_nir_ssa(b, w[3]);
 
    vtn_fail_if(src->num_components * src->bit_size !=
                glsl_get_vector_elements(type->type) * glsl_get_bit_size(type->type),
                "Source and destination of OpBitcast must have the same "
                "total number of bits");
-   val->def = nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
-   vtn_push_ssa(b, w[2], type, val);
+   nir_ssa_def *val =
+      nir_bitcast_vector(&b->nb, src, glsl_get_bit_size(type->type));
+   vtn_push_nir_ssa(b, w[2], val);
 }