nir/algebraic: Add support for unsized conversion opcodes
[mesa.git] / src / compiler / nir / nir_search.c
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include <inttypes.h>
29 #include "nir_search.h"
30 #include "nir_builder.h"
31 #include "util/half_float.h"
32
/**
 * State accumulated while matching one search pattern against an
 * instruction tree.
 */
struct match_state {
   /* Set once any matched search expression is marked inexact. */
   bool inexact_match;
   /* Set once any matched ALU instruction is marked exact.  A match is
    * rejected when this and inexact_match are both set, and replacement
    * instructions inherit this flag as their exact bit.
    */
   bool has_exact_alu;
   /* Bitmask of search-variable indices that have already been bound. */
   unsigned variables_seen;
   /* Source (SSA def plus composed swizzle) bound to each variable index. */
   nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
};
39
/* match_value() and match_expression() are mutually recursive; declare
 * match_expression() so match_value() can call it.
 */
static bool
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                 unsigned num_components, const uint8_t *swizzle,
                 struct match_state *state);

/* Swizzle mapping each component to itself.  Used for explicitly sized
 * sources during matching and as the swizzle of newly built replacement
 * values.
 * NOTE(review): only four initializers are listed; if NIR_MAX_VEC_COMPONENTS
 * is ever larger than 4, the remaining entries are zero rather than
 * identity — confirm NIR_MAX_VEC_COMPONENTS in this tree.
 */
static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
46
47 /**
48 * Check if a source produces a value of the given type.
49 *
50 * Used for satisfying 'a@type' constraints.
51 */
52 static bool
53 src_is_type(nir_src src, nir_alu_type type)
54 {
55 assert(type != nir_type_invalid);
56
57 if (!src.is_ssa)
58 return false;
59
60 if (src.ssa->parent_instr->type == nir_instr_type_alu) {
61 nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
62 nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
63
64 if (type == nir_type_bool) {
65 switch (src_alu->op) {
66 case nir_op_iand:
67 case nir_op_ior:
68 case nir_op_ixor:
69 return src_is_type(src_alu->src[0].src, nir_type_bool) &&
70 src_is_type(src_alu->src[1].src, nir_type_bool);
71 case nir_op_inot:
72 return src_is_type(src_alu->src[0].src, nir_type_bool);
73 default:
74 break;
75 }
76 }
77
78 return nir_alu_type_get_base_type(output_type) == type;
79 } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
80 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
81
82 if (type == nir_type_bool) {
83 return intr->intrinsic == nir_intrinsic_load_front_face ||
84 intr->intrinsic == nir_intrinsic_load_helper_invocation;
85 }
86 }
87
88 /* don't know */
89 return false;
90 }
91
92 static bool
93 nir_op_matches_search_op(nir_op nop, uint16_t sop)
94 {
95 if (sop <= nir_last_opcode)
96 return nop == sop;
97
98 #define MATCH_FCONV_CASE(op) \
99 case nir_search_op_##op: \
100 return nop == nir_op_##op##16 || \
101 nop == nir_op_##op##32 || \
102 nop == nir_op_##op##64;
103
104 #define MATCH_ICONV_CASE(op) \
105 case nir_search_op_##op: \
106 return nop == nir_op_##op##8 || \
107 nop == nir_op_##op##16 || \
108 nop == nir_op_##op##32 || \
109 nop == nir_op_##op##64;
110
111 switch (sop) {
112 MATCH_FCONV_CASE(i2f)
113 MATCH_FCONV_CASE(u2f)
114 MATCH_FCONV_CASE(f2f)
115 MATCH_ICONV_CASE(f2u)
116 MATCH_ICONV_CASE(f2i)
117 MATCH_ICONV_CASE(u2u)
118 MATCH_ICONV_CASE(i2i)
119 default:
120 unreachable("Invalid nir_search_op");
121 }
122
123 #undef MATCH_FCONV_CASE
124 #undef MATCH_ICONV_CASE
125 }
126
127 static nir_op
128 nir_op_for_search_op(uint16_t sop, unsigned bit_size)
129 {
130 if (sop <= nir_last_opcode)
131 return sop;
132
133 #define RET_FCONV_CASE(op) \
134 case nir_search_op_##op: \
135 switch (bit_size) { \
136 case 16: return nir_op_##op##16; \
137 case 32: return nir_op_##op##32; \
138 case 64: return nir_op_##op##64; \
139 default: unreachable("Invalid bit size"); \
140 }
141
142 #define RET_ICONV_CASE(op) \
143 case nir_search_op_##op: \
144 switch (bit_size) { \
145 case 8: return nir_op_##op##8; \
146 case 16: return nir_op_##op##16; \
147 case 32: return nir_op_##op##32; \
148 case 64: return nir_op_##op##64; \
149 default: unreachable("Invalid bit size"); \
150 }
151
152 switch (sop) {
153 RET_FCONV_CASE(i2f)
154 RET_FCONV_CASE(u2f)
155 RET_FCONV_CASE(f2f)
156 RET_ICONV_CASE(f2u)
157 RET_ICONV_CASE(f2i)
158 RET_ICONV_CASE(u2u)
159 RET_ICONV_CASE(i2i)
160 default:
161 unreachable("Invalid nir_search_op");
162 }
163
164 #undef RET_FCONV_CASE
165 #undef RET_ICONV_CASE
166 }
167
/**
 * Try to match search value \p value against source \p src of \p instr.
 *
 * \param num_components  number of components being matched; replaced by
 *                        the opcode's input size when \p src is explicitly
 *                        sized.
 * \param swizzle         swizzle accumulated from enclosing expressions; it
 *                        is composed with this source's own swizzle.
 * \param state           match state; variable bindings and exactness flags
 *                        are recorded here as matching proceeds.
 *
 * \return true if the source matches.  On failure, \p state may still hold
 *         bindings made by a partially successful nested expression match.
 */
