nir/search: Add debugging code to dump the pattern matched
[mesa.git] / src / compiler / nir / nir_search.c
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include <inttypes.h>
29 #include "nir_search.h"
30 #include "nir_builder.h"
31 #include "util/half_float.h"
32
33 #define NIR_SEARCH_MAX_COMM_OPS 4
34
35 struct match_state {
36 bool inexact_match;
37 bool has_exact_alu;
38 uint8_t comm_op_direction;
39 unsigned variables_seen;
40 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
41 };
42
43 static bool
44 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
45 unsigned num_components, const uint8_t *swizzle,
46 struct match_state *state);
47
48 static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
49
50 /**
51 * Check if a source produces a value of the given type.
52 *
53 * Used for satisfying 'a@type' constraints.
54 */
55 static bool
56 src_is_type(nir_src src, nir_alu_type type)
57 {
58 assert(type != nir_type_invalid);
59
60 if (!src.is_ssa)
61 return false;
62
63 if (src.ssa->parent_instr->type == nir_instr_type_alu) {
64 nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
65 nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
66
67 if (type == nir_type_bool) {
68 switch (src_alu->op) {
69 case nir_op_iand:
70 case nir_op_ior:
71 case nir_op_ixor:
72 return src_is_type(src_alu->src[0].src, nir_type_bool) &&
73 src_is_type(src_alu->src[1].src, nir_type_bool);
74 case nir_op_inot:
75 return src_is_type(src_alu->src[0].src, nir_type_bool);
76 default:
77 break;
78 }
79 }
80
81 return nir_alu_type_get_base_type(output_type) == type;
82 } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
83 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
84
85 if (type == nir_type_bool) {
86 return intr->intrinsic == nir_intrinsic_load_front_face ||
87 intr->intrinsic == nir_intrinsic_load_helper_invocation;
88 }
89 }
90
91 /* don't know */
92 return false;
93 }
94
95 static bool
96 nir_op_matches_search_op(nir_op nop, uint16_t sop)
97 {
98 if (sop <= nir_last_opcode)
99 return nop == sop;
100
101 #define MATCH_FCONV_CASE(op) \
102 case nir_search_op_##op: \
103 return nop == nir_op_##op##16 || \
104 nop == nir_op_##op##32 || \
105 nop == nir_op_##op##64;
106
107 #define MATCH_ICONV_CASE(op) \
108 case nir_search_op_##op: \
109 return nop == nir_op_##op##8 || \
110 nop == nir_op_##op##16 || \
111 nop == nir_op_##op##32 || \
112 nop == nir_op_##op##64;
113
114 #define MATCH_BCONV_CASE(op) \
115 case nir_search_op_##op: \
116 return nop == nir_op_##op##1 || \
117 nop == nir_op_##op##32;
118
119 switch (sop) {
120 MATCH_FCONV_CASE(i2f)
121 MATCH_FCONV_CASE(u2f)
122 MATCH_FCONV_CASE(f2f)
123 MATCH_ICONV_CASE(f2u)
124 MATCH_ICONV_CASE(f2i)
125 MATCH_ICONV_CASE(u2u)
126 MATCH_ICONV_CASE(i2i)
127 MATCH_FCONV_CASE(b2f)
128 MATCH_ICONV_CASE(b2i)
129 MATCH_BCONV_CASE(i2b)
130 MATCH_BCONV_CASE(f2b)
131 default:
132 unreachable("Invalid nir_search_op");
133 }
134
135 #undef MATCH_FCONV_CASE
136 #undef MATCH_ICONV_CASE
137 #undef MATCH_BCONV_CASE
138 }
139
140 uint16_t
141 nir_search_op_for_nir_op(nir_op nop)
142 {
143 #define MATCH_FCONV_CASE(op) \
144 case nir_op_##op##16: \
145 case nir_op_##op##32: \
146 case nir_op_##op##64: \
147 return nir_search_op_##op;
148
149 #define MATCH_ICONV_CASE(op) \
150 case nir_op_##op##8: \
151 case nir_op_##op##16: \
152 case nir_op_##op##32: \
153 case nir_op_##op##64: \
154 return nir_search_op_##op;
155
156 #define MATCH_BCONV_CASE(op) \
157 case nir_op_##op##1: \
158 case nir_op_##op##32: \
159 return nir_search_op_##op;
160
161
162 switch (nop) {
163 MATCH_FCONV_CASE(i2f)
164 MATCH_FCONV_CASE(u2f)
165 MATCH_FCONV_CASE(f2f)
166 MATCH_ICONV_CASE(f2u)
167 MATCH_ICONV_CASE(f2i)
168 MATCH_ICONV_CASE(u2u)
169 MATCH_ICONV_CASE(i2i)
170 MATCH_FCONV_CASE(b2f)
171 MATCH_ICONV_CASE(b2i)
172 MATCH_BCONV_CASE(i2b)
173 MATCH_BCONV_CASE(f2b)
174 default:
175 return nop;
176 }
177
178 #undef MATCH_FCONV_CASE
179 #undef MATCH_ICONV_CASE
180 #undef MATCH_BCONV_CASE
181 }
182
183 static nir_op
184 nir_op_for_search_op(uint16_t sop, unsigned bit_size)
185 {
186 if (sop <= nir_last_opcode)
187 return sop;
188
189 #define RET_FCONV_CASE(op) \
190 case nir_search_op_##op: \
191 switch (bit_size) { \
192 case 16: return nir_op_##op##16; \
193 case 32: return nir_op_##op##32; \
194 case 64: return nir_op_##op##64; \
195 default: unreachable("Invalid bit size"); \
196 }
197
198 #define RET_ICONV_CASE(op) \
199 case nir_search_op_##op: \
200 switch (bit_size) { \
201 case 8: return nir_op_##op##8; \
202 case 16: return nir_op_##op##16; \
203 case 32: return nir_op_##op##32; \
204 case 64: return nir_op_##op##64; \
205 default: unreachable("Invalid bit size"); \
206 }
207
208 #define RET_BCONV_CASE(op) \
209 case nir_search_op_##op: \
210 switch (bit_size) { \
211 case 1: return nir_op_##op##1; \
212 case 32: return nir_op_##op##32; \
213 default: unreachable("Invalid bit size"); \
214 }
215
216 switch (sop) {
217 RET_FCONV_CASE(i2f)
218 RET_FCONV_CASE(u2f)
219 RET_FCONV_CASE(f2f)
220 RET_ICONV_CASE(f2u)
221 RET_ICONV_CASE(f2i)
222 RET_ICONV_CASE(u2u)
223 RET_ICONV_CASE(i2i)
224 RET_FCONV_CASE(b2f)
225 RET_ICONV_CASE(b2i)
226 RET_BCONV_CASE(i2b)
227 RET_BCONV_CASE(f2b)
228 default:
229 unreachable("Invalid nir_search_op");
230 }
231
232 #undef RET_FCONV_CASE
233 #undef RET_ICONV_CASE
234 #undef RET_BCONV_CASE
235 }
236
237 static bool
238 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
239 unsigned num_components, const uint8_t *swizzle,
240 struct match_state *state)
241 {
242 uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS];
243
244 /* Searching only works on SSA values because, if it's not SSA, we can't
245 * know if the value changed between one instance of that value in the
246 * expression and another. Also, the replace operation will place reads of
247 * that value right before the last instruction in the expression we're
248 * replacing so those reads will happen after the original reads and may
249 * not be valid if they're register reads.
250 */
251 assert(instr->src[src].src.is_ssa);
252
253 /* If the source is an explicitly sized source, then we need to reset
254 * both the number of components and the swizzle.
255 */
256 if (nir_op_infos[instr->op].input_sizes[src] != 0) {
257 num_components = nir_op_infos[instr->op].input_sizes[src];
258 swizzle = identity_swizzle;
259 }
260
261 for (unsigned i = 0; i < num_components; ++i)
262 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
263
264 /* If the value has a specific bit size and it doesn't match, bail */
265 if (value->bit_size > 0 &&
266 nir_src_bit_size(instr->src[src].src) != value->bit_size)
267 return false;
268
269 switch (value->type) {
270 case nir_search_value_expression:
271 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
272 return false;
273
274 return match_expression(nir_search_value_as_expression(value),
275 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
276 num_components, new_swizzle, state);
277
278 case nir_search_value_variable: {
279 nir_search_variable *var = nir_search_value_as_variable(value);
280 assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
281
282 if (state->variables_seen & (1 << var->variable)) {
283 if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
284 return false;
285
286 assert(!instr->src[src].abs && !instr->src[src].negate);
287
288 for (unsigned i = 0; i < num_components; ++i) {
289 if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
290 return false;
291 }
292
293 return true;
294 } else {
295 if (var->is_constant &&
296 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
297 return false;
298
299 if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
300 return false;
301
302 if (var->type != nir_type_invalid &&
303 !src_is_type(instr->src[src].src, var->type))
304 return false;
305
306 state->variables_seen |= (1 << var->variable);
307 state->variables[var->variable].src = instr->src[src].src;
308 state->variables[var->variable].abs = false;
309 state->variables[var->variable].negate = false;
310
311 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) {
312 if (i < num_components)
313 state->variables[var->variable].swizzle[i] = new_swizzle[i];
314 else
315 state->variables[var->variable].swizzle[i] = 0;
316 }
317
318 return true;
319 }
320 }
321
322 case nir_search_value_constant: {
323 nir_search_constant *const_val = nir_search_value_as_constant(value);
324
325 if (!nir_src_is_const(instr->src[src].src))
326 return false;
327
328 switch (const_val->type) {
329 case nir_type_float:
330 for (unsigned i = 0; i < num_components; ++i) {
331 double val = nir_src_comp_as_float(instr->src[src].src,
332 new_swizzle[i]);
333 if (val != const_val->data.d)
334 return false;
335 }
336 return true;
337
338 case nir_type_int:
339 case nir_type_uint:
340 case nir_type_bool: {
341 unsigned bit_size = nir_src_bit_size(instr->src[src].src);
342 uint64_t mask = bit_size == 64 ? UINT64_MAX : (1ull << bit_size) - 1;
343 for (unsigned i = 0; i < num_components; ++i) {
344 uint64_t val = nir_src_comp_as_uint(instr->src[src].src,
345 new_swizzle[i]);
346 if ((val & mask) != (const_val->data.u & mask))
347 return false;
348 }
349 return true;
350 }
351
352 default:
353 unreachable("Invalid alu source type");
354 }
355 }
356
357 default:
358 unreachable("Invalid search value type");
359 }
360 }
361
362 static bool
363 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
364 unsigned num_components, const uint8_t *swizzle,
365 struct match_state *state)
366 {
367 if (expr->cond && !expr->cond(instr))
368 return false;
369
370 if (!nir_op_matches_search_op(instr->op, expr->opcode))
371 return false;
372
373 assert(instr->dest.dest.is_ssa);
374
375 if (expr->value.bit_size > 0 &&
376 instr->dest.dest.ssa.bit_size != expr->value.bit_size)
377 return false;
378
379 state->inexact_match = expr->inexact || state->inexact_match;
380 state->has_exact_alu = instr->exact || state->has_exact_alu;
381 if (state->inexact_match && state->has_exact_alu)
382 return false;
383
384 assert(!instr->dest.saturate);
385 assert(nir_op_infos[instr->op].num_inputs > 0);
386
387 /* If we have an explicitly sized destination, we can only handle the
388 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
389 * expression, we don't have the information right now to propagate that
390 * swizzle through. We can only properly propagate swizzles if the
391 * instruction is vectorized.
392 */
393 if (nir_op_infos[instr->op].output_size != 0) {
394 for (unsigned i = 0; i < num_components; i++) {
395 if (swizzle[i] != i)
396 return false;
397 }
398 }
399
400 /* If this is a commutative expression and it's one of the first few, look
401 * up its direction for the current search operation. We'll use that value
402 * to possibly flip the sources for the match.
403 */
404 unsigned comm_op_flip =
405 (expr->comm_expr_idx >= 0 &&
406 expr->comm_expr_idx < NIR_SEARCH_MAX_COMM_OPS) ?
407 ((state->comm_op_direction >> expr->comm_expr_idx) & 1) : 0;
408
409 bool matched = true;
410 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
411 if (!match_value(expr->srcs[i], instr, i ^ comm_op_flip,
412 num_components, swizzle, state)) {
413 matched = false;
414 break;
415 }
416 }
417
418 return matched;
419 }
420
421 static unsigned
422 replace_bitsize(const nir_search_value *value, unsigned search_bitsize,
423 struct match_state *state)
424 {
425 if (value->bit_size > 0)
426 return value->bit_size;
427 if (value->bit_size < 0)
428 return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
429 return search_bitsize;
430 }
431
432 static nir_alu_src
433 construct_value(nir_builder *build,
434 const nir_search_value *value,
435 unsigned num_components, unsigned search_bitsize,
436 struct match_state *state,
437 nir_instr *instr)
438 {
439 switch (value->type) {
440 case nir_search_value_expression: {
441 const nir_search_expression *expr = nir_search_value_as_expression(value);
442 unsigned dst_bit_size = replace_bitsize(value, search_bitsize, state);
443 nir_op op = nir_op_for_search_op(expr->opcode, dst_bit_size);
444
445 if (nir_op_infos[op].output_size != 0)
446 num_components = nir_op_infos[op].output_size;
447
448 nir_alu_instr *alu = nir_alu_instr_create(build->shader, op);
449 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
450 dst_bit_size, NULL);
451 alu->dest.write_mask = (1 << num_components) - 1;
452 alu->dest.saturate = false;
453
454 /* We have no way of knowing what values in a given search expression
455 * map to a particular replacement value. Therefore, if the
456 * expression we are replacing has any exact values, the entire
457 * replacement should be exact.
458 */
459 alu->exact = state->has_exact_alu;
460
461 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
462 /* If the source is an explicitly sized source, then we need to reset
463 * the number of components to match.
464 */
465 if (nir_op_infos[alu->op].input_sizes[i] != 0)
466 num_components = nir_op_infos[alu->op].input_sizes[i];
467
468 alu->src[i] = construct_value(build, expr->srcs[i],
469 num_components, search_bitsize,
470 state, instr);
471 }
472
473 nir_builder_instr_insert(build, &alu->instr);
474
475 nir_alu_src val;
476 val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
477 val.negate = false;
478 val.abs = false,
479 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
480
481 return val;
482 }
483
484 case nir_search_value_variable: {
485 const nir_search_variable *var = nir_search_value_as_variable(value);
486 assert(state->variables_seen & (1 << var->variable));
487
488 nir_alu_src val = { NIR_SRC_INIT };
489 nir_alu_src_copy(&val, &state->variables[var->variable],
490 (void *)build->shader);
491 assert(!var->is_constant);
492
493 return val;
494 }
495
496 case nir_search_value_constant: {
497 const nir_search_constant *c = nir_search_value_as_constant(value);
498 unsigned bit_size = replace_bitsize(value, search_bitsize, state);
499
500 nir_ssa_def *cval;
501 switch (c->type) {
502 case nir_type_float:
503 cval = nir_imm_floatN_t(build, c->data.d, bit_size);
504 break;
505
506 case nir_type_int:
507 case nir_type_uint:
508 cval = nir_imm_intN_t(build, c->data.i, bit_size);
509 break;
510
511 case nir_type_bool:
512 cval = nir_imm_boolN_t(build, c->data.u, bit_size);
513 break;
514
515 default:
516 unreachable("Invalid alu source type");
517 }
518
519 nir_alu_src val;
520 val.src = nir_src_for_ssa(cval);
521 val.negate = false;
522 val.abs = false,
523 memset(val.swizzle, 0, sizeof val.swizzle);
524
525 return val;
526 }
527
528 default:
529 unreachable("Invalid search value type");
530 }
531 }
532
533 MAYBE_UNUSED static void dump_value(const nir_search_value *val)
534 {
535 switch (val->type) {
536 case nir_search_value_constant: {
537 const nir_search_constant *sconst = nir_search_value_as_constant(val);
538 switch (sconst->type) {
539 case nir_type_float:
540 printf("%f", sconst->data.d);
541 break;
542 case nir_type_int:
543 printf("%"PRId64, sconst->data.i);
544 break;
545 case nir_type_uint:
546 printf("0x%"PRIx64, sconst->data.u);
547 break;
548 default:
549 unreachable("bad const type");
550 }
551 break;
552 }
553
554 case nir_search_value_variable: {
555 const nir_search_variable *var = nir_search_value_as_variable(val);
556 if (var->is_constant)
557 printf("#");
558 printf("%c", var->variable + 'a');
559 break;
560 }
561
562 case nir_search_value_expression: {
563 const nir_search_expression *expr = nir_search_value_as_expression(val);
564 printf("(");
565 if (expr->inexact)
566 printf("~");
567 switch (expr->opcode) {
568 #define CASE(n) \
569 case nir_search_op_##n: printf(#n); break;
570 CASE(f2b)
571 CASE(b2f)
572 CASE(b2i)
573 CASE(i2b)
574 CASE(i2i)
575 CASE(f2i)
576 CASE(i2f)
577 #undef CASE
578 default:
579 printf("%s", nir_op_infos[expr->opcode].name);
580 }
581
582 unsigned num_srcs = 1;
583 if (expr->opcode <= nir_last_opcode)
584 num_srcs = nir_op_infos[expr->opcode].num_inputs;
585
586 for (unsigned i = 0; i < num_srcs; i++) {
587 printf(" ");
588 dump_value(expr->srcs[i]);
589 }
590
591 printf(")");
592 break;
593 }
594 }
595
596 if (val->bit_size > 0)
597 printf("@%d", val->bit_size);
598 }
599
600 nir_ssa_def *
601 nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
602 const nir_search_expression *search,
603 const nir_search_value *replace)
604 {
605 uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
606
607 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
608 swizzle[i] = i;
609
610 assert(instr->dest.dest.is_ssa);
611
612 struct match_state state;
613 state.inexact_match = false;
614 state.has_exact_alu = false;
615
616 unsigned comm_expr_combinations =
617 1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS);
618
619 bool found = false;
620 for (unsigned comb = 0; comb < comm_expr_combinations; comb++) {
621 /* The bitfield of directions is just the current iteration. Hooray for
622 * binary.
623 */
624 state.comm_op_direction = comb;
625 state.variables_seen = 0;
626
627 if (match_expression(search, instr,
628 instr->dest.dest.ssa.num_components,
629 swizzle, &state)) {
630 found = true;
631 break;
632 }
633 }
634 if (!found)
635 return NULL;
636
637 #if 0
638 printf("matched: ");
639 dump_value(&search->value);
640 printf(" -> ");
641 dump_value(replace);
642 printf(" ssa_%d\n", instr->dest.dest.ssa.index);
643 #endif
644
645 build->cursor = nir_before_instr(&instr->instr);
646
647 nir_alu_src val = construct_value(build, replace,
648 instr->dest.dest.ssa.num_components,
649 instr->dest.dest.ssa.bit_size,
650 &state, &instr->instr);
651
652 /* Inserting a mov may be unnecessary. However, it's much easier to
653 * simply let copy propagation clean this up than to try to go through
654 * and rewrite swizzles ourselves.
655 */
656 nir_ssa_def *ssa_val =
657 nir_imov_alu(build, val, instr->dest.dest.ssa.num_components);
658 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
659
660 /* We know this one has no more uses because we just rewrote them all,
661 * so we can remove it. The rest of the matched expression, however, we
662 * don't know so much about. We'll just let dead code clean them up.
663 */
664 nir_instr_remove(&instr->instr);
665
666 return ssa_val;
667 }