9c5cb547a701b02c944aefca9ae0facb5408bc62
[mesa.git] / src / compiler / nir / nir_search.c
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include <inttypes.h>
29 #include "nir_search.h"
30 #include "util/half_float.h"
31
32 struct match_state {
33 bool inexact_match;
34 bool has_exact_alu;
35 unsigned variables_seen;
36 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
37 };
38
39 static bool
40 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
41 unsigned num_components, const uint8_t *swizzle,
42 struct match_state *state);
43
44 static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
45
46 /**
47 * Check if a source produces a value of the given type.
48 *
49 * Used for satisfying 'a@type' constraints.
50 */
51 static bool
52 src_is_type(nir_src src, nir_alu_type type)
53 {
54 assert(type != nir_type_invalid);
55
56 if (!src.is_ssa)
57 return false;
58
59 /* Turn nir_type_bool32 into nir_type_bool...they're the same thing. */
60 if (nir_alu_type_get_base_type(type) == nir_type_bool)
61 type = nir_type_bool;
62
63 if (src.ssa->parent_instr->type == nir_instr_type_alu) {
64 nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
65 nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
66
67 if (type == nir_type_bool) {
68 switch (src_alu->op) {
69 case nir_op_iand:
70 case nir_op_ior:
71 case nir_op_ixor:
72 return src_is_type(src_alu->src[0].src, nir_type_bool) &&
73 src_is_type(src_alu->src[1].src, nir_type_bool);
74 case nir_op_inot:
75 return src_is_type(src_alu->src[0].src, nir_type_bool);
76 default:
77 break;
78 }
79 }
80
81 return nir_alu_type_get_base_type(output_type) == type;
82 } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
83 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
84
85 if (type == nir_type_bool) {
86 return intr->intrinsic == nir_intrinsic_load_front_face ||
87 intr->intrinsic == nir_intrinsic_load_helper_invocation;
88 }
89 }
90
91 /* don't know */
92 return false;
93 }
94
95 static bool
96 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
97 unsigned num_components, const uint8_t *swizzle,
98 struct match_state *state)
99 {
100 uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS];
101
102 /* Searching only works on SSA values because, if it's not SSA, we can't
103 * know if the value changed between one instance of that value in the
104 * expression and another. Also, the replace operation will place reads of
105 * that value right before the last instruction in the expression we're
106 * replacing so those reads will happen after the original reads and may
107 * not be valid if they're register reads.
108 */
109 if (!instr->src[src].src.is_ssa)
110 return false;
111
112 /* If the source is an explicitly sized source, then we need to reset
113 * both the number of components and the swizzle.
114 */
115 if (nir_op_infos[instr->op].input_sizes[src] != 0) {
116 num_components = nir_op_infos[instr->op].input_sizes[src];
117 swizzle = identity_swizzle;
118 }
119
120 for (unsigned i = 0; i < num_components; ++i)
121 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
122
123 /* If the value has a specific bit size and it doesn't match, bail */
124 if (value->bit_size &&
125 nir_src_bit_size(instr->src[src].src) != value->bit_size)
126 return false;
127
128 switch (value->type) {
129 case nir_search_value_expression:
130 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
131 return false;
132
133 return match_expression(nir_search_value_as_expression(value),
134 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
135 num_components, new_swizzle, state);
136
137 case nir_search_value_variable: {
138 nir_search_variable *var = nir_search_value_as_variable(value);
139 assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
140
141 if (state->variables_seen & (1 << var->variable)) {
142 if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
143 return false;
144
145 assert(!instr->src[src].abs && !instr->src[src].negate);
146
147 for (unsigned i = 0; i < num_components; ++i) {
148 if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
149 return false;
150 }
151
152 return true;
153 } else {
154 if (var->is_constant &&
155 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
156 return false;
157
158 if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
159 return false;
160
161 if (var->type != nir_type_invalid &&
162 !src_is_type(instr->src[src].src, var->type))
163 return false;
164
165 state->variables_seen |= (1 << var->variable);
166 state->variables[var->variable].src = instr->src[src].src;
167 state->variables[var->variable].abs = false;
168 state->variables[var->variable].negate = false;
169
170 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) {
171 if (i < num_components)
172 state->variables[var->variable].swizzle[i] = new_swizzle[i];
173 else
174 state->variables[var->variable].swizzle[i] = 0;
175 }
176
177 return true;
178 }
179 }
180
181 case nir_search_value_constant: {
182 nir_search_constant *const_val = nir_search_value_as_constant(value);
183
184 if (!nir_src_is_const(instr->src[src].src))
185 return false;
186
187 switch (const_val->type) {
188 case nir_type_float:
189 for (unsigned i = 0; i < num_components; ++i) {
190 double val = nir_src_comp_as_float(instr->src[src].src,
191 new_swizzle[i]);
192 if (val != const_val->data.d)
193 return false;
194 }
195 return true;
196
197 case nir_type_int:
198 case nir_type_uint:
199 case nir_type_bool32: {
200 unsigned bit_size = nir_src_bit_size(instr->src[src].src);
201 uint64_t mask = bit_size == 64 ? UINT64_MAX : (1ull << bit_size) - 1;
202 for (unsigned i = 0; i < num_components; ++i) {
203 uint64_t val = nir_src_comp_as_uint(instr->src[src].src,
204 new_swizzle[i]);
205 if ((val & mask) != (const_val->data.u & mask))
206 return false;
207 }
208 return true;
209 }
210
211 default:
212 unreachable("Invalid alu source type");
213 }
214 }
215
216 default:
217 unreachable("Invalid search value type");
218 }
219 }
220
221 static bool
222 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
223 unsigned num_components, const uint8_t *swizzle,
224 struct match_state *state)
225 {
226 if (expr->cond && !expr->cond(instr))
227 return false;
228
229 if (instr->op != expr->opcode)
230 return false;
231
232 assert(instr->dest.dest.is_ssa);
233
234 if (expr->value.bit_size &&
235 instr->dest.dest.ssa.bit_size != expr->value.bit_size)
236 return false;
237
238 state->inexact_match = expr->inexact || state->inexact_match;
239 state->has_exact_alu = instr->exact || state->has_exact_alu;
240 if (state->inexact_match && state->has_exact_alu)
241 return false;
242
243 assert(!instr->dest.saturate);
244 assert(nir_op_infos[instr->op].num_inputs > 0);
245
246 /* If we have an explicitly sized destination, we can only handle the
247 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
248 * expression, we don't have the information right now to propagate that
249 * swizzle through. We can only properly propagate swizzles if the
250 * instruction is vectorized.
251 */
252 if (nir_op_infos[instr->op].output_size != 0) {
253 for (unsigned i = 0; i < num_components; i++) {
254 if (swizzle[i] != i)
255 return false;
256 }
257 }
258
259 /* Stash off the current variables_seen bitmask. This way we can
260 * restore it prior to matching in the commutative case below.
261 */
262 unsigned variables_seen_stash = state->variables_seen;
263
264 bool matched = true;
265 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
266 if (!match_value(expr->srcs[i], instr, i, num_components,
267 swizzle, state)) {
268 matched = false;
269 break;
270 }
271 }
272
273 if (matched)
274 return true;
275
276 if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
277 assert(nir_op_infos[instr->op].num_inputs == 2);
278
279 /* Restore the variables_seen bitmask. If we don't do this, then we
280 * could end up with an erroneous failure due to variables found in the
281 * first match attempt above not matching those in the second.
282 */
283 state->variables_seen = variables_seen_stash;
284
285 if (!match_value(expr->srcs[0], instr, 1, num_components,
286 swizzle, state))
287 return false;
288
289 return match_value(expr->srcs[1], instr, 0, num_components,
290 swizzle, state);
291 } else {
292 return false;
293 }
294 }
295
/* Bit-size bookkeeping for one node of the replacement pattern.  The tree
 * mirrors the replacement nir_search_value.  Sizes are filled in by
 * bitsize_tree_filter_up() and bitsize_tree_filter_down(); a size of 0
 * means "not known yet".
 */
