compiler/nir: add an is_conversion field to nir_op_info
[mesa.git] / src / compiler / nir / nir_opcodes.py
1 #
2 # Copyright (C) 2014 Connor Abbott
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Connor Abbott (cwabbott0@gmail.com)
25
26 import re
27
28 # Class that represents all the information we have about the opcode
29 # NOTE: this must be kept in sync with nir_op_info
30
class Opcode(object):
    """Everything the generators know about one NIR ALU opcode.

    NOTE: this must be kept in sync with nir_op_info
    """
    def __init__(self, name, output_size, output_type, input_sizes,
                 input_types, is_conversion, algebraic_properties, const_expr):
        """Validate the description of an opcode and store it on the instance.

        Parameters:

        - name is the name of the opcode (prepend nir_op_ for the enum name)
        - output_size and every entry of input_sizes must be in 0..4; when
          output_size is non-zero, no input size may be zero
        - all types are strings that get nir_type_ prepended to them
        - input_types is a list of types, one per entry of input_sizes
        - is_conversion is true if this opcode represents a type conversion
        - algebraic_properties is a space-separated string, where nir_op_is_
          is prepended before each entry
        - const_expr is an expression or series of statements that computes
          the constant value of the opcode given the constant values of its
          inputs.

        Constant expressions are formed from the variables src0, src1, ...,
        src(N-1), where N is the number of arguments.  The output of the
        expression should be stored in the dst variable.  Per-component input
        and output variables will be scalars and non-per-component input and
        output variables will be a struct with fields named x, y, z, and w
        all of the correct type.  Input and output variables can be assumed
        to already be of the correct type and need no conversion.  In
        particular, the conversion from the C bool type to/from NIR_TRUE and
        NIR_FALSE happens automatically.

        For per-component instructions, the entire expression will be
        executed once for each component.  For non-per-component
        instructions, the expression is expected to store the correct values
        in dst.x, dst.y, etc.  If "dst" does not exist anywhere in the
        constant expression, an assignment to dst will happen automatically
        and the result will be equivalent to "dst = <expression>" for
        per-component instructions and "dst.x = dst.y = ... = <expression>"
        for non-per-component instructions.
        """
        # Scalar argument types.
        assert isinstance(name, str)
        assert isinstance(output_size, int)
        assert isinstance(output_type, str)

        # The input lists: only the first element of each is type-checked.
        assert isinstance(input_sizes, list)
        first_size = input_sizes[0]
        assert isinstance(first_size, int)
        assert isinstance(input_types, list)
        first_type = input_types[0]
        assert isinstance(first_type, str)

        assert isinstance(is_conversion, bool)
        assert isinstance(algebraic_properties, str)
        assert isinstance(const_expr, str)

        # Sizes and types must describe the same number of inputs.
        num_inputs = len(input_sizes)
        assert num_inputs == len(input_types)

        # Range checks on the component counts.
        assert 0 <= output_size <= 4
        for size in input_sizes:
            assert 0 <= size <= 4
            # A non-zero output size forbids zero-sized inputs.
            assert output_size == 0 or size != 0

        self.name = name
        self.num_inputs = num_inputs
        self.output_size = output_size
        self.output_type = output_type
        self.input_sizes = input_sizes
        self.input_types = input_types
        self.is_conversion = is_conversion
        self.algebraic_properties = algebraic_properties
        self.const_expr = const_expr
92
# Helper variables holding the NIR type-name strings used by the opcode
# definitions below.  Names without a number are unsized base types; the
# others carry an explicit bit size.

# Unsized base types.
tfloat = "float"
tint = "int"
tuint = "uint"
tbool = "bool"

# Sized booleans.
tbool1 = "bool1"
tbool32 = "bool32"

# Sized floats.
tfloat32 = "float32"
tfloat64 = "float64"

# Sized signed integers.
tint32 = "int32"
tint64 = "int64"

# Sized unsigned integers.
tuint16 = "uint16"
tuint32 = "uint32"
tuint64 = "uint64"
107
# Splits a NIR type string into its base type and optional bit size.
_TYPE_SPLIT_RE = re.compile(r'(?P<type>int|uint|float|bool)(?P<bits>\d+)?')


def _split_type(type_):
    """Return (base_type, bits) for a NIR type string.

    bits is the bit-size digits as a string, or None when the type is
    unsized.  Asserts when the string does not start with a known base type.
    """
    parsed = _TYPE_SPLIT_RE.match(type_)
    assert parsed is not None, 'Invalid NIR type string: "{}"'.format(type_)
    return parsed.group('type'), parsed.group('bits')


def type_has_size(type_):
    """Return True when the NIR type string carries an explicit bit size."""
    _, bits = _split_type(type_)
    return bits is not None


def type_size(type_):
    """Return the bit size of a sized NIR type string as an int.

    Asserts when the type string is invalid or has no bit size.
    """
    _, bits = _split_type(type_)
    assert bits is not None, \
        'NIR type string has no bit size: "{}"'.format(type_)
    return int(bits)


def type_sizes(type_):
    """Return the list of bit sizes associated with a NIR type string.

    A sized type yields just its own size; 'bool' and 'float' have
    dedicated lists and every other unsized type gets the integer list.
    """
    if type_has_size(type_):
        return [type_size(type_)]
    # The table is built per call so callers always get a fresh list.
    unsized = {'bool': [1, 32], 'float': [16, 32, 64]}
    return unsized.get(type_, [1, 8, 16, 32, 64])


def type_base_type(type_):
    """Return the base type ('int', 'uint', 'float' or 'bool')."""
    base, _ = _split_type(type_)
    return base
136
# Algebraic property tags.  Each ends with a space so that tags can be
# joined with "+" into the space-separated string Opcode expects.
commutative = "commutative "
associative = "associative "

# Global dictionary of opcodes, keyed by opcode name; opcode() fills it in.
opcodes = dict()
142
def opcode(name, output_size, output_type, input_sizes, input_types,
           is_conversion, algebraic_properties, const_expr):
    """Create an Opcode and record it in the global opcodes dictionary.

    The arguments are forwarded unchanged to the Opcode constructor.
    Asserts if an opcode with the same name was already registered.
    """
    assert name not in opcodes
    info = Opcode(name, output_size, output_type, input_sizes, input_types,
                  is_conversion, algebraic_properties, const_expr)
    opcodes[name] = info
149
def unop_convert(name, out_type, in_type, const_expr):
    """Register a one-input opcode whose input and output types may differ.

    Output and input sizes are both 0, the opcode is not flagged as a
    conversion and it has no algebraic properties.
    """
    opcode(name, output_size=0, output_type=out_type,
           input_sizes=[0], input_types=[in_type],
           is_conversion=False, algebraic_properties="",
           const_expr=const_expr)
152
def unop(name, ty, const_expr):
    """Register a one-input opcode whose input and output share type ty.

    Output and input sizes are both 0, the opcode is not flagged as a
    conversion and it has no algebraic properties.
    """
    opcode(name, output_size=0, output_type=ty,
           input_sizes=[0], input_types=[ty],
           is_conversion=False, algebraic_properties="",
           const_expr=const_expr)
