anv,i965: Lower away image derefs in the driver
[mesa.git] / src / compiler / nir / nir_search.c
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include <inttypes.h>
29 #include "nir_search.h"
30 #include "util/half_float.h"
31
32 struct match_state {
33 bool inexact_match;
34 bool has_exact_alu;
35 unsigned variables_seen;
36 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
37 };
38
39 static bool
40 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
41 unsigned num_components, const uint8_t *swizzle,
42 struct match_state *state);
43
44 static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
45
46 /**
47 * Check if a source produces a value of the given type.
48 *
49 * Used for satisfying 'a@type' constraints.
50 */
51 static bool
52 src_is_type(nir_src src, nir_alu_type type)
53 {
54 assert(type != nir_type_invalid);
55
56 if (!src.is_ssa)
57 return false;
58
59 /* Turn nir_type_bool32 into nir_type_bool...they're the same thing. */
60 if (nir_alu_type_get_base_type(type) == nir_type_bool)
61 type = nir_type_bool;
62
63 if (src.ssa->parent_instr->type == nir_instr_type_alu) {
64 nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
65 nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
66
67 if (type == nir_type_bool) {
68 switch (src_alu->op) {
69 case nir_op_iand:
70 case nir_op_ior:
71 case nir_op_ixor:
72 return src_is_type(src_alu->src[0].src, nir_type_bool) &&
73 src_is_type(src_alu->src[1].src, nir_type_bool);
74 case nir_op_inot:
75 return src_is_type(src_alu->src[0].src, nir_type_bool);
76 default:
77 break;
78 }
79 }
80
81 return nir_alu_type_get_base_type(output_type) == type;
82 } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
83 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
84
85 if (type == nir_type_bool) {
86 return intr->intrinsic == nir_intrinsic_load_front_face ||
87 intr->intrinsic == nir_intrinsic_load_helper_invocation;
88 }
89 }
90
91 /* don't know */
92 return false;
93 }
94
95 static bool
96 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
97 unsigned num_components, const uint8_t *swizzle,
98 struct match_state *state)
99 {
100 uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS];
101
102 /* Searching only works on SSA values because, if it's not SSA, we can't
103 * know if the value changed between one instance of that value in the
104 * expression and another. Also, the replace operation will place reads of
105 * that value right before the last instruction in the expression we're
106 * replacing so those reads will happen after the original reads and may
107 * not be valid if they're register reads.
108 */
109 if (!instr->src[src].src.is_ssa)
110 return false;
111
112 /* If the source is an explicitly sized source, then we need to reset
113 * both the number of components and the swizzle.
114 */
115 if (nir_op_infos[instr->op].input_sizes[src] != 0) {
116 num_components = nir_op_infos[instr->op].input_sizes[src];
117 swizzle = identity_swizzle;
118 }
119
120 for (unsigned i = 0; i < num_components; ++i)
121 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
122
123 /* If the value has a specific bit size and it doesn't match, bail */
124 if (value->bit_size &&
125 nir_src_bit_size(instr->src[src].src) != value->bit_size)
126 return false;
127
128 switch (value->type) {
129 case nir_search_value_expression:
130 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
131 return false;
132
133 return match_expression(nir_search_value_as_expression(value),
134 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
135 num_components, new_swizzle, state);
136
137 case nir_search_value_variable: {
138 nir_search_variable *var = nir_search_value_as_variable(value);
139 assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
140
141 if (state->variables_seen & (1 << var->variable)) {
142 if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
143 return false;
144
145 assert(!instr->src[src].abs && !instr->src[src].negate);
146
147 for (unsigned i = 0; i < num_components; ++i) {
148 if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
149 return false;
150 }
151
152 return true;
153 } else {
154 if (var->is_constant &&
155 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
156 return false;
157
158 if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
159 return false;
160
161 if (var->type != nir_type_invalid &&
162 !src_is_type(instr->src[src].src, var->type))
163 return false;
164
165 state->variables_seen |= (1 << var->variable);
166 state->variables[var->variable].src = instr->src[src].src;
167 state->variables[var->variable].abs = false;
168 state->variables[var->variable].negate = false;
169
170 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) {
171 if (i < num_components)
172 state->variables[var->variable].swizzle[i] = new_swizzle[i];
173 else
174 state->variables[var->variable].swizzle[i] = 0;
175 }
176
177 return true;
178 }
179 }
180
181 case nir_search_value_constant: {
182 nir_search_constant *const_val = nir_search_value_as_constant(value);
183
184 if (!instr->src[src].src.is_ssa)
185 return false;
186
187 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
188 return false;
189
190 nir_load_const_instr *load =
191 nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
192
193 switch (const_val->type) {
194 case nir_type_float:
195 for (unsigned i = 0; i < num_components; ++i) {
196 double val;
197 switch (load->def.bit_size) {
198 case 16:
199 val = _mesa_half_to_float(load->value.u16[new_swizzle[i]]);
200 break;
201 case 32:
202 val = load->value.f32[new_swizzle[i]];
203 break;
204 case 64:
205 val = load->value.f64[new_swizzle[i]];
206 break;
207 default:
208 unreachable("unknown bit size");
209 }
210
211 if (val != const_val->data.d)
212 return false;
213 }
214 return true;
215
216 case nir_type_int:
217 case nir_type_uint:
218 case nir_type_bool32:
219 switch (load->def.bit_size) {
220 case 8:
221 for (unsigned i = 0; i < num_components; ++i) {
222 if (load->value.u8[new_swizzle[i]] !=
223 (uint8_t)const_val->data.u)
224 return false;
225 }
226 return true;
227
228 case 16:
229 for (unsigned i = 0; i < num_components; ++i) {
230 if (load->value.u16[new_swizzle[i]] !=
231 (uint16_t)const_val->data.u)
232 return false;
233 }
234 return true;
235
236 case 32:
237 for (unsigned i = 0; i < num_components; ++i) {
238 if (load->value.u32[new_swizzle[i]] !=
239 (uint32_t)const_val->data.u)
240 return false;
241 }
242 return true;
243
244 case 64:
245 for (unsigned i = 0; i < num_components; ++i) {
246 if (load->value.u64[new_swizzle[i]] != const_val->data.u)
247 return false;
248 }
249 return true;
250
251 default:
252 unreachable("unknown bit size");
253 }
254
255 default:
256 unreachable("Invalid alu source type");
257 }
258 }
259
260 default:
261 unreachable("Invalid search value type");
262 }
263 }
264
265 static bool
266 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
267 unsigned num_components, const uint8_t *swizzle,
268 struct match_state *state)
269 {
270 if (expr->cond && !expr->cond(instr))
271 return false;
272
273 if (instr->op != expr->opcode)
274 return false;
275
276 assert(instr->dest.dest.is_ssa);
277
278 if (expr->value.bit_size &&
279 instr->dest.dest.ssa.bit_size != expr->value.bit_size)
280 return false;
281
282 state->inexact_match = expr->inexact || state->inexact_match;
283 state->has_exact_alu = instr->exact || state->has_exact_alu;
284 if (state->inexact_match && state->has_exact_alu)
285 return false;
286
287 assert(!instr->dest.saturate);
288 assert(nir_op_infos[instr->op].num_inputs > 0);
289
290 /* If we have an explicitly sized destination, we can only handle the
291 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
292 * expression, we don't have the information right now to propagate that
293 * swizzle through. We can only properly propagate swizzles if the
294 * instruction is vectorized.
295 */
296 if (nir_op_infos[instr->op].output_size != 0) {
297 for (unsigned i = 0; i < num_components; i++) {
298 if (swizzle[i] != i)
299 return false;
300 }
301 }
302
303 /* Stash off the current variables_seen bitmask. This way we can
304 * restore it prior to matching in the commutative case below.
305 */
306 unsigned variables_seen_stash = state->variables_seen;
307
308 bool matched = true;
309 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
310 if (!match_value(expr->srcs[i], instr, i, num_components,
311 swizzle, state)) {
312 matched = false;
313 break;
314 }
315 }
316
317 if (matched)
318 return true;
319
320 if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
321 assert(nir_op_infos[instr->op].num_inputs == 2);
322
323 /* Restore the variables_seen bitmask. If we don't do this, then we
324 * could end up with an erroneous failure due to variables found in the
325 * first match attempt above not matching those in the second.
326 */
327 state->variables_seen = variables_seen_stash;
328
329 if (!match_value(expr->srcs[0], instr, 1, num_components,
330 swizzle, state))
331 return false;
332
333 return match_value(expr->srcs[1], instr, 0, num_components,
334 swizzle, state);
335 } else {
336 return false;
337 }
338 }
339
/* Mirror of a replacement value tree, used to work out the bit size of
 * every value that construct_value() is going to create.
 */
