/* nir/algebraic: Rewrite bit-size inference
 * [mesa.git] / src / compiler / nir / nir_search.c
 */
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include <inttypes.h>
29 #include "nir_search.h"
30 #include "nir_builder.h"
31 #include "util/half_float.h"
32
33 struct match_state {
34 bool inexact_match;
35 bool has_exact_alu;
36 unsigned variables_seen;
37 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
38 };
39
/* Forward declaration: match_value() and match_expression() recurse into
 * each other, so one of them has to be declared before both are defined.
 */
static bool
match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
                 unsigned num_components, const uint8_t *swizzle,
                 struct match_state *state);
44
45 static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3 };
46
47 /**
48 * Check if a source produces a value of the given type.
49 *
50 * Used for satisfying 'a@type' constraints.
51 */
52 static bool
53 src_is_type(nir_src src, nir_alu_type type)
54 {
55 assert(type != nir_type_invalid);
56
57 if (!src.is_ssa)
58 return false;
59
60 if (src.ssa->parent_instr->type == nir_instr_type_alu) {
61 nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
62 nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
63
64 if (type == nir_type_bool) {
65 switch (src_alu->op) {
66 case nir_op_iand:
67 case nir_op_ior:
68 case nir_op_ixor:
69 return src_is_type(src_alu->src[0].src, nir_type_bool) &&
70 src_is_type(src_alu->src[1].src, nir_type_bool);
71 case nir_op_inot:
72 return src_is_type(src_alu->src[0].src, nir_type_bool);
73 default:
74 break;
75 }
76 }
77
78 return nir_alu_type_get_base_type(output_type) == type;
79 } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
80 nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
81
82 if (type == nir_type_bool) {
83 return intr->intrinsic == nir_intrinsic_load_front_face ||
84 intr->intrinsic == nir_intrinsic_load_helper_invocation;
85 }
86 }
87
88 /* don't know */
89 return false;
90 }
91
92 static bool
93 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
94 unsigned num_components, const uint8_t *swizzle,
95 struct match_state *state)
96 {
97 uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS];
98
99 /* Searching only works on SSA values because, if it's not SSA, we can't
100 * know if the value changed between one instance of that value in the
101 * expression and another. Also, the replace operation will place reads of
102 * that value right before the last instruction in the expression we're
103 * replacing so those reads will happen after the original reads and may
104 * not be valid if they're register reads.
105 */
106 if (!instr->src[src].src.is_ssa)
107 return false;
108
109 /* If the source is an explicitly sized source, then we need to reset
110 * both the number of components and the swizzle.
111 */
112 if (nir_op_infos[instr->op].input_sizes[src] != 0) {
113 num_components = nir_op_infos[instr->op].input_sizes[src];
114 swizzle = identity_swizzle;
115 }
116
117 for (unsigned i = 0; i < num_components; ++i)
118 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
119
120 /* If the value has a specific bit size and it doesn't match, bail */
121 if (value->bit_size > 0 &&
122 nir_src_bit_size(instr->src[src].src) != value->bit_size)
123 return false;
124
125 switch (value->type) {
126 case nir_search_value_expression:
127 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
128 return false;
129
130 return match_expression(nir_search_value_as_expression(value),
131 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
132 num_components, new_swizzle, state);
133
134 case nir_search_value_variable: {
135 nir_search_variable *var = nir_search_value_as_variable(value);
136 assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
137
138 if (state->variables_seen & (1 << var->variable)) {
139 if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
140 return false;
141
142 assert(!instr->src[src].abs && !instr->src[src].negate);
143
144 for (unsigned i = 0; i < num_components; ++i) {
145 if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
146 return false;
147 }
148
149 return true;
150 } else {
151 if (var->is_constant &&
152 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
153 return false;
154
155 if (var->cond && !var->cond(instr, src, num_components, new_swizzle))
156 return false;
157
158 if (var->type != nir_type_invalid &&
159 !src_is_type(instr->src[src].src, var->type))
160 return false;
161
162 state->variables_seen |= (1 << var->variable);
163 state->variables[var->variable].src = instr->src[src].src;
164 state->variables[var->variable].abs = false;
165 state->variables[var->variable].negate = false;
166
167 for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) {
168 if (i < num_components)
169 state->variables[var->variable].swizzle[i] = new_swizzle[i];
170 else
171 state->variables[var->variable].swizzle[i] = 0;
172 }
173
174 return true;
175 }
176 }
177
178 case nir_search_value_constant: {
179 nir_search_constant *const_val = nir_search_value_as_constant(value);
180
181 if (!nir_src_is_const(instr->src[src].src))
182 return false;
183
184 switch (const_val->type) {
185 case nir_type_float:
186 for (unsigned i = 0; i < num_components; ++i) {
187 double val = nir_src_comp_as_float(instr->src[src].src,
188 new_swizzle[i]);
189 if (val != const_val->data.d)
190 return false;
191 }
192 return true;
193
194 case nir_type_int:
195 case nir_type_uint:
196 case nir_type_bool: {
197 unsigned bit_size = nir_src_bit_size(instr->src[src].src);
198 uint64_t mask = bit_size == 64 ? UINT64_MAX : (1ull << bit_size) - 1;
199 for (unsigned i = 0; i < num_components; ++i) {
200 uint64_t val = nir_src_comp_as_uint(instr->src[src].src,
201 new_swizzle[i]);
202 if ((val & mask) != (const_val->data.u & mask))
203 return false;
204 }
205 return true;
206 }
207
208 default:
209 unreachable("Invalid alu source type");
210 }
211 }
212
213 default:
214 unreachable("Invalid search value type");
215 }
216 }
217
218 static bool
219 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
220 unsigned num_components, const uint8_t *swizzle,
221 struct match_state *state)
222 {
223 if (expr->cond && !expr->cond(instr))
224 return false;
225
226 if (instr->op != expr->opcode)
227 return false;
228
229 assert(instr->dest.dest.is_ssa);
230
231 if (expr->value.bit_size > 0 &&
232 instr->dest.dest.ssa.bit_size != expr->value.bit_size)
233 return false;
234
235 state->inexact_match = expr->inexact || state->inexact_match;
236 state->has_exact_alu = instr->exact || state->has_exact_alu;
237 if (state->inexact_match && state->has_exact_alu)
238 return false;
239
240 assert(!instr->dest.saturate);
241 assert(nir_op_infos[instr->op].num_inputs > 0);
242
243 /* If we have an explicitly sized destination, we can only handle the
244 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
245 * expression, we don't have the information right now to propagate that
246 * swizzle through. We can only properly propagate swizzles if the
247 * instruction is vectorized.
248 */
249 if (nir_op_infos[instr->op].output_size != 0) {
250 for (unsigned i = 0; i < num_components; i++) {
251 if (swizzle[i] != i)
252 return false;
253 }
254 }
255
256 /* Stash off the current variables_seen bitmask. This way we can
257 * restore it prior to matching in the commutative case below.
258 */
259 unsigned variables_seen_stash = state->variables_seen;
260
261 bool matched = true;
262 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
263 if (!match_value(expr->srcs[i], instr, i, num_components,
264 swizzle, state)) {
265 matched = false;
266 break;
267 }
268 }
269
270 if (matched)
271 return true;
272
273 if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
274 assert(nir_op_infos[instr->op].num_inputs == 2);
275
276 /* Restore the variables_seen bitmask. If we don't do this, then we
277 * could end up with an erroneous failure due to variables found in the
278 * first match attempt above not matching those in the second.
279 */
280 state->variables_seen = variables_seen_stash;
281
282 if (!match_value(expr->srcs[0], instr, 1, num_components,
283 swizzle, state))
284 return false;
285
286 return match_value(expr->srcs[1], instr, 0, num_components,
287 swizzle, state);
288 } else {
289 return false;
290 }
291 }
292
293 static unsigned
294 replace_bitsize(const nir_search_value *value, unsigned search_bitsize,
295 struct match_state *state)
296 {
297 if (value->bit_size > 0)
298 return value->bit_size;
299 if (value->bit_size < 0)
300 return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
301 return search_bitsize;
302 }
303
304 static nir_alu_src
305 construct_value(nir_builder *build,
306 const nir_search_value *value,
307 unsigned num_components, unsigned search_bitsize,
308 struct match_state *state,
309 nir_instr *instr)
310 {
311 switch (value->type) {
312 case nir_search_value_expression: {
313 const nir_search_expression *expr = nir_search_value_as_expression(value);
314
315 if (nir_op_infos[expr->opcode].output_size != 0)
316 num_components = nir_op_infos[expr->opcode].output_size;
317
318 nir_alu_instr *alu = nir_alu_instr_create(build->shader, expr->opcode);
319 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
320 replace_bitsize(value, search_bitsize, state), NULL);
321 alu->dest.write_mask = (1 << num_components) - 1;
322 alu->dest.saturate = false;
323
324 /* We have no way of knowing what values in a given search expression
325 * map to a particular replacement value. Therefore, if the
326 * expression we are replacing has any exact values, the entire
327 * replacement should be exact.
328 */
329 alu->exact = state->has_exact_alu;
330
331 for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
332 /* If the source is an explicitly sized source, then we need to reset
333 * the number of components to match.
334 */
335 if (nir_op_infos[alu->op].input_sizes[i] != 0)
336 num_components = nir_op_infos[alu->op].input_sizes[i];
337
338 alu->src[i] = construct_value(build, expr->srcs[i],
339 num_components, search_bitsize,
340 state, instr);
341 }
342
343 nir_builder_instr_insert(build, &alu->instr);
344
345 nir_alu_src val;
346 val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
347 val.negate = false;
348 val.abs = false,
349 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
350
351 return val;
352 }
353
354 case nir_search_value_variable: {
355 const nir_search_variable *var = nir_search_value_as_variable(value);
356 assert(state->variables_seen & (1 << var->variable));
357
358 nir_alu_src val = { NIR_SRC_INIT };
359 nir_alu_src_copy(&val, &state->variables[var->variable],
360 (void *)build->shader);
361 assert(!var->is_constant);
362
363 return val;
364 }
365
366 case nir_search_value_constant: {
367 const nir_search_constant *c = nir_search_value_as_constant(value);
368 unsigned bit_size = replace_bitsize(value, search_bitsize, state);
369
370 nir_ssa_def *cval;
371 switch (c->type) {
372 case nir_type_float:
373 cval = nir_imm_floatN_t(build, c->data.d, bit_size);
374 break;
375
376 case nir_type_int:
377 case nir_type_uint:
378 cval = nir_imm_intN_t(build, c->data.i, bit_size);
379 break;
380
381 case nir_type_bool:
382 cval = nir_imm_bool(build, c->data.u);
383 break;
384 default:
385 unreachable("Invalid alu source type");
386 }
387
388 nir_alu_src val;
389 val.src = nir_src_for_ssa(cval);
390 val.negate = false;
391 val.abs = false,
392 memset(val.swizzle, 0, sizeof val.swizzle);
393
394 return val;
395 }
396
397 default:
398 unreachable("Invalid search value type");
399 }
400 }
401
402 nir_ssa_def *
403 nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
404 const nir_search_expression *search,
405 const nir_search_value *replace)
406 {
407 uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };
408
409 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
410 swizzle[i] = i;
411
412 assert(instr->dest.dest.is_ssa);
413
414 struct match_state state;
415 state.inexact_match = false;
416 state.has_exact_alu = false;
417 state.variables_seen = 0;
418
419 if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
420 swizzle, &state))
421 return NULL;
422
423 build->cursor = nir_before_instr(&instr->instr);
424
425 nir_alu_src val = construct_value(build, replace,
426 instr->dest.dest.ssa.num_components,
427 instr->dest.dest.ssa.bit_size,
428 &state, &instr->instr);
429
430 /* Inserting a mov may be unnecessary. However, it's much easier to
431 * simply let copy propagation clean this up than to try to go through
432 * and rewrite swizzles ourselves.
433 */
434 nir_ssa_def *ssa_val =
435 nir_imov_alu(build, val, instr->dest.dest.ssa.num_components);
436 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(ssa_val));
437
438 /* We know this one has no more uses because we just rewrote them all,
439 * so we can remove it. The rest of the matched expression, however, we
440 * don't know so much about. We'll just let dead code clean them up.
441 */
442 nir_instr_remove(&instr->instr);
443
444 return ssa_val;
445 }