From 12913bcf86031cefbfcb70cca5e0565be36eef5f Mon Sep 17 00:00:00 2001 From: Dave Airlie Date: Tue, 30 Apr 2019 06:57:11 +1000 Subject: [PATCH] spirv/nir/opencl: handle some multiply instructions. This adds support for some missing 24-bit and hi multiply variants. Reviewed-by: Jason Ekstrand --- src/compiler/nir/nir_builtin_builder.h | 37 ++++++++++++++++++++++++++ src/compiler/spirv/vtn_opencl.c | 18 +++++++++++++ 2 files changed, 55 insertions(+) diff --git a/src/compiler/nir/nir_builtin_builder.h b/src/compiler/nir/nir_builtin_builder.h index 2ec300c4474..e2fbf02b3ea 100644 --- a/src/compiler/nir/nir_builtin_builder.h +++ b/src/compiler/nir/nir_builtin_builder.h @@ -82,6 +82,43 @@ nir_uabs_diff(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y) return nir_bcsel(b, cond, res0, res1); } +static inline nir_ssa_def * +nir_umul24(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y) +{ + nir_ssa_def *mask = nir_imm_int(b, 0xffffff); + nir_ssa_def *x_24 = nir_iand(b, x, mask); + nir_ssa_def *y_24 = nir_iand(b, y, mask); + return nir_imul(b, x_24, y_24); +} + +static inline nir_ssa_def * +nir_umad24(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *z) +{ + nir_ssa_def *temp = nir_umul24(b, x, y); + return nir_iadd(b, temp, z); +} + +static inline nir_ssa_def * +nir_imad24(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *z) +{ + nir_ssa_def *temp = nir_imul24(b, x, y); + return nir_iadd(b, temp, z); +} + +static inline nir_ssa_def * +nir_imad_hi(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *z) +{ + nir_ssa_def *temp = nir_imul_high(b, x, y); + return nir_iadd(b, temp, z); +} + +static inline nir_ssa_def * +nir_umad_hi(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *z) +{ + nir_ssa_def *temp = nir_umul_high(b, x, y); + return nir_iadd(b, temp, z); +} + static inline nir_ssa_def * nir_bitselect(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y, nir_ssa_def *s) { diff --git a/src/compiler/spirv/vtn_opencl.c b/src/compiler/spirv/vtn_opencl.c index e26b5bc0457..165c0de4a7a 100644 --- a/src/compiler/spirv/vtn_opencl.c +++ b/src/compiler/spirv/vtn_opencl.c @@ -129,6 +129,18 @@ handle_special(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode, return nir_uabs_diff(nb, srcs[0], srcs[1]); case OpenCLstd_Bitselect: return nir_bitselect(nb, srcs[0], srcs[1], srcs[2]); + case OpenCLstd_SMad_hi: + return nir_imad_hi(nb, srcs[0], srcs[1], srcs[2]); + case OpenCLstd_UMad_hi: + return nir_umad_hi(nb, srcs[0], srcs[1], srcs[2]); + case OpenCLstd_SMul24: + return nir_imul24(nb, srcs[0], srcs[1]); + case OpenCLstd_UMul24: + return nir_umul24(nb, srcs[0], srcs[1]); + case OpenCLstd_SMad24: + return nir_imad24(nb, srcs[0], srcs[1], srcs[2]); + case OpenCLstd_UMad24: + return nir_umad24(nb, srcs[0], srcs[1], srcs[2]); case OpenCLstd_FClamp: return nir_fclamp(nb, srcs[0], srcs[1], srcs[2]); case OpenCLstd_SClamp: @@ -288,6 +300,12 @@ vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode, return true; case OpenCLstd_SAbs_diff: case OpenCLstd_UAbs_diff: + case OpenCLstd_SMad_hi: + case OpenCLstd_UMad_hi: + case OpenCLstd_SMad24: + case OpenCLstd_UMad24: + case OpenCLstd_SMul24: + case OpenCLstd_UMul24: case OpenCLstd_Bitselect: case OpenCLstd_FClamp: case OpenCLstd_SClamp: -- 2.30.2