vtn: handle SpvExecutionModelKernel
[mesa.git] / src / compiler / nir / nir_search.c
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include <inttypes.h>
29 #include "nir_search.h"
30 #include "nir_builder.h"
31 #include "util/half_float.h"
32
33 struct match_state {
34 bool inexact_match;
35 bool has_exact_alu;
36 unsigned variables_seen;
37 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
38 };
39
40 static bool
41 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
42 unsigned num_components, const uint8_t *swizzle,
43 struct match_state *state);
44
45 static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
46
47 /**
48 * Check if a source produces a value of the given type.
49 *
50 * Used for satisfying 'a@type' constraints.
51 */
52 static bool
53 src_is_type(nir_src src, nir_alu_type type)
54 {
55 assert(type != nir_type_invalid);
56
57 if (!src.is_ssa)
58 return false;
59
60 if (src.ssa->parent_instr->type == nir_instr_type_alu) {
61 nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
62 nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
63
64 if (type == nir_type_bool) {
65 switch (src_alu->op) {
66 case nir_op_iand:
67 case nir_op_ior:
68 case nir_op_ixor:
69 return src_is_type(src_alu->src[0].src, nir_type_bool) &&
70 src_is_type(src_alu->src[1].src, nir_type_bool);
71 case nir_op_inot:
72 return src_is_type(src_alu->src[0].src, nir_type_bool);
73 default:
74 break;
75 }
76 }
77
78 return nir_alu_type_get_base_type(output_type) == type;
79 } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
80 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
81
82 if (type == nir_type_bool) {
83 return intr->intrinsic == nir_intrinsic_load_front_face ||
84 intr->intrinsic == nir_intrinsic_load_helper_invocation;
85 }
86 }
87
88 /* don't know */
89 return false;
90 }
91
92 static bool
93 nir_op_matches_search_op(nir_op nop, uint16_t sop)
94 {
95 if (sop <= nir_last_opcode)
96 return nop == sop;
97
98 #define MATCH_FCONV_CASE(op) \
99 case nir_search_op_##op: \
100 return nop == nir_op_##op##16 || \
101 nop == nir_op_##op##32 || \
102 nop == nir_op_##op##64;
103
104 #define MATCH_ICONV_CASE(op) \
105 case nir_search_op_##op: \
106 return nop == nir_op_##op##8 || \
107 nop == nir_op_##op##16 || \
108 nop == nir_op_##op##32 || \
109 nop == nir_op_##op##64;
110
111 #define MATCH_BCONV_CASE(op) \
112 case nir_search_op_##op: \
113 return nop == nir_op_##op##1 || \
114 nop == nir_op_##op##32;
115
116 switch (sop) {
117 MATCH_FCONV_CASE(i2f)
118 MATCH_FCONV_CASE(u2f)
119 MATCH_FCONV_CASE(f2f)
120 MATCH_ICONV_CASE(f2u)
121 MATCH_ICONV_CASE(f2i)
122 MATCH_ICONV_CASE(u2u)
123 MATCH_ICONV_CASE(i2i)
124 MATCH_FCONV_CASE(b2f)
125 MATCH_ICONV_CASE(b2i)
126 MATCH_BCONV_CASE(i2b)
127 MATCH_BCONV_CASE(f2b)
128 default:
129 unreachable("Invalid nir_search_op");
130 }
131
132 #undef MATCH_FCONV_CASE
133 #undef MATCH_ICONV_CASE
134 }
135
136 static nir_op
137 nir_op_for_search_op(uint16_t sop, unsigned bit_size)
138 {
139 if (sop <= nir_last_opcode)
140 return sop;
141
142 #define RET_FCONV_CASE(op) \
143 case nir_search_op_##op: \
144 switch (bit_size) { \
145 case 16: return nir_op_##op##16; \
146 case 32: return nir_op_##op##32; \
147 case 64: return nir_op_##op##64; \
148 default: unreachable("Invalid bit size"); \
149 }
150
151 #define RET_ICONV_CASE(op) \
152 case nir_search_op_##op: \
153 switch (bit_size) { \
154 case 8: return nir_op_##op##8; \
155 case 16: return nir_op_##op##16; \
156 case 32: return nir_op_##op##32; \
157 case 64: return nir_op_##op##64; \
158 default: unreachable("Invalid bit size"); \
159 }
160
161 #define RET_BCONV_CASE(op) \
162 case nir_search_op_##op: \
163 switch (bit_size) { \
164 case 1: return nir_op_##op##1; \
165 case 32: return nir_op_##op##32; \
166 default: unreachable("Invalid bit size"); \
167 }
168
169 switch (sop) {
170 RET_FCONV_CASE(i2f)
171 RET_FCONV_CASE(u2f)
172 RET_FCONV_CASE(f2f)
173 RET_ICONV_CASE(f2u)
174 RET_ICONV_CASE(f2i)
175 RET_ICONV_CASE(u2u)
176 RET_ICONV_CASE(i2i)
177 RET_FCONV_CASE(b2f)
178 RET_ICONV_CASE(b2i)
179 RET_BCONV_CASE(i2b)
180 RET_BCONV_CASE(f2b)
181 default:
182 unreachable("Invalid nir_search_op");
183 }
184
185 #undef RET_FCONV_CASE
186 #undef RET_ICONV_CASE
187 }
188
189 static bool
190 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
191 unsigned num_components, const uint8_t *swizzle,
192 struct match_state *state)
193 {
194 uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS];
195
196 /* Searching only works on SSA values because, if it's not SSA, we can't
197 * know if the value changed between one instance of that value in the
198 * expression and another. Also, the replace operation will place reads of
199 * that value right before the last instruction in the expression we're
200 * replacing so those reads will happen after the original reads and may
201 * not be valid if they're register reads.
202 */
203 if (!instr->src[src].src.is_ssa)
204 return false;
205
206 /* If the source is an explicitly sized source, then we need to reset
207 * both the number of components and the swizzle.
208 */
209 if (nir_op_infos[instr->op].input_sizes[src] != 0) {
210 num_components = nir_op_infos[instr->op].input_sizes[src];
211 swizzle = identity_swizzle;
212 }
213
214 for (unsigned i = 0; i < num_components; ++i)
215 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
216
217 /* If the value has a specific bit size and it doesn't match, bail */
218 if (value->bit_size > 0 &&
219 nir_src_bit_size(instr->src[src].src) != value->bit_size)
220 return false;
221
222 switch (value->type) {
223 case nir_search_value_expression:
224 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
225 return false;
226
227 return match_expression(nir_search_value_as_expression(value),
228 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
229 num_components, new_swizzle, state);
230
231 case nir_search_value_variable: {
232 nir_search_variable *var = nir_search_value_as_variable(value);
233 assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
234
235 if (state->variables_seen & (1 << var->variable)) {
236 if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
237 return false;
238
239 assert(!instr->src[src].abs && !instr->src[src].negate);
240
241 for (unsigned i = 0; i < num_components; ++i) {
242 if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
243 return false;
244 }
245
246 return true;
247 } else {
248 if (var->is_constant &&
249 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
250 return false;
251
252 if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
253 return false;
254
255 if (var->type != nir_type_invalid &&
256 !src_is_type(instr->src[src].src, var->type))
257 return false;
258
259 state->variables_seen |= (1 << var->variable);
260 state->variables[var->variable].src = instr->src[src].src;
261 state->variables[var->variable].abs = false;
262 state->variables[var->variable].negate = false;
263
264 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) {
265 if (i < num_components)
266 state->variables[var->variable].swizzle[i] = new_swizzle[i];
267 else
268 state->variables[var->variable].swizzle[i] = 0;
269 }
270
271 return true;
272 }
273 }
274
275 case nir_search_value_constant: {
276 nir_search_constant *const_val = nir_search_value_as_constant(value);
277
278 if (!nir_src_is_const(instr->src[src].src))
279 return false;
280
281 switch (const_val->type) {
282 case nir_type_float:
283 for (unsigned i = 0; i < num_components; ++i) {
284 double val = nir_src_comp_as_float(instr->src[src].src,
285 new_swizzle[i]);
286 if (val != const_val->data.d)
287 return false;
288 }
289 return true;
290
291 case nir_type_int:
292 case nir_type_uint:
293 case nir_type_bool: {
294 unsigned bit_size = nir_src_bit_size(instr->src[src].src);
295 uint64_t mask = bit_size == 64 ? UINT64_MAX : (1ull << bit_size) - 1;
296 for (unsigned i = 0; i < num_components; ++i) {
297 uint64_t val = nir_src_comp_as_uint(instr->src[src].src,
298 new_swizzle[i]);
299 if ((val & mask) != (const_val->data.u & mask))
300 return false;
301 }
302 return true;
303 }
304
305 default:
306 unreachable("Invalid alu source type");
307 }
308 }
309
310 default:
311 unreachable("Invalid search value type");
312 }
313 }
314
315 static bool
316 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
317 unsigned num_components, const uint8_t *swizzle,
318 struct match_state *state)
319 {
320 if (expr->cond && !expr->cond(instr))
321 return false;
322
323 if (!nir_op_matches_search_op(instr->op, expr->opcode))
324 return false;
325
326 assert(instr->dest.dest.is_ssa);
327
328 if (expr->value.bit_size > 0 &&
329 instr->dest.dest.ssa.bit_size != expr->value.bit_size)
330 return false;
331
332 state->inexact_match = expr->inexact || state->inexact_match;
333 state->has_exact_alu = instr->exact || state->has_exact_alu;
334 if (state->inexact_match && state->has_exact_alu)
335 return false;
336
337 assert(!instr->dest.saturate);
338 assert(nir_op_infos[instr->op].num_inputs > 0);
339
340 /* If we have an explicitly sized destination, we can only handle the
341 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
342 * expression, we don't have the information right now to propagate that
343 * swizzle through. We can only properly propagate swizzles if the
344 * instruction is vectorized.
345 */
346 if (nir_op_infos[instr->op].output_size != 0) {
347 for (unsigned i = 0; i < num_components; i++) {
348 if (swizzle[i] != i)
349 return false;
350 }
351 }
352
353 /* Stash off the current variables_seen bitmask. This way we can
354 * restore it prior to matching in the commutative case below.
355 */
356 unsigned variables_seen_stash = state->variables_seen;
357
358 bool matched = true;
359 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
360 if (!match_value(expr->srcs[i], instr, i, num_components,
361 swizzle, state)) {
362 matched = false;
363 break;
364 }
365 }
366
367 if (matched)
368 return true;
369
370 if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
371 assert(nir_op_infos[instr->op].num_inputs == 2);
372
373 /* Restore the variables_seen bitmask. If we don't do this, then we
374 * could end up with an erroneous failure due to variables found in the
375 * first match attempt above not matching those in the second.
376 */
377 state->variables_seen = variables_seen_stash;
378
379 if (!match_value(expr->srcs[0], instr, 1, num_components,
380 swizzle, state))
381 return false;
382
383 return match_value(expr->srcs[1], instr, 0, num_components,
384 swizzle, state);
385 } else {
386 return false;
387 }
388 }
389
390 static unsigned
391 replace_bitsize(const nir_search_value *value, unsigned search_bitsize,
392 struct match_state *state)
393 {
394 if (value->bit_size > 0)
395 return value->bit_size;
396 if (value->bit_size < 0)
397 return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
398 return search_bitsize;
399 }
400
401 static nir_alu_src
402 construct_value(nir_builder *build,
403 const nir_search_value *value,
404 unsigned num_components, unsigned search_bitsize,
405 struct match_state *state,
406 nir_instr *instr)
407 {
408 switch (value->type) {
409 case nir_search_value_expression: {
410 const nir_search_expression *expr = nir_search_value_as_expression(value);
411 unsigned dst_bit_size = replace_bitsize(value, search_bitsize, state);
412 nir_op op = nir_op_for_search_op(expr->opcode, dst_bit_size);
413
414 if (nir_op_infos[op].output_size != 0)
415 num_components = nir_op_infos[op].output_size;
416
417 nir_alu_instr *alu = nir_alu_instr_create(build->shader, op);
418 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
419 dst_bit_size, NULL);
420 alu->dest.write_mask = (1 << num_components) - 1;
421 alu->dest.saturate = false;
422
423 /* We have no way of knowing what values in a given search expression
424 * map to a particular replacement value. Therefore, if the
425 * expression we are replacing has any exact values, the entire
426 * replacement should be exact.
427 */
428 alu->exact = state->has_exact_alu;
429
430 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
431 /* If the source is an explicitly sized source, then we need to reset
432 * the number of components to match.
433 */
434 if (nir_op_infos[alu->op].input_sizes[i] != 0)
435 num_components = nir_op_infos[alu->op].input_sizes[i];
436
437 alu->src[i] = construct_value(build, expr->srcs[i],
438 num_components, search_bitsize,
439 state, instr);
440 }
441
442 nir_builder_instr_insert(build, &alu->instr);
443
444 nir_alu_src val;
445 val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
446 val.negate = false;
447 val.abs = false,
448 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
449
450 return val;
451 }
452
453 case nir_search_value_variable: {
454 const nir_search_variable *var = nir_search_value_as_variable(value);
455 assert(state->variables_seen & (1 << var->variable));
456
457 nir_alu_src val = { NIR_SRC_INIT };
458 nir_alu_src_copy(&val, &state->variables[var->variable],
459 (void *)build->shader);
460 assert(!var->is_constant);
461
462 return val;
463 }
464
465 case nir_search_value_constant: {
466 const nir_search_constant *c = nir_search_value_as_constant(value);
467 unsigned bit_size = replace_bitsize(value, search_bitsize, state);
468
469 nir_ssa_def *cval;
470 switch (c->type) {
471 case nir_type_float:
472 cval = nir_imm_floatN_t(build, c->data.d, bit_size);
473 break;
474
475 case nir_type_int:
476 case nir_type_uint:
477 cval = nir_imm_intN_t(build, c->data.i, bit_size);
478 break;
479
480 case nir_type_bool:
481 cval = nir_imm_boolN_t(build, c->data.u, bit_size);
482 break;
483
484 default:
485 unreachable("Invalid alu source type");
486 }
487
488 nir_alu_src val;
489 val.src = nir_src_for_ssa(cval);
490 val.negate = false;
491 val.abs = false,
492 memset(val.swizzle, 0, sizeof val.swizzle);
493
494 return val;
495 }
496
497 default:
498 unreachable("Invalid search value type");
499 }
500 }
501
502 nir_ssa_def *
503 nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
504 const nir_search_expression *search,
505 const nir_search_value *replace)
506 {
507 uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
508
509 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
510 swizzle[i] = i;
511
512 assert(instr->dest.dest.is_ssa);
513
514 struct match_state state;
515 state.inexact_match = false;
516 state.has_exact_alu = false;
517 state.variables_seen = 0;
518
519 if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
520 swizzle, &state))
521 return NULL;
522
523 build->cursor = nir_before_instr(&instr->instr);
524
525 nir_alu_src val = construct_value(build, replace,
526 instr->dest.dest.ssa.num_components,
527 instr->dest.dest.ssa.bit_size,
528 &state, &instr->instr);
529
530 /* Inserting a mov may be unnecessary. However, it's much easier to
531 * simply let copy propagation clean this up than to try to go through
532 * and rewrite swizzles ourselves.
533 */
534 nir_ssa_def *ssa_val =
535 nir_imov_alu(build, val, instr->dest.dest.ssa.num_components);
536 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
537
538 /* We know this one has no more uses because we just rewrote them all,
539 * so we can remove it. The rest of the matched expression, however, we
540 * don't know so much about. We'll just let dead code clean them up.
541 */
542 nir_instr_remove(&instr->instr);
543
544 return ssa_val;
545 }