nir: Drop imov/fmov in favor of one mov instruction
[mesa.git] / src / compiler / nir / nir_opcodes.py
1 #
2 # Copyright (C) 2014 Connor Abbott
3 #
4 # Permission is hereby granted, free of charge, to any person obtaining a
5 # copy of this software and associated documentation files (the "Software"),
6 # to deal in the Software without restriction, including without limitation
7 # the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 # and/or sell copies of the Software, and to permit persons to whom the
9 # Software is furnished to do so, subject to the following conditions:
10 #
11 # The above copyright notice and this permission notice (including the next
12 # paragraph) shall be included in all copies or substantial portions of the
13 # Software.
14 #
15 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 # FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 # IN THE SOFTWARE.
22 #
23 # Authors:
24 # Connor Abbott (cwabbott0@gmail.com)
25
26 import re
27
28 # Class that represents all the information we have about the opcode
29 # NOTE: this must be kept in sync with nir_op_info
30
31 class Opcode(object):
32 """Class that represents all the information we have about the opcode
33 NOTE: this must be kept in sync with nir_op_info
34 """
35 def __init__(self, name, output_size, output_type, input_sizes,
36 input_types, is_conversion, algebraic_properties, const_expr):
37 """Parameters:
38
39 - name is the name of the opcode (prepend nir_op_ for the enum name)
40 - all types are strings that get nir_type_ prepended to them
41 - input_types is a list of types
42 - is_conversion is true if this opcode represents a type conversion
43 - algebraic_properties is a space-seperated string, where nir_op_is_ is
44 prepended before each entry
45 - const_expr is an expression or series of statements that computes the
46 constant value of the opcode given the constant values of its inputs.
47
48 Constant expressions are formed from the variables src0, src1, ...,
49 src(N-1), where N is the number of arguments. The output of the
50 expression should be stored in the dst variable. Per-component input
51 and output variables will be scalars and non-per-component input and
52 output variables will be a struct with fields named x, y, z, and w
53 all of the correct type. Input and output variables can be assumed
54 to already be of the correct type and need no conversion. In
55 particular, the conversion from the C bool type to/from NIR_TRUE and
56 NIR_FALSE happens automatically.
57
58 For per-component instructions, the entire expression will be
59 executed once for each component. For non-per-component
60 instructions, the expression is expected to store the correct values
61 in dst.x, dst.y, etc. If "dst" does not exist anywhere in the
62 constant expression, an assignment to dst will happen automatically
63 and the result will be equivalent to "dst = <expression>" for
64 per-component instructions and "dst.x = dst.y = ... = <expression>"
65 for non-per-component instructions.
66 """
67 assert isinstance(name, str)
68 assert isinstance(output_size, int)
69 assert isinstance(output_type, str)
70 assert isinstance(input_sizes, list)
71 assert isinstance(input_sizes[0], int)
72 assert isinstance(input_types, list)
73 assert isinstance(input_types[0], str)
74 assert isinstance(is_conversion, bool)
75 assert isinstance(algebraic_properties, str)
76 assert isinstance(const_expr, str)
77 assert len(input_sizes) == len(input_types)
78 assert 0 <= output_size <= 4
79 for size in input_sizes:
80 assert 0 <= size <= 4
81 if output_size != 0:
82 assert size != 0
83 self.name = name
84 self.num_inputs = len(input_sizes)
85 self.output_size = output_size
86 self.output_type = output_type
87 self.input_sizes = input_sizes
88 self.input_types = input_types
89 self.is_conversion = is_conversion
90 self.algebraic_properties = algebraic_properties
91 self.const_expr = const_expr
92
93 # helper variables for strings
94 tfloat = "float"
95 tint = "int"
96 tbool = "bool"
97 tbool1 = "bool1"
98 tbool32 = "bool32"
99 tuint = "uint"
100 tuint16 = "uint16"
101 tfloat32 = "float32"
102 tint32 = "int32"
103 tuint32 = "uint32"
104 tint64 = "int64"
105 tuint64 = "uint64"
106 tfloat64 = "float64"
107
108 _TYPE_SPLIT_RE = re.compile(r'(?P<type>int|uint|float|bool)(?P<bits>\d+)?')
109
110 def type_has_size(type_):
111 m = _TYPE_SPLIT_RE.match(type_)
112 assert m is not None, 'Invalid NIR type string: "{}"'.format(type_)
113 return m.group('bits') is not None
114
115 def type_size(type_):
116 m = _TYPE_SPLIT_RE.match(type_)
117 assert m is not None, 'Invalid NIR type string: "{}"'.format(type_)
118 assert m.group('bits') is not None, \
119 'NIR type string has no bit size: "{}"'.format(type_)
120 return int(m.group('bits'))
121
122 def type_sizes(type_):
123 if type_has_size(type_):
124 return [type_size(type_)]
125 elif type_ == 'bool':
126 return [1, 32]
127 elif type_ == 'float':
128 return [16, 32, 64]
129 else:
130 return [1, 8, 16, 32, 64]
131
132 def type_base_type(type_):
133 m = _TYPE_SPLIT_RE.match(type_)
134 assert m is not None, 'Invalid NIR type string: "{}"'.format(type_)
135 return m.group('type')
136
137 # Operation where the first two sources are commutative.
138 #
139 # For 2-source operations, this just mathematical commutativity. Some
140 # 3-source operations, like ffma, are only commutative in the first two
141 # sources.
142 _2src_commutative = "2src_commutative "
143 associative = "associative "
144
145 # global dictionary of opcodes
146 opcodes = {}
147
148 def opcode(name, output_size, output_type, input_sizes, input_types,
149 is_conversion, algebraic_properties, const_expr):
150 assert name not in opcodes
151 opcodes[name] = Opcode(name, output_size, output_type, input_sizes,
152 input_types, is_conversion, algebraic_properties,
153 const_expr)
154
155 def unop_convert(name, out_type, in_type, const_expr):
156 opcode(name, 0, out_type, [0], [in_type], False, "", const_expr)
157
158 def unop(name, ty, const_expr):
159 opcode(name, 0, ty, [0], [ty], False, "", const_expr)
160
161 def unop_horiz(name, output_size, output_type, input_size, input_type,
162 const_expr):
163 opcode(name, output_size, output_type, [input_size], [input_type],
164 False, "", const_expr)
165
166 def unop_reduce(name, output_size, output_type, input_type, prereduce_expr,
167 reduce_expr, final_expr):
168 def prereduce(src):
169 return "(" + prereduce_expr.format(src=src) + ")"
170 def final(src):
171 return final_expr.format(src="(" + src + ")")
172 def reduce_(src0, src1):
173 return reduce_expr.format(src0=src0, src1=src1)
174 src0 = prereduce("src0.x")
175 src1 = prereduce("src0.y")
176 src2 = prereduce("src0.z")
177 src3 = prereduce("src0.w")
178 unop_horiz(name + "2", output_size, output_type, 2, input_type,
179 final(reduce_(src0, src1)))
180 unop_horiz(name + "3", output_size, output_type, 3, input_type,
181 final(reduce_(reduce_(src0, src1), src2)))
182 unop_horiz(name + "4", output_size, output_type, 4, input_type,
183 final(reduce_(reduce_(src0, src1), reduce_(src2, src3))))
184
185 def unop_numeric_convert(name, out_type, in_type, const_expr):
186 opcode(name, 0, out_type, [0], [in_type], True, "", const_expr)
187
188 unop("mov", tuint, "src0")
189
190 unop("ineg", tint, "-src0")
191 unop("fneg", tfloat, "-src0")
192 unop("inot", tint, "~src0") # invert every bit of the integer
193 unop("fnot", tfloat, ("bit_size == 64 ? ((src0 == 0.0) ? 1.0 : 0.0f) : " +
194 "((src0 == 0.0f) ? 1.0f : 0.0f)"))
195 unop("fsign", tfloat, ("bit_size == 64 ? " +
196 "((src0 == 0.0) ? 0.0 : ((src0 > 0.0) ? 1.0 : -1.0)) : " +
197 "((src0 == 0.0f) ? 0.0f : ((src0 > 0.0f) ? 1.0f : -1.0f))"))
198 unop("isign", tint, "(src0 == 0) ? 0 : ((src0 > 0) ? 1 : -1)")
199 unop("iabs", tint, "(src0 < 0) ? -src0 : src0")
200 unop("fabs", tfloat, "fabs(src0)")
201 unop("fsat", tfloat, ("bit_size == 64 ? " +
202 "((src0 > 1.0) ? 1.0 : ((src0 <= 0.0) ? 0.0 : src0)) : " +
203 "((src0 > 1.0f) ? 1.0f : ((src0 <= 0.0f) ? 0.0f : src0))"))
204 unop("frcp", tfloat, "bit_size == 64 ? 1.0 / src0 : 1.0f / src0")
205 unop("frsq", tfloat, "bit_size == 64 ? 1.0 / sqrt(src0) : 1.0f / sqrtf(src0)")
206 unop("fsqrt", tfloat, "bit_size == 64 ? sqrt(src0) : sqrtf(src0)")
207 unop("fexp2", tfloat, "exp2f(src0)")
208 unop("flog2", tfloat, "log2f(src0)")
209
210 # Generate all of the numeric conversion opcodes
211 for src_t in [tint, tuint, tfloat, tbool]:
212 if src_t == tbool:
213 dst_types = [tfloat, tint]
214 elif src_t == tint:
215 dst_types = [tfloat, tint, tbool]
216 elif src_t == tuint:
217 dst_types = [tfloat, tuint]
218 elif src_t == tfloat:
219 dst_types = [tint, tuint, tfloat, tbool]
220
221 for dst_t in dst_types:
222 for bit_size in type_sizes(dst_t):
223 if bit_size == 16 and dst_t == tfloat and src_t == tfloat:
224 rnd_modes = ['_rtne', '_rtz', '']
225 for rnd_mode in rnd_modes:
226 unop_numeric_convert("{0}2{1}{2}{3}".format(src_t[0], dst_t[0],
227 bit_size, rnd_mode),
228 dst_t + str(bit_size), src_t, "src0")
229 else:
230 conv_expr = "src0 != 0" if dst_t == tbool else "src0"
231 unop_numeric_convert("{0}2{1}{2}".format(src_t[0], dst_t[0], bit_size),
232 dst_t + str(bit_size), src_t, conv_expr)
233
234
235 # Unary floating-point rounding operations.
236
237
238 unop("ftrunc", tfloat, "bit_size == 64 ? trunc(src0) : truncf(src0)")
239 unop("fceil", tfloat, "bit_size == 64 ? ceil(src0) : ceilf(src0)")
240 unop("ffloor", tfloat, "bit_size == 64 ? floor(src0) : floorf(src0)")
241 unop("ffract", tfloat, "src0 - (bit_size == 64 ? floor(src0) : floorf(src0))")
242 unop("fround_even", tfloat, "bit_size == 64 ? _mesa_roundeven(src0) : _mesa_roundevenf(src0)")
243
244 unop("fquantize2f16", tfloat, "(fabs(src0) < ldexpf(1.0, -14)) ? copysignf(0.0f, src0) : _mesa_half_to_float(_mesa_float_to_half(src0))")
245
246 # Trigonometric operations.
247
248
249 unop("fsin", tfloat, "bit_size == 64 ? sin(src0) : sinf(src0)")
250 unop("fcos", tfloat, "bit_size == 64 ? cos(src0) : cosf(src0)")
251
252 # dfrexp
253 unop_convert("frexp_exp", tint32, tfloat, "frexp(src0, &dst);")
254 unop_convert("frexp_sig", tfloat, tfloat, "int n; dst = frexp(src0, &n);")
255
256 # Partial derivatives.
257
258
259 unop("fddx", tfloat, "0.0") # the derivative of a constant is 0.
260 unop("fddy", tfloat, "0.0")
261 unop("fddx_fine", tfloat, "0.0")
262 unop("fddy_fine", tfloat, "0.0")
263 unop("fddx_coarse", tfloat, "0.0")
264 unop("fddy_coarse", tfloat, "0.0")
265
266
267 # Floating point pack and unpack operations.
268
269 def pack_2x16(fmt):
270 unop_horiz("pack_" + fmt + "_2x16", 1, tuint32, 2, tfloat32, """
271 dst.x = (uint32_t) pack_fmt_1x16(src0.x);
272 dst.x |= ((uint32_t) pack_fmt_1x16(src0.y)) << 16;
273 """.replace("fmt", fmt))
274
275 def pack_4x8(fmt):
276 unop_horiz("pack_" + fmt + "_4x8", 1, tuint32, 4, tfloat32, """
277 dst.x = (uint32_t) pack_fmt_1x8(src0.x);
278 dst.x |= ((uint32_t) pack_fmt_1x8(src0.y)) << 8;
279 dst.x |= ((uint32_t) pack_fmt_1x8(src0.z)) << 16;
280 dst.x |= ((uint32_t) pack_fmt_1x8(src0.w)) << 24;
281 """.replace("fmt", fmt))
282
283 def unpack_2x16(fmt):
284 unop_horiz("unpack_" + fmt + "_2x16", 2, tfloat32, 1, tuint32, """
285 dst.x = unpack_fmt_1x16((uint16_t)(src0.x & 0xffff));
286 dst.y = unpack_fmt_1x16((uint16_t)(src0.x << 16));
287 """.replace("fmt", fmt))
288
289 def unpack_4x8(fmt):
290 unop_horiz("unpack_" + fmt + "_4x8", 4, tfloat32, 1, tuint32, """
291 dst.x = unpack_fmt_1x8((uint8_t)(src0.x & 0xff));
292 dst.y = unpack_fmt_1x8((uint8_t)((src0.x >> 8) & 0xff));
293 dst.z = unpack_fmt_1x8((uint8_t)((src0.x >> 16) & 0xff));
294 dst.w = unpack_fmt_1x8((uint8_t)(src0.x >> 24));
295 """.replace("fmt", fmt))
296
297
298 pack_2x16("snorm")
299 pack_4x8("snorm")
300 pack_2x16("unorm")
301 pack_4x8("unorm")
302 pack_2x16("half")
303 unpack_2x16("snorm")
304 unpack_4x8("snorm")
305 unpack_2x16("unorm")
306 unpack_4x8("unorm")
307 unpack_2x16("half")
308
309 unop_horiz("pack_uvec2_to_uint", 1, tuint32, 2, tuint32, """
310 dst.x = (src0.x & 0xffff) | (src0.y << 16);
311 """)
312
313 unop_horiz("pack_uvec4_to_uint", 1, tuint32, 4, tuint32, """
314 dst.x = (src0.x << 0) |
315 (src0.y << 8) |
316 (src0.z << 16) |
317 (src0.w << 24);
318 """)
319
320 unop_horiz("pack_32_2x16", 1, tuint32, 2, tuint16,
321 "dst.x = src0.x | ((uint32_t)src0.y << 16);")
322
323 unop_horiz("pack_64_2x32", 1, tuint64, 2, tuint32,
324 "dst.x = src0.x | ((uint64_t)src0.y << 32);")
325
326 unop_horiz("pack_64_4x16", 1, tuint64, 4, tuint16,
327 "dst.x = src0.x | ((uint64_t)src0.y << 16) | ((uint64_t)src0.z << 32) | ((uint64_t)src0.w << 48);")
328
329 unop_horiz("unpack_64_2x32", 2, tuint32, 1, tuint64,
330 "dst.x = src0.x; dst.y = src0.x >> 32;")
331
332 unop_horiz("unpack_64_4x16", 4, tuint16, 1, tuint64,
333 "dst.x = src0.x; dst.y = src0.x >> 16; dst.z = src0.x >> 32; dst.w = src0.w >> 48;")
334
335 unop_horiz("unpack_32_2x16", 2, tuint16, 1, tuint32,
336 "dst.x = src0.x; dst.y = src0.x >> 16;")
337
338 # Lowered floating point unpacking operations.
339
340
341 unop_convert("unpack_half_2x16_split_x", tfloat32, tuint32,
342 "unpack_half_1x16((uint16_t)(src0 & 0xffff))")
343 unop_convert("unpack_half_2x16_split_y", tfloat32, tuint32,
344 "unpack_half_1x16((uint16_t)(src0 >> 16))")
345
346 unop_convert("unpack_32_2x16_split_x", tuint16, tuint32, "src0")
347 unop_convert("unpack_32_2x16_split_y", tuint16, tuint32, "src0 >> 16")
348
349 unop_convert("unpack_64_2x32_split_x", tuint32, tuint64, "src0")
350 unop_convert("unpack_64_2x32_split_y", tuint32, tuint64, "src0 >> 32")
351
352 # Bit operations, part of ARB_gpu_shader5.
353
354
355 unop("bitfield_reverse", tuint32, """
356 /* we're not winning any awards for speed here, but that's ok */
357 dst = 0;
358 for (unsigned bit = 0; bit < 32; bit++)
359 dst |= ((src0 >> bit) & 1) << (31 - bit);
360 """)
361 unop_convert("bit_count", tuint32, tuint, """
362 dst = 0;
363 for (unsigned bit = 0; bit < bit_size; bit++) {
364 if ((src0 >> bit) & 1)
365 dst++;
366 }
367 """)
368
369 unop_convert("ufind_msb", tint32, tuint, """
370 dst = -1;
371 for (int bit = bit_size - 1; bit >= 0; bit--) {
372 if ((src0 >> bit) & 1) {
373 dst = bit;
374 break;
375 }
376 }
377 """)
378
379 unop("ifind_msb", tint32, """
380 dst = -1;
381 for (int bit = 31; bit >= 0; bit--) {
382 /* If src0 < 0, we're looking for the first 0 bit.
383 * if src0 >= 0, we're looking for the first 1 bit.
384 */
385 if ((((src0 >> bit) & 1) && (src0 >= 0)) ||
386 (!((src0 >> bit) & 1) && (src0 < 0))) {
387 dst = bit;
388 break;
389 }
390 }
391 """)
392
393 unop_convert("find_lsb", tint32, tint, """
394 dst = -1;
395 for (unsigned bit = 0; bit < bit_size; bit++) {
396 if ((src0 >> bit) & 1) {
397 dst = bit;
398 break;
399 }
400 }
401 """)
402
403
404 for i in range(1, 5):
405 for j in range(1, 5):
406 unop_horiz("fnoise{0}_{1}".format(i, j), i, tfloat, j, tfloat, "0.0f")
407
408
409 # AMD_gcn_shader extended instructions
410 unop_horiz("cube_face_coord", 2, tfloat32, 3, tfloat32, """
411 dst.x = dst.y = 0.0;
412 float absX = fabs(src0.x);
413 float absY = fabs(src0.y);
414 float absZ = fabs(src0.z);
415
416 float ma = 0.0;
417 if (absX >= absY && absX >= absZ) { ma = 2 * src0.x; }
418 if (absY >= absX && absY >= absZ) { ma = 2 * src0.y; }
419 if (absZ >= absX && absZ >= absY) { ma = 2 * src0.z; }
420
421 if (src0.x >= 0 && absX >= absY && absX >= absZ) { dst.x = -src0.z; dst.y = -src0.y; }
422 if (src0.x < 0 && absX >= absY && absX >= absZ) { dst.x = src0.z; dst.y = -src0.y; }
423 if (src0.y >= 0 && absY >= absX && absY >= absZ) { dst.x = src0.x; dst.y = src0.z; }
424 if (src0.y < 0 && absY >= absX && absY >= absZ) { dst.x = src0.x; dst.y = -src0.z; }
425 if (src0.z >= 0 && absZ >= absX && absZ >= absY) { dst.x = src0.x; dst.y = -src0.y; }
426 if (src0.z < 0 && absZ >= absX && absZ >= absY) { dst.x = -src0.x; dst.y = -src0.y; }
427
428 dst.x = dst.x / ma + 0.5;
429 dst.y = dst.y / ma + 0.5;
430 """)
431
432 unop_horiz("cube_face_index", 1, tfloat32, 3, tfloat32, """
433 float absX = fabs(src0.x);
434 float absY = fabs(src0.y);
435 float absZ = fabs(src0.z);
436 if (src0.x >= 0 && absX >= absY && absX >= absZ) dst.x = 0;
437 if (src0.x < 0 && absX >= absY && absX >= absZ) dst.x = 1;
438 if (src0.y >= 0 && absY >= absX && absY >= absZ) dst.x = 2;
439 if (src0.y < 0 && absY >= absX && absY >= absZ) dst.x = 3;
440 if (src0.z >= 0 && absZ >= absX && absZ >= absY) dst.x = 4;
441 if (src0.z < 0 && absZ >= absX && absZ >= absY) dst.x = 5;
442 """)
443
444
445 def binop_convert(name, out_type, in_type, alg_props, const_expr):
446 opcode(name, 0, out_type, [0, 0], [in_type, in_type],
447 False, alg_props, const_expr)
448
449 def binop(name, ty, alg_props, const_expr):
450 binop_convert(name, ty, ty, alg_props, const_expr)
451
452 def binop_compare(name, ty, alg_props, const_expr):
453 binop_convert(name, tbool1, ty, alg_props, const_expr)
454
455 def binop_compare32(name, ty, alg_props, const_expr):
456 binop_convert(name, tbool32, ty, alg_props, const_expr)
457
458 def binop_horiz(name, out_size, out_type, src1_size, src1_type, src2_size,
459 src2_type, const_expr):
460 opcode(name, out_size, out_type, [src1_size, src2_size], [src1_type, src2_type],
461 False, "", const_expr)
462
463 def binop_reduce(name, output_size, output_type, src_type, prereduce_expr,
464 reduce_expr, final_expr):
465 def final(src):
466 return final_expr.format(src= "(" + src + ")")
467 def reduce_(src0, src1):
468 return reduce_expr.format(src0=src0, src1=src1)
469 def prereduce(src0, src1):
470 return "(" + prereduce_expr.format(src0=src0, src1=src1) + ")"
471 src0 = prereduce("src0.x", "src1.x")
472 src1 = prereduce("src0.y", "src1.y")
473 src2 = prereduce("src0.z", "src1.z")
474 src3 = prereduce("src0.w", "src1.w")
475 opcode(name + "2", output_size, output_type,
476 [2, 2], [src_type, src_type], False, _2src_commutative,
477 final(reduce_(src0, src1)))
478 opcode(name + "3", output_size, output_type,
479 [3, 3], [src_type, src_type], False, _2src_commutative,
480 final(reduce_(reduce_(src0, src1), src2)))
481 opcode(name + "4", output_size, output_type,
482 [4, 4], [src_type, src_type], False, _2src_commutative,
483 final(reduce_(reduce_(src0, src1), reduce_(src2, src3))))
484
485 binop("fadd", tfloat, _2src_commutative + associative, "src0 + src1")
486 binop("iadd", tint, _2src_commutative + associative, "src0 + src1")
487 binop("iadd_sat", tint, _2src_commutative, """
488 src1 > 0 ?
489 (src0 + src1 < src0 ? (1ull << (bit_size - 1)) - 1 : src0 + src1) :
490 (src0 < src0 + src1 ? (1ull << (bit_size - 1)) : src0 + src1)
491 """)
492 binop("uadd_sat", tuint, _2src_commutative,
493 "(src0 + src1) < src0 ? MAX_UINT_FOR_SIZE(sizeof(src0) * 8) : (src0 + src1)")
494 binop("isub_sat", tint, "", """
495 src1 < 0 ?
496 (src0 - src1 < src0 ? (1ull << (bit_size - 1)) - 1 : src0 - src1) :
497 (src0 < src0 - src1 ? (1ull << (bit_size - 1)) : src0 - src1)
498 """)
499 binop("usub_sat", tuint, "", "src0 < src1 ? 0 : src0 - src1")
500
501 binop("fsub", tfloat, "", "src0 - src1")
502 binop("isub", tint, "", "src0 - src1")
503
504 binop("fmul", tfloat, _2src_commutative + associative, "src0 * src1")
505 # low 32-bits of signed/unsigned integer multiply
506 binop("imul", tint, _2src_commutative + associative, "src0 * src1")
507
508 # Generate 64 bit result from 2 32 bits quantity
509 binop_convert("imul_2x32_64", tint64, tint32, _2src_commutative,
510 "(int64_t)src0 * (int64_t)src1")
511 binop_convert("umul_2x32_64", tuint64, tuint32, _2src_commutative,
512 "(uint64_t)src0 * (uint64_t)src1")
513
514 # high 32-bits of signed integer multiply
515 binop("imul_high", tint, _2src_commutative, """
516 if (bit_size == 64) {
517 /* We need to do a full 128-bit x 128-bit multiply in order for the sign
518 * extension to work properly. The casts are kind-of annoying but needed
519 * to prevent compiler warnings.
520 */
521 uint32_t src0_u32[4] = {
522 src0,
523 (int64_t)src0 >> 32,
524 (int64_t)src0 >> 63,
525 (int64_t)src0 >> 63,
526 };
527 uint32_t src1_u32[4] = {
528 src1,
529 (int64_t)src1 >> 32,
530 (int64_t)src1 >> 63,
531 (int64_t)src1 >> 63,
532 };
533 uint32_t prod_u32[4];
534 ubm_mul_u32arr(prod_u32, src0_u32, src1_u32);
535 dst = (uint64_t)prod_u32[2] | ((uint64_t)prod_u32[3] << 32);
536 } else {
537 dst = ((int64_t)src0 * (int64_t)src1) >> bit_size;
538 }
539 """)
540
541 # high 32-bits of unsigned integer multiply
542 binop("umul_high", tuint, _2src_commutative, """
543 if (bit_size == 64) {
544 /* The casts are kind-of annoying but needed to prevent compiler warnings. */
545 uint32_t src0_u32[2] = { src0, (uint64_t)src0 >> 32 };
546 uint32_t src1_u32[2] = { src1, (uint64_t)src1 >> 32 };
547 uint32_t prod_u32[4];
548 ubm_mul_u32arr(prod_u32, src0_u32, src1_u32);
549 dst = (uint64_t)prod_u32[2] | ((uint64_t)prod_u32[3] << 32);
550 } else {
551 dst = ((uint64_t)src0 * (uint64_t)src1) >> bit_size;
552 }
553 """)
554
555 binop("fdiv", tfloat, "", "src0 / src1")
556 binop("idiv", tint, "", "src1 == 0 ? 0 : (src0 / src1)")
557 binop("udiv", tuint, "", "src1 == 0 ? 0 : (src0 / src1)")
558
559 # returns a boolean representing the carry resulting from the addition of
560 # the two unsigned arguments.
561
562 binop_convert("uadd_carry", tuint, tuint, _2src_commutative, "src0 + src1 < src0")
563
564 # returns a boolean representing the borrow resulting from the subtraction
565 # of the two unsigned arguments.
566
567 binop_convert("usub_borrow", tuint, tuint, "", "src0 < src1")
568
569 # hadd: (a + b) >> 1 (without overflow)
570 # x + y = x - (x & ~y) + (x & ~y) + y - (~x & y) + (~x & y)
571 # = (x & y) + (x & ~y) + (x & y) + (~x & y)
572 # = 2 * (x & y) + (x & ~y) + (~x & y)
573 # = ((x & y) << 1) + (x ^ y)
574 #
575 # Since we know that the bottom bit of (x & y) << 1 is zero,
576 #
577 # (x + y) >> 1 = (((x & y) << 1) + (x ^ y)) >> 1
578 # = (x & y) + ((x ^ y) >> 1)
579 binop("ihadd", tint, _2src_commutative, "(src0 & src1) + ((src0 ^ src1) >> 1)")
580 binop("uhadd", tuint, _2src_commutative, "(src0 & src1) + ((src0 ^ src1) >> 1)")
581
582 # rhadd: (a + b + 1) >> 1 (without overflow)
583 # x + y + 1 = x + (~x & y) - (~x & y) + y + (x & ~y) - (x & ~y) + 1
584 # = (x | y) - (~x & y) + (x | y) - (x & ~y) + 1
585 # = 2 * (x | y) - ((~x & y) + (x & ~y)) + 1
586 # = ((x | y) << 1) - (x ^ y) + 1
587 #
588 # Since we know that the bottom bit of (x & y) << 1 is zero,
589 #
590 # (x + y + 1) >> 1 = (x | y) + (-(x ^ y) + 1) >> 1)
591 # = (x | y) - ((x ^ y) >> 1)
592 binop("irhadd", tint, _2src_commutative, "(src0 | src1) + ((src0 ^ src1) >> 1)")
593 binop("urhadd", tuint, _2src_commutative, "(src0 | src1) + ((src0 ^ src1) >> 1)")
594
595 binop("umod", tuint, "", "src1 == 0 ? 0 : src0 % src1")
596
597 # For signed integers, there are several different possible definitions of
598 # "modulus" or "remainder". We follow the conventions used by LLVM and
599 # SPIR-V. The irem opcode implements the standard C/C++ signed "%"
600 # operation while the imod opcode implements the more mathematical
601 # "modulus" operation. For details on the difference, see
602 #
603 # http://mathforum.org/library/drmath/view/52343.html
604
605 binop("irem", tint, "", "src1 == 0 ? 0 : src0 % src1")
606 binop("imod", tint, "",
607 "src1 == 0 ? 0 : ((src0 % src1 == 0 || (src0 >= 0) == (src1 >= 0)) ?"
608 " src0 % src1 : src0 % src1 + src1)")
609 binop("fmod", tfloat, "", "src0 - src1 * floorf(src0 / src1)")
610 binop("frem", tfloat, "", "src0 - src1 * truncf(src0 / src1)")
611
612 #
613 # Comparisons
614 #
615
616
617 # these integer-aware comparisons return a boolean (0 or ~0)
618
619 binop_compare("flt", tfloat, "", "src0 < src1")
620 binop_compare("fge", tfloat, "", "src0 >= src1")
621 binop_compare("feq", tfloat, _2src_commutative, "src0 == src1")
622 binop_compare("fne", tfloat, _2src_commutative, "src0 != src1")
623 binop_compare("ilt", tint, "", "src0 < src1")
624 binop_compare("ige", tint, "", "src0 >= src1")
625 binop_compare("ieq", tint, _2src_commutative, "src0 == src1")
626 binop_compare("ine", tint, _2src_commutative, "src0 != src1")
627 binop_compare("ult", tuint, "", "src0 < src1")
628 binop_compare("uge", tuint, "", "src0 >= src1")
629 binop_compare32("flt32", tfloat, "", "src0 < src1")
630 binop_compare32("fge32", tfloat, "", "src0 >= src1")
631 binop_compare32("feq32", tfloat, _2src_commutative, "src0 == src1")
632 binop_compare32("fne32", tfloat, _2src_commutative, "src0 != src1")
633 binop_compare32("ilt32", tint, "", "src0 < src1")
634 binop_compare32("ige32", tint, "", "src0 >= src1")
635 binop_compare32("ieq32", tint, _2src_commutative, "src0 == src1")
636 binop_compare32("ine32", tint, _2src_commutative, "src0 != src1")
637 binop_compare32("ult32", tuint, "", "src0 < src1")
638 binop_compare32("uge32", tuint, "", "src0 >= src1")
639
640 # integer-aware GLSL-style comparisons that compare floats and ints
641
642 binop_reduce("ball_fequal", 1, tbool1, tfloat, "{src0} == {src1}",
643 "{src0} && {src1}", "{src}")
644 binop_reduce("bany_fnequal", 1, tbool1, tfloat, "{src0} != {src1}",
645 "{src0} || {src1}", "{src}")
646 binop_reduce("ball_iequal", 1, tbool1, tint, "{src0} == {src1}",
647 "{src0} && {src1}", "{src}")
648 binop_reduce("bany_inequal", 1, tbool1, tint, "{src0} != {src1}",
649 "{src0} || {src1}", "{src}")
650
651 binop_reduce("b32all_fequal", 1, tbool32, tfloat, "{src0} == {src1}",
652 "{src0} && {src1}", "{src}")
653 binop_reduce("b32any_fnequal", 1, tbool32, tfloat, "{src0} != {src1}",
654 "{src0} || {src1}", "{src}")
655 binop_reduce("b32all_iequal", 1, tbool32, tint, "{src0} == {src1}",
656 "{src0} && {src1}", "{src}")
657 binop_reduce("b32any_inequal", 1, tbool32, tint, "{src0} != {src1}",
658 "{src0} || {src1}", "{src}")
659
660 # non-integer-aware GLSL-style comparisons that return 0.0 or 1.0
661
662 binop_reduce("fall_equal", 1, tfloat32, tfloat32, "{src0} == {src1}",
663 "{src0} && {src1}", "{src} ? 1.0f : 0.0f")
664 binop_reduce("fany_nequal", 1, tfloat32, tfloat32, "{src0} != {src1}",
665 "{src0} || {src1}", "{src} ? 1.0f : 0.0f")
666
667 # These comparisons for integer-less hardware return 1.0 and 0.0 for true
668 # and false respectively
669
670 binop("slt", tfloat32, "", "(src0 < src1) ? 1.0f : 0.0f") # Set on Less Than
671 binop("sge", tfloat, "", "(src0 >= src1) ? 1.0f : 0.0f") # Set on Greater or Equal
672 binop("seq", tfloat32, _2src_commutative, "(src0 == src1) ? 1.0f : 0.0f") # Set on Equal
673 binop("sne", tfloat32, _2src_commutative, "(src0 != src1) ? 1.0f : 0.0f") # Set on Not Equal
674
675 # SPIRV shifts are undefined for shift-operands >= bitsize,
676 # but SM5 shifts are defined to use the least significant bits, only
677 # The NIR definition is according to the SM5 specification.
678 opcode("ishl", 0, tint, [0, 0], [tint, tuint32], False, "",
679 "src0 << (src1 & (sizeof(src0) * 8 - 1))")
680 opcode("ishr", 0, tint, [0, 0], [tint, tuint32], False, "",
681 "src0 >> (src1 & (sizeof(src0) * 8 - 1))")
682 opcode("ushr", 0, tuint, [0, 0], [tuint, tuint32], False, "",
683 "src0 >> (src1 & (sizeof(src0) * 8 - 1))")
684
685 # bitwise logic operators
686 #
687 # These are also used as boolean and, or, xor for hardware supporting
688 # integers.
689
690
691 binop("iand", tuint, _2src_commutative + associative, "src0 & src1")
692 binop("ior", tuint, _2src_commutative + associative, "src0 | src1")
693 binop("ixor", tuint, _2src_commutative + associative, "src0 ^ src1")
694
695
696 # floating point logic operators
697 #
698 # These use (src != 0.0) for testing the truth of the input, and output 1.0
699 # for true and 0.0 for false
700
701 binop("fand", tfloat32, _2src_commutative,
702 "((src0 != 0.0f) && (src1 != 0.0f)) ? 1.0f : 0.0f")
703 binop("for", tfloat32, _2src_commutative,
704 "((src0 != 0.0f) || (src1 != 0.0f)) ? 1.0f : 0.0f")
705 binop("fxor", tfloat32, _2src_commutative,
706 "(src0 != 0.0f && src1 == 0.0f) || (src0 == 0.0f && src1 != 0.0f) ? 1.0f : 0.0f")
707
708 binop_reduce("fdot", 1, tfloat, tfloat, "{src0} * {src1}", "{src0} + {src1}",
709 "{src}")
710
711 binop_reduce("fdot_replicated", 4, tfloat, tfloat,
712 "{src0} * {src1}", "{src0} + {src1}", "{src}")
713
714 opcode("fdph", 1, tfloat, [3, 4], [tfloat, tfloat], False, "",
715 "src0.x * src1.x + src0.y * src1.y + src0.z * src1.z + src1.w")
716 opcode("fdph_replicated", 4, tfloat, [3, 4], [tfloat, tfloat], False, "",
717 "src0.x * src1.x + src0.y * src1.y + src0.z * src1.z + src1.w")
718
719 binop("fmin", tfloat, "", "fminf(src0, src1)")
720 binop("imin", tint, _2src_commutative + associative, "src1 > src0 ? src0 : src1")
721 binop("umin", tuint, _2src_commutative + associative, "src1 > src0 ? src0 : src1")
722 binop("fmax", tfloat, "", "fmaxf(src0, src1)")
723 binop("imax", tint, _2src_commutative + associative, "src1 > src0 ? src1 : src0")
724 binop("umax", tuint, _2src_commutative + associative, "src1 > src0 ? src1 : src0")
725
726 # Saturated vector add for 4 8bit ints.
727 binop("usadd_4x8", tint32, _2src_commutative + associative, """
728 dst = 0;
729 for (int i = 0; i < 32; i += 8) {
730 dst |= MIN2(((src0 >> i) & 0xff) + ((src1 >> i) & 0xff), 0xff) << i;
731 }
732 """)
733
734 # Saturated vector subtract for 4 8bit ints.
735 binop("ussub_4x8", tint32, "", """
736 dst = 0;
737 for (int i = 0; i < 32; i += 8) {
738 int src0_chan = (src0 >> i) & 0xff;
739 int src1_chan = (src1 >> i) & 0xff;
740 if (src0_chan > src1_chan)
741 dst |= (src0_chan - src1_chan) << i;
742 }
743 """)
744
745 # vector min for 4 8bit ints.
746 binop("umin_4x8", tint32, _2src_commutative + associative, """
747 dst = 0;
748 for (int i = 0; i < 32; i += 8) {
749 dst |= MIN2((src0 >> i) & 0xff, (src1 >> i) & 0xff) << i;
750 }
751 """)
752
753 # vector max for 4 8bit ints.
754 binop("umax_4x8", tint32, _2src_commutative + associative, """
755 dst = 0;
756 for (int i = 0; i < 32; i += 8) {
757 dst |= MAX2((src0 >> i) & 0xff, (src1 >> i) & 0xff) << i;
758 }
759 """)
760
761 # unorm multiply: (a * b) / 255.
762 binop("umul_unorm_4x8", tint32, _2src_commutative + associative, """
763 dst = 0;
764 for (int i = 0; i < 32; i += 8) {
765 int src0_chan = (src0 >> i) & 0xff;
766 int src1_chan = (src1 >> i) & 0xff;
767 dst |= ((src0_chan * src1_chan) / 255) << i;
768 }
769 """)
770
771 binop("fpow", tfloat, "", "bit_size == 64 ? powf(src0, src1) : pow(src0, src1)")
772
773 binop_horiz("pack_half_2x16_split", 1, tuint32, 1, tfloat32, 1, tfloat32,
774 "pack_half_1x16(src0.x) | (pack_half_1x16(src1.x) << 16)")
775
776 binop_convert("pack_64_2x32_split", tuint64, tuint32, "",
777 "src0 | ((uint64_t)src1 << 32)")
778
779 binop_convert("pack_32_2x16_split", tuint32, tuint16, "",
780 "src0 | ((uint32_t)src1 << 16)")
781
782 # bfm implements the behavior of the first operation of the SM5 "bfi" assembly
783 # and that of the "bfi1" i965 instruction. That is, it has undefined behavior
784 # if either of its arguments are 32.
785 binop_convert("bfm", tuint32, tint32, "", """
786 int bits = src0, offset = src1;
787 if (offset < 0 || bits < 0 || offset > 31 || bits > 31 || offset + bits > 32)
788 dst = 0; /* undefined */
789 else
790 dst = ((1u << bits) - 1) << offset;
791 """)
792
793 opcode("ldexp", 0, tfloat, [0, 0], [tfloat, tint32], False, "", """
794 dst = (bit_size == 64) ? ldexp(src0, src1) : ldexpf(src0, src1);
795 /* flush denormals to zero. */
796 if (!isnormal(dst))
797 dst = copysignf(0.0f, src0);
798 """)
799
800 # Combines the first component of each input to make a 2-component vector.
801
802 binop_horiz("vec2", 2, tuint, 1, tuint, 1, tuint, """
803 dst.x = src0.x;
804 dst.y = src1.x;
805 """)
806
807 # Byte extraction
808 binop("extract_u8", tuint, "", "(uint8_t)(src0 >> (src1 * 8))")
809 binop("extract_i8", tint, "", "(int8_t)(src0 >> (src1 * 8))")
810
811 # Word extraction
812 binop("extract_u16", tuint, "", "(uint16_t)(src0 >> (src1 * 16))")
813 binop("extract_i16", tint, "", "(int16_t)(src0 >> (src1 * 16))")
814
815
816 def triop(name, ty, alg_props, const_expr):
817 opcode(name, 0, ty, [0, 0, 0], [ty, ty, ty], False, alg_props, const_expr)
818 def triop_horiz(name, output_size, src1_size, src2_size, src3_size, const_expr):
819 opcode(name, output_size, tuint,
820 [src1_size, src2_size, src3_size],
821 [tuint, tuint, tuint], False, "", const_expr)
822
823 triop("ffma", tfloat, _2src_commutative, "src0 * src1 + src2")
824
825 triop("flrp", tfloat, "", "src0 * (1 - src2) + src1 * src2")
826
827 # Conditional Select
828 #
829 # A vector conditional select instruction (like ?:, but operating per-
830 # component on vectors). There are two versions, one for floating point
831 # bools (0.0 vs 1.0) and one for integer bools (0 vs ~0).
832
833
834 triop("fcsel", tfloat32, "", "(src0 != 0.0f) ? src1 : src2")
835
836 # 3 way min/max/med
837 triop("fmin3", tfloat, "", "fminf(src0, fminf(src1, src2))")
838 triop("imin3", tint, "", "MIN2(src0, MIN2(src1, src2))")
839 triop("umin3", tuint, "", "MIN2(src0, MIN2(src1, src2))")
840
841 triop("fmax3", tfloat, "", "fmaxf(src0, fmaxf(src1, src2))")
842 triop("imax3", tint, "", "MAX2(src0, MAX2(src1, src2))")
843 triop("umax3", tuint, "", "MAX2(src0, MAX2(src1, src2))")
844
845 triop("fmed3", tfloat, "", "fmaxf(fminf(fmaxf(src0, src1), src2), fminf(src0, src1))")
846 triop("imed3", tint, "", "MAX2(MIN2(MAX2(src0, src1), src2), MIN2(src0, src1))")
847 triop("umed3", tuint, "", "MAX2(MIN2(MAX2(src0, src1), src2), MIN2(src0, src1))")
848
849 opcode("bcsel", 0, tuint, [0, 0, 0],
850 [tbool1, tuint, tuint], False, "", "src0 ? src1 : src2")
851 opcode("b32csel", 0, tuint, [0, 0, 0],
852 [tbool32, tuint, tuint], False, "", "src0 ? src1 : src2")
853
854 # SM5 bfi assembly
855 triop("bfi", tuint32, "", """
856 unsigned mask = src0, insert = src1, base = src2;
857 if (mask == 0) {
858 dst = base;
859 } else {
860 unsigned tmp = mask;
861 while (!(tmp & 1)) {
862 tmp >>= 1;
863 insert <<= 1;
864 }
865 dst = (base & ~mask) | (insert & mask);
866 }
867 """)
868
869 # SM5 ubfe/ibfe assembly
870 opcode("ubfe", 0, tuint32,
871 [0, 0, 0], [tuint32, tint32, tint32], False, "", """
872 unsigned base = src0;
873 int offset = src1, bits = src2;
874 if (bits == 0) {
875 dst = 0;
876 } else if (bits < 0 || offset < 0) {
877 dst = 0; /* undefined */
878 } else if (offset + bits < 32) {
879 dst = (base << (32 - bits - offset)) >> (32 - bits);
880 } else {
881 dst = base >> offset;
882 }
883 """)
884 opcode("ibfe", 0, tint32,
885 [0, 0, 0], [tint32, tint32, tint32], False, "", """
886 int base = src0;
887 int offset = src1, bits = src2;
888 if (bits == 0) {
889 dst = 0;
890 } else if (bits < 0 || offset < 0) {
891 dst = 0; /* undefined */
892 } else if (offset + bits < 32) {
893 dst = (base << (32 - bits - offset)) >> (32 - bits);
894 } else {
895 dst = base >> offset;
896 }
897 """)
898
899 # GLSL bitfieldExtract()
900 opcode("ubitfield_extract", 0, tuint32,
901 [0, 0, 0], [tuint32, tint32, tint32], False, "", """
902 unsigned base = src0;
903 int offset = src1, bits = src2;
904 if (bits == 0) {
905 dst = 0;
906 } else if (bits < 0 || offset < 0 || offset + bits > 32) {
907 dst = 0; /* undefined per the spec */
908 } else {
909 dst = (base >> offset) & ((1ull << bits) - 1);
910 }
911 """)
912 opcode("ibitfield_extract", 0, tint32,
913 [0, 0, 0], [tint32, tint32, tint32], False, "", """
914 int base = src0;
915 int offset = src1, bits = src2;
916 if (bits == 0) {
917 dst = 0;
918 } else if (offset < 0 || bits < 0 || offset + bits > 32) {
919 dst = 0;
920 } else {
921 dst = (base << (32 - offset - bits)) >> offset; /* use sign-extending shift */
922 }
923 """)
924
925 # Combines the first component of each input to make a 3-component vector.
926
927 triop_horiz("vec3", 3, 1, 1, 1, """
928 dst.x = src0.x;
929 dst.y = src1.x;
930 dst.z = src2.x;
931 """)
932
933 def quadop_horiz(name, output_size, src1_size, src2_size, src3_size,
934 src4_size, const_expr):
935 opcode(name, output_size, tuint,
936 [src1_size, src2_size, src3_size, src4_size],
937 [tuint, tuint, tuint, tuint],
938 False, "", const_expr)
939
940 opcode("bitfield_insert", 0, tuint32, [0, 0, 0, 0],
941 [tuint32, tuint32, tint32, tint32], False, "", """
942 unsigned base = src0, insert = src1;
943 int offset = src2, bits = src3;
944 if (bits == 0) {
945 dst = base;
946 } else if (offset < 0 || bits < 0 || bits + offset > 32) {
947 dst = 0;
948 } else {
949 unsigned mask = ((1ull << bits) - 1) << offset;
950 dst = (base & ~mask) | ((insert << offset) & mask);
951 }
952 """)
953
954 quadop_horiz("vec4", 4, 1, 1, 1, 1, """
955 dst.x = src0.x;
956 dst.y = src1.x;
957 dst.z = src2.x;
958 dst.w = src3.x;
959 """)
960
961