155
def unop_horiz(name, output_size, output_type, input_size, input_type,
               const_expr):
    """Register a one-input opcode with explicit output and input sizes.

    The opcode is not flagged as a conversion and has no algebraic
    properties.
    """
    sizes = [input_size]
    types = [input_type]
    opcode(name, output_size, output_type, sizes, types, False, "",
           const_expr)
160
def unop_reduce(name, output_size, output_type, input_type, prereduce_expr,
                reduce_expr, final_expr):
    """Register the 2-, 3- and 4-component variants of a unary reduction.

    - prereduce_expr is formatted with {src} for each source component and
      the result is wrapped in parentheses
    - reduce_expr is formatted with {src0} and {src1} to combine two partial
      results
    - final_expr is formatted with {src} bound to the parenthesized result
      of the reduction

    The opcodes are registered as name + "2", name + "3" and name + "4".
    """
    # One parenthesized pre-reduced term per component of src0.
    term_x, term_y, term_z, term_w = [
        "(" + prereduce_expr.format(src="src0." + comp) + ")"
        for comp in "xyzw"
    ]

    def combine(lhs, rhs):
        return reduce_expr.format(src0=lhs, src1=rhs)

    pair_xy = combine(term_x, term_y)
    reductions = (
        (2, pair_xy),
        (3, combine(pair_xy, term_z)),
        (4, combine(pair_xy, combine(term_z, term_w))),
    )
    for width, reduced in reductions:
        unop_horiz(name + str(width), output_size, output_type, width,
                   input_type, final_expr.format(src="(" + reduced + ")"))
179
def unop_numeric_convert(name, out_type, in_type, const_expr):
    """Register a one-input opcode that is flagged as a type conversion.

    Same shape as unop_convert (sizes of 0, no algebraic properties), but
    is_conversion is True.
    """
    opcode(name, output_size=0, output_type=out_type,
           input_sizes=[0], input_types=[in_type],
           is_conversion=True, algebraic_properties="",
           const_expr=const_expr)
182
# Basic unary opcodes.  Each unop() call registers an opcode whose input and
# output share one type; the string is the C constant-folding expression.
# Several float expressions select between double and float forms with the
# bit_size variable.

# These two move instructions differ in what modifiers they support and what
# the negate modifier means. Otherwise, they are identical.
unop("fmov", tfloat, "src0")
unop("imov", tint, "src0")

unop("ineg", tint, "-src0")
unop("fneg", tfloat, "-src0")
unop("inot", tint, "~src0") # invert every bit of the integer
unop("fnot", tfloat, ("bit_size == 64 ? ((src0 == 0.0) ? 1.0 : 0.0f) : " +
                      "((src0 == 0.0f) ? 1.0f : 0.0f)"))
unop("fsign", tfloat, ("bit_size == 64 ? " +
     "((src0 == 0.0) ? 0.0 : ((src0 > 0.0) ? 1.0 : -1.0)) : " +
     "((src0 == 0.0f) ? 0.0f : ((src0 > 0.0f) ? 1.0f : -1.0f))"))
unop("isign", tint, "(src0 == 0) ? 0 : ((src0 > 0) ? 1 : -1)")
unop("iabs", tint, "(src0 < 0) ? -src0 : src0")
unop("fabs", tfloat, "fabs(src0)")
unop("fsat", tfloat, ("bit_size == 64 ? " +
     "((src0 > 1.0) ? 1.0 : ((src0 <= 0.0) ? 0.0 : src0)) : " +
     "((src0 > 1.0f) ? 1.0f : ((src0 <= 0.0f) ? 0.0f : src0))"))
unop("frcp", tfloat, "bit_size == 64 ? 1.0 / src0 : 1.0f / src0")
unop("frsq", tfloat, "bit_size == 64 ? 1.0 / sqrt(src0) : 1.0f / sqrtf(src0)")
unop("fsqrt", tfloat, "bit_size == 64 ? sqrt(src0) : sqrtf(src0)")
# NOTE(review): unlike frsq/fsqrt above, these two call the single-precision
# function regardless of bit_size -- confirm that is intended for 64-bit.
unop("fexp2", tfloat, "exp2f(src0)")
unop("flog2", tfloat, "log2f(src0)")
207
# Generate all of the numeric conversion opcodes.
#
# Opcode names have the form <src initial>2<dst initial><bit size>, e.g.
# "f2i32".  Float-to-16-bit-float conversions additionally get "_rtne" and
# "_rtz" variants next to the unsuffixed one.
for src_t in [tint, tuint, tfloat, tbool]:
    # Destination base types that each source base type converts to.
    dst_types = {
        tbool: [tfloat, tint],
        tint: [tfloat, tint, tbool],
        tuint: [tfloat, tuint],
        tfloat: [tint, tuint, tfloat, tbool],
    }[src_t]

    for dst_t in dst_types:
        for bit_size in type_sizes(dst_t):
            sized_dst = dst_t + str(bit_size)
            base_name = "{0}2{1}{2}".format(src_t[0], dst_t[0], bit_size)
            if bit_size == 16 and dst_t == tfloat and src_t == tfloat:
                rnd_modes = ['_rtne', '_rtz', '']
                for rnd_mode in rnd_modes:
                    unop_numeric_convert(base_name + rnd_mode, sized_dst,
                                         src_t, "src0")
            else:
                # Conversions to bool test against zero; every other
                # conversion passes the source value through.
                conv_expr = "src0 != 0" if dst_t == tbool else "src0"
                unop_numeric_convert(base_name, sized_dst, src_t, conv_expr)
231
232
# Unary floating-point rounding operations.
#
# Each expression picks the double or float variant of the C function based
# on bit_size.


unop("ftrunc", tfloat, "bit_size == 64 ? trunc(src0) : truncf(src0)")
unop("fceil", tfloat, "bit_size == 64 ? ceil(src0) : ceilf(src0)")
unop("ffloor", tfloat, "bit_size == 64 ? floor(src0) : floorf(src0)")
unop("ffract", tfloat, "src0 - (bit_size == 64 ? floor(src0) : floorf(src0))")
unop("fround_even", tfloat, "bit_size == 64 ? _mesa_roundeven(src0) : _mesa_roundevenf(src0)")

# Values with magnitude below ldexpf(1.0, -14) become a signed zero; all
# others are round-tripped through _mesa_float_to_half/_mesa_half_to_float.
unop("fquantize2f16", tfloat, "(fabs(src0) < ldexpf(1.0, -14)) ? copysignf(0.0f, src0) : _mesa_half_to_float(_mesa_float_to_half(src0))")

# Trigonometric operations.


unop("fsin", tfloat, "bit_size == 64 ? sin(src0) : sinf(src0)")
unop("fcos", tfloat, "bit_size == 64 ? cos(src0) : cosf(src0)")

