nir: Add a nir_tex_instr_has_implicit_derivatives helper
[mesa.git] / src / compiler / nir / nir_search.c
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include <inttypes.h>
29 #include "nir_search.h"
30 #include "nir_builder.h"
31 #include "util/half_float.h"
32
33 /* This should be the same as nir_search_max_comm_ops in nir_algebraic.py. */
34 #define NIR_SEARCH_MAX_COMM_OPS 8
35
36 struct match_state {
37 bool inexact_match;
38 bool has_exact_alu;
39 uint8_t comm_op_direction;
40 unsigned variables_seen;
41 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
42 };
43
44 static bool
45 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
46 unsigned num_components, const uint8_t *swizzle,
47 struct match_state *state);
48
49 static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
50
51 /**
52 * Check if a source produces a value of the given type.
53 *
54 * Used for satisfying 'a@type' constraints.
55 */
56 static bool
57 src_is_type(nir_src src, nir_alu_type type)
58 {
59 assert(type != nir_type_invalid);
60
61 if (!src.is_ssa)
62 return false;
63
64 if (src.ssa->parent_instr->type == nir_instr_type_alu) {
65 nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
66 nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
67
68 if (type == nir_type_bool) {
69 switch (src_alu->op) {
70 case nir_op_iand:
71 case nir_op_ior:
72 case nir_op_ixor:
73 return src_is_type(src_alu->src[0].src, nir_type_bool) &&
74 src_is_type(src_alu->src[1].src, nir_type_bool);
75 case nir_op_inot:
76 return src_is_type(src_alu->src[0].src, nir_type_bool);
77 default:
78 break;
79 }
80 }
81
82 return nir_alu_type_get_base_type(output_type) == type;
83 } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
84 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
85
86 if (type == nir_type_bool) {
87 return intr->intrinsic == nir_intrinsic_load_front_face ||
88 intr->intrinsic == nir_intrinsic_load_helper_invocation;
89 }
90 }
91
92 /* don't know */
93 return false;
94 }
95
96 static bool
97 nir_op_matches_search_op(nir_op nop, uint16_t sop)
98 {
99 if (sop <= nir_last_opcode)
100 return nop == sop;
101
102 #define MATCH_FCONV_CASE(op) \
103 case nir_search_op_##op: \
104 return nop == nir_op_##op##16 || \
105 nop == nir_op_##op##32 || \
106 nop == nir_op_##op##64;
107
108 #define MATCH_ICONV_CASE(op) \
109 case nir_search_op_##op: \
110 return nop == nir_op_##op##8 || \
111 nop == nir_op_##op##16 || \
112 nop == nir_op_##op##32 || \
113 nop == nir_op_##op##64;
114
115 #define MATCH_BCONV_CASE(op) \
116 case nir_search_op_##op: \
117 return nop == nir_op_##op##1 || \
118 nop == nir_op_##op##32;
119
120 switch (sop) {
121 MATCH_FCONV_CASE(i2f)
122 MATCH_FCONV_CASE(u2f)
123 MATCH_FCONV_CASE(f2f)
124 MATCH_ICONV_CASE(f2u)
125 MATCH_ICONV_CASE(f2i)
126 MATCH_ICONV_CASE(u2u)
127 MATCH_ICONV_CASE(i2i)
128 MATCH_FCONV_CASE(b2f)
129 MATCH_ICONV_CASE(b2i)
130 MATCH_BCONV_CASE(i2b)
131 MATCH_BCONV_CASE(f2b)
132 default:
133 unreachable("Invalid nir_search_op");
134 }
135
136 #undef MATCH_FCONV_CASE
137 #undef MATCH_ICONV_CASE
138 #undef MATCH_BCONV_CASE
139 }
140
141 uint16_t
142 nir_search_op_for_nir_op(nir_op nop)
143 {
144 #define MATCH_FCONV_CASE(op) \
145 case nir_op_##op##16: \
146 case nir_op_##op##32: \
147 case nir_op_##op##64: \
148 return nir_search_op_##op;
149
150 #define MATCH_ICONV_CASE(op) \
151 case nir_op_##op##8: \
152 case nir_op_##op##16: \
153 case nir_op_##op##32: \
154 case nir_op_##op##64: \
155 return nir_search_op_##op;
156
157 #define MATCH_BCONV_CASE(op) \
158 case nir_op_##op##1: \
159 case nir_op_##op##32: \
160 return nir_search_op_##op;
161
162
163 switch (nop) {
164 MATCH_FCONV_CASE(i2f)
165 MATCH_FCONV_CASE(u2f)
166 MATCH_FCONV_CASE(f2f)
167 MATCH_ICONV_CASE(f2u)
168 MATCH_ICONV_CASE(f2i)
169 MATCH_ICONV_CASE(u2u)
170 MATCH_ICONV_CASE(i2i)
171 MATCH_FCONV_CASE(b2f)
172 MATCH_ICONV_CASE(b2i)
173 MATCH_BCONV_CASE(i2b)
174 MATCH_BCONV_CASE(f2b)
175 default:
176 return nop;
177 }
178
179 #undef MATCH_FCONV_CASE
180 #undef MATCH_ICONV_CASE
181 #undef MATCH_BCONV_CASE
182 }
183
184 static nir_op
185 nir_op_for_search_op(uint16_t sop, unsigned bit_size)
186 {
187 if (sop <= nir_last_opcode)
188 return sop;
189
190 #define RET_FCONV_CASE(op) \
191 case nir_search_op_##op: \
192 switch (bit_size) { \
193 case 16: return nir_op_##op##16; \
194 case 32: return nir_op_##op##32; \
195 case 64: return nir_op_##op##64; \
196 default: unreachable("Invalid bit size"); \
197 }
198
199 #define RET_ICONV_CASE(op) \
200 case nir_search_op_##op: \
201 switch (bit_size) { \
202 case 8: return nir_op_##op##8; \
203 case 16: return nir_op_##op##16; \
204 case 32: return nir_op_##op##32; \
205 case 64: return nir_op_##op##64; \
206 default: unreachable("Invalid bit size"); \
207 }
208
209 #define RET_BCONV_CASE(op) \
210 case nir_search_op_##op: \
211 switch (bit_size) { \
212 case 1: return nir_op_##op##1; \
213 case 32: return nir_op_##op##32; \
214 default: unreachable("Invalid bit size"); \
215 }
216
217 switch (sop) {
218 RET_FCONV_CASE(i2f)
219 RET_FCONV_CASE(u2f)
220 RET_FCONV_CASE(f2f)
221 RET_ICONV_CASE(f2u)
222 RET_ICONV_CASE(f2i)
223 RET_ICONV_CASE(u2u)
224 RET_ICONV_CASE(i2i)
225 RET_FCONV_CASE(b2f)
226 RET_ICONV_CASE(b2i)
227 RET_BCONV_CASE(i2b)
228 RET_BCONV_CASE(f2b)
229 default:
230 unreachable("Invalid nir_search_op");
231 }
232
233 #undef RET_FCONV_CASE
234 #undef RET_ICONV_CASE
235 #undef RET_BCONV_CASE
236 }
237
238 static bool
239 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
240 unsigned num_components, const uint8_t *swizzle,
241 struct match_state *state)
242 {
243 uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS];
244
245 /* Searching only works on SSA values because, if it's not SSA, we can't
246 * know if the value changed between one instance of that value in the
247 * expression and another. Also, the replace operation will place reads of
248 * that value right before the last instruction in the expression we're
249 * replacing so those reads will happen after the original reads and may
250 * not be valid if they're register reads.
251 */
252 assert(instr->src[src].src.is_ssa);
253
254 /* If the source is an explicitly sized source, then we need to reset
255 * both the number of components and the swizzle.
256 */
257 if (nir_op_infos[instr->op].input_sizes[src] != 0) {
258 num_components = nir_op_infos[instr->op].input_sizes[src];
259 swizzle = identity_swizzle;
260 }
261
262 for (unsigned i = 0; i < num_components; ++i)
263 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
264
265 /* If the value has a specific bit size and it doesn't match, bail */
266 if (value->bit_size > 0 &&
267 nir_src_bit_size(instr->src[src].src) != value->bit_size)
268 return false;
269
270 switch (value->type) {
271 case nir_search_value_expression:
272 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
273 return false;
274
275 return match_expression(nir_search_value_as_expression(value),
276 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
277 num_components, new_swizzle, state);
278
279 case nir_search_value_variable: {
280 nir_search_variable *var = nir_search_value_as_variable(value);
281 assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
282
283 if (state->variables_seen & (1 << var->variable)) {
284 if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
285 return false;
286
287 assert(!instr->src[src].abs && !instr->src[src].negate);
288
289 for (unsigned i = 0; i < num_components; ++i) {
290 if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
291 return false;
292 }
293
294 return true;
295 } else {
296 if (var->is_constant &&
297 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
298 return false;
299
300 if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
301 return false;
302
303 if (var->type != nir_type_invalid &&
304 !src_is_type(instr->src[src].src, var->type))
305 return false;
306
307 state->variables_seen |= (1 << var->variable);
308 state->variables[var->variable].src = instr->src[src].src;
309 state->variables[var->variable].abs = false;
310 state->variables[var->variable].negate = false;
311
312 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) {
313 if (i < num_components)
314 state->variables[var->variable].swizzle[i] = new_swizzle[i];
315 else
316 state->variables[var->variable].swizzle[i] = 0;
317 }
318
319 return true;
320 }
321 }
322
323 case nir_search_value_constant: {
324 nir_search_constant *const_val = nir_search_value_as_constant(value);
325
326 if (!nir_src_is_const(instr->src[src].src))
327 return false;
328
329 switch (const_val->type) {
330 case nir_type_float: {
331 nir_load_const_instr *const load =
332 nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
333
334 /* There are 8-bit and 1-bit integer types, but there are no 8-bit or
335 * 1-bit float types. This prevents potential assertion failures in
336 * nir_src_comp_as_float.
337 */
338 if (load->def.bit_size < 16)
339 return false;
340
341 for (unsigned i = 0; i < num_components; ++i) {
342 double val = nir_src_comp_as_float(instr->src[src].src,
343 new_swizzle[i]);
344 if (val != const_val->data.d)
345 return false;
346 }
347 return true;
348 }
349
350 case nir_type_int:
351 case nir_type_uint:
352 case nir_type_bool: {
353 unsigned bit_size = nir_src_bit_size(instr->src[src].src);
354 uint64_t mask = bit_size == 64 ? UINT64_MAX : (1ull << bit_size) - 1;
355 for (unsigned i = 0; i < num_components; ++i) {
356 uint64_t val = nir_src_comp_as_uint(instr->src[src].src,
357 new_swizzle[i]);
358 if ((val & mask) != (const_val->data.u & mask))
359 return false;
360 }
361 return true;
362 }
363
364 default:
365 unreachable("Invalid alu source type");
366 }
367 }
368
369 default:
370 unreachable("Invalid search value type");
371 }
372 }
373
374 static bool
375 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
376 unsigned num_components, const uint8_t *swizzle,
377 struct match_state *state)
378 {
379 if (expr->cond && !expr->cond(instr))
380 return false;
381
382 if (!nir_op_matches_search_op(instr->op, expr->opcode))
383 return false;
384
385 assert(instr->dest.dest.is_ssa);
386
387 if (expr->value.bit_size > 0 &&
388 instr->dest.dest.ssa.bit_size != expr->value.bit_size)
389 return false;
390
391 state->inexact_match = expr->inexact || state->inexact_match;
392 state->has_exact_alu = instr->exact || state->has_exact_alu;
393 if (state->inexact_match && state->has_exact_alu)
394 return false;
395
396 assert(!instr->dest.saturate);
397 assert(nir_op_infos[instr->op].num_inputs > 0);
398
399 /* If we have an explicitly sized destination, we can only handle the
400 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
401 * expression, we don't have the information right now to propagate that
402 * swizzle through. We can only properly propagate swizzles if the
403 * instruction is vectorized.
404 */
405 if (nir_op_infos[instr->op].output_size != 0) {
406 for (unsigned i = 0; i < num_components; i++) {
407 if (swizzle[i] != i)
408 return false;
409 }
410 }
411
412 /* If this is a commutative expression and it's one of the first few, look
413 * up its direction for the current search operation. We'll use that value
414 * to possibly flip the sources for the match.
415 */
416 unsigned comm_op_flip =
417 (expr->comm_expr_idx >= 0 &&
418 expr->comm_expr_idx < NIR_SEARCH_MAX_COMM_OPS) ?
419 ((state->comm_op_direction >> expr->comm_expr_idx) & 1) : 0;
420
421 bool matched = true;
422 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
423 /* 2src_commutative instructions that have 3 sources are only commutative
424 * in the first two sources. Source 2 is always source 2.
425 */
426 if (!match_value(expr->srcs[i], instr,
427 i < 2 ? i ^ comm_op_flip : i,
428 num_components, swizzle, state)) {
429 matched = false;
430 break;
431 }
432 }
433
434 return matched;
435 }
436
437 static unsigned
438 replace_bitsize(const nir_search_value *value, unsigned search_bitsize,
439 struct match_state *state)
440 {
441 if (value->bit_size > 0)
442 return value->bit_size;
443 if (value->bit_size < 0)
444 return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
445 return search_bitsize;
446 }
447
448 static nir_alu_src
449 construct_value(nir_builder *build,
450 const nir_search_value *value,
451 unsigned num_components, unsigned search_bitsize,
452 struct match_state *state,
453 nir_instr *instr)
454 {
455 switch (value->type) {
456 case nir_search_value_expression: {
457 const nir_search_expression *expr = nir_search_value_as_expression(value);
458 unsigned dst_bit_size = replace_bitsize(value, search_bitsize, state);
459 nir_op op = nir_op_for_search_op(expr->opcode, dst_bit_size);
460
461 if (nir_op_infos[op].output_size != 0)
462 num_components = nir_op_infos[op].output_size;
463
464 nir_alu_instr *alu = nir_alu_instr_create(build->shader, op);
465 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
466 dst_bit_size, NULL);
467 alu->dest.write_mask = (1 << num_components) - 1;
468 alu->dest.saturate = false;
469
470 /* We have no way of knowing what values in a given search expression
471 * map to a particular replacement value. Therefore, if the
472 * expression we are replacing has any exact values, the entire
473 * replacement should be exact.
474 */
475 alu->exact = state->has_exact_alu;
476
477 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
478 /* If the source is an explicitly sized source, then we need to reset
479 * the number of components to match.
480 */
481 if (nir_op_infos[alu->op].input_sizes[i] != 0)
482 num_components = nir_op_infos[alu->op].input_sizes[i];
483
484 alu->src[i] = construct_value(build, expr->srcs[i],
485 num_components, search_bitsize,
486 state, instr);
487 }
488
489 nir_builder_instr_insert(build, &alu->instr);
490
491 nir_alu_src val;
492 val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
493 val.negate = false;
494 val.abs = false,
495 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
496
497 return val;
498 }
499
500 case nir_search_value_variable: {
501 const nir_search_variable *var = nir_search_value_as_variable(value);
502 assert(state->variables_seen & (1 << var->variable));
503
504 nir_alu_src val = { NIR_SRC_INIT };
505 nir_alu_src_copy(&val, &state->variables[var->variable],
506 (void *)build->shader);
507 assert(!var->is_constant);
508
509 return val;
510 }
511
512 case nir_search_value_constant: {
513 const nir_search_constant *c = nir_search_value_as_constant(value);
514 unsigned bit_size = replace_bitsize(value, search_bitsize, state);
515
516 nir_ssa_def *cval;
517 switch (c->type) {
518 case nir_type_float:
519 cval = nir_imm_floatN_t(build, c->data.d, bit_size);
520 break;
521
522 case nir_type_int:
523 case nir_type_uint:
524 cval = nir_imm_intN_t(build, c->data.i, bit_size);
525 break;
526
527 case nir_type_bool:
528 cval = nir_imm_boolN_t(build, c->data.u, bit_size);
529 break;
530
531 default:
532 unreachable("Invalid alu source type");
533 }
534
535 nir_alu_src val;
536 val.src = nir_src_for_ssa(cval);
537 val.negate = false;
538 val.abs = false,
539 memset(val.swizzle, 0, sizeof val.swizzle);
540
541 return val;
542 }
543
544 default:
545 unreachable("Invalid search value type");
546 }
547 }
548
549 MAYBE_UNUSED static void dump_value(const nir_search_value *val)
550 {
551 switch (val->type) {
552 case nir_search_value_constant: {
553 const nir_search_constant *sconst = nir_search_value_as_constant(val);
554 switch (sconst->type) {
555 case nir_type_float:
556 printf("%f", sconst->data.d);
557 break;
558 case nir_type_int:
559 printf("%"PRId64, sconst->data.i);
560 break;
561 case nir_type_uint:
562 printf("0x%"PRIx64, sconst->data.u);
563 break;
564 case nir_type_bool:
565 printf("%s", sconst->data.u != 0 ? "True" : "False");
566 break;
567 default:
568 unreachable("bad const type");
569 }
570 break;
571 }
572
573 case nir_search_value_variable: {
574 const nir_search_variable *var = nir_search_value_as_variable(val);
575 if (var->is_constant)
576 printf("#");
577 printf("%c", var->variable + 'a');
578 break;
579 }
580
581 case nir_search_value_expression: {
582 const nir_search_expression *expr = nir_search_value_as_expression(val);
583 printf("(");
584 if (expr->inexact)
585 printf("~");
586 switch (expr->opcode) {
587 #define CASE(n) \
588 case nir_search_op_##n: printf(#n); break;
589 CASE(f2b)
590 CASE(b2f)
591 CASE(b2i)
592 CASE(i2b)
593 CASE(i2i)
594 CASE(f2i)
595 CASE(i2f)
596 #undef CASE
597 default:
598 printf("%s", nir_op_infos[expr->opcode].name);
599 }
600
601 unsigned num_srcs = 1;
602 if (expr->opcode <= nir_last_opcode)
603 num_srcs = nir_op_infos[expr->opcode].num_inputs;
604
605 for (unsigned i = 0; i < num_srcs; i++) {
606 printf(" ");
607 dump_value(expr->srcs[i]);
608 }
609
610 printf(")");
611 break;
612 }
613 }
614
615 if (val->bit_size > 0)
616 printf("@%d", val->bit_size);
617 }
618
619 nir_ssa_def *
620 nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
621 const nir_search_expression *search,
622 const nir_search_value *replace)
623 {
624 uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
625
626 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
627 swizzle[i] = i;
628
629 assert(instr->dest.dest.is_ssa);
630
631 struct match_state state;
632 state.inexact_match = false;
633 state.has_exact_alu = false;
634
635 STATIC_ASSERT(sizeof(state.comm_op_direction) * 8 >= NIR_SEARCH_MAX_COMM_OPS);
636
637 unsigned comm_expr_combinations =
638 1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS);
639
640 bool found = false;
641 for (unsigned comb = 0; comb < comm_expr_combinations; comb++) {
642 /* The bitfield of directions is just the current iteration. Hooray for
643 * binary.
644 */
645 state.comm_op_direction = comb;
646 state.variables_seen = 0;
647
648 if (match_expression(search, instr,
649 instr->dest.dest.ssa.num_components,
650 swizzle, &state)) {
651 found = true;
652 break;
653 }
654 }
655 if (!found)
656 return NULL;
657
658 #if 0
659 printf("matched: ");
660 dump_value(&search->value);
661 printf(" -> ");
662 dump_value(replace);
663 printf(" ssa_%d\n", instr->dest.dest.ssa.index);
664 #endif
665
666 build->cursor = nir_before_instr(&instr->instr);
667
668 nir_alu_src val = construct_value(build, replace,
669 instr->dest.dest.ssa.num_components,
670 instr->dest.dest.ssa.bit_size,
671 &state, &instr->instr);
672
673 /* Inserting a mov may be unnecessary. However, it's much easier to
674 * simply let copy propagation clean this up than to try to go through
675 * and rewrite swizzles ourselves.
676 */
677 nir_ssa_def *ssa_val =
678 nir_mov_alu(build, val, instr->dest.dest.ssa.num_components);
679 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
680
681 /* We know this one has no more uses because we just rewrote them all,
682 * so we can remove it. The rest of the matched expression, however, we
683 * don't know so much about. We'll just let dead code clean them up.
684 */
685 nir_instr_remove(&instr->instr);
686
687 return ssa_val;
688 }