static bool
match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
            unsigned num_components, const uint8_t *swizzle,
            struct match_state *state)
{
   uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS];

   /* Searching only works on SSA values because, if it's not SSA, we can't
    * know if the value changed between one instance of that value in the
    * expression and another. Also, the replace operation will place reads of
    * that value right before the last instruction in the expression we're
    * replacing so those reads will happen after the original reads and may
    * not be valid if they're register reads.
    */
   if (!instr->src[src].src.is_ssa)
      return false;

   /* If the source is an explicitly sized source, then we need to reset
    * both the number of components and the swizzle.
    */
   if (nir_op_infos[instr->op].input_sizes[src] != 0) {
      num_components = nir_op_infos[instr->op].input_sizes[src];
      swizzle = identity_swizzle;
   }

   /* Compose the incoming swizzle with this source's swizzle, so that
    * new_swizzle[i] is the component of the source's SSA def read for
    * matched component i.
    */
   for (unsigned i = 0; i < num_components; ++i)
      new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];

   /* If the value has a specific bit size and it doesn't match, bail */
   if (value->bit_size > 0 &&
       nir_src_bit_size(instr->src[src].src) != value->bit_size)
      return false;

   switch (value->type) {
   case nir_search_value_expression:
      /* Sub-expressions can only match values produced by ALU instructions. */
      if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
         return false;

      return match_expression(nir_search_value_as_expression(value),
                              nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
                              num_components, new_swizzle, state);

   case nir_search_value_variable: {
      nir_search_variable *var = nir_search_value_as_variable(value);
      assert(var->variable < NIR_SEARCH_MAX_VARIABLES);

      if (state->variables_seen & (1 << var->variable)) {
         /* Variable already bound: this occurrence must read the same SSA
          * def through the same composed swizzle as the earlier binding.
          */
         if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
            return false;

         assert(!instr->src[src].abs && !instr->src[src].negate);

         for (unsigned i = 0; i < num_components; ++i) {
            if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
               return false;
         }

         return true;
      } else {
         /* First occurrence: check the variable's constraints (constant-only,
          * custom condition, @type) and then bind it.
          */
         if (var->is_constant &&
             instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
            return false;

         if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
            return false;

         if (var->type != nir_type_invalid &&
             !src_is_type(instr->src[src].src, var->type))
            return false;

         state->variables_seen |= (1 << var->variable);
         state->variables[var->variable].src = instr->src[src].src;
         state->variables[var->variable].abs = false;
         state->variables[var->variable].negate = false;

         /* Record the composed swizzle; unused trailing entries are zeroed. */
         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) {
            if (i < num_components)
               state->variables[var->variable].swizzle[i] = new_swizzle[i];
            else
               state->variables[var->variable].swizzle[i] = 0;
         }

         return true;
      }
   }

   case nir_search_value_constant: {
      nir_search_constant *const_val = nir_search_value_as_constant(value);

      if (!nir_src_is_const(instr->src[src].src))
         return false;

      switch (const_val->type) {
      case nir_type_float:
         /* Every selected component must equal the search constant. */
         for (unsigned i = 0; i < num_components; ++i) {
            double val = nir_src_comp_as_float(instr->src[src].src,
                                               new_swizzle[i]);
            if (val != const_val->data.d)
               return false;
         }
         return true;

      case nir_type_int:
      case nir_type_uint:
      case nir_type_bool: {
         /* Compare only the low bit_size bits of both the source component
          * and the search constant.
          */
         unsigned bit_size = nir_src_bit_size(instr->src[src].src);
         uint64_t mask = bit_size == 64 ? UINT64_MAX : (1ull << bit_size) - 1;
         for (unsigned i = 0; i < num_components; ++i) {
            uint64_t val = nir_src_comp_as_uint(instr->src[src].src,
                                                new_swizzle[i]);
            if ((val & mask) != (const_val->data.u & mask))
               return false;
         }
         return true;
      }

      default:
         unreachable("Invalid alu source type");
      }
   }

   default:
      unreachable("Invalid search value type");
   }
}
293
294 static bool
295 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
296 unsigned num_components, const uint8_t *swizzle,
297 struct match_state *state)
298 {
299 if (expr->cond && !expr->cond(instr))
300 return false;
301
302 if (!nir_op_matches_search_op(instr->op, expr->opcode))
303 return false;
304
305 assert(instr->dest.dest.is_ssa);
306
307 if (expr->value.bit_size > 0 &&
308 instr->dest.dest.ssa.bit_size != expr->value.bit_size)
309 return false;
310
311 state->inexact_match = expr->inexact || state->inexact_match;
312 state->has_exact_alu = instr->exact || state->has_exact_alu;
313 if (state->inexact_match && state->has_exact_alu)
314 return false;
315
316 assert(!instr->dest.saturate);
317 assert(nir_op_infos[instr->op].num_inputs > 0);
318
319 /* If we have an explicitly sized destination, we can only handle the
320 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
321 * expression, we don't have the information right now to propagate that
322 * swizzle through. We can only properly propagate swizzles if the
323 * instruction is vectorized.
324 */
325 if (nir_op_infos[instr->op].output_size != 0) {
326 for (unsigned i = 0; i < num_components; i++) {
327 if (swizzle[i] != i)
328 return false;
329 }
330 }
331
332 /* Stash off the current variables_seen bitmask. This way we can
333 * restore it prior to matching in the commutative case below.
334 */
335 unsigned variables_seen_stash = state->variables_seen;
336
337 bool matched = true;
338 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
339 if (!match_value(expr->srcs[i], instr, i, num_components,
340 swizzle, state)) {
341 matched = false;
342 break;
343 }
344 }
345
346 if (matched)
347 return true;
348
349 if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
350 assert(nir_op_infos[instr->op].num_inputs == 2);
351
352 /* Restore the variables_seen bitmask. If we don't do this, then we
353 * could end up with an erroneous failure due to variables found in the
354 * first match attempt above not matching those in the second.
355 */
356 state->variables_seen = variables_seen_stash;
357
358 if (!match_value(expr->srcs[0], instr, 1, num_components,
359 swizzle, state))
360 return false;
361
362 return match_value(expr->srcs[1], instr, 0, num_components,
363 swizzle, state);
364 } else {
365 return false;
366 }
367 }
368
369 static unsigned
370 replace_bitsize(const nir_search_value *value, unsigned search_bitsize,
371 struct match_state *state)
372 {
373 if (value->bit_size > 0)
374 return value->bit_size;
375 if (value->bit_size < 0)
376 return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
377 return search_bitsize;
378 }
379
380 static nir_alu_src
381 construct_value(nir_builder *build,
382 const nir_search_value *value,
383 unsigned num_components, unsigned search_bitsize,
384 struct match_state *state,
385 nir_instr *instr)
386 {
387 switch (value->type) {
388 case nir_search_value_expression: {
389 const nir_search_expression *expr = nir_search_value_as_expression(value);
390 unsigned dst_bit_size = replace_bitsize(value, search_bitsize, state);
391 nir_op op = nir_op_for_search_op(expr->opcode, dst_bit_size);
392
393 if (nir_op_infos[op].output_size != 0)
394 num_components = nir_op_infos[op].output_size;
395
396 nir_alu_instr *alu = nir_alu_instr_create(build->shader, op);
397 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
398 dst_bit_size, NULL);
399 alu->dest.write_mask = (1 << num_components) - 1;
400 alu->dest.saturate = false;
401
402 /* We have no way of knowing what values in a given search expression
403 * map to a particular replacement value. Therefore, if the
404 * expression we are replacing has any exact values, the entire
405 * replacement should be exact.
406 */
407 alu->exact = state->has_exact_alu;
408
409 for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
410 /* If the source is an explicitly sized source, then we need to reset
411 * the number of components to match.
412 */
413 if (nir_op_infos[alu->op].input_sizes[i] != 0)
414 num_components = nir_op_infos[alu->op].input_sizes[i];
415
416 alu->src[i] = construct_value(build, expr->srcs[i],
417 num_components, search_bitsize,
418 state, instr);
419 }
420
421 nir_builder_instr_insert(build, &alu->instr);
422
423 nir_alu_src val;
424 val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
425 val.negate = false;
426 val.abs = false,
427 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
428
429 return val;
430 }
431
432 case nir_search_value_variable: {
433 const nir_search_variable *var = nir_search_value_as_variable(value);
434 assert(state->variables_seen & (1 << var->variable));
435
436 nir_alu_src val = { NIR_SRC_INIT };
437 nir_alu_src_copy(&val, &state->variables[var->variable],
438 (void *)build->shader);
439 assert(!var->is_constant);
440
441 return val;
442 }
443
444 case nir_search_value_constant: {
445 const nir_search_constant *c = nir_search_value_as_constant(value);
446 unsigned bit_size = replace_bitsize(value, search_bitsize, state);
447
448 nir_ssa_def *cval;
449 switch (c->type) {
450 case nir_type_float:
451 cval = nir_imm_floatN_t(build, c->data.d, bit_size);
452 break;
453
454 case nir_type_int:
455 case nir_type_uint:
456 cval = nir_imm_intN_t(build, c->data.i, bit_size);
457 break;
458
459 case nir_type_bool:
460 cval = nir_imm_bool(build, c->data.u);
461 break;
462 default:
463 unreachable("Invalid alu source type");
464 }
465
466 nir_alu_src val;
467 val.src = nir_src_for_ssa(cval);
468 val.negate = false;
469 val.abs = false,
470 memset(val.swizzle, 0, sizeof val.swizzle);
471
472 return val;
473 }
474
475 default:
476 unreachable("Invalid search value type");
477 }
478 }
479
480 nir_ssa_def *
481 nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
482 const nir_search_expression *search,
483 const nir_search_value *replace)
484 {
485 uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
486
487 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
488 swizzle[i] = i;
489
490 assert(instr->dest.dest.is_ssa);
491
492 struct match_state state;
493 state.inexact_match = false;
494 state.has_exact_alu = false;
495 state.variables_seen = 0;
496
497 if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
498 swizzle, &state))
499 return NULL;
500
501 build->cursor = nir_before_instr(&instr->instr);
502
503 nir_alu_src val = construct_value(build, replace,
504 instr->dest.dest.ssa.num_components,
505 instr->dest.dest.ssa.bit_size,
506 &state, &instr->instr);
507
508 /* Inserting a mov may be unnecessary. However, it's much easier to
509 * simply let copy propagation clean this up than to try to go through
510 * and rewrite swizzles ourselves.
511 */
512 nir_ssa_def *ssa_val =
513 nir_imov_alu(build, val, instr->dest.dest.ssa.num_components);
514 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
515
516 /* We know this one has no more uses because we just rewrote them all,
517 * so we can remove it. The rest of the matched expression, however, we
518 * don't know so much about. We'll just let dead code clean them up.
519 */
520 nir_instr_remove(&instr->instr);
521
522 return ssa_val;
523 }