# dfrexp
# frexp_exp stores the exponent reported by frexp(); frexp_sig stores the
# value frexp() returns and discards the exponent.
unop_convert("frexp_exp", tint32, tfloat64, "frexp(src0, &dst);")
unop_convert("frexp_sig", tfloat64, tfloat64, "int n; dst = frexp(src0, &n);")

# Partial derivatives.


unop("fddx", tfloat, "0.0") # the derivative of a constant is 0.
unop("fddy", tfloat, "0.0")
unop("fddx_fine", tfloat, "0.0")
unop("fddy_fine", tfloat, "0.0")
unop("fddx_coarse", tfloat, "0.0")
unop("fddy_coarse", tfloat, "0.0")
263
264
265 # Floating point pack and unpack operations.
266
def pack_2x16(fmt):
    """Register "pack_<fmt>_2x16": two float32 components into one uint32.

    The constant expression packs src0.x into the low 16 bits and src0.y
    into the high 16 bits using the pack_<fmt>_1x16 helper.
    """
    template = """
dst.x = (uint32_t) pack_fmt_1x16(src0.x);
dst.x |= ((uint32_t) pack_fmt_1x16(src0.y)) << 16;
"""
    opname = "pack_" + fmt + "_2x16"
    unop_horiz(opname, 1, tuint32, 2, tfloat32, template.replace("fmt", fmt))
272
def pack_4x8(fmt):
    """Register "pack_<fmt>_4x8": four float32 components into one uint32.

    The constant expression packs src0.x, .y, .z and .w into successive
    bytes (lowest first) using the pack_<fmt>_1x8 helper.
    """
    template = """
dst.x = (uint32_t) pack_fmt_1x8(src0.x);
dst.x |= ((uint32_t) pack_fmt_1x8(src0.y)) << 8;
dst.x |= ((uint32_t) pack_fmt_1x8(src0.z)) << 16;
dst.x |= ((uint32_t) pack_fmt_1x8(src0.w)) << 24;
"""
    opname = "pack_" + fmt + "_4x8"
    unop_horiz(opname, 1, tuint32, 4, tfloat32, template.replace("fmt", fmt))
280
def unpack_2x16(fmt):
    """Register "unpack_<fmt>_2x16": one uint32 into two float32 components.

    The constant expression unpacks the low 16 bits of src0.x into dst.x
    and the high 16 bits into dst.y using the unpack_<fmt>_1x16 helper.
    """
    # The high half must be shifted *down* before the uint16_t cast.  A left
    # shift by 16 leaves the low 16 bits zero, so the cast would always pass
    # 0 to the unpack helper.
    unop_horiz("unpack_" + fmt + "_2x16", 2, tfloat32, 1, tuint32, """
dst.x = unpack_fmt_1x16((uint16_t)(src0.x & 0xffff));
dst.y = unpack_fmt_1x16((uint16_t)(src0.x >> 16));
""".replace("fmt", fmt))
286
def unpack_4x8(fmt):
    """Register "unpack_<fmt>_4x8": one uint32 into four float32 components.

    The constant expression unpacks successive bytes of src0.x (lowest
    first) into dst.x, .y, .z and .w using the unpack_<fmt>_1x8 helper.
    """
    template = """
dst.x = unpack_fmt_1x8((uint8_t)(src0.x & 0xff));
dst.y = unpack_fmt_1x8((uint8_t)((src0.x >> 8) & 0xff));
dst.z = unpack_fmt_1x8((uint8_t)((src0.x >> 16) & 0xff));
dst.w = unpack_fmt_1x8((uint8_t)(src0.x >> 24));
"""
    opname = "unpack_" + fmt + "_4x8"
    unop_horiz(opname, 4, tfloat32, 1, tuint32, template.replace("fmt", fmt))
294
295
# Instantiate the float pack/unpack opcodes for each supported format.
# Note that "half" only has the 2x16 variants.
pack_2x16("snorm")
pack_4x8("snorm")
pack_2x16("unorm")
pack_4x8("unorm")
pack_2x16("half")
unpack_2x16("snorm")
unpack_4x8("snorm")
unpack_2x16("unorm")
unpack_4x8("unorm")
unpack_2x16("half")

# Integer packing: only src0.x is masked to 16 bits; src0.y is shifted as is.
unop_horiz("pack_uvec2_to_uint", 1, tuint32, 2, tuint32, """
dst.x = (src0.x & 0xffff) | (src0.y << 16);
""")

# The four components are shifted into place without masking.
unop_horiz("pack_uvec4_to_uint", 1, tuint32, 4, tuint32, """
dst.x = (src0.x << 0) |
(src0.y << 8) |
(src0.z << 16) |
(src0.w << 24);
""")

# Bit-exact packs of narrower unsigned components into a wider unsigned
# value; component x lands in the least-significant bits.
unop_horiz("pack_32_2x16", 1, tuint32, 2, tuint16,
           "dst.x = src0.x | ((uint32_t)src0.y << 16);")

unop_horiz("pack_64_2x32", 1, tuint64, 2, tuint32,
           "dst.x = src0.x | ((uint64_t)src0.y << 32);")

unop_horiz("pack_64_4x16", 1, tuint64, 4, tuint16,
           "dst.x = src0.x | ((uint64_t)src0.y << 16) | ((uint64_t)src0.z << 32) | ((uint64_t)src0.w << 48);")

# Inverse of pack_64_2x32: dst.x receives the low bits of src0.x, dst.y the
# bits from 32 upwards.
unop_horiz("unpack_64_2x32", 2, tuint32, 1, tuint64,
           "dst.x = src0.x; dst.y = src0.x >> 32;")
329
# Inverse of pack_64_4x16: splits the single 64-bit source component into
# four 16-bit components, least-significant first.  The source has only one
# component, so every destination channel must read src0.x.
unop_horiz("unpack_64_4x16", 4, tuint16, 1, tuint64,
           "dst.x = src0.x; dst.y = src0.x >> 16; dst.z = src0.x >> 32; dst.w = src0.x >> 48;")
332
# Splits a 32-bit source component into two 16-bit components, low half in
# dst.x.
unop_horiz("unpack_32_2x16", 2, tuint16, 1, tuint32,
           "dst.x = src0.x; dst.y = src0.x >> 16;")

# Lowered floating point unpacking operations.
#
# The *_split_x opcodes produce the low half of the source and the *_split_y
# opcodes the high half, one value per opcode.


unop_convert("unpack_half_2x16_split_x", tfloat32, tuint32,
             "unpack_half_1x16((uint16_t)(src0 & 0xffff))")
unop_convert("unpack_half_2x16_split_y", tfloat32, tuint32,
             "unpack_half_1x16((uint16_t)(src0 >> 16))")

unop_convert("unpack_32_2x16_split_x", tuint16, tuint32, "src0")
unop_convert("unpack_32_2x16_split_y", tuint16, tuint32, "src0 >> 16")

unop_convert("unpack_64_2x32_split_x", tuint32, tuint64, "src0")
unop_convert("unpack_64_2x32_split_y", tuint32, tuint64, "src0 >> 32")
349
# Bit operations, part of ARB_gpu_shader5.
#
# bitfield_reverse and ifind_msb are fixed at 32 bits; bit_count, ufind_msb
# and find_lsb take an unsized integer source and loop over bit_size bits.


