nir/spirv: initial handling of OpenCL.std extension opcodes
[mesa.git] / src / compiler / spirv / vtn_opencl.c
1 /*
2 * Copyright © 2018 Red Hat
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Rob Clark (robdclark@gmail.com)
25 */
26
27 #include "math.h"
28
29 #include "nir/nir_builtin_builder.h"
30
31 #include "vtn_private.h"
32 #include "OpenCL.std.h"
33
34 typedef nir_ssa_def *(*nir_handler)(struct vtn_builder *b, enum OpenCLstd opcode,
35 unsigned num_srcs, nir_ssa_def **srcs,
36 const struct glsl_type *dest_type);
37
38 static void
39 handle_instr(struct vtn_builder *b, enum OpenCLstd opcode, const uint32_t *w,
40 unsigned count, nir_handler handler)
41 {
42 const struct glsl_type *dest_type =
43 vtn_value(b, w[1], vtn_value_type_type)->type->type;
44
45 unsigned num_srcs = count - 5;
46 nir_ssa_def *srcs[3] = { NULL };
47 vtn_assert(num_srcs <= ARRAY_SIZE(srcs));
48 for (unsigned i = 0; i < num_srcs; i++) {
49 srcs[i] = vtn_ssa_value(b, w[i + 5])->def;
50 }
51
52 nir_ssa_def *result = handler(b, opcode, num_srcs, srcs, dest_type);
53 if (result) {
54 struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
55 val->ssa = vtn_create_ssa_value(b, dest_type);
56 val->ssa->def = result;
57 } else {
58 vtn_assert(dest_type == glsl_void_type());
59 }
60 }
61
62 static nir_op
63 nir_alu_op_for_opencl_opcode(struct vtn_builder *b, enum OpenCLstd opcode)
64 {
65 switch (opcode) {
66 case Fabs: return nir_op_fabs;
67 case SAbs: return nir_op_iabs;
68 case SAdd_sat: return nir_op_iadd_sat;
69 case UAdd_sat: return nir_op_uadd_sat;
70 case Ceil: return nir_op_fceil;
71 case Cos: return nir_op_fcos;
72 case Exp2: return nir_op_fexp2;
73 case Log2: return nir_op_flog2;
74 case Floor: return nir_op_ffloor;
75 case SHadd: return nir_op_ihadd;
76 case UHadd: return nir_op_uhadd;
77 case Fma: return nir_op_ffma;
78 case Fmax: return nir_op_fmax;
79 case SMax: return nir_op_imax;
80 case UMax: return nir_op_umax;
81 case Fmin: return nir_op_fmin;
82 case SMin: return nir_op_imin;
83 case UMin: return nir_op_umin;
84 case Fmod: return nir_op_fmod;
85 case Mix: return nir_op_flrp;
86 case SMul_hi: return nir_op_imul_high;
87 case UMul_hi: return nir_op_umul_high;
88 case Popcount: return nir_op_bit_count;
89 case Pow: return nir_op_fpow;
90 case Remainder: return nir_op_frem;
91 case SRhadd: return nir_op_irhadd;
92 case URhadd: return nir_op_urhadd;
93 case Rsqrt: return nir_op_frsq;
94 case Sign: return nir_op_fsign;
95 case Sin: return nir_op_fsin;
96 case Sqrt: return nir_op_fsqrt;
97 case SSub_sat: return nir_op_isub_sat;
98 case USub_sat: return nir_op_usub_sat;
99 case Trunc: return nir_op_ftrunc;
100 /* uhm... */
101 case UAbs: return nir_op_imov;
102 default:
103 vtn_fail("No NIR equivalent");
104 }
105 }
106
107 static nir_ssa_def *
108 handle_alu(struct vtn_builder *b, enum OpenCLstd opcode, unsigned num_srcs,
109 nir_ssa_def **srcs, const struct glsl_type *dest_type)
110 {
111 return nir_build_alu(&b->nb, nir_alu_op_for_opencl_opcode(b, opcode),
112 srcs[0], srcs[1], srcs[2], NULL);
113 }
114
115 static nir_ssa_def *
116 handle_special(struct vtn_builder *b, enum OpenCLstd opcode, unsigned num_srcs,
117 nir_ssa_def **srcs, const struct glsl_type *dest_type)
118 {
119 nir_builder *nb = &b->nb;
120
121 switch (opcode) {
122 case SAbs_diff:
123 return nir_iabs_diff(nb, srcs[0], srcs[1]);
124 case UAbs_diff:
125 return nir_uabs_diff(nb, srcs[0], srcs[1]);
126 case Bitselect:
127 return nir_bitselect(nb, srcs[0], srcs[1], srcs[2]);
128 case FClamp:
129 return nir_fclamp(nb, srcs[0], srcs[1], srcs[2]);
130 case SClamp:
131 return nir_iclamp(nb, srcs[0], srcs[1], srcs[2]);
132 case UClamp:
133 return nir_uclamp(nb, srcs[0], srcs[1], srcs[2]);
134 case Copysign:
135 return nir_copysign(nb, srcs[0], srcs[1]);
136 case Cross:
137 if (glsl_get_components(dest_type) == 4)
138 return nir_cross4(nb, srcs[0], srcs[1]);
139 return nir_cross3(nb, srcs[0], srcs[1]);
140 case Degrees:
141 return nir_degrees(nb, srcs[0]);
142 case Fdim:
143 return nir_fdim(nb, srcs[0], srcs[1]);
144 case Distance:
145 return nir_distance(nb, srcs[0], srcs[1]);
146 case Fast_distance:
147 return nir_fast_distance(nb, srcs[0], srcs[1]);
148 case Fast_length:
149 return nir_fast_length(nb, srcs[0]);
150 case Fast_normalize:
151 return nir_fast_normalize(nb, srcs[0]);
152 case Length:
153 return nir_length(nb, srcs[0]);
154 case Mad:
155 return nir_fmad(nb, srcs[0], srcs[1], srcs[2]);
156 case Maxmag:
157 return nir_maxmag(nb, srcs[0], srcs[1]);
158 case Minmag:
159 return nir_minmag(nb, srcs[0], srcs[1]);
160 case Nan:
161 return nir_nan(nb, srcs[0]);
162 case Nextafter:
163 return nir_nextafter(nb, srcs[0], srcs[1]);
164 case Normalize:
165 return nir_normalize(nb, srcs[0]);
166 case Radians:
167 return nir_radians(nb, srcs[0]);
168 case Rotate:
169 return nir_rotate(nb, srcs[0], srcs[1]);
170 case Smoothstep:
171 return nir_smoothstep(nb, srcs[0], srcs[1], srcs[2]);
172 case Select:
173 return nir_select(nb, srcs[0], srcs[1], srcs[2]);
174 case Step:
175 return nir_sge(nb, srcs[1], srcs[0]);
176 case S_Upsample:
177 case U_Upsample:
178 return nir_upsample(nb, srcs[0], srcs[1]);
179 default:
180 vtn_fail("No NIR equivalent");
181 return NULL;
182 }
183 }
184
185 static nir_ssa_def *
186 handle_printf(struct vtn_builder *b, enum OpenCLstd opcode, unsigned num_srcs,
187 nir_ssa_def **srcs, const struct glsl_type *dest_type)
188 {
189 /* hahah, yeah, right.. */
190 return nir_imm_int(&b->nb, -1);
191 }
192
193 bool
194 vtn_handle_opencl_instruction(struct vtn_builder *b, uint32_t ext_opcode,
195 const uint32_t *w, unsigned count)
196 {
197 switch (ext_opcode) {
198 case Fabs:
199 case SAbs:
200 case UAbs:
201 case SAdd_sat:
202 case UAdd_sat:
203 case Ceil:
204 case Cos:
205 case Exp2:
206 case Log2:
207 case Floor:
208 case Fma:
209 case Fmax:
210 case SHadd:
211 case UHadd:
212 case SMax:
213 case UMax:
214 case Fmin:
215 case SMin:
216 case UMin:
217 case Mix:
218 case Fmod:
219 case SMul_hi:
220 case UMul_hi:
221 case Popcount:
222 case Pow:
223 case Remainder:
224 case SRhadd:
225 case URhadd:
226 case Rsqrt:
227 case Sign:
228 case Sin:
229 case Sqrt:
230 case SSub_sat:
231 case USub_sat:
232 case Trunc:
233 handle_instr(b, ext_opcode, w, count, handle_alu);
234 return true;
235 case SAbs_diff:
236 case UAbs_diff:
237 case Bitselect:
238 case FClamp:
239 case SClamp:
240 case UClamp:
241 case Copysign:
242 case Cross:
243 case Degrees:
244 case Fdim:
245 case Distance:
246 case Fast_distance:
247 case Fast_length:
248 case Fast_normalize:
249 case Length:
250 case Mad:
251 case Maxmag:
252 case Minmag:
253 case Nan:
254 case Nextafter:
255 case Normalize:
256 case Radians:
257 case Rotate:
258 case Select:
259 case Step:
260 case Smoothstep:
261 case S_Upsample:
262 case U_Upsample:
263 handle_instr(b, ext_opcode, w, count, handle_special);
264 return true;
265 case Printf:
266 handle_instr(b, ext_opcode, w, count, handle_printf);
267 return true;
268 case Prefetch:
269 /* TODO maybe add a nir instruction for this? */
270 return true;
271 default:
272 vtn_fail("unhandled opencl opc: %u\n", ext_opcode);
273 return false;
274 }
275 }