nir/search: Assert that variable id's are in range
[mesa.git] / src / glsl / nir / nir_search.c
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include "nir_search.h"
29
30 struct match_state {
31 unsigned variables_seen;
32 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
33 };
34
35 static bool
36 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
37 unsigned num_components, const uint8_t *swizzle,
38 struct match_state *state);
39
40 static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 };
41
42 static bool alu_instr_is_bool(nir_alu_instr *instr);
43
44 static bool
45 src_is_bool(nir_src src)
46 {
47 if (!src.is_ssa)
48 return false;
49 if (src.ssa->parent_instr->type != nir_instr_type_alu)
50 return false;
51 return alu_instr_is_bool((nir_alu_instr *)src.ssa->parent_instr);
52 }
53
54 static bool
55 alu_instr_is_bool(nir_alu_instr *instr)
56 {
57 switch (instr->op) {
58 case nir_op_iand:
59 case nir_op_ior:
60 case nir_op_ixor:
61 return src_is_bool(instr->src[0].src) && src_is_bool(instr->src[1].src);
62 case nir_op_inot:
63 return src_is_bool(instr->src[0].src);
64 default:
65 return nir_op_infos[instr->op].output_type == nir_type_bool;
66 }
67 }
68
69 static bool
70 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
71 unsigned num_components, const uint8_t *swizzle,
72 struct match_state *state)
73 {
74 uint8_t new_swizzle[4];
75
76 /* If the source is an explicitly sized source, then we need to reset
77 * both the number of components and the swizzle.
78 */
79 if (nir_op_infos[instr->op].input_sizes[src] != 0) {
80 num_components = nir_op_infos[instr->op].input_sizes[src];
81 swizzle = identity_swizzle;
82 }
83
84 for (int i = 0; i < num_components; ++i)
85 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
86
87 switch (value->type) {
88 case nir_search_value_expression:
89 if (!instr->src[src].src.is_ssa)
90 return false;
91
92 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
93 return false;
94
95 return match_expression(nir_search_value_as_expression(value),
96 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
97 num_components, new_swizzle, state);
98
99 case nir_search_value_variable: {
100 nir_search_variable *var = nir_search_value_as_variable(value);
101 assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
102
103 if (state->variables_seen & (1 << var->variable)) {
104 if (!nir_srcs_equal(state->variables[var->variable].src,
105 instr->src[src].src))
106 return false;
107
108 assert(!instr->src[src].abs && !instr->src[src].negate);
109
110 for (int i = 0; i < num_components; ++i) {
111 if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
112 return false;
113 }
114
115 return true;
116 } else {
117 if (var->is_constant &&
118 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
119 return false;
120
121 if (var->type != nir_type_invalid) {
122 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
123 return false;
124
125 nir_alu_instr *src_alu =
126 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr);
127
128 if (nir_op_infos[src_alu->op].output_type != var->type &&
129 !(var->type == nir_type_bool && alu_instr_is_bool(src_alu)))
130 return false;
131 }
132
133 state->variables_seen |= (1 << var->variable);
134 state->variables[var->variable].src = instr->src[src].src;
135 state->variables[var->variable].abs = false;
136 state->variables[var->variable].negate = false;
137
138 for (int i = 0; i < 4; ++i) {
139 if (i < num_components)
140 state->variables[var->variable].swizzle[i] = new_swizzle[i];
141 else
142 state->variables[var->variable].swizzle[i] = 0;
143 }
144
145 return true;
146 }
147 }
148
149 case nir_search_value_constant: {
150 nir_search_constant *const_val = nir_search_value_as_constant(value);
151
152 if (!instr->src[src].src.is_ssa)
153 return false;
154
155 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
156 return false;
157
158 nir_load_const_instr *load =
159 nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
160
161 switch (nir_op_infos[instr->op].input_types[src]) {
162 case nir_type_float:
163 for (unsigned i = 0; i < num_components; ++i) {
164 if (load->value.f[new_swizzle[i]] != const_val->data.f)
165 return false;
166 }
167 return true;
168 case nir_type_int:
169 case nir_type_unsigned:
170 case nir_type_bool:
171 for (unsigned i = 0; i < num_components; ++i) {
172 if (load->value.i[new_swizzle[i]] != const_val->data.i)
173 return false;
174 }
175 return true;
176 default:
177 unreachable("Invalid alu source type");
178 }
179 }
180
181 default:
182 unreachable("Invalid search value type");
183 }
184 }
185
186 static bool
187 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
188 unsigned num_components, const uint8_t *swizzle,
189 struct match_state *state)
190 {
191 if (instr->op != expr->opcode)
192 return false;
193
194 assert(!instr->dest.saturate);
195 assert(nir_op_infos[instr->op].num_inputs > 0);
196
197 /* If we have an explicitly sized destination, we can only handle the
198 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
199 * expression, we don't have the information right now to propagate that
200 * swizzle through. We can only properly propagate swizzles if the
201 * instruction is vectorized.
202 */
203 if (nir_op_infos[instr->op].output_size != 0) {
204 for (unsigned i = 0; i < num_components; i++) {
205 if (swizzle[i] != i)
206 return false;
207 }
208 }
209
210 bool matched = true;
211 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
212 if (!match_value(expr->srcs[i], instr, i, num_components,
213 swizzle, state)) {
214 matched = false;
215 break;
216 }
217 }
218
219 if (matched)
220 return true;
221
222 if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
223 assert(nir_op_infos[instr->op].num_inputs == 2);
224 if (!match_value(expr->srcs[0], instr, 1, num_components,
225 swizzle, state))
226 return false;
227
228 return match_value(expr->srcs[1], instr, 0, num_components,
229 swizzle, state);
230 } else {
231 return false;
232 }
233 }
234
235 static nir_alu_src
236 construct_value(const nir_search_value *value, nir_alu_type type,
237 unsigned num_components, struct match_state *state,
238 nir_instr *instr, void *mem_ctx)
239 {
240 switch (value->type) {
241 case nir_search_value_expression: {
242 const nir_search_expression *expr = nir_search_value_as_expression(value);
243
244 if (nir_op_infos[expr->opcode].output_size != 0)
245 num_components = nir_op_infos[expr->opcode].output_size;
246
247 nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode);
248 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components, NULL);
249 alu->dest.write_mask = (1 << num_components) - 1;
250 alu->dest.saturate = false;
251
252 for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
253 /* If the source is an explicitly sized source, then we need to reset
254 * the number of components to match.
255 */
256 if (nir_op_infos[alu->op].input_sizes[i] != 0)
257 num_components = nir_op_infos[alu->op].input_sizes[i];
258
259 alu->src[i] = construct_value(expr->srcs[i],
260 nir_op_infos[alu->op].input_types[i],
261 num_components,
262 state, instr, mem_ctx);
263 }
264
265 nir_instr_insert_before(instr, &alu->instr);
266
267 nir_alu_src val;
268 val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
269 val.negate = false;
270 val.abs = false,
271 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
272
273 return val;
274 }
275
276 case nir_search_value_variable: {
277 const nir_search_variable *var = nir_search_value_as_variable(value);
278 assert(state->variables_seen & (1 << var->variable));
279
280 nir_alu_src val = { NIR_SRC_INIT };
281 nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx);
282
283 assert(!var->is_constant);
284
285 return val;
286 }
287
288 case nir_search_value_constant: {
289 const nir_search_constant *c = nir_search_value_as_constant(value);
290 nir_load_const_instr *load = nir_load_const_instr_create(mem_ctx, 1);
291
292 switch (type) {
293 case nir_type_float:
294 load->def.name = ralloc_asprintf(mem_ctx, "%f", c->data.f);
295 load->value.f[0] = c->data.f;
296 break;
297 case nir_type_int:
298 load->def.name = ralloc_asprintf(mem_ctx, "%d", c->data.i);
299 load->value.i[0] = c->data.i;
300 break;
301 case nir_type_unsigned:
302 case nir_type_bool:
303 load->value.u[0] = c->data.u;
304 break;
305 default:
306 unreachable("Invalid alu source type");
307 }
308
309 nir_instr_insert_before(instr, &load->instr);
310
311 nir_alu_src val;
312 val.src = nir_src_for_ssa(&load->def);
313 val.negate = false;
314 val.abs = false,
315 memset(val.swizzle, 0, sizeof val.swizzle);
316
317 return val;
318 }
319
320 default:
321 unreachable("Invalid search value type");
322 }
323 }
324
325 nir_alu_instr *
326 nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
327 const nir_search_value *replace, void *mem_ctx)
328 {
329 uint8_t swizzle[4] = { 0, 0, 0, 0 };
330
331 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
332 swizzle[i] = i;
333
334 assert(instr->dest.dest.is_ssa);
335
336 struct match_state state;
337 state.variables_seen = 0;
338
339 if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
340 swizzle, &state))
341 return NULL;
342
343 /* Inserting a mov may be unnecessary. However, it's much easier to
344 * simply let copy propagation clean this up than to try to go through
345 * and rewrite swizzles ourselves.
346 */
347 nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov);
348 mov->dest.write_mask = instr->dest.write_mask;
349 nir_ssa_dest_init(&mov->instr, &mov->dest.dest,
350 instr->dest.dest.ssa.num_components, NULL);
351
352 mov->src[0] = construct_value(replace, nir_op_infos[instr->op].output_type,
353 instr->dest.dest.ssa.num_components, &state,
354 &instr->instr, mem_ctx);
355 nir_instr_insert_before(&instr->instr, &mov->instr);
356
357 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
358 nir_src_for_ssa(&mov->dest.dest.ssa), mem_ctx);
359
360 /* We know this one has no more uses because we just rewrote them all,
361 * so we can remove it. The rest of the matched expression, however, we
362 * don't know so much about. We'll just let dead code clean them up.
363 */
364 nir_instr_remove(&instr->instr);
365
366 return mov;
367 }