unop("bitfield_reverse", tuint32, """
/* we're not winning any awards for speed here, but that's ok */
dst = 0;
for (unsigned bit = 0; bit < 32; bit++)
dst |= ((src0 >> bit) & 1) << (31 - bit);
""")
# Counts the set bits of the source.
unop_convert("bit_count", tuint32, tuint, """
dst = 0;
for (unsigned bit = 0; bit < bit_size; bit++) {
if ((src0 >> bit) & 1)
dst++;
}
""")

# Index of the highest set bit, or -1 when no bit is set.
unop_convert("ufind_msb", tint32, tuint, """
dst = -1;
for (int bit = bit_size - 1; bit >= 0; bit--) {
if ((src0 >> bit) & 1) {
dst = bit;
break;
}
}
""")

# Signed variant: yields -1 when no qualifying bit exists (src0 is 0 or -1).
unop("ifind_msb", tint32, """
dst = -1;
for (int bit = 31; bit >= 0; bit--) {
/* If src0 < 0, we're looking for the first 0 bit.
* if src0 >= 0, we're looking for the first 1 bit.
*/
if ((((src0 >> bit) & 1) && (src0 >= 0)) ||
(!((src0 >> bit) & 1) && (src0 < 0))) {
dst = bit;
break;
}
}
""")

# Index of the lowest set bit, or -1 when no bit is set.
unop_convert("find_lsb", tint32, tint, """
dst = -1;
for (unsigned bit = 0; bit < bit_size; bit++) {
if ((src0 >> bit) & 1) {
dst = bit;
break;
}
}
""")
400
401
# Registers fnoise<i>_<j> for every output size i and input size j in 1..4.
# The constant expression is always 0.0f.
for i in range(1, 5):
    for j in range(1, 5):
        unop_horiz("fnoise{0}_{1}".format(i, j), i, tfloat, j, tfloat, "0.0f")


# AMD_gcn_shader extended instructions
#
# Both opcodes choose a cube face from the largest-magnitude component of
# the 3-component source.  The tests use >=, so on a tie the last matching
# "if" determines the result.
unop_horiz("cube_face_coord", 2, tfloat32, 3, tfloat32, """
dst.x = dst.y = 0.0;
float absX = fabs(src0.x);
float absY = fabs(src0.y);
float absZ = fabs(src0.z);
if (src0.x >= 0 && absX >= absY && absX >= absZ) { dst.x = -src0.y; dst.y = -src0.z; }
if (src0.x < 0 && absX >= absY && absX >= absZ) { dst.x = -src0.y; dst.y = src0.z; }
if (src0.y >= 0 && absY >= absX && absY >= absZ) { dst.x = src0.z; dst.y = src0.x; }
if (src0.y < 0 && absY >= absX && absY >= absZ) { dst.x = -src0.z; dst.y = src0.x; }
if (src0.z >= 0 && absZ >= absX && absZ >= absY) { dst.x = -src0.y; dst.y = src0.x; }
if (src0.z < 0 && absZ >= absX && absZ >= absY) { dst.x = -src0.y; dst.y = -src0.x; }
""")

# Face index: 0/1 for +X/-X, 2/3 for +Y/-Y, 4/5 for +Z/-Z.
unop_horiz("cube_face_index", 1, tfloat32, 3, tfloat32, """
float absX = fabs(src0.x);
float absY = fabs(src0.y);
float absZ = fabs(src0.z);
if (src0.x >= 0 && absX >= absY && absX >= absZ) dst.x = 0;
if (src0.x < 0 && absX >= absY && absX >= absZ) dst.x = 1;
if (src0.y >= 0 && absY >= absX && absY >= absZ) dst.x = 2;
if (src0.y < 0 && absY >= absX && absY >= absZ) dst.x = 3;
if (src0.z >= 0 && absZ >= absX && absZ >= absY) dst.x = 4;
if (src0.z < 0 && absZ >= absX && absZ >= absY) dst.x = 5;
""")
432
433
def binop_convert(name, out_type, in_type, alg_props, const_expr):
    """Register a two-input opcode whose inputs share in_type.

    The output type may differ from the input type.  All sizes are 0 and
    the opcode is not flagged as a conversion.
    """
    opcode(name, output_size=0, output_type=out_type,
           input_sizes=[0, 0], input_types=[in_type, in_type],
           is_conversion=False, algebraic_properties=alg_props,
           const_expr=const_expr)
437
def binop(name, ty, alg_props, const_expr):
    """Register a two-input opcode whose inputs and output all use type ty."""
    binop_convert(name, out_type=ty, in_type=ty, alg_props=alg_props,
                  const_expr=const_expr)
440
def binop_compare(name, ty, alg_props, const_expr):
    """Register a two-input opcode on type ty that produces a bool1."""
    binop_convert(name, out_type=tbool1, in_type=ty, alg_props=alg_props,
                  const_expr=const_expr)
443
def binop_compare32(name, ty, alg_props, const_expr):
    """Register a two-input opcode on type ty that produces a bool32."""
    binop_convert(name, out_type=tbool32, in_type=ty, alg_props=alg_props,
                  const_expr=const_expr)
446
def binop_horiz(name, out_size, out_type, src1_size, src1_type, src2_size,
                src2_type, const_expr):
    """Register a two-input opcode with explicit sizes and types per source.

    The opcode is not flagged as a conversion and has no algebraic
    properties.
    """
    sizes = [src1_size, src2_size]
    types = [src1_type, src2_type]
    opcode(name, out_size, out_type, sizes, types, False, "", const_expr)
451
def binop_reduce(name, output_size, output_type, src_type, prereduce_expr,
                 reduce_expr, final_expr):
    """Register the 2-, 3- and 4-component variants of a binary reduction.

    - prereduce_expr is formatted with {src0} and {src1} bound to matching
      components of the two sources, and the result is wrapped in
      parentheses
    - reduce_expr is formatted with {src0} and {src1} to combine two partial
      results
    - final_expr is formatted with {src} bound to the parenthesized result
      of the reduction

    The opcodes are registered as name + "2", name + "3" and name + "4",
    all flagged commutative.
    """
    # One parenthesized pre-reduced term per component pair.
    term_x, term_y, term_z, term_w = [
        "(" + prereduce_expr.format(src0="src0." + comp,
                                    src1="src1." + comp) + ")"
        for comp in "xyzw"
    ]

    def combine(lhs, rhs):
        return reduce_expr.format(src0=lhs, src1=rhs)

    pair_xy = combine(term_x, term_y)
    reductions = (
        (2, pair_xy),
        (3, combine(pair_xy, term_z)),
        (4, combine(pair_xy, combine(term_z, term_w))),
    )
    for width, reduced in reductions:
        opcode(name + str(width), output_size, output_type,
               [width, width], [src_type, src_type], False, commutative,
               final_expr.format(src="(" + reduced + ")"))
