nir/search: handle explicitly sized sources in match_value
[mesa.git] / src / glsl / nir / nir_search.c
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include "nir_search.h"
29
30 struct match_state {
31 unsigned variables_seen;
32 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
33 };
34
35 static bool
36 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
37 unsigned num_components, const uint8_t *swizzle,
38 struct match_state *state);
39
40 static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 };
41
42 static bool alu_instr_is_bool(nir_alu_instr *instr);
43
44 static bool
45 src_is_bool(nir_src src)
46 {
47 if (!src.is_ssa)
48 return false;
49 if (src.ssa->parent_instr->type != nir_instr_type_alu)
50 return false;
51 return alu_instr_is_bool((nir_alu_instr *)src.ssa->parent_instr);
52 }
53
54 static bool
55 alu_instr_is_bool(nir_alu_instr *instr)
56 {
57 switch (instr->op) {
58 case nir_op_iand:
59 case nir_op_ior:
60 case nir_op_ixor:
61 return src_is_bool(instr->src[0].src) && src_is_bool(instr->src[1].src);
62 case nir_op_inot:
63 return src_is_bool(instr->src[0].src);
64 default:
65 return nir_op_infos[instr->op].output_type == nir_type_bool;
66 }
67 }
68
69 static bool
70 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
71 unsigned num_components, const uint8_t *swizzle,
72 struct match_state *state)
73 {
74 uint8_t new_swizzle[4];
75
76 /* If the source is an explicitly sized source, then we need to reset
77 * both the number of components and the swizzle.
78 */
79 if (nir_op_infos[instr->op].input_sizes[src] != 0) {
80 num_components = nir_op_infos[instr->op].input_sizes[src];
81 swizzle = identity_swizzle;
82 }
83
84 for (int i = 0; i < num_components; ++i)
85 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
86
87 switch (value->type) {
88 case nir_search_value_expression:
89 if (!instr->src[src].src.is_ssa)
90 return false;
91
92 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
93 return false;
94
95 return match_expression(nir_search_value_as_expression(value),
96 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
97 num_components, new_swizzle, state);
98
99 case nir_search_value_variable: {
100 nir_search_variable *var = nir_search_value_as_variable(value);
101
102 if (state->variables_seen & (1 << var->variable)) {
103 if (!nir_srcs_equal(state->variables[var->variable].src,
104 instr->src[src].src))
105 return false;
106
107 assert(!instr->src[src].abs && !instr->src[src].negate);
108
109 for (int i = 0; i < num_components; ++i) {
110 if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
111 return false;
112 }
113
114 return true;
115 } else {
116 if (var->is_constant &&
117 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
118 return false;
119
120 if (var->type != nir_type_invalid) {
121 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
122 return false;
123
124 nir_alu_instr *src_alu =
125 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr);
126
127 if (nir_op_infos[src_alu->op].output_type != var->type &&
128 !(var->type == nir_type_bool && alu_instr_is_bool(src_alu)))
129 return false;
130 }
131
132 state->variables_seen |= (1 << var->variable);
133 state->variables[var->variable].src = instr->src[src].src;
134 state->variables[var->variable].abs = false;
135 state->variables[var->variable].negate = false;
136
137 for (int i = 0; i < 4; ++i) {
138 if (i < num_components)
139 state->variables[var->variable].swizzle[i] = new_swizzle[i];
140 else
141 state->variables[var->variable].swizzle[i] = 0;
142 }
143
144 return true;
145 }
146 }
147
148 case nir_search_value_constant: {
149 nir_search_constant *const_val = nir_search_value_as_constant(value);
150
151 if (!instr->src[src].src.is_ssa)
152 return false;
153
154 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
155 return false;
156
157 nir_load_const_instr *load =
158 nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
159
160 switch (nir_op_infos[instr->op].input_types[src]) {
161 case nir_type_float:
162 for (unsigned i = 0; i < num_components; ++i) {
163 if (load->value.f[new_swizzle[i]] != const_val->data.f)
164 return false;
165 }
166 return true;
167 case nir_type_int:
168 case nir_type_unsigned:
169 case nir_type_bool:
170 for (unsigned i = 0; i < num_components; ++i) {
171 if (load->value.i[new_swizzle[i]] != const_val->data.i)
172 return false;
173 }
174 return true;
175 default:
176 unreachable("Invalid alu source type");
177 }
178 }
179
180 default:
181 unreachable("Invalid search value type");
182 }
183 }
184
185 static bool
186 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
187 unsigned num_components, const uint8_t *swizzle,
188 struct match_state *state)
189 {
190 if (instr->op != expr->opcode)
191 return false;
192
193 assert(!instr->dest.saturate);
194 assert(nir_op_infos[instr->op].num_inputs > 0);
195
196 /* If we have an explicitly sized destination, we can only handle the
197 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
198 * expression, we don't have the information right now to propagate that
199 * swizzle through. We can only properly propagate swizzles if the
200 * instruction is vectorized.
201 */
202 if (nir_op_infos[instr->op].output_size != 0) {
203 for (unsigned i = 0; i < num_components; i++) {
204 if (swizzle[i] != i)
205 return false;
206 }
207 }
208
209 bool matched = true;
210 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
211 if (!match_value(expr->srcs[i], instr, i, num_components,
212 swizzle, state)) {
213 matched = false;
214 break;
215 }
216 }
217
218 if (matched)
219 return true;
220
221 if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
222 assert(nir_op_infos[instr->op].num_inputs == 2);
223 if (!match_value(expr->srcs[0], instr, 1, num_components,
224 swizzle, state))
225 return false;
226
227 return match_value(expr->srcs[1], instr, 0, num_components,
228 swizzle, state);
229 } else {
230 return false;
231 }
232 }
233
234 static nir_alu_src
235 construct_value(const nir_search_value *value, nir_alu_type type,
236 unsigned num_components, struct match_state *state,
237 nir_instr *instr, void *mem_ctx)
238 {
239 switch (value->type) {
240 case nir_search_value_expression: {
241 const nir_search_expression *expr = nir_search_value_as_expression(value);
242
243 if (nir_op_infos[expr->opcode].output_size != 0)
244 num_components = nir_op_infos[expr->opcode].output_size;
245
246 nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode);
247 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components, NULL);
248 alu->dest.write_mask = (1 << num_components) - 1;
249 alu->dest.saturate = false;
250
251 for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
252 /* If the source is an explicitly sized source, then we need to reset
253 * the number of components to match.
254 */
255 if (nir_op_infos[alu->op].input_sizes[i] != 0)
256 num_components = nir_op_infos[alu->op].input_sizes[i];
257
258 alu->src[i] = construct_value(expr->srcs[i],
259 nir_op_infos[alu->op].input_types[i],
260 num_components,
261 state, instr, mem_ctx);
262 }
263
264 nir_instr_insert_before(instr, &alu->instr);
265
266 nir_alu_src val;
267 val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
268 val.negate = false;
269 val.abs = false,
270 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
271
272 return val;
273 }
274
275 case nir_search_value_variable: {
276 const nir_search_variable *var = nir_search_value_as_variable(value);
277 assert(state->variables_seen & (1 << var->variable));
278
279 nir_alu_src val = { NIR_SRC_INIT };
280 nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx);
281
282 assert(!var->is_constant);
283
284 return val;
285 }
286
287 case nir_search_value_constant: {
288 const nir_search_constant *c = nir_search_value_as_constant(value);
289 nir_load_const_instr *load = nir_load_const_instr_create(mem_ctx, 1);
290
291 switch (type) {
292 case nir_type_float:
293 load->def.name = ralloc_asprintf(mem_ctx, "%f", c->data.f);
294 load->value.f[0] = c->data.f;
295 break;
296 case nir_type_int:
297 load->def.name = ralloc_asprintf(mem_ctx, "%d", c->data.i);
298 load->value.i[0] = c->data.i;
299 break;
300 case nir_type_unsigned:
301 case nir_type_bool:
302 load->value.u[0] = c->data.u;
303 break;
304 default:
305 unreachable("Invalid alu source type");
306 }
307
308 nir_instr_insert_before(instr, &load->instr);
309
310 nir_alu_src val;
311 val.src = nir_src_for_ssa(&load->def);
312 val.negate = false;
313 val.abs = false,
314 memset(val.swizzle, 0, sizeof val.swizzle);
315
316 return val;
317 }
318
319 default:
320 unreachable("Invalid search value type");
321 }
322 }
323
324 nir_alu_instr *
325 nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
326 const nir_search_value *replace, void *mem_ctx)
327 {
328 uint8_t swizzle[4] = { 0, 0, 0, 0 };
329
330 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
331 swizzle[i] = i;
332
333 assert(instr->dest.dest.is_ssa);
334
335 struct match_state state;
336 state.variables_seen = 0;
337
338 if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
339 swizzle, &state))
340 return NULL;
341
342 /* Inserting a mov may be unnecessary. However, it's much easier to
343 * simply let copy propagation clean this up than to try to go through
344 * and rewrite swizzles ourselves.
345 */
346 nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov);
347 mov->dest.write_mask = instr->dest.write_mask;
348 nir_ssa_dest_init(&mov->instr, &mov->dest.dest,
349 instr->dest.dest.ssa.num_components, NULL);
350
351 mov->src[0] = construct_value(replace, nir_op_infos[instr->op].output_type,
352 instr->dest.dest.ssa.num_components, &state,
353 &instr->instr, mem_ctx);
354 nir_instr_insert_before(&instr->instr, &mov->instr);
355
356 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
357 nir_src_for_ssa(&mov->dest.dest.ssa), mem_ctx);
358
359 /* We know this one has no more uses because we just rewrote them all,
360 * so we can remove it. The rest of the matched expression, however, we
361 * don't know so much about. We'll just let dead code clean them up.
362 */
363 nir_instr_remove(&instr->instr);
364
365 return mov;
366 }