nir: Allow [iu]mul_high on non-32-bit types
[mesa.git] / src / compiler / nir / nir_constant_expressions.py
1 from __future__ import print_function
2
3 import re
4 from nir_opcodes import opcodes
5 from nir_opcodes import type_has_size, type_size, type_sizes, type_base_type
6
def type_add_size(type_, size):
    """Return ``type_`` with ``size`` appended unless it is already sized.

    ``type_has_size`` decides whether the type name already carries a size.
    Sized names are returned unchanged; otherwise ``str(size)`` is appended.
    """
    return type_ if type_has_size(type_) else type_ + str(size)
11
def op_bit_sizes(op):
    """Return the sorted bit sizes shared by all unsized types of ``op``.

    The output type is examined first, then each input type in order.  Types
    for which ``type_has_size`` is true are skipped; the ``type_sizes`` of the
    remaining ones are intersected.  Returns ``None`` when the output type
    and every input type are already sized.
    """
    common = None
    for candidate in [op.output_type] + list(op.input_types):
        if type_has_size(candidate):
            continue
        widths = set(type_sizes(candidate))
        # The first unsized type seeds the set; later ones narrow it.
        common = widths if common is None else common & widths

    return None if common is None else sorted(common)
25
def get_const_field(type_):
    """Return the constant-value field name used to access ``type_``.

    "bool32" maps to "u32" and "float16" maps to "u16".  Any other type maps
    to the first character of its base type followed by its size.
    """
    # Types whose field name is not derived from base type and size.
    for special_name, field in (("bool32", "u32"), ("float16", "u16")):
        if type_ == special_name:
            return field

    base_letter = type_base_type(type_)[0]
    return base_letter + str(type_size(type_))