473
# Addition, subtraction and multiplication.  The third argument of binop()
# is the algebraic-property string built from commutative/associative.
binop("fadd", tfloat, commutative + associative, "src0 + src1")
binop("iadd", tint, commutative + associative, "src0 + src1")
# Saturating add/subtract: the expressions detect wrap-around by comparing
# the wrapped result against src0 and substitute the limit value.
binop("iadd_sat", tint, commutative + associative, """
src1 > 0 ?
(src0 + src1 < src0 ? (1ull << (bit_size - 1)) - 1 : src0 + src1) :
(src0 < src0 + src1 ? (1ull << (bit_size - 1)) : src0 + src1)
""")
binop("uadd_sat", tuint, commutative,
      "(src0 + src1) < src0 ? UINT64_MAX : (src0 + src1)")
binop("isub_sat", tint, "", """
src1 < 0 ?
(src0 - src1 < src0 ? (1ull << (bit_size - 1)) - 1 : src0 - src1) :
(src0 < src0 - src1 ? (1ull << (bit_size - 1)) : src0 - src1)
""")
binop("usub_sat", tuint, "", "src0 < src1 ? 0 : src0 - src1")

binop("fsub", tfloat, "", "src0 - src1")
binop("isub", tint, "", "src0 - src1")

binop("fmul", tfloat, commutative + associative, "src0 * src1")
# low 32-bits of signed/unsigned integer multiply
binop("imul", tint, commutative + associative, "src0 * src1")

# Generate 64 bit result from 2 32 bits quantity
binop_convert("imul_2x32_64", tint64, tint32, commutative,
              "(int64_t)src0 * (int64_t)src1")
binop_convert("umul_2x32_64", tuint64, tuint32, commutative,
              "(uint64_t)src0 * (uint64_t)src1")

# high 32-bits of signed integer multiply
# For 64-bit sources the product is computed with ubm_mul_u32arr on arrays
# of 32-bit words; narrower sizes widen to int64_t and shift by bit_size.
binop("imul_high", tint, commutative, """
if (bit_size == 64) {
/* We need to do a full 128-bit x 128-bit multiply in order for the sign
* extension to work properly. The casts are kind-of annoying but needed
* to prevent compiler warnings.
*/
uint32_t src0_u32[4] = {
src0,
(int64_t)src0 >> 32,
(int64_t)src0 >> 63,
(int64_t)src0 >> 63,
};
uint32_t src1_u32[4] = {
src1,
(int64_t)src1 >> 32,
(int64_t)src1 >> 63,
(int64_t)src1 >> 63,
};
uint32_t prod_u32[4];
ubm_mul_u32arr(prod_u32, src0_u32, src1_u32);
dst = (uint64_t)prod_u32[2] | ((uint64_t)prod_u32[3] << 32);
} else {
dst = ((int64_t)src0 * (int64_t)src1) >> bit_size;
}
""")

# high 32-bits of unsigned integer multiply
# Same structure as imul_high, but with two-word source arrays and no sign
# extension.
binop("umul_high", tuint, commutative, """
if (bit_size == 64) {
/* The casts are kind-of annoying but needed to prevent compiler warnings. */
uint32_t src0_u32[2] = { src0, (uint64_t)src0 >> 32 };
uint32_t src1_u32[2] = { src1, (uint64_t)src1 >> 32 };
uint32_t prod_u32[4];
ubm_mul_u32arr(prod_u32, src0_u32, src1_u32);
dst = (uint64_t)prod_u32[2] | ((uint64_t)prod_u32[3] << 32);
} else {
dst = ((uint64_t)src0 * (uint64_t)src1) >> bit_size;
}
""")
543
# Division.  The integer variants fold a division by zero to 0 instead of
# evaluating the C division.
binop("fdiv", tfloat, "", "src0 / src1")
binop("idiv", tint, "", "src1 == 0 ? 0 : (src0 / src1)")
binop("udiv", tuint, "", "src1 == 0 ? 0 : (src0 / src1)")

# returns a boolean representing the carry resulting from the addition of
# the two unsigned arguments.

binop_convert("uadd_carry", tuint, tuint, commutative, "src0 + src1 < src0")

# returns a boolean representing the borrow resulting from the subtraction
# of the two unsigned arguments.

binop_convert("usub_borrow", tuint, tuint, "", "src0 < src1")

# hadd: (a + b) >> 1 (without overflow)
# x + y = x - (x & ~y) + (x & ~y) + y - (~x & y) + (~x & y)
#       = (x & y) + (x & ~y) + (x & y) + (~x & y)
#       = 2 * (x & y) + (x & ~y) + (~x & y)
#       = ((x & y) << 1) + (x ^ y)
#
# Since we know that the bottom bit of (x & y) << 1 is zero,
#
# (x + y) >> 1 = (((x & y) << 1) + (x ^ y)) >> 1
#              = (x & y) + ((x ^ y) >> 1)
binop("ihadd", tint, commutative, "(src0 & src1) + ((src0 ^ src1) >> 1)")
binop("uhadd", tuint, commutative, "(src0 & src1) + ((src0 ^ src1) >> 1)")
570
# rhadd: (a + b + 1) >> 1 (without overflow)
# x + y + 1 = x + (~x & y) - (~x & y) + y + (x & ~y) - (x & ~y) + 1
#           = (x | y) - (~x & y) + (x | y) - (x & ~y) + 1
#           = 2 * (x | y) - ((~x & y) + (x & ~y)) + 1
#           = ((x | y) << 1) - (x ^ y) + 1
#
# Since we know that the bottom bit of (x | y) << 1 is zero,
#
# (x + y + 1) >> 1 = (x | y) + ((-(x ^ y) + 1) >> 1)
#                  = (x | y) - ((x ^ y) >> 1)
#
# Note the subtraction: adding the shifted XOR term instead would, for
# example, fold rhadd(2, 0) to 3 rather than 1.
binop("irhadd", tint, commutative, "(src0 | src1) - ((src0 ^ src1) >> 1)")
binop("urhadd", tuint, commutative, "(src0 | src1) - ((src0 ^ src1) >> 1)")
583
# Unsigned modulus; a zero divisor folds to 0.
binop("umod", tuint, "", "src1 == 0 ? 0 : src0 % src1")

# For signed integers, there are several different possible definitions of
# "modulus" or "remainder". We follow the conventions used by LLVM and
# SPIR-V. The irem opcode implements the standard C/C++ signed "%"
# operation while the imod opcode implements the more mathematical
# "modulus" operation. For details on the difference, see
#
# http://mathforum.org/library/drmath/view/52343.html

binop("irem", tint, "", "src1 == 0 ? 0 : src0 % src1")
# imod adds src1 to the C remainder when the remainder is non-zero and the
# operands have different signs.
binop("imod", tint, "",
      "src1 == 0 ? 0 : ((src0 % src1 == 0 || (src0 >= 0) == (src1 >= 0)) ?"
      " src0 % src1 : src0 % src1 + src1)")
