iris: Only enable GL_AMD_depth_clamp_separate on Gen9+
[mesa.git] / src / compiler / nir / nir_search.c
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include <inttypes.h>
29 #include "nir_search.h"
30 #include "nir_builder.h"
31 #include "util/half_float.h"
32
/* Maximum number of commutative search expressions, per pattern, whose
 * source order is permuted during matching.  Each one owns a bit of
 * match_state::comm_op_direction (a uint8_t, so this must not exceed 8), and
 * nir_replace_instr tries 1 << MIN2(comm_exprs, NIR_SEARCH_MAX_COMM_OPS)
 * orderings.  Expressions whose comm_expr_idx is at or beyond this limit are
 * matched only in their written source order.
 */
#define NIR_SEARCH_MAX_COMM_OPS 4

/* State accumulated while matching a search pattern against a tree of NIR
 * ALU instructions, and consumed when building the replacement.
 */
struct match_state {
   /* Set once any search expression that passed the opcode/bit-size checks
    * was marked inexact.
    */
   bool inexact_match;
   /* Set once any NIR ALU instruction that passed the opcode/bit-size checks
    * was marked exact.  An inexact pattern may not match exact instructions,
    * and the replacement instructions inherit this flag.
    */
   bool has_exact_alu;
   /* Bit i chooses whether the first two sources of commutative search
    * expression i are swapped for the current attempt.
    */
   uint8_t comm_op_direction;
   /* Bitmask of search variables that have already been bound. */
   unsigned variables_seen;
   /* Source (with composed swizzle) bound to each search variable. */
   nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
};
42
/* Forward declaration: match_value() and match_expression() recurse into
 * each other.
 */
static bool
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                 unsigned num_components, const uint8_t *swizzle,
                 struct match_state *state);

/* Identity swizzle, used for explicitly sized sources in match_value() and
 * for freshly built ALU results in construct_value().
 * NOTE(review): only four entries are spelled out; if NIR_MAX_VEC_COMPONENTS
 * is ever larger than 4 the remaining entries are zero rather than identity
 * — confirm against nir.h.
 */