33
# Mako template for the generated C file.  It contains, in order: the license
# header and #includes, static pack/unpack helper functions, typedefs and
# per-type x/y/z/w vector structs, the evaluate_op() Mako def that emits the
# code evaluating an opcode's const_expr at one bit size, one
# evaluate_<name>() C function per entry in opcodes, and the
# nir_eval_const_opcode() dispatcher that switches over nir_op_<name>.
# Lines starting with "%" are Mako control lines and lines starting with "##"
# are Mako comments.  The backslash after the opening quotes keeps the
# string from starting with a newline.
template = """\
/*
 * Copyright (C) 2014 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * Authors:
 * Jason Ekstrand (jason@jlekstrand.net)
 */

#include <math.h>
#include "util/rounding.h" /* for _mesa_roundeven */
#include "util/half_float.h"
#include "util/bigmath.h"
#include "nir_constant_expressions.h"

/**
 * Evaluate one component of packSnorm4x8.
 */
static uint8_t
pack_snorm_1x8(float x)
{
   /* From section 8.4 of the GLSL 4.30 spec:
    *
    * packSnorm4x8
    * ------------
    * The conversion for component c of v to fixed point is done as
    * follows:
    *
    * packSnorm4x8: round(clamp(c, -1, +1) * 127.0)
    *
    * We must first cast the float to an int, because casting a negative
    * float to a uint is undefined.
    */
   return (uint8_t) (int)
          _mesa_roundevenf(CLAMP(x, -1.0f, +1.0f) * 127.0f);
}

/**
 * Evaluate one component of packSnorm2x16.
 */
static uint16_t
pack_snorm_1x16(float x)
{
   /* From section 8.4 of the GLSL ES 3.00 spec:
    *
    * packSnorm2x16
    * -------------
    * The conversion for component c of v to fixed point is done as
    * follows:
    *
    * packSnorm2x16: round(clamp(c, -1, +1) * 32767.0)
    *
    * We must first cast the float to an int, because casting a negative
    * float to a uint is undefined.
    */
   return (uint16_t) (int)
          _mesa_roundevenf(CLAMP(x, -1.0f, +1.0f) * 32767.0f);
}

/**
 * Evaluate one component of unpackSnorm4x8.
 */
static float
unpack_snorm_1x8(uint8_t u)
{
   /* From section 8.4 of the GLSL 4.30 spec:
    *
    * unpackSnorm4x8
    * --------------
    * The conversion for unpacked fixed-point value f to floating point is
    * done as follows:
    *
    * unpackSnorm4x8: clamp(f / 127.0, -1, +1)
    */
   return CLAMP((int8_t) u / 127.0f, -1.0f, +1.0f);
}

/**
 * Evaluate one component of unpackSnorm2x16.
 */
static float
unpack_snorm_1x16(uint16_t u)
{
   /* From section 8.4 of the GLSL ES 3.00 spec:
    *
    * unpackSnorm2x16
    * ---------------
    * The conversion for unpacked fixed-point value f to floating point is
    * done as follows:
    *
    * unpackSnorm2x16: clamp(f / 32767.0, -1, +1)
    */
   return CLAMP((int16_t) u / 32767.0f, -1.0f, +1.0f);
}

/**
 * Evaluate one component packUnorm4x8.
 */
static uint8_t
pack_unorm_1x8(float x)
{
   /* From section 8.4 of the GLSL 4.30 spec:
    *
    * packUnorm4x8
    * ------------
    * The conversion for component c of v to fixed point is done as
    * follows:
    *
    * packUnorm4x8: round(clamp(c, 0, +1) * 255.0)
    */
   return (uint8_t) (int)
          _mesa_roundevenf(CLAMP(x, 0.0f, 1.0f) * 255.0f);
}

/**
 * Evaluate one component packUnorm2x16.
 */
static uint16_t
pack_unorm_1x16(float x)
{
   /* From section 8.4 of the GLSL ES 3.00 spec:
    *
    * packUnorm2x16
    * -------------
    * The conversion for component c of v to fixed point is done as
    * follows:
    *
    * packUnorm2x16: round(clamp(c, 0, +1) * 65535.0)
    */
   return (uint16_t) (int)
          _mesa_roundevenf(CLAMP(x, 0.0f, 1.0f) * 65535.0f);
}

/**
 * Evaluate one component of unpackUnorm4x8.
 */
static float
unpack_unorm_1x8(uint8_t u)
{
   /* From section 8.4 of the GLSL 4.30 spec:
    *
    * unpackUnorm4x8
    * --------------
    * The conversion for unpacked fixed-point value f to floating point is
    * done as follows:
    *
    * unpackUnorm4x8: f / 255.0
    */
   return (float) u / 255.0f;
}

/**
 * Evaluate one component of unpackUnorm2x16.
 */
static float
unpack_unorm_1x16(uint16_t u)
{
   /* From section 8.4 of the GLSL ES 3.00 spec:
    *
    * unpackUnorm2x16
    * ---------------
    * The conversion for unpacked fixed-point value f to floating point is
    * done as follows:
    *
    * unpackUnorm2x16: f / 65535.0
    */
   return (float) u / 65535.0f;
}

/**
 * Evaluate one component of packHalf2x16.
 */
static uint16_t
pack_half_1x16(float x)
{
   return _mesa_float_to_half(x);
}

/**
 * Evaluate one component of unpackHalf2x16.
 */
static float
unpack_half_1x16(uint16_t u)
{
   return _mesa_half_to_float(u);
}

/* Some typed vector structures to make things like src0.y work */
typedef float float16_t;
typedef float float32_t;
typedef double float64_t;
typedef bool bool32_t;
% for type in ["float", "int", "uint"]:
% for width in type_sizes(type):
struct ${type}${width}_vec {
   ${type}${width}_t x;
   ${type}${width}_t y;
   ${type}${width}_t z;
   ${type}${width}_t w;
};
% endfor
% endfor

struct bool32_vec {
    bool x;
    bool y;
    bool z;
    bool w;
};

<%def name="evaluate_op(op, bit_size)">
   <%
   output_type = type_add_size(op.output_type, bit_size)
   input_types = [type_add_size(type_, bit_size) for type_ in op.input_types]
   %>

   ## For each non-per-component input, create a variable srcN that
   ## contains x, y, z, and w elements which are filled in with the
   ## appropriately-typed values.
   % for j in range(op.num_inputs):
      % if op.input_sizes[j] == 0:
         <% continue %>
      % elif "src" + str(j) not in op.const_expr:
         ## Avoid unused variable warnings
         <% continue %>
      %endif

      const struct ${input_types[j]}_vec src${j} = {
      % for k in range(op.input_sizes[j]):
         % if input_types[j] == "bool32":
            _src[${j}].u32[${k}] != 0,
         % elif input_types[j] == "float16":
            _mesa_half_to_float(_src[${j}].u16[${k}]),
         % else:
            _src[${j}].${get_const_field(input_types[j])}[${k}],
         % endif
      % endfor
      % for k in range(op.input_sizes[j], 4):
         0,
      % endfor
      };
   % endfor

   % if op.output_size == 0:
      ## For per-component instructions, we need to iterate over the
      ## components and apply the constant expression one component
      ## at a time.
      for (unsigned _i = 0; _i < num_components; _i++) {
         ## For each per-component input, create a variable srcN that
         ## contains the value of the current (_i'th) component.
         % for j in range(op.num_inputs):
            % if op.input_sizes[j] != 0:
               <% continue %>
            % elif "src" + str(j) not in op.const_expr:
               ## Avoid unused variable warnings
               <% continue %>
            % elif input_types[j] == "bool32":
               const bool src${j} = _src[${j}].u32[_i] != 0;
            % elif input_types[j] == "float16":
               const float src${j} =
                  _mesa_half_to_float(_src[${j}].u16[_i]);
            % else:
               const ${input_types[j]}_t src${j} =
                  _src[${j}].${get_const_field(input_types[j])}[_i];
            % endif
         % endfor

         ## Create an appropriately-typed variable dst and assign the
         ## result of the const_expr to it. If const_expr already contains
         ## writes to dst, just include const_expr directly.
         % if "dst" in op.const_expr:
            ${output_type}_t dst;

            ${op.const_expr}
         % else:
            ${output_type}_t dst = ${op.const_expr};
         % endif

         ## Store the current component of the actual destination to the
         ## value of dst.
         % if output_type == "bool32":
            ## Sanitize the C value to a proper NIR bool
            _dst_val.u32[_i] = dst ? NIR_TRUE : NIR_FALSE;
         % elif output_type == "float16":
            _dst_val.u16[_i] = _mesa_float_to_half(dst);
         % else:
            _dst_val.${get_const_field(output_type)}[_i] = dst;
         % endif
      }
   % else:
      ## In the non-per-component case, create a struct dst with
      ## appropriately-typed elements x, y, z, and w and assign the result
      ## of the const_expr to all components of dst, or include the
      ## const_expr directly if it writes to dst already.
      struct ${output_type}_vec dst;

      % if "dst" in op.const_expr:
         ${op.const_expr}
      % else:
         ## Splat the value to all components. This way expressions which
         ## write the same value to all components don't need to explicitly
         ## write to dest. One such example is fnoise which has a
         ## const_expr of 0.0f.
         dst.x = dst.y = dst.z = dst.w = ${op.const_expr};
      % endif

      ## For each component in the destination, copy the value of dst to
      ## the actual destination.
      % for k in range(op.output_size):
         % if output_type == "bool32":
            ## Sanitize the C value to a proper NIR bool
            _dst_val.u32[${k}] = dst.${"xyzw"[k]} ? NIR_TRUE : NIR_FALSE;
         % elif output_type == "float16":
            _dst_val.u16[${k}] = _mesa_float_to_half(dst.${"xyzw"[k]});
         % else:
            _dst_val.${get_const_field(output_type)}[${k}] = dst.${"xyzw"[k]};
         % endif
      % endfor
   % endif
</%def>

% for name, op in sorted(opcodes.items()):
static nir_const_value
evaluate_${name}(MAYBE_UNUSED unsigned num_components,
                 ${"UNUSED" if op_bit_sizes(op) is None else ""} unsigned bit_size,
                 MAYBE_UNUSED nir_const_value *_src)
{
   nir_const_value _dst_val = { {0, } };

   % if op_bit_sizes(op) is not None:
      switch (bit_size) {
      % for bit_size in op_bit_sizes(op):
      case ${bit_size}: {
         ${evaluate_op(op, bit_size)}
         break;
      }
      % endfor

      default:
         unreachable("unknown bit width");
      }
   % else:
      ${evaluate_op(op, 0)}
   % endif

   return _dst_val;
}
% endfor

nir_const_value
nir_eval_const_opcode(nir_op op, unsigned num_components,
                      unsigned bit_width, nir_const_value *src)
{
   switch (op) {
% for name in sorted(opcodes.keys()):
   case nir_op_${name}:
      return evaluate_${name}(num_components, bit_width, src);
% endfor
   default:
      unreachable("shouldn't get here");
   }
}"""
414
415 from mako.template import Template
416
# Render the Mako template, handing it the opcode table and every helper the
# template text calls by name, then write the generated C source to stdout.
_rendered = Template(template).render(
    opcodes=opcodes,
    type_sizes=type_sizes,
    type_has_size=type_has_size,
    type_add_size=type_add_size,
    op_bit_sizes=op_bit_sizes,
    get_const_field=get_const_field,
)
print(_rendered)