spirv/nir/opencl: handle some multiply instructions.
authorDave Airlie <airlied@redhat.com>
Mon, 29 Apr 2019 20:57:11 +0000 (06:57 +1000)
committerDave Airlie <airlied@redhat.com>
Thu, 21 Nov 2019 18:37:25 +0000 (04:37 +1000)
This adds support for some missing 24-bit and hi multiply
variants.

Reviewed-by: Jason Ekstrand <jason@jlekstrand.net>
src/compiler/nir/nir_builtin_builder.h
src/compiler/spirv/vtn_opencl.c

index 2ec300c4474f5896e20bb413e7d2ab447522523c..e2fbf02b3eafb55859dcacd1ce224d5741bef13d 100644 (file)
@@ -82,6 +82,43 @@ nir_uabs_diff(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
    return nir_bcsel(b, cond, res0, res1);
 }
 
+static inline nir_ssa_def *
+nir_umul24(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
+{
+   nir_ssa_def *mask = nir_imm_int(b, 0xffffff);
+   nir_ssa_def *x_24 = nir_iand(b, x, mask);
+   nir_ssa_def *y_24 = nir_iand(b, y, mask);
+   return nir_imul(b, x_24, y_24);
+}
+
+static inline nir_ssa_def *
+nir_umad24(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *z)
+{
+   nir_ssa_def *temp = nir_umul24(b, x, y);
+   return nir_iadd(b, temp, z);
+}
+
+static inline nir_ssa_def *
+nir_imad24(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *z)
+{
+   nir_ssa_def *temp = nir_imul24(b, x, y);
+   return nir_iadd(b, temp, z);
+}
+
+static inline nir_ssa_def *
+nir_imad_hi(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *z)
+{
+   nir_ssa_def *temp = nir_imul_high(b, x, y);
+   return nir_iadd(b, temp, z);
+}
+
+static inline nir_ssa_def *
+nir_umad_hi(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *z)
+{
+   nir_ssa_def *temp = nir_umul_high(b, x, y);
+   return nir_iadd(b, temp, z);
+}
+
 static inline nir_ssa_def *
 nir_bitselect(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *s)
 {
index e26b5bc0457e6bf364fcf7c891638560af461406..165c0de4a7aec69d68c84ac306372348c5a75b00 100644 (file)
@@ -129,6 +129,18 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
       return nir_uabs_diff(nb, srcs[0], srcs[1]);
    case OpenCLstd_Bitselect:
       return nir_bitselect(nb, srcs[0], srcs[1], srcs[2]);
+   case OpenCLstd_SMad_hi:
+      return nir_imad_hi(nb, srcs[0], srcs[1], srcs[2]);
+   case OpenCLstd_UMad_hi:
+      return nir_umad_hi(nb, srcs[0], srcs[1], srcs[2]);
+   case OpenCLstd_SMul24:
+      return nir_imul24(nb, srcs[0], srcs[1]);
+   case OpenCLstd_UMul24:
+      return nir_umul24(nb, srcs[0], srcs[1]);
+   case OpenCLstd_SMad24:
+      return nir_imad24(nb, srcs[0], srcs[1], srcs[2]);
+   case OpenCLstd_UMad24:
+      return nir_umad24(nb, srcs[0], srcs[1], srcs[2]);
    case OpenCLstd_FClamp:
       return nir_fclamp(nb, srcs[0], srcs[1], srcs[2]);
    case OpenCLstd_SClamp:
@@ -288,6 +300,12 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode,
       return true;
    case OpenCLstd_SAbs_diff:
    case OpenCLstd_UAbs_diff:
+   case OpenCLstd_SMad_hi:
+   case OpenCLstd_UMad_hi:
+   case OpenCLstd_SMad24:
+   case OpenCLstd_UMad24:
+   case OpenCLstd_SMul24:
+   case OpenCLstd_UMul24:
    case OpenCLstd_Bitselect:
    case OpenCLstd_FClamp:
    case OpenCLstd_SClamp: