nir/spirv: initial handling of OpenCL.std extension opcodes
[mesa.git] / src / compiler / spirv / vtn_alu.c
index b04ada92199a133a7d064a0d02f61dd26aef7930..6bc015a096d642fed0c9bbbaabbb6c1d664add80 100644 (file)
@@ -244,15 +244,15 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
    case SpvOpShiftRightArithmetic:  return nir_op_ishr;
    case SpvOpShiftLeftLogical:      return nir_op_ishl;
    case SpvOpLogicalOr:             return nir_op_ior;
-   case SpvOpLogicalEqual:          return nir_op_ieq32;
-   case SpvOpLogicalNotEqual:       return nir_op_ine32;
+   case SpvOpLogicalEqual:          return nir_op_ieq;
+   case SpvOpLogicalNotEqual:       return nir_op_ine;
    case SpvOpLogicalAnd:            return nir_op_iand;
    case SpvOpLogicalNot:            return nir_op_inot;
    case SpvOpBitwiseOr:             return nir_op_ior;
    case SpvOpBitwiseXor:            return nir_op_ixor;
    case SpvOpBitwiseAnd:            return nir_op_iand;
-   case SpvOpSelect:                return nir_op_b32csel;
-   case SpvOpIEqual:                return nir_op_ieq32;
+   case SpvOpSelect:                return nir_op_bcsel;
+   case SpvOpIEqual:                return nir_op_ieq;
 
    case SpvOpBitFieldInsert:        return nir_op_bitfield_insert;
    case SpvOpBitFieldSExtract:      return nir_op_ibitfield_extract;
@@ -264,27 +264,27 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
     * the logical operator to use since they also need to check if operands are
     * ordered.
     */
-   case SpvOpFOrdEqual:                            return nir_op_feq32;
-   case SpvOpFUnordEqual:                          return nir_op_feq32;
-   case SpvOpINotEqual:                            return nir_op_ine32;
-   case SpvOpFOrdNotEqual:                         return nir_op_fne32;
-   case SpvOpFUnordNotEqual:                       return nir_op_fne32;
-   case SpvOpULessThan:                            return nir_op_ult32;
-   case SpvOpSLessThan:                            return nir_op_ilt32;
-   case SpvOpFOrdLessThan:                         return nir_op_flt32;
-   case SpvOpFUnordLessThan:                       return nir_op_flt32;
-   case SpvOpUGreaterThan:          *swap = true;  return nir_op_ult32;
-   case SpvOpSGreaterThan:          *swap = true;  return nir_op_ilt32;
-   case SpvOpFOrdGreaterThan:       *swap = true;  return nir_op_flt32;
-   case SpvOpFUnordGreaterThan:     *swap = true;  return nir_op_flt32;
-   case SpvOpULessThanEqual:        *swap = true;  return nir_op_uge32;
-   case SpvOpSLessThanEqual:        *swap = true;  return nir_op_ige32;
-   case SpvOpFOrdLessThanEqual:     *swap = true;  return nir_op_fge32;
-   case SpvOpFUnordLessThanEqual:   *swap = true;  return nir_op_fge32;
-   case SpvOpUGreaterThanEqual:                    return nir_op_uge32;
-   case SpvOpSGreaterThanEqual:                    return nir_op_ige32;
-   case SpvOpFOrdGreaterThanEqual:                 return nir_op_fge32;
-   case SpvOpFUnordGreaterThanEqual:               return nir_op_fge32;
+   case SpvOpFOrdEqual:                            return nir_op_feq;
+   case SpvOpFUnordEqual:                          return nir_op_feq;
+   case SpvOpINotEqual:                            return nir_op_ine;
+   case SpvOpFOrdNotEqual:                         return nir_op_fne;
+   case SpvOpFUnordNotEqual:                       return nir_op_fne;
+   case SpvOpULessThan:                            return nir_op_ult;
+   case SpvOpSLessThan:                            return nir_op_ilt;
+   case SpvOpFOrdLessThan:                         return nir_op_flt;
+   case SpvOpFUnordLessThan:                       return nir_op_flt;
+   case SpvOpUGreaterThan:          *swap = true;  return nir_op_ult;
+   case SpvOpSGreaterThan:          *swap = true;  return nir_op_ilt;
+   case SpvOpFOrdGreaterThan:       *swap = true;  return nir_op_flt;
+   case SpvOpFUnordGreaterThan:     *swap = true;  return nir_op_flt;
+   case SpvOpULessThanEqual:        *swap = true;  return nir_op_uge;
+   case SpvOpSLessThanEqual:        *swap = true;  return nir_op_ige;
+   case SpvOpFOrdLessThanEqual:     *swap = true;  return nir_op_fge;
+   case SpvOpFUnordLessThanEqual:   *swap = true;  return nir_op_fge;
+   case SpvOpUGreaterThanEqual:                    return nir_op_uge;
+   case SpvOpSGreaterThanEqual:                    return nir_op_ige;
+   case SpvOpFOrdGreaterThanEqual:                 return nir_op_fge;
+   case SpvOpFUnordGreaterThanEqual:               return nir_op_fge;
 
    /* Conversions: */
    case SpvOpQuantizeToF16:         return nir_op_fquantize2f16;
