compiler: replace MAYBE_UNUSED with UNUSED
[mesa.git] / src / compiler / nir / nir_search.c
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include <inttypes.h>
29 #include "nir_search.h"
30 #include "nir_builder.h"
31 #include "util/half_float.h"
32
33 /* This should be the same as nir_search_max_comm_ops in nir_algebraic.py. */
34 #define NIR_SEARCH_MAX_COMM_OPS 8
35
36 struct match_state {
37 bool inexact_match;
38 bool has_exact_alu;
39 uint8_t comm_op_direction;
40 unsigned variables_seen;
41 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
42 };
43
44 static bool
45 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
46 unsigned num_components, const uint8_t *swizzle,
47 struct match_state *state);
48
49 static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
50
51 /**
52 * Check if a source produces a value of the given type.
53 *
54 * Used for satisfying 'a@type' constraints.
55 */
56 static bool
57 src_is_type(nir_src src, nir_alu_type type)
58 {
59 assert(type != nir_type_invalid);
60
61 if (!src.is_ssa)
62 return false;
63
64 if (src.ssa->parent_instr->type == nir_instr_type_alu) {
65 nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
66 nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
67
68 if (type == nir_type_bool) {
69 switch (src_alu->op) {
70 case nir_op_iand:
71 case nir_op_ior:
72 case nir_op_ixor:
73 return src_is_type(src_alu->src[0].src, nir_type_bool) &&
74 src_is_type(src_alu->src[1].src, nir_type_bool);
75 case nir_op_inot:
76 return src_is_type(src_alu->src[0].src, nir_type_bool);
77 default:
78 break;
79 }
80 }
81
82 return nir_alu_type_get_base_type(output_type) == type;
83 } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
84 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
85
86 if (type == nir_type_bool) {
87 return intr->intrinsic == nir_intrinsic_load_front_face ||
88 intr->intrinsic == nir_intrinsic_load_helper_invocation;
89 }
90 }
91
92 /* don't know */
93 return false;
94 }
95
96 static bool
97 nir_op_matches_search_op(nir_op nop, uint16_t sop)
98 {
99 if (sop <= nir_last_opcode)
100 return nop == sop;
101
102 #define MATCH_FCONV_CASE(op) \
103 case nir_search_op_##op: \
104 return nop == nir_op_##op##16 || \
105 nop == nir_op_##op##32 || \
106 nop == nir_op_##op##64;
107
108 #define MATCH_ICONV_CASE(op) \
109 case nir_search_op_##op: \
110 return nop == nir_op_##op##8 || \
111 nop == nir_op_##op##16 || \
112 nop == nir_op_##op##32 || \
113 nop == nir_op_##op##64;
114
115 #define MATCH_BCONV_CASE(op) \
116 case nir_search_op_##op: \
117 return nop == nir_op_##op##1 || \
118 nop == nir_op_##op##32;
119
120 switch (sop) {
121 MATCH_FCONV_CASE(i2f)
122 MATCH_FCONV_CASE(u2f)
123 MATCH_FCONV_CASE(f2f)
124 MATCH_ICONV_CASE(f2u)
125 MATCH_ICONV_CASE(f2i)
126 MATCH_ICONV_CASE(u2u)
127 MATCH_ICONV_CASE(i2i)
128 MATCH_FCONV_CASE(b2f)
129 MATCH_ICONV_CASE(b2i)
130 MATCH_BCONV_CASE(i2b)
131 MATCH_BCONV_CASE(f2b)
132 default:
133 unreachable("Invalid nir_search_op");
134 }
135
136 #undef MATCH_FCONV_CASE
137 #undef MATCH_ICONV_CASE
138 #undef MATCH_BCONV_CASE
139 }
140
141 uint16_t
142 nir_search_op_for_nir_op(nir_op nop)
143 {
144 #define MATCH_FCONV_CASE(op) \
145 case nir_op_##op##16: \
146 case nir_op_##op##32: \
147 case nir_op_##op##64: \
148 return nir_search_op_##op;
149
150 #define MATCH_ICONV_CASE(op) \
151 case nir_op_##op##8: \
152 case nir_op_##op##16: \
153 case nir_op_##op##32: \
154 case nir_op_##op##64: \
155 return nir_search_op_##op;
156
157 #define MATCH_BCONV_CASE(op) \
158 case nir_op_##op##1: \
159 case nir_op_##op##32: \
160 return nir_search_op_##op;
161
162
163 switch (nop) {
164 MATCH_FCONV_CASE(i2f)
165 MATCH_FCONV_CASE(u2f)
166 MATCH_FCONV_CASE(f2f)
167 MATCH_ICONV_CASE(f2u)
168 MATCH_ICONV_CASE(f2i)
169 MATCH_ICONV_CASE(u2u)
170 MATCH_ICONV_CASE(i2i)
171 MATCH_FCONV_CASE(b2f)
172 MATCH_ICONV_CASE(b2i)
173 MATCH_BCONV_CASE(i2b)
174 MATCH_BCONV_CASE(f2b)
175 default:
176 return nop;
177 }
178
179 #undef MATCH_FCONV_CASE
180 #undef MATCH_ICONV_CASE
181 #undef MATCH_BCONV_CASE
182 }
183
184 static nir_op
185 nir_op_for_search_op(uint16_t sop, unsigned bit_size)
186 {
187 if (sop <= nir_last_opcode)
188 return sop;
189
190 #define RET_FCONV_CASE(op) \
191 case nir_search_op_##op: \
192 switch (bit_size) { \
193 case 16: return nir_op_##op##16; \
194 case 32: return nir_op_##op##32; \
195 case 64: return nir_op_##op##64; \
196 default: unreachable("Invalid bit size"); \
197 }
198
199 #define RET_ICONV_CASE(op) \
200 case nir_search_op_##op: \
201 switch (bit_size) { \
202 case 8: return nir_op_##op##8; \
203 case 16: return nir_op_##op##16; \
204 case 32: return nir_op_##op##32; \
205 case 64: return nir_op_##op##64; \
206 default: unreachable("Invalid bit size"); \
207 }
208
209 #define RET_BCONV_CASE(op) \
210 case nir_search_op_##op: \
211 switch (bit_size) { \
212 case 1: return nir_op_##op##1; \
213 case 32: return nir_op_##op##32; \
214 default: unreachable("Invalid bit size"); \
215 }
216
217 switch (sop) {
218 RET_FCONV_CASE(i2f)
219 RET_FCONV_CASE(u2f)
220 RET_FCONV_CASE(f2f)
221 RET_ICONV_CASE(f2u)
222 RET_ICONV_CASE(f2i)
223 RET_ICONV_CASE(u2u)
224 RET_ICONV_CASE(i2i)
225 RET_FCONV_CASE(b2f)
226 RET_ICONV_CASE(b2i)
227 RET_BCONV_CASE(i2b)
228 RET_BCONV_CASE(f2b)
229 default:
230 unreachable("Invalid nir_search_op");
231 }
232
233 #undef RET_FCONV_CASE
234 #undef RET_ICONV_CASE
235 #undef RET_BCONV_CASE
236 }
237
238 static bool
239 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
240 unsigned num_components, const uint8_t *swizzle,
241 struct match_state *state)
242 {
243 uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS];
244
245 /* Searching only works on SSA values because, if it's not SSA, we can't
246 * know if the value changed between one instance of that value in the
247 * expression and another. Also, the replace operation will place reads of
248 * that value right before the last instruction in the expression we're
249 * replacing so those reads will happen after the original reads and may
250 * not be valid if they're register reads.
251 */
252 assert(instr->src[src].src.is_ssa);
253
254 /* If the source is an explicitly sized source, then we need to reset
255 * both the number of components and the swizzle.
256 */
257 if (nir_op_infos[instr->op].input_sizes[src] != 0) {
258 num_components = nir_op_infos[instr->op].input_sizes[src];
259 swizzle = identity_swizzle;
260 }
261
262 for (unsigned i = 0; i < num_components; ++i)
263 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
264
265 /* If the value has a specific bit size and it doesn't match, bail */
266 if (value->bit_size > 0 &&
267 nir_src_bit_size(instr->src[src].src) != value->bit_size)
268 return false;
269
270 switch (value->type) {
271 case nir_search_value_expression:
272 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
273 return false;
274
275 return match_expression(nir_search_value_as_expression(value),
276 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
277 num_components, new_swizzle, state);
278
279 case nir_search_value_variable: {
280 nir_search_variable *var = nir_search_value_as_variable(value);
281 assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
282
283 if (state->variables_seen & (1 << var->variable)) {
284 if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
285 return false;
286
287 assert(!instr->src[src].abs && !instr->src[src].negate);
288
289 for (unsigned i = 0; i < num_components; ++i) {
290 if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
291 return false;
292 }
293
294 return true;
295 } else {
296 if (var->is_constant &&
297 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
298 return false;
299
300 if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
301 return false;
302
303 if (var->type != nir_type_invalid &&
304 !src_is_type(instr->src[src].src, var->type))
305 return false;
306
307 state->variables_seen |= (1 << var->variable);
308 state->variables[var->variable].src = instr->src[src].src;
309 state->variables[var->variable].abs = false;
310 state->variables[var->variable].negate = false;
311
312 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) {
313 if (i < num_components)
314 state->variables[var->variable].swizzle[i] = new_swizzle[i];
315 else
316 state->variables[var->variable].swizzle[i] = 0;
317 }
318
319 return true;
320 }
321 }
322
323 case nir_search_value_constant: {
324 nir_search_constant *const_val = nir_search_value_as_constant(value);
325
326 if (!nir_src_is_const(instr->src[src].src))
327 return false;
328
329 switch (const_val->type) {
330 case nir_type_float: {
331 nir_load_const_instr *const load =
332 nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
333
334 /* There are 8-bit and 1-bit integer types, but there are no 8-bit or
335 * 1-bit float types. This prevents potential assertion failures in
336 * nir_src_comp_as_float.
337 */
338 if (load->def.bit_size < 16)
339 return false;
340
341 for (unsigned i = 0; i < num_components; ++i) {
342 double val = nir_src_comp_as_float(instr->src[src].src,
343 new_swizzle[i]);
344 if (val != const_val->data.d)
345 return false;
346 }
347 return true;
348 }
349
350 case nir_type_int:
351 case nir_type_uint:
352 case nir_type_bool: {
353 unsigned bit_size = nir_src_bit_size(instr->src[src].src);
354 uint64_t mask = bit_size == 64 ? UINT64_MAX : (1ull << bit_size) - 1;
355 for (unsigned i = 0; i < num_components; ++i) {
356 uint64_t val = nir_src_comp_as_uint(instr->src[src].src,
357 new_swizzle[i]);
358 if ((val & mask) != (const_val->data.u & mask))
359 return false;
360 }
361 return true;
362 }
363
364 default:
365 unreachable("Invalid alu source type");
366 }
367 }
368
369 default:
370 unreachable("Invalid search value type");
371 }
372 }
373
374 static bool
375 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
376 unsigned num_components, const uint8_t *swizzle,
377 struct match_state *state)
378 {
379 if (expr->cond && !expr->cond(instr))
380 return false;
381
382 if (!nir_op_matches_search_op(instr->op, expr->opcode))
383 return false;
384
385 assert(instr->dest.dest.is_ssa);
386
387 if (expr->value.bit_size > 0 &&
388 instr->dest.dest.ssa.bit_size != expr->value.bit_size)
389 return false;
390
391 state->inexact_match = expr->inexact || state->inexact_match;
392 state->has_exact_alu = instr->exact || state->has_exact_alu;
393 if (state->inexact_match && state->has_exact_alu)
394 return false;
395
396 assert(!instr->dest.saturate);
397 assert(nir_op_infos[instr->op].num_inputs > 0);
398
399 /* If we have an explicitly sized destination, we can only handle the
400 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
401 * expression, we don't have the information right now to propagate that
402 * swizzle through. We can only properly propagate swizzles if the
403 * instruction is vectorized.
404 */
405 if (nir_op_infos[instr->op].output_size != 0) {
406 for (unsigned i = 0; i < num_components; i++) {
407 if (swizzle[i] != i)
408 return false;
409 }
410 }
411
412 /* If this is a commutative expression and it's one of the first few, look
413 * up its direction for the current search operation. We'll use that value
414 * to possibly flip the sources for the match.
415 */
416 unsigned comm_op_flip =
417 (expr->comm_expr_idx >= 0 &&
418 expr->comm_expr_idx < NIR_SEARCH_MAX_COMM_OPS) ?
419 ((state->comm_op_direction >> expr->comm_expr_idx) & 1) : 0;
420
421 bool matched = true;
422 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
423 /* 2src_commutative instructions that have 3 sources are only commutative
424 * in the first two sources. Source 2 is always source 2.
425 */
426 if (!match_value(expr->srcs[i], instr,
427 i < 2 ? i ^ comm_op_flip : i,
428 num_components, swizzle, state)) {
429 matched = false;
430 break;
431 }
432 }
433
434 return matched;
435 }
436
437 static unsigned
438 replace_bitsize(const nir_search_value *value, unsigned search_bitsize,
439 struct match_state *state)
440 {
441 if (value->bit_size > 0)
442 return value->bit_size;
443 if (value->bit_size < 0)
444 return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
445 return search_bitsize;
446 }
447
448 static nir_alu_src
449 construct_value(nir_builder *build,
450 const nir_search_value *value,
451 unsigned num_components, unsigned search_bitsize,
452 struct match_state *state,
453 nir_instr *instr)
454 {
455 switch (value->type) {
456 case nir_search_value_expression: {
457 const nir_search_expression *expr = nir_search_value_as_expression(value);
458 unsigned dst_bit_size = replace_bitsize(value, search_bitsize, state);
459 nir_op op = nir_op_for_search_op(expr->opcode, dst_bit_size);
460
461 if (nir_op_infos[op].output_size != 0)
462 num_components = nir_op_infos[op].output_size;
463
464 nir_alu_instr *alu = nir_alu_instr_create(build->shader, op);
465 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
466 dst_bit_size, NULL);
467 alu->dest.write_mask = (1 << num_components) - 1;
468 alu->dest.saturate = false;
469
470 /* We have no way of knowing what values in a given search expression
471 * map to a particular replacement value. Therefore, if the
472 * expression we are replacing has any exact values, the entire
473 * replacement should be exact.
474 */
475 alu->exact = state->has_exact_alu;
476
477 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
478 /* If the source is an explicitly sized source, then we need to reset
479 * the number of components to match.
480 */
481 if (nir_op_infos[alu->op].input_sizes[i] != 0)
482 num_components = nir_op_infos[alu->op].input_sizes[i];
483
484 alu->src[i] = construct_value(build, expr->srcs[i],
485 num_components, search_bitsize,
486 state, instr);
487 }
488
489 nir_builder_instr_insert(build, &alu->instr);
490
491 nir_alu_src val;
492 val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
493 val.negate = false;
494 val.abs = false,
495 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
496
497 return val;
498 }
499
500 case nir_search_value_variable: {
501 const nir_search_variable *var = nir_search_value_as_variable(value);
502 assert(state->variables_seen & (1 << var->variable));
503
504 nir_alu_src val = { NIR_SRC_INIT };
505 nir_alu_src_copy(&val, &state->variables[var->variable],
506 (void *)build->shader);
507 assert(!var->is_constant);
508
509 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
510 val.swizzle[i] = state->variables[var->variable].swizzle[var->swizzle[i]];
511
512 return val;
513 }
514
515 case nir_search_value_constant: {
516 const nir_search_constant *c = nir_search_value_as_constant(value);
517 unsigned bit_size = replace_bitsize(value, search_bitsize, state);
518
519 nir_ssa_def *cval;
520 switch (c->type) {
521 case nir_type_float:
522 cval = nir_imm_floatN_t(build, c->data.d, bit_size);
523 break;
524
525 case nir_type_int:
526 case nir_type_uint:
527 cval = nir_imm_intN_t(build, c->data.i, bit_size);
528 break;
529
530 case nir_type_bool:
531 cval = nir_imm_boolN_t(build, c->data.u, bit_size);
532 break;
533
534 default:
535 unreachable("Invalid alu source type");
536 }
537
538 nir_alu_src val;
539 val.src = nir_src_for_ssa(cval);
540 val.negate = false;
541 val.abs = false,
542 memset(val.swizzle, 0, sizeof val.swizzle);
543
544 return val;
545 }
546
547 default:
548 unreachable("Invalid search value type");
549 }
550 }
551
552 UNUSED static void dump_value(const nir_search_value *val)
553 {
554 switch (val->type) {
555 case nir_search_value_constant: {
556 const nir_search_constant *sconst = nir_search_value_as_constant(val);
557 switch (sconst->type) {
558 case nir_type_float:
559 printf("%f", sconst->data.d);
560 break;
561 case nir_type_int:
562 printf("%"PRId64, sconst->data.i);
563 break;
564 case nir_type_uint:
565 printf("0x%"PRIx64, sconst->data.u);
566 break;
567 case nir_type_bool:
568 printf("%s", sconst->data.u != 0 ? "True" : "False");
569 break;
570 default:
571 unreachable("bad const type");
572 }
573 break;
574 }
575
576 case nir_search_value_variable: {
577 const nir_search_variable *var = nir_search_value_as_variable(val);
578 if (var->is_constant)
579 printf("#");
580 printf("%c", var->variable + 'a');
581 break;
582 }
583
584 case nir_search_value_expression: {
585 const nir_search_expression *expr = nir_search_value_as_expression(val);
586 printf("(");
587 if (expr->inexact)
588 printf("~");
589 switch (expr->opcode) {
590 #define CASE(n) \
591 case nir_search_op_##n: printf(#n); break;
592 CASE(f2b)
593 CASE(b2f)
594 CASE(b2i)
595 CASE(i2b)
596 CASE(i2i)
597 CASE(f2i)
598 CASE(i2f)
599 #undef CASE
600 default:
601 printf("%s", nir_op_infos[expr->opcode].name);
602 }
603
604 unsigned num_srcs = 1;
605 if (expr->opcode <= nir_last_opcode)
606 num_srcs = nir_op_infos[expr->opcode].num_inputs;
607
608 for (unsigned i = 0; i < num_srcs; i++) {
609 printf(" ");
610 dump_value(expr->srcs[i]);
611 }
612
613 printf(")");
614 break;
615 }
616 }
617
618 if (val->bit_size > 0)
619 printf("@%d", val->bit_size);
620 }
621
622 nir_ssa_def *
623 nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
624 const nir_search_expression *search,
625 const nir_search_value *replace)
626 {
627 uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
628
629 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
630 swizzle[i] = i;
631
632 assert(instr->dest.dest.is_ssa);
633
634 struct match_state state;
635 state.inexact_match = false;
636 state.has_exact_alu = false;
637
638 STATIC_ASSERT(sizeof(state.comm_op_direction) * 8 >= NIR_SEARCH_MAX_COMM_OPS);
639
640 unsigned comm_expr_combinations =
641 1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS);
642
643 bool found = false;
644 for (unsigned comb = 0; comb < comm_expr_combinations; comb++) {
645 /* The bitfield of directions is just the current iteration. Hooray for
646 * binary.
647 */
648 state.comm_op_direction = comb;
649 state.variables_seen = 0;
650
651 if (match_expression(search, instr,
652 instr->dest.dest.ssa.num_components,
653 swizzle, &state)) {
654 found = true;
655 break;
656 }
657 }
658 if (!found)
659 return NULL;
660
661 #if 0
662 printf("matched: ");
663 dump_value(&search->value);
664 printf(" -> ");
665 dump_value(replace);
666 printf(" ssa_%d\n", instr->dest.dest.ssa.index);
667 #endif
668
669 build->cursor = nir_before_instr(&instr->instr);
670
671 nir_alu_src val = construct_value(build, replace,
672 instr->dest.dest.ssa.num_components,
673 instr->dest.dest.ssa.bit_size,
674 &state, &instr->instr);
675
676 /* Inserting a mov may be unnecessary. However, it's much easier to
677 * simply let copy propagation clean this up than to try to go through
678 * and rewrite swizzles ourselves.
679 */
680 nir_ssa_def *ssa_val =
681 nir_mov_alu(build, val, instr->dest.dest.ssa.num_components);
682 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
683
684 /* We know this one has no more uses because we just rewrote them all,
685 * so we can remove it. The rest of the matched expression, however, we
686 * don't know so much about. We'll just let dead code clean them up.
687 */
688 nir_instr_remove(&instr->instr);
689
690 return ssa_val;
691 }