nir/spirv: initial handling of OpenCL.std extension opcodes
[mesa.git] / src / compiler / spirv / vtn_alu.c
index dc6fedc9129674ef5303e809e366de9e38de6c64..6bc015a096d642fed0c9bbbaabbb6c1d664add80 100644 (file)
@@ -395,7 +395,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    if (glsl_type_is_matrix(vtn_src[0]->type) ||
        (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
       vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
-      b->nb.exact = false;
+      b->nb.exact = b->exact;
       return;
    }
 
@@ -465,17 +465,21 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       val->ssa->elems[1]->def = nir_usub_borrow(&b->nb, src[0], src[1]);
       break;
 
-   case SpvOpUMulExtended:
+   case SpvOpUMulExtended: {
       vtn_assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
+      nir_ssa_def *umul = nir_umul_2x32_64(&b->nb, src[0], src[1]);
+      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, umul);
+      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, umul);
       break;
+   }
 
-   case SpvOpSMulExtended:
+   case SpvOpSMulExtended: {
       vtn_assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
+      nir_ssa_def *smul = nir_imul_2x32_64(&b->nb, src[0], src[1]);
+      val->ssa->elems[0]->def = nir_unpack_64_2x32_split_x(&b->nb, smul);
+      val->ssa->elems[1]->def = nir_unpack_64_2x32_split_y(&b->nb, smul);
       break;
+   }
 
    case SpvOpFwidth:
       val->ssa->def = nir_fadd(&b->nb,
@@ -632,6 +636,21 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       break;
    }
 
+   case SpvOpSignBitSet: {
+      unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
+      if (src[0]->num_components == 1)
+         val->ssa->def =
+            nir_ushr(&b->nb, src[0], nir_imm_int(&b->nb, src_bit_size - 1));
+      else
+         val->ssa->def =
+            nir_ishr(&b->nb, src[0], nir_imm_int(&b->nb, src_bit_size - 1));
+
+      if (src_bit_size != 32)
+         val->ssa->def = nir_u2u32(&b->nb, val->ssa->def);
+
+      break;
+   }
+
    default: {
       bool swap;
       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
@@ -661,5 +680,5 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
    } /* default */
    }
 
-   b->nb.exact = false;
+   b->nb.exact = b->exact;
 }