typedef struct bitsize_tree {
   /* Number of valid entries in srcs[], is_src_sized[] and src_size[]. */
   unsigned num_srcs;
   struct bitsize_tree *srcs[4];

   /* Bit size shared by the unsized sources/destination; 0 while unknown. */
   unsigned common_size;

   /* Whether the opcode's type fixes the bit size of a source / the dest. */
   bool is_src_sized[4];
   bool is_dest_sized;

   /* Resolved bit sizes; 0 while still unknown. */
   unsigned dest_size;
   unsigned src_size[4];
} bitsize_tree;
351
352 static bitsize_tree *
353 build_bitsize_tree(void *mem_ctx, struct match_state *state,
354 const nir_search_value *value)
355 {
356 bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree);
357
358 switch (value->type) {
359 case nir_search_value_expression: {
360 nir_search_expression *expr = nir_search_value_as_expression(value);
361 nir_op_info info = nir_op_infos[expr->opcode];
362 tree->num_srcs = info.num_inputs;
363 tree->common_size = 0;
364 for (unsigned i = 0; i < info.num_inputs; i++) {
365 tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
366 if (tree->is_src_sized[i])
367 tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
368 tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
369 }
370 tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
371 if (tree->is_dest_sized)
372 tree->dest_size = nir_alu_type_get_type_size(info.output_type);
373 break;
374 }
375
376 case nir_search_value_variable: {
377 nir_search_variable *var = nir_search_value_as_variable(value);
378 tree->num_srcs = 0;
379 tree->is_dest_sized = true;
380 tree->dest_size = nir_src_bit_size(state->variables[var->variable].src);
381 break;
382 }
383
384 case nir_search_value_constant: {
385 tree->num_srcs = 0;
386 tree->is_dest_sized = false;
387 tree->common_size = 0;
388 break;
389 }
390 }
391
392 if (value->bit_size) {
393 assert(!tree->is_dest_sized || tree->dest_size == value->bit_size);
394 tree->common_size = value->bit_size;
395 }
396
397 return tree;
398 }
399
400 static unsigned
401 bitsize_tree_filter_up(bitsize_tree *tree)
402 {
403 for (unsigned i = 0; i < tree->num_srcs; i++) {
404 unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]);
405 if (src_size == 0)
406 continue;
407
408 if (tree->is_src_sized[i]) {
409 assert(src_size == tree->src_size[i]);
410 } else if (tree->common_size != 0) {
411 assert(src_size == tree->common_size);
412 tree->src_size[i] = src_size;
413 } else {
414 tree->common_size = src_size;
415 tree->src_size[i] = src_size;
416 }
417 }
418
419 if (tree->num_srcs && tree->common_size) {
420 if (tree->dest_size == 0)
421 tree->dest_size = tree->common_size;
422 else if (!tree->is_dest_sized)
423 assert(tree->dest_size == tree->common_size);
424
425 for (unsigned i = 0; i < tree->num_srcs; i++) {
426 if (!tree->src_size[i])
427 tree->src_size[i] = tree->common_size;
428 }
429 }
430
431 return tree->dest_size;
432 }
433
434 static void
435 bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
436 {
437 if (tree->dest_size)
438 assert(tree->dest_size == size);
439 else
440 tree->dest_size = size;
441
442 if (!tree->is_dest_sized) {
443 if (tree->common_size)
444 assert(tree->common_size == size);
445 else
446 tree->common_size = size;
447 }
448
449 for (unsigned i = 0; i < tree->num_srcs; i++) {
450 if (!tree->src_size[i]) {
451 assert(tree->common_size);
452 tree->src_size[i] = tree->common_size;
453 }
454 bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]);
455 }
456 }
457
458 static nir_alu_src
459 construct_value(const nir_search_value *value,
460 unsigned num_components, bitsize_tree *bitsize,
461 struct match_state *state,
462 nir_instr *instr, void *mem_ctx)
463 {
464 switch (value->type) {
465 case nir_search_value_expression: {
466 const nir_search_expression *expr = nir_search_value_as_expression(value);
467
468 if (nir_op_infos[expr->opcode].output_size != 0)
469 num_components = nir_op_infos[expr->opcode].output_size;
470
471 nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode);
472 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
473 bitsize->dest_size, NULL);
474 alu->dest.write_mask = (1 << num_components) - 1;
475 alu->dest.saturate = false;
476
477 /* We have no way of knowing what values in a given search expression
478 * map to a particular replacement value. Therefore, if the
479 * expression we are replacing has any exact values, the entire
480 * replacement should be exact.
481 */
482 alu->exact = state->has_exact_alu;
483
484 for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
485 /* If the source is an explicitly sized source, then we need to reset
486 * the number of components to match.
487 */
488 if (nir_op_infos[alu->op].input_sizes[i] != 0)
489 num_components = nir_op_infos[alu->op].input_sizes[i];
490
491 alu->src[i] = construct_value(expr->srcs[i],
492 num_components, bitsize->srcs[i],
493 state, instr, mem_ctx);
494 }
495
496 nir_instr_insert_before(instr, &alu->instr);
497
498 nir_alu_src val;
499 val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
500 val.negate = false;
501 val.abs = false,
502 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
503
504 return val;
505 }
506
507 case nir_search_value_variable: {
508 const nir_search_variable *var = nir_search_value_as_variable(value);
509 assert(state->variables_seen & (1 << var->variable));
510
511 nir_alu_src val = { NIR_SRC_INIT };
512 nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx);
513
514 assert(!var->is_constant);
515
516 return val;
517 }
518
519 case nir_search_value_constant: {
520 const nir_search_constant *c = nir_search_value_as_constant(value);
521 nir_load_const_instr *load =
522 nir_load_const_instr_create(mem_ctx, 1, bitsize->dest_size);
523
524 switch (c->type) {
525 case nir_type_float:
526 load->def.name = ralloc_asprintf(load, "%f", c->data.d);
527 switch (bitsize->dest_size) {
528 case 16:
529 load->value.u16[0] = _mesa_float_to_half(c->data.d);
530 break;
531 case 32:
532 load->value.f32[0] = c->data.d;
533 break;
534 case 64:
535 load->value.f64[0] = c->data.d;
536 break;
537 default:
538 unreachable("unknown bit size");
539 }
540 break;
541
542 case nir_type_int:
543 load->def.name = ralloc_asprintf(load, "%" PRIi64, c->data.i);
544 switch (bitsize->dest_size) {
545 case 8:
546 load->value.i8[0] = c->data.i;
547 break;
548 case 16:
549 load->value.i16[0] = c->data.i;
550 break;
551 case 32:
552 load->value.i32[0] = c->data.i;
553 break;
554 case 64:
555 load->value.i64[0] = c->data.i;
556 break;
557 default:
558 unreachable("unknown bit size");
559 }
560 break;
561
562 case nir_type_uint:
563 load->def.name = ralloc_asprintf(load, "%" PRIu64, c->data.u);
564 switch (bitsize->dest_size) {
565 case 8:
566 load->value.u8[0] = c->data.u;
567 break;
568 case 16:
569 load->value.u16[0] = c->data.u;
570 break;
571 case 32:
572 load->value.u32[0] = c->data.u;
573 break;
574 case 64:
575 load->value.u64[0] = c->data.u;
576 break;
577 default:
578 unreachable("unknown bit size");
579 }
580 break;
581
582 case nir_type_bool32:
583 load->value.u32[0] = c->data.u;
584 break;
585 default:
586 unreachable("Invalid alu source type");
587 }
588
589 nir_instr_insert_before(instr, &load->instr);
590
591 nir_alu_src val;
592 val.src = nir_src_for_ssa(&load->def);
593 val.negate = false;
594 val.abs = false,
595 memset(val.swizzle, 0, sizeof val.swizzle);
596
597 return val;
598 }
599
600 default:
601 unreachable("Invalid search value type");
602 }
603 }
604
605 nir_alu_instr *
606 nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
607 const nir_search_value *replace, void *mem_ctx)
608 {
609 uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
610
611 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
612 swizzle[i] = i;
613
614 assert(instr->dest.dest.is_ssa);
615
616 struct match_state state;
617 state.inexact_match = false;
618 state.has_exact_alu = false;
619 state.variables_seen = 0;
620
621 if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
622 swizzle, &state))
623 return NULL;
624
625 void *bitsize_ctx = ralloc_context(NULL);
626 bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace);
627 bitsize_tree_filter_up(tree);
628 bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size);
629
630 /* Inserting a mov may be unnecessary. However, it's much easier to
631 * simply let copy propagation clean this up than to try to go through
632 * and rewrite swizzles ourselves.
633 */
634 nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov);
635 mov->dest.write_mask = instr->dest.write_mask;
636 nir_ssa_dest_init(&mov->instr, &mov->dest.dest,
637 instr->dest.dest.ssa.num_components,
638 instr->dest.dest.ssa.bit_size, NULL);
639
640 mov->src[0] = construct_value(replace,
641 instr->dest.dest.ssa.num_components, tree,
642 &state, &instr->instr, mem_ctx);
643 nir_instr_insert_before(&instr->instr, &mov->instr);
644
645 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
646 nir_src_for_ssa(&mov->dest.dest.ssa));
647
648 /* We know this one has no more uses because we just rewrote them all,
649 * so we can remove it. The rest of the matched expression, however, we
650 * don't know so much about. We'll just let dead code clean them up.
651 */
652 nir_instr_remove(&instr->instr);
653
654 ralloc_free(bitsize_ctx);
655
656 return mov;
657 }