typedef struct bitsize_tree {
   /* Number of valid entries in srcs[], is_src_sized[] and src_size[]. */
   unsigned num_srcs;
   /* Child node for each source of an expression node. */
   struct bitsize_tree *srcs[4];

   /* Size shared by the node's unsized sources and unsized destination. */
   unsigned common_size;
   /* True when the opcode's input type fixes the size of source i. */
   bool is_src_sized[4];
   /* True when the destination size is fixed (opcode output type or a
    * bound variable).
    */
   bool is_dest_sized;

   /* Destination bit size. */
   unsigned dest_size;
   /* Bit size of each source. */
   unsigned src_size[4];
} bitsize_tree;
307
308 static bitsize_tree *
309 build_bitsize_tree(void *mem_ctx, struct match_state *state,
310 const nir_search_value *value)
311 {
312 bitsize_tree *tree = rzalloc(mem_ctx, bitsize_tree);
313
314 switch (value->type) {
315 case nir_search_value_expression: {
316 nir_search_expression *expr = nir_search_value_as_expression(value);
317 nir_op_info info = nir_op_infos[expr->opcode];
318 tree->num_srcs = info.num_inputs;
319 tree->common_size = 0;
320 for (unsigned i = 0; i < info.num_inputs; i++) {
321 tree->is_src_sized[i] = !!nir_alu_type_get_type_size(info.input_types[i]);
322 if (tree->is_src_sized[i])
323 tree->src_size[i] = nir_alu_type_get_type_size(info.input_types[i]);
324 tree->srcs[i] = build_bitsize_tree(mem_ctx, state, expr->srcs[i]);
325 }
326 tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type);
327 if (tree->is_dest_sized)
328 tree->dest_size = nir_alu_type_get_type_size(info.output_type);
329 break;
330 }
331
332 case nir_search_value_variable: {
333 nir_search_variable *var = nir_search_value_as_variable(value);
334 tree->num_srcs = 0;
335 tree->is_dest_sized = true;
336 tree->dest_size = nir_src_bit_size(state->variables[var->variable].src);
337 break;
338 }
339
340 case nir_search_value_constant: {
341 tree->num_srcs = 0;
342 tree->is_dest_sized = false;
343 tree->common_size = 0;
344 break;
345 }
346 }
347
348 if (value->bit_size) {
349 assert(!tree->is_dest_sized || tree->dest_size == value->bit_size);
350 tree->common_size = value->bit_size;
351 }
352
353 return tree;
354 }
355
356 static unsigned
357 bitsize_tree_filter_up(bitsize_tree *tree)
358 {
359 for (unsigned i = 0; i < tree->num_srcs; i++) {
360 unsigned src_size = bitsize_tree_filter_up(tree->srcs[i]);
361 if (src_size == 0)
362 continue;
363
364 if (tree->is_src_sized[i]) {
365 assert(src_size == tree->src_size[i]);
366 } else if (tree->common_size != 0) {
367 assert(src_size == tree->common_size);
368 tree->src_size[i] = src_size;
369 } else {
370 tree->common_size = src_size;
371 tree->src_size[i] = src_size;
372 }
373 }
374
375 if (tree->num_srcs && tree->common_size) {
376 if (tree->dest_size == 0)
377 tree->dest_size = tree->common_size;
378 else if (!tree->is_dest_sized)
379 assert(tree->dest_size == tree->common_size);
380
381 for (unsigned i = 0; i < tree->num_srcs; i++) {
382 if (!tree->src_size[i])
383 tree->src_size[i] = tree->common_size;
384 }
385 }
386
387 return tree->dest_size;
388 }
389
390 static void
391 bitsize_tree_filter_down(bitsize_tree *tree, unsigned size)
392 {
393 if (tree->dest_size)
394 assert(tree->dest_size == size);
395 else
396 tree->dest_size = size;
397
398 if (!tree->is_dest_sized) {
399 if (tree->common_size)
400 assert(tree->common_size == size);
401 else
402 tree->common_size = size;
403 }
404
405 for (unsigned i = 0; i < tree->num_srcs; i++) {
406 if (!tree->src_size[i]) {
407 assert(tree->common_size);
408 tree->src_size[i] = tree->common_size;
409 }
410 bitsize_tree_filter_down(tree->srcs[i], tree->src_size[i]);
411 }
412 }
413
414 static nir_alu_src
415 construct_value(const nir_search_value *value,
416 unsigned num_components, bitsize_tree *bitsize,
417 struct match_state *state,
418 nir_instr *instr, void *mem_ctx)
419 {
420 switch (value->type) {
421 case nir_search_value_expression: {
422 const nir_search_expression *expr = nir_search_value_as_expression(value);
423
424 if (nir_op_infos[expr->opcode].output_size != 0)
425 num_components = nir_op_infos[expr->opcode].output_size;
426
427 nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode);
428 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
429 bitsize->dest_size, NULL);
430 alu->dest.write_mask = (1 << num_components) - 1;
431 alu->dest.saturate = false;
432
433 /* We have no way of knowing what values in a given search expression
434 * map to a particular replacement value. Therefore, if the
435 * expression we are replacing has any exact values, the entire
436 * replacement should be exact.
437 */
438 alu->exact = state->has_exact_alu;
439
440 for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
441 /* If the source is an explicitly sized source, then we need to reset
442 * the number of components to match.
443 */
444 if (nir_op_infos[alu->op].input_sizes[i] != 0)
445 num_components = nir_op_infos[alu->op].input_sizes[i];
446
447 alu->src[i] = construct_value(expr->srcs[i],
448 num_components, bitsize->srcs[i],
449 state, instr, mem_ctx);
450 }
451
452 nir_instr_insert_before(instr, &alu->instr);
453
454 nir_alu_src val;
455 val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
456 val.negate = false;
457 val.abs = false,
458 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
459
460 return val;
461 }
462
463 case nir_search_value_variable: {
464 const nir_search_variable *var = nir_search_value_as_variable(value);
465 assert(state->variables_seen & (1 << var->variable));
466
467 nir_alu_src val = { NIR_SRC_INIT };
468 nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx);
469
470 assert(!var->is_constant);
471
472 return val;
473 }
474
475 case nir_search_value_constant: {
476 const nir_search_constant *c = nir_search_value_as_constant(value);
477 nir_load_const_instr *load =
478 nir_load_const_instr_create(mem_ctx, 1, bitsize->dest_size);
479
480 switch (c->type) {
481 case nir_type_float:
482 load->def.name = ralloc_asprintf(load, "%f", c->data.d);
483 switch (bitsize->dest_size) {
484 case 16:
485 load->value.u16[0] = _mesa_float_to_half(c->data.d);
486 break;
487 case 32:
488 load->value.f32[0] = c->data.d;
489 break;
490 case 64:
491 load->value.f64[0] = c->data.d;
492 break;
493 default:
494 unreachable("unknown bit size");
495 }
496 break;
497
498 case nir_type_int:
499 load->def.name = ralloc_asprintf(load, "%" PRIi64, c->data.i);
500 switch (bitsize->dest_size) {
501 case 8:
502 load->value.i8[0] = c->data.i;
503 break;
504 case 16:
505 load->value.i16[0] = c->data.i;
506 break;
507 case 32:
508 load->value.i32[0] = c->data.i;
509 break;
510 case 64:
511 load->value.i64[0] = c->data.i;
512 break;
513 default:
514 unreachable("unknown bit size");
515 }
516 break;
517
518 case nir_type_uint:
519 load->def.name = ralloc_asprintf(load, "%" PRIu64, c->data.u);
520 switch (bitsize->dest_size) {
521 case 8:
522 load->value.u8[0] = c->data.u;
523 break;
524 case 16:
525 load->value.u16[0] = c->data.u;
526 break;
527 case 32:
528 load->value.u32[0] = c->data.u;
529 break;
530 case 64:
531 load->value.u64[0] = c->data.u;
532 break;
533 default:
534 unreachable("unknown bit size");
535 }
536 break;
537
538 case nir_type_bool32:
539 load->value.u32[0] = c->data.u;
540 break;
541 default:
542 unreachable("Invalid alu source type");
543 }
544
545 nir_instr_insert_before(instr, &load->instr);
546
547 nir_alu_src val;
548 val.src = nir_src_for_ssa(&load->def);
549 val.negate = false;
550 val.abs = false,
551 memset(val.swizzle, 0, sizeof val.swizzle);
552
553 return val;
554 }
555
556 default:
557 unreachable("Invalid search value type");
558 }
559 }
560
561 nir_alu_instr *
562 nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
563 const nir_search_value *replace, void *mem_ctx)
564 {
565 uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
566
567 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
568 swizzle[i] = i;
569
570 assert(instr->dest.dest.is_ssa);
571
572 struct match_state state;
573 state.inexact_match = false;
574 state.has_exact_alu = false;
575 state.variables_seen = 0;
576
577 if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
578 swizzle, &state))
579 return NULL;
580
581 void *bitsize_ctx = ralloc_context(NULL);
582 bitsize_tree *tree = build_bitsize_tree(bitsize_ctx, &state, replace);
583 bitsize_tree_filter_up(tree);
584 bitsize_tree_filter_down(tree, instr->dest.dest.ssa.bit_size);
585
586 /* Inserting a mov may be unnecessary. However, it's much easier to
587 * simply let copy propagation clean this up than to try to go through
588 * and rewrite swizzles ourselves.
589 */
590 nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov);
591 mov->dest.write_mask = instr->dest.write_mask;
592 nir_ssa_dest_init(&mov->instr, &mov->dest.dest,
593 instr->dest.dest.ssa.num_components,
594 instr->dest.dest.ssa.bit_size, NULL);
595
596 mov->src[0] = construct_value(replace,
597 instr->dest.dest.ssa.num_components, tree,
598 &state, &instr->instr, mem_ctx);
599 nir_instr_insert_before(&instr->instr, &mov->instr);
600
601 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
602 nir_src_for_ssa(&mov->dest.dest.ssa));
603
604 /* We know this one has no more uses because we just rewrote them all,
605 * so we can remove it. The rest of the matched expression, however, we
606 * don't know so much about. We'll just let dead code clean them up.
607 */
608 nir_instr_remove(&instr->instr);
609
610 ralloc_free(bitsize_ctx);
611
612 return mov;
613 }