@@ -395,7 +395,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    if (glsl_type_is_matrix(vtn_src[0]->type) ||
        (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
       vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
-      b->nb.exact = false;
+      b->nb.exact = b->exact;
       return;
    }
 
@@ -413,9 +413,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       } else {
          nir_op op;
          switch (src[0]->num_components) {
-         case 2:  op = nir_op_b32any_inequal2; break;
-         case 3:  op = nir_op_b32any_inequal3; break;
-         case 4:  op = nir_op_b32any_inequal4; break;
+         case 2:  op = nir_op_bany_inequal2; break;
+         case 3:  op = nir_op_bany_inequal3; break;
+         case 4:  op = nir_op_bany_inequal4; break;
          default: vtn_fail("invalid number of components");
          }
          val->ssa->def = nir_build_alu(&b->nb, op, src[0],
@@ -430,9 +430,9 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       } else {
          nir_op op;
          switch (src[0]->num_components) {
-         case 2:  op = nir_op_b32all_iequal2;  break;
-         case 3:  op = nir_op_b32all_iequal3;  break;
-         case 4:  op = nir_op_b32all_iequal4;  break;
+         case 2:  op = nir_op_ball_iequal2;  break;
+         case 3:  op = nir_op_ball_iequal3;  break;
+         case 4:  op = nir_op_ball_iequal4;  break;
          default: vtn_fail("invalid number of components");
          }
          val->ssa->def = nir_build_alu(&b->nb, op, src[0],
@@ -465,17 +465,21 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
       break;
 
-   case SpvOpUMulExtended:
+   case SpvOpUMulExtended: {
       vtn_assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
+      nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
+      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
+      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
       break;
+   }
 
-   case SpvOpSMulExtended:
+   case SpvOpSMulExtended: {
       vtn_assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
+      nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
+      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
+      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
       break;
+   }
 
    case SpvOpFwidth:
       val->ssa->def = nir_fadd(&b->nb,
@@ -632,6 +636,21 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
+   case SpvOpSignBitSet: {
+      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
+      if (src[0]->num_components == 1)
+         val->ssa->def =
+            nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src_bit_size - 1));
+      else
+         val->ssa->def =
+            nir_ishr(&b->nb, src[0], nir_imm_int(&b->nb, src_bit_size - 1));
+
+      if (src_bit_size != 32)
+         val->ssa->def = nir_u2u32(&b->nb, val->ssa->def);
+
+      break;
+   }
+
    default: {
       bool swap;
       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
@@ -661,5 +680,5 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    } /* default */
    }
 
-   b->nb.exact = false;
+   b->nb.exact = b->exact;
 }