# NOTE(review): fmod/frem call floorf/truncf regardless of bit_size --
# confirm that single precision is intended for 64-bit operands.
binop("fmod", tfloat, "", "src0 - src1 * floorf(src0 / src1)")
binop("frem", tfloat, "", "src0 - src1 * truncf(src0 / src1)")
600
#
# Comparisons
#
# binop_compare registers opcodes producing bool1; binop_compare32 registers
# the "32"-suffixed variants producing bool32.


# these integer-aware comparisons return a boolean (0 or ~0)

binop_compare("flt", tfloat, "", "src0 < src1")
binop_compare("fge", tfloat, "", "src0 >= src1")
binop_compare("feq", tfloat, commutative, "src0 == src1")
binop_compare("fne", tfloat, commutative, "src0 != src1")
binop_compare("ilt", tint, "", "src0 < src1")
binop_compare("ige", tint, "", "src0 >= src1")
binop_compare("ieq", tint, commutative, "src0 == src1")
binop_compare("ine", tint, commutative, "src0 != src1")
binop_compare("ult", tuint, "", "src0 < src1")
binop_compare("uge", tuint, "", "src0 >= src1")
binop_compare32("flt32", tfloat, "", "src0 < src1")
binop_compare32("fge32", tfloat, "", "src0 >= src1")
binop_compare32("feq32", tfloat, commutative, "src0 == src1")
binop_compare32("fne32", tfloat, commutative, "src0 != src1")
binop_compare32("ilt32", tint, "", "src0 < src1")
binop_compare32("ige32", tint, "", "src0 >= src1")
binop_compare32("ieq32", tint, commutative, "src0 == src1")
binop_compare32("ine32", tint, commutative, "src0 != src1")
binop_compare32("ult32", tuint, "", "src0 < src1")
binop_compare32("uge32", tuint, "", "src0 >= src1")

# integer-aware GLSL-style comparisons that compare floats and ints
#
# Each call registers the 2-, 3- and 4-component variants; "all" reductions
# combine per-component results with && and "any" reductions with ||.

binop_reduce("ball_fequal", 1, tbool1, tfloat, "{src0} == {src1}",
             "{src0} && {src1}", "{src}")
binop_reduce("bany_fnequal", 1, tbool1, tfloat, "{src0} != {src1}",
             "{src0} || {src1}", "{src}")
binop_reduce("ball_iequal", 1, tbool1, tint, "{src0} == {src1}",
             "{src0} && {src1}", "{src}")
binop_reduce("bany_inequal", 1, tbool1, tint, "{src0} != {src1}",
             "{src0} || {src1}", "{src}")

binop_reduce("b32all_fequal", 1, tbool32, tfloat, "{src0} == {src1}",
             "{src0} && {src1}", "{src}")
binop_reduce("b32any_fnequal", 1, tbool32, tfloat, "{src0} != {src1}",
             "{src0} || {src1}", "{src}")
binop_reduce("b32all_iequal", 1, tbool32, tint, "{src0} == {src1}",
             "{src0} && {src1}", "{src}")
binop_reduce("b32any_inequal", 1, tbool32, tint, "{src0} != {src1}",
             "{src0} || {src1}", "{src}")

# non-integer-aware GLSL-style comparisons that return 0.0 or 1.0

binop_reduce("fall_equal", 1, tfloat32, tfloat32, "{src0} == {src1}",
             "{src0} && {src1}", "{src} ? 1.0f : 0.0f")
binop_reduce("fany_nequal", 1, tfloat32, tfloat32, "{src0} != {src1}",
             "{src0} || {src1}", "{src} ? 1.0f : 0.0f")

# These comparisons for integer-less hardware return 1.0 and 0.0 for true
# and false respectively
# NOTE(review): sge is declared on unsized tfloat while slt/seq/sne use
# tfloat32 -- confirm whether the difference is intentional.

binop("slt", tfloat32, "", "(src0 < src1) ? 1.0f : 0.0f") # Set on Less Than
binop("sge", tfloat, "", "(src0 >= src1) ? 1.0f : 0.0f") # Set on Greater or Equal
binop("seq", tfloat32, commutative, "(src0 == src1) ? 1.0f : 0.0f") # Set on Equal
binop("sne", tfloat32, commutative, "(src0 != src1) ? 1.0f : 0.0f") # Set on Not Equal
663
# SPIRV shifts are undefined for shift-operands >= bitsize,
# but SM5 shifts are defined to use the least significant bits, only
# The NIR definition is according to the SM5 specification.
#
# The shift count is always a uint32 and is masked to the width of src0.
opcode("ishl", 0, tint, [0, 0], [tint, tuint32], False, "",
       "src0 << (src1 & (sizeof(src0) * 8 - 1))")
opcode("ishr", 0, tint, [0, 0], [tint, tuint32], False, "",
       "src0 >> (src1 & (sizeof(src0) * 8 - 1))")
opcode("ushr", 0, tuint, [0, 0], [tuint, tuint32], False, "",
       "src0 >> (src1 & (sizeof(src0) * 8 - 1))")

# bitwise logic operators
#
# These are also used as boolean and, or, xor for hardware supporting
# integers.


binop("iand", tuint, commutative + associative, "src0 & src1")
binop("ior", tuint, commutative + associative, "src0 | src1")
binop("ixor", tuint, commutative + associative, "src0 ^ src1")


# floating point logic operators
#
# These use (src != 0.0) for testing the truth of the input, and output 1.0
# for true and 0.0 for false

binop("fand", tfloat32, commutative,
      "((src0 != 0.0f) && (src1 != 0.0f)) ? 1.0f : 0.0f")
binop("for", tfloat32, commutative,
      "((src0 != 0.0f) || (src1 != 0.0f)) ? 1.0f : 0.0f")
binop("fxor", tfloat32, commutative,
      "(src0 != 0.0f && src1 == 0.0f) || (src0 == 0.0f && src1 != 0.0f) ? 1.0f : 0.0f")

# Dot products: fdot has a 1-component result and fdot_replicated a
# 4-component one; binop_reduce registers the 2/3/4-component variants.
binop_reduce("fdot", 1, tfloat, tfloat, "{src0} * {src1}", "{src0} + {src1}",
             "{src}")

binop_reduce("fdot_replicated", 4, tfloat, tfloat,
             "{src0} * {src1}", "{src0} + {src1}", "{src}")

# Dot product of a 3-component src0 with the xyz of src1, plus src1.w.
opcode("fdph", 1, tfloat, [3, 4], [tfloat, tfloat], False, "",
       "src0.x * src1.x + src0.y * src1.y + src0.z * src1.z + src1.w")
opcode("fdph_replicated", 4, tfloat, [3, 4], [tfloat, tfloat], False, "",
       "src0.x * src1.x + src0.y * src1.y + src0.z * src1.z + src1.w")

# Minimum / maximum.  Note that the float variants carry no algebraic
# properties while the integer variants are commutative and associative.
binop("fmin", tfloat, "", "fminf(src0, src1)")
binop("imin", tint, commutative + associative, "src1 > src0 ? src0 : src1")
binop("umin", tuint, commutative + associative, "src1 > src0 ? src0 : src1")
binop("fmax", tfloat, "", "fmaxf(src0, src1)")
binop("imax", tint, commutative + associative, "src1 > src0 ? src1 : src0")
binop("umax", tuint, commutative + associative, "src1 > src0 ? src1 : src0")
714
# The *_4x8 opcodes below treat each 32-bit operand as four 8-bit lanes and
# loop over the lanes with a shift of 0, 8, 16 and 24.

# Saturated vector add for 4 8bit ints.
binop("usadd_4x8", tint32, commutative + associative, """
dst = 0;
for (int i = 0; i < 32; i += 8) {
dst |= MIN2(((src0 >> i) & 0xff) + ((src1 >> i) & 0xff), 0xff) << i;
}
""")

# Saturated vector subtract for 4 8bit ints.
# Lanes where src0 is not greater than src1 are left at 0.
binop("ussub_4x8", tint32, "", """
dst = 0;
for (int i = 0; i < 32; i += 8) {
int src0_chan = (src0 >> i) & 0xff;
int src1_chan = (src1 >> i) & 0xff;
if (src0_chan > src1_chan)
dst |= (src0_chan - src1_chan) << i;
}
""")

# vector min for 4 8bit ints.
binop("umin_4x8", tint32, commutative + associative, """
dst = 0;
for (int i = 0; i < 32; i += 8) {
dst |= MIN2((src0 >> i) & 0xff, (src1 >> i) & 0xff) << i;
}
""")

# vector max for 4 8bit ints.
binop("umax_4x8", tint32, commutative + associative, """
dst = 0;
for (int i = 0; i < 32; i += 8) {
dst |= MAX2((src0 >> i) & 0xff, (src1 >> i) & 0xff) << i;
}
""")

# unorm multiply: (a * b) / 255.
binop("umul_unorm_4x8", tint32, commutative + associative, """
dst = 0;
for (int i = 0; i < 32; i += 8) {
int src0_chan = (src0 >> i) & 0xff;
int src1_chan = (src1 >> i) & 0xff;
dst |= ((src0_chan * src1_chan) / 255) << i;
}
""")
759
# Power.  As with the other float opcodes in this file, 64-bit operands use
# the double-precision function (pow) and all other sizes the
# single-precision one (powf).
binop("fpow", tfloat, "", "bit_size == 64 ? pow(src0, src1) : powf(src0, src1)")
761
# Two-source packs: src0 supplies the low half of the result and src1 the
# high half.
binop_horiz("pack_half_2x16_split", 1, tuint32, 1, tfloat32, 1, tfloat32,
            "pack_half_1x16(src0.x) | (pack_half_1x16(src1.x) << 16)")

binop_convert("pack_64_2x32_split", tuint64, tuint32, "",
              "src0 | ((uint64_t)src1 << 32)")

binop_convert("pack_32_2x16_split", tuint32, tuint16, "",
              "src0 | ((uint32_t)src1 << 16)")

# bfm implements the behavior of the first operation of the SM5 "bfi" assembly
# and that of the "bfi1" i965 instruction. That is, it has undefined behavior
# if either of its arguments are 32.
# The constant expression folds every out-of-range combination to 0.
binop_convert("bfm", tuint32, tint32, "", """
int bits = src0, offset = src1;
if (offset < 0 || bits < 0 || offset > 31 || bits > 31 || offset + bits > 32)
dst = 0; /* undefined */
else
dst = ((1u << bits) - 1) << offset;
""")

# ldexp takes a float source and an int32 exponent.  Results for which
# isnormal() is false are replaced by a zero carrying the sign of src0.
opcode("ldexp", 0, tfloat, [0, 0], [tfloat, tint32], False, "", """
dst = (bit_size == 64) ? ldexp(src0, src1) : ldexpf(src0, src1);
/* flush denormals to zero. */
if (!isnormal(dst))
dst = copysignf(0.0f, src0);
""")
788
# Combines the first component of each input to make a 2-component vector.

binop_horiz("vec2", 2, tuint, 1, tuint, 1, tuint, """
dst.x = src0.x;
dst.y = src1.x;
""")

# Byte extraction
# src1 selects which byte of src0 to return; the cast to uint8_t/int8_t
# determines zero- versus sign-extension of the result.
binop("extract_u8", tuint, "", "(uint8_t)(src0 >> (src1 * 8))")
binop("extract_i8", tint, "", "(int8_t)(src0 >> (src1 * 8))")

# Word extraction
# Same as above with 16-bit words.
binop("extract_u16", tuint, "", "(uint16_t)(src0 >> (src1 * 16))")
binop("extract_i16", tint, "", "(int16_t)(src0 >> (src1 * 16))")
803
804
def triop(name, ty, const_expr):
    """Register a three-input opcode whose inputs and output all use ty.

    All sizes are 0, the opcode is not flagged as a conversion and it has
    no algebraic properties.
    """
    opcode(name, output_size=0, output_type=ty,
           input_sizes=[0, 0, 0], input_types=[ty, ty, ty],
           is_conversion=False, algebraic_properties="",
           const_expr=const_expr)
def triop_horiz(name, output_size, src1_size, src2_size, src3_size, const_expr):
    """Register a three-input uint opcode with explicit sizes.

    The output and all three sources use type tuint; the opcode is not
    flagged as a conversion and has no algebraic properties.
    """
    sizes = [src1_size, src2_size, src3_size]
    types = [tuint, tuint, tuint]
    opcode(name, output_size, tuint, sizes, types, False, "", const_expr)
811
812 triop("ffma", tfloat, "src0 * src1 + src2")
813
814 triop("flrp", tfloat, "src0 * (1 - src2) + src1 * src2")
815
# Conditional Select
#
# Vector counterpart of the ?: operator, applied to each component
# independently.  It exists in two flavours: one whose condition is a
# floating point bool (0.0 vs 1.0) and one whose condition is an integer
# bool (0 vs ~0).

# Floating point flavour: any src0 other than 0.0 selects src1.
triop("fcsel",
      tfloat32,
      "(src0 != 0.0f) ? src1 : src2")
824
# 3 way min/max/med
#
# The float variants are declared with the unsized tfloat type, and unsized
# float opcodes can be evaluated at 64 bits (ldexp in this file checks
# bit_size == 64).  They therefore use the double-precision fmin()/fmax():
# fminf()/fmaxf() would round 64-bit operands to single precision before
# comparing and so return an imprecise result.  For narrower floats the
# promotion to double and back is exact, so their results are unchanged.
triop("fmin3", tfloat, "fmin(src0, fmin(src1, src2))")
triop("imin3", tint, "MIN2(src0, MIN2(src1, src2))")
triop("umin3", tuint, "MIN2(src0, MIN2(src1, src2))")

triop("fmax3", tfloat, "fmax(src0, fmax(src1, src2))")
triop("imax3", tint, "MAX2(src0, MAX2(src1, src2))")
triop("umax3", tuint, "MAX2(src0, MAX2(src1, src2))")