static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
49
50 /**
51 * Check if a source produces a value of the given type.
52 *
53 * Used for satisfying 'a@type' constraints.
54 */
55 static bool
56 src_is_type(nir_src src, nir_alu_type type)
57 {
58 assert(type != nir_type_invalid);
59
60 if (!src.is_ssa)
61 return false;
62
63 if (src.ssa->parent_instr->type == nir_instr_type_alu) {
64 nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
65 nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
66
67 if (type == nir_type_bool) {
68 switch (src_alu->op) {
69 case nir_op_iand:
70 case nir_op_ior:
71 case nir_op_ixor:
72 return src_is_type(src_alu->src[0].src, nir_type_bool) &&
73 src_is_type(src_alu->src[1].src, nir_type_bool);
74 case nir_op_inot:
75 return src_is_type(src_alu->src[0].src, nir_type_bool);
76 default:
77 break;
78 }
79 }
80
81 return nir_alu_type_get_base_type(output_type) == type;
82 } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
83 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
84
85 if (type == nir_type_bool) {
86 return intr->intrinsic == nir_intrinsic_load_front_face ||
87 intr->intrinsic == nir_intrinsic_load_helper_invocation;
88 }
89 }
90
91 /* don't know */
92 return false;
93 }
94
95 static bool
96 nir_op_matches_search_op(nir_op nop, uint16_t sop)
97 {
98 if (sop <= nir_last_opcode)
99 return nop == sop;
100
101 #define MATCH_FCONV_CASE(op) \
102 case nir_search_op_##op: \
103 return nop == nir_op_##op##16 || \
104 nop == nir_op_##op##32 || \
105 nop == nir_op_##op##64;
106
107 #define MATCH_ICONV_CASE(op) \
108 case nir_search_op_##op: \
109 return nop == nir_op_##op##8 || \
110 nop == nir_op_##op##16 || \
111 nop == nir_op_##op##32 || \
112 nop == nir_op_##op##64;
113
114 #define MATCH_BCONV_CASE(op) \
115 case nir_search_op_##op: \
116 return nop == nir_op_##op##1 || \
117 nop == nir_op_##op##32;
118
119 switch (sop) {
120 MATCH_FCONV_CASE(i2f)
121 MATCH_FCONV_CASE(u2f)
122 MATCH_FCONV_CASE(f2f)
123 MATCH_ICONV_CASE(f2u)
124 MATCH_ICONV_CASE(f2i)
125 MATCH_ICONV_CASE(u2u)
126 MATCH_ICONV_CASE(i2i)
127 MATCH_FCONV_CASE(b2f)
128 MATCH_ICONV_CASE(b2i)
129 MATCH_BCONV_CASE(i2b)
130 MATCH_BCONV_CASE(f2b)
131 default:
132 unreachable("Invalid nir_search_op");
133 }
134
135 #undef MATCH_FCONV_CASE
136 #undef MATCH_ICONV_CASE
137 }
138
139 static nir_op
140 nir_op_for_search_op(uint16_t sop, unsigned bit_size)
141 {
142 if (sop <= nir_last_opcode)
143 return sop;
144
145 #define RET_FCONV_CASE(op) \
146 case nir_search_op_##op: \
147 switch (bit_size) { \
148 case 16: return nir_op_##op##16; \
149 case 32: return nir_op_##op##32; \
150 case 64: return nir_op_##op##64; \
151 default: unreachable("Invalid bit size"); \
152 }
153
154 #define RET_ICONV_CASE(op) \
155 case nir_search_op_##op: \
156 switch (bit_size) { \
157 case 8: return nir_op_##op##8; \
158 case 16: return nir_op_##op##16; \
159 case 32: return nir_op_##op##32; \
160 case 64: return nir_op_##op##64; \
161 default: unreachable("Invalid bit size"); \
162 }
163
164 #define RET_BCONV_CASE(op) \
165 case nir_search_op_##op: \
166 switch (bit_size) { \
167 case 1: return nir_op_##op##1; \
168 case 32: return nir_op_##op##32; \
169 default: unreachable("Invalid bit size"); \
170 }
171
172 switch (sop) {
173 RET_FCONV_CASE(i2f)
174 RET_FCONV_CASE(u2f)
175 RET_FCONV_CASE(f2f)
176 RET_ICONV_CASE(f2u)
177 RET_ICONV_CASE(f2i)
178 RET_ICONV_CASE(u2u)
179 RET_ICONV_CASE(i2i)
180 RET_FCONV_CASE(b2f)
181 RET_ICONV_CASE(b2i)
182 RET_BCONV_CASE(i2b)
183 RET_BCONV_CASE(f2b)
184 default:
185 unreachable("Invalid nir_search_op");
186 }
187
188 #undef RET_FCONV_CASE
189 #undef RET_ICONV_CASE
190 }
191
192 static bool
193 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
194 unsigned num_components, const uint8_t *swizzle,
195 struct match_state *state)
196 {
197 uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS];
198
199 /* Searching only works on SSA values because, if it's not SSA, we can't
200 * know if the value changed between one instance of that value in the
201 * expression and another. Also, the replace operation will place reads of
202 * that value right before the last instruction in the expression we're
203 * replacing so those reads will happen after the original reads and may
204 * not be valid if they're register reads.
205 */
206 assert(instr->src[src].src.is_ssa);
207
208 /* If the source is an explicitly sized source, then we need to reset
209 * both the number of components and the swizzle.
210 */
211 if (nir_op_infos[instr->op].input_sizes[src] != 0) {
212 num_components = nir_op_infos[instr->op].input_sizes[src];
213 swizzle = identity_swizzle;
214 }
215
216 for (unsigned i = 0; i < num_components; ++i)
217 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
218
219 /* If the value has a specific bit size and it doesn't match, bail */
220 if (value->bit_size > 0 &&
221 nir_src_bit_size(instr->src[src].src) != value->bit_size)
222 return false;
223
224 switch (value->type) {
225 case nir_search_value_expression:
226 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
227 return false;
228
229 return match_expression(nir_search_value_as_expression(value),
230 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
231 num_components, new_swizzle, state);
232
233 case nir_search_value_variable: {
234 nir_search_variable *var = nir_search_value_as_variable(value);
235 assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
236
237 if (state->variables_seen & (1 << var->variable)) {
238 if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
239 return false;
240
241 assert(!instr->src[src].abs && !instr->src[src].negate);
242
243 for (unsigned i = 0; i < num_components; ++i) {
244 if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
245 return false;
246 }
247
248 return true;
249 } else {
250 if (var->is_constant &&
251 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
252 return false;
253
254 if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
255 return false;
256
257 if (var->type != nir_type_invalid &&
258 !src_is_type(instr->src[src].src, var->type))
259 return false;
260
261 state->variables_seen |= (1 << var->variable);
262 state->variables[var->variable].src = instr->src[src].src;
263 state->variables[var->variable].abs = false;
264 state->variables[var->variable].negate = false;
265
266 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) {
267 if (i < num_components)
268 state->variables[var->variable].swizzle[i] = new_swizzle[i];
269 else
270 state->variables[var->variable].swizzle[i] = 0;
271 }
272
273 return true;
274 }
275 }
276
277 case nir_search_value_constant: {
278 nir_search_constant *const_val = nir_search_value_as_constant(value);
279
280 if (!nir_src_is_const(instr->src[src].src))
281 return false;
282
283 switch (const_val->type) {
284 case nir_type_float:
285 for (unsigned i = 0; i < num_components; ++i) {
286 double val = nir_src_comp_as_float(instr->src[src].src,
287 new_swizzle[i]);
288 if (val != const_val->data.d)
289 return false;
290 }
291 return true;
292
293 case nir_type_int:
294 case nir_type_uint:
295 case nir_type_bool: {
296 unsigned bit_size = nir_src_bit_size(instr->src[src].src);
297 uint64_t mask = bit_size == 64 ? UINT64_MAX : (1ull << bit_size) - 1;
298 for (unsigned i = 0; i < num_components; ++i) {
299 uint64_t val = nir_src_comp_as_uint(instr->src[src].src,
300 new_swizzle[i]);
301 if ((val & mask) != (const_val->data.u & mask))
302 return false;
303 }
304 return true;
305 }
306
307 default:
308 unreachable("Invalid alu source type");
309 }
310 }
311
312 default:
313 unreachable("Invalid search value type");
314 }
315 }
316
317 static bool
318 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
319 unsigned num_components, const uint8_t *swizzle,
320 struct match_state *state)
321 {
322 if (expr->cond && !expr->cond(instr))
323 return false;
324
325 if (!nir_op_matches_search_op(instr->op, expr->opcode))
326 return false;
327
328 assert(instr->dest.dest.is_ssa);
329
330 if (expr->value.bit_size > 0 &&
331 instr->dest.dest.ssa.bit_size != expr->value.bit_size)
332 return false;
333
334 state->inexact_match = expr->inexact || state->inexact_match;
335 state->has_exact_alu = instr->exact || state->has_exact_alu;
336 if (state->inexact_match && state->has_exact_alu)
337 return false;
338
339 assert(!instr->dest.saturate);
340 assert(nir_op_infos[instr->op].num_inputs > 0);
341
342 /* If we have an explicitly sized destination, we can only handle the
343 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
344 * expression, we don't have the information right now to propagate that
345 * swizzle through. We can only properly propagate swizzles if the
346 * instruction is vectorized.
347 */
348 if (nir_op_infos[instr->op].output_size != 0) {
349 for (unsigned i = 0; i < num_components; i++) {
350 if (swizzle[i] != i)
351 return false;
352 }
353 }
354
355 /* If this is a commutative expression and it's one of the first few, look
356 * up its direction for the current search operation. We'll use that value
357 * to possibly flip the sources for the match.
358 */
359 unsigned comm_op_flip =
360 (expr->comm_expr_idx >= 0 &&
361 expr->comm_expr_idx < NIR_SEARCH_MAX_COMM_OPS) ?
362 ((state->comm_op_direction >> expr->comm_expr_idx) & 1) : 0;
363
364 bool matched = true;
365 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
366 if (!match_value(expr->srcs[i], instr, i ^ comm_op_flip,
367 num_components, swizzle, state)) {
368 matched = false;
369 break;
370 }
371 }
372
373 return matched;
374 }
375
376 static unsigned
377 replace_bitsize(const nir_search_value *value, unsigned search_bitsize,
378 struct match_state *state)
379 {
380 if (value->bit_size > 0)
381 return value->bit_size;
382 if (value->bit_size < 0)
383 return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
384 return search_bitsize;
385 }
386
387 static nir_alu_src
388 construct_value(nir_builder *build,
389 const nir_search_value *value,
390 unsigned num_components, unsigned search_bitsize,
391 struct match_state *state,
392 nir_instr *instr)
393 {
394 switch (value->type) {
395 case nir_search_value_expression: {
396 const nir_search_expression *expr = nir_search_value_as_expression(value);
397 unsigned dst_bit_size = replace_bitsize(value, search_bitsize, state);
398 nir_op op = nir_op_for_search_op(expr->opcode, dst_bit_size);
399
400 if (nir_op_infos[op].output_size != 0)
401 num_components = nir_op_infos[op].output_size;
402
403 nir_alu_instr *alu = nir_alu_instr_create(build->shader, op);
404 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
405 dst_bit_size, NULL);
406 alu->dest.write_mask = (1 << num_components) - 1;
407 alu->dest.saturate = false;
408
409 /* We have no way of knowing what values in a given search expression
410 * map to a particular replacement value. Therefore, if the
411 * expression we are replacing has any exact values, the entire
412 * replacement should be exact.
413 */
414 alu->exact = state->has_exact_alu;
415
416 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
417 /* If the source is an explicitly sized source, then we need to reset
418 * the number of components to match.
419 */
420 if (nir_op_infos[alu->op].input_sizes[i] != 0)
421 num_components = nir_op_infos[alu->op].input_sizes[i];
422
423 alu->src[i] = construct_value(build, expr->srcs[i],
424 num_components, search_bitsize,
425 state, instr);
426 }
427
428 nir_builder_instr_insert(build, &alu->instr);
429
430 nir_alu_src val;
431 val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
432 val.negate = false;
433 val.abs = false,
434 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
435
436 return val;
437 }
438
439 case nir_search_value_variable: {
440 const nir_search_variable *var = nir_search_value_as_variable(value);
441 assert(state->variables_seen & (1 << var->variable));
442
443 nir_alu_src val = { NIR_SRC_INIT };
444 nir_alu_src_copy(&val, &state->variables[var->variable],
445 (void *)build->shader);
446 assert(!var->is_constant);
447
448 return val;
449 }
450
451 case nir_search_value_constant: {
452 const nir_search_constant *c = nir_search_value_as_constant(value);
453 unsigned bit_size = replace_bitsize(value, search_bitsize, state);
454
455 nir_ssa_def *cval;
456 switch (c->type) {
457 case nir_type_float:
458 cval = nir_imm_floatN_t(build, c->data.d, bit_size);
459 break;
460
461 case nir_type_int:
462 case nir_type_uint:
463 cval = nir_imm_intN_t(build, c->data.i, bit_size);
464 break;
465
466 case nir_type_bool:
467 cval = nir_imm_boolN_t(build, c->data.u, bit_size);
468 break;
469
470 default:
471 unreachable("Invalid alu source type");
472 }
473
474 nir_alu_src val;
475 val.src = nir_src_for_ssa(cval);
476 val.negate = false;
477 val.abs = false,
478 memset(val.swizzle, 0, sizeof val.swizzle);
479
480 return val;
481 }
482
483 default:
484 unreachable("Invalid search value type");
485 }
486 }
487
488 nir_ssa_def *
489 nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
490 const nir_search_expression *search,
491 const nir_search_value *replace)
492 {
493 uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
494
495 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
496 swizzle[i] = i;
497
498 assert(instr->dest.dest.is_ssa);
499
500 struct match_state state;
501 state.inexact_match = false;
502 state.has_exact_alu = false;
503
504 unsigned comm_expr_combinations =
505 1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS);
506
507 bool found = false;
508 for (unsigned comb = 0; comb < comm_expr_combinations; comb++) {
509 /* The bitfield of directions is just the current iteration. Hooray for
510 * binary.
511 */
512 state.comm_op_direction = comb;
513 state.variables_seen = 0;
514
515 if (match_expression(search, instr,
516 instr->dest.dest.ssa.num_components,
517 swizzle, &state)) {
518 found = true;
519 break;
520 }
521 }
522 if (!found)
523 return NULL;
524
525 build->cursor = nir_before_instr(&instr->instr);
526
527 nir_alu_src val = construct_value(build, replace,
528 instr->dest.dest.ssa.num_components,
529 instr->dest.dest.ssa.bit_size,
530 &state, &instr->instr);
531
532 /* Inserting a mov may be unnecessary. However, it's much easier to
533 * simply let copy propagation clean this up than to try to go through
534 * and rewrite swizzles ourselves.
535 */
536 nir_ssa_def *ssa_val =
537 nir_imov_alu(build, val, instr->dest.dest.ssa.num_components);
538 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
539
540 /* We know this one has no more uses because we just rewrote them all,
541 * so we can remove it. The rest of the matched expression, however, we
542 * don't know so much about. We'll just let dead code clean them up.
543 */
544 nir_instr_remove(&instr->instr);
545
546 return ssa_val;
547 }