# Median of three: the larger of min(max(src0, src1), src2) and
# min(src0, src1).
triop("fmed3", tfloat, "fmax(fmin(fmax(src0, src1), src2), fmin(src0, src1))")
triop("imed3", tint, "MAX2(MIN2(MAX2(src0, src1), src2), MIN2(src0, src1))")
triop("umed3", tuint, "MAX2(MIN2(MAX2(src0, src1), src2), MIN2(src0, src1))")
837
# Conditional selects with a boolean condition in src0: bcsel takes a
# tbool1 condition and b32csel a tbool32 one.  A true condition picks src1,
# otherwise src2.
opcode("bcsel", 0, tuint,
       [0, 0, 0],
       [tbool1, tuint, tuint],
       False, "", "src0 ? src1 : src2")
opcode("b32csel", 0, tuint,
       [0, 0, 0],
       [tbool32, tuint, tuint],
       False, "", "src0 ? src1 : src2")
842
# SM5 bfi assembly
#
# Sources are (mask, insert, base).  With an empty mask the result is just
# base.  Otherwise insert is shifted left once for every trailing zero bit
# of mask, and the result takes the bits selected by mask from the shifted
# insert and every other bit from base.
triop(
    "bfi", tuint32,
    """
unsigned mask = src0, insert = src1, base = src2;
if (mask == 0) {
dst = base;
} else {
unsigned tmp = mask;
while (!(tmp & 1)) {
tmp >>= 1;
insert <<= 1;
}
dst = (base & ~mask) | (insert & mask);
}
""")
857
# SM5 ubfe/ibfe assembly
#
# Sources are (base, offset, bits).  Both opcodes share the same shape:
#   - bits == 0 gives 0;
#   - a negative offset or bits gives 0 (the input is undefined);
#   - when offset + bits < 32 the field is shifted up to the top of the word
#     and back down by 32 - bits;
#   - otherwise the result is base shifted right by offset.
# They differ only in the type of base (unsigned for ubfe, int for ibfe),
# which decides the kind of right shift that is applied.
opcode("ubfe", 0, tuint32,
       [0, 0, 0],
       [tuint32, tint32, tint32],
       False, "", """
unsigned base = src0;
int offset = src1, bits = src2;
if (bits == 0) {
dst = 0;
} else if (bits < 0 || offset < 0) {
dst = 0; /* undefined */
} else if (offset + bits < 32) {
dst = (base << (32 - bits - offset)) >> (32 - bits);
} else {
dst = base >> offset;
}
""")
opcode("ibfe", 0, tint32,
       [0, 0, 0],
       [tint32, tint32, tint32],
       False, "", """
int base = src0;
int offset = src1, bits = src2;
if (bits == 0) {
dst = 0;
} else if (bits < 0 || offset < 0) {
dst = 0; /* undefined */
} else if (offset + bits < 32) {
dst = (base << (32 - bits - offset)) >> (32 - bits);
} else {
dst = base >> offset;
}
""")
887
# GLSL bitfieldExtract()
#
# Unsigned variant.  Sources are (base, offset, bits).  bits == 0 gives 0,
# and so does any negative argument or a field reaching past bit 31 (the
# result is undefined per the spec).  Otherwise base is shifted down by
# offset and masked to `bits` bits; the mask is built from a 64-bit one
# (1ull) so that bits == 32 does not overflow the shift.
opcode("ubitfield_extract", 0, tuint32,
       [0, 0, 0],
       [tuint32, tint32, tint32],
       False, "", """
unsigned base = src0;
int offset = src1, bits = src2;
if (bits == 0) {
dst = 0;
} else if (bits < 0 || offset < 0 || offset + bits > 32) {
dst = 0; /* undefined per the spec */
} else {
dst = (base >> offset) & ((1ull << bits) - 1);
}
""")
# Signed bitfield extract.  Sources are (base, offset, bits).  bits == 0
# gives 0, and so does any negative argument or a field reaching past bit 31.
# Otherwise the field is first shifted left so that its top bit lands in
# bit 31, then shifted right by 32 - bits so that its lowest bit lands in
# bit 0 with the sign bit replicated above it.
#
# The right shift must be by 32 - bits (as in ibfe), not by offset: shifting
# by offset leaves the field misplaced whenever offset != 32 - bits, e.g.
# base = 0xF0, offset = 4, bits = 4 would give 0xFF000000 instead of -1.
opcode("ibitfield_extract", 0, tint32,
       [0, 0, 0], [tint32, tint32, tint32], False, "", """
int base = src0;
int offset = src1, bits = src2;
if (bits == 0) {
dst = 0;
} else if (offset < 0 || bits < 0 || offset + bits > 32) {
dst = 0;
} else {
dst = (base << (32 - offset - bits)) >> (32 - bits); /* use sign-extending shift */
}
""")
913
# vec3: builds a 3-component vector from the first component of each source
# (x, y and z come from src0.x, src1.x and src2.x respectively).
triop_horiz("vec3", 3,
            1, 1, 1,
            """
dst.x = src0.x;
dst.y = src1.x;
dst.z = src2.x;
""")
921
def quadop_horiz(name, output_size, src1_size, src2_size, src3_size,
                 src4_size, const_expr):
    """Declare a four-source opcode with explicit output and source sizes.

    The output and all four sources are typed tuint; the opcode is not
    flagged as a conversion and has no algebraic properties.

    - name is the opcode name
    - output_size is the size passed to opcode() for the result
    - src1_size .. src4_size are the sizes of the four sources
    - const_expr is the constant-folding expression passed through unchanged
    """
    source_sizes = [src1_size, src2_size, src3_size, src4_size]
    source_types = [tuint] * 4
    opcode(name, output_size, tuint, source_sizes, source_types,
           False, "", const_expr)
928
# bitfield_insert: sources are (base, insert, offset, bits).  bits == 0
# returns base untouched; a negative argument or a field reaching past
# bit 31 gives 0.  Otherwise the `bits` bits of base starting at offset are
# replaced by the low bits of insert.  The mask is built from a 64-bit one
# (1ull) so that bits == 32 does not overflow the shift.
opcode("bitfield_insert", 0, tuint32,
       [0, 0, 0, 0],
       [tuint32, tuint32, tint32, tint32],
       False, "", """
unsigned base = src0, insert = src1;
int offset = src2, bits = src3;
if (bits == 0) {
dst = base;
} else if (offset < 0 || bits < 0 || bits + offset > 32) {
dst = 0;
} else {
unsigned mask = ((1ull << bits) - 1) << offset;
dst = (base & ~mask) | ((insert << offset) & mask);
}
""")
942
# vec4: builds a 4-component vector from the first component of each source
# (x, y, z and w come from src0.x, src1.x, src2.x and src3.x respectively).
quadop_horiz("vec4", 4,
             1, 1, 1, 1,
             """
dst.x = src0.x;
dst.y = src1.x;
dst.z = src2.x;
dst.w = src3.x;
""")
949
950