nir: add a bit_size parameter to nir_ssa_dest_init
[mesa.git] / src / compiler / nir / nir_search.c
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include "nir_search.h"
29
30 struct match_state {
31 unsigned variables_seen;
32 nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
33 };
34
35 static bool
36 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
37 unsigned num_components, const uint8_t *swizzle,
38 struct match_state *state);
39
40 static const uint8_t identity_swizzle[] = { 0, 1, 2, 3 };
41
42 static bool alu_instr_is_bool(nir_alu_instr *instr);
43
44 static bool
45 src_is_bool(nir_src src)
46 {
47 if (!src.is_ssa)
48 return false;
49 if (src.ssa->parent_instr->type != nir_instr_type_alu)
50 return false;
51 return alu_instr_is_bool(nir_instr_as_alu(src.ssa->parent_instr));
52 }
53
54 static bool
55 alu_instr_is_bool(nir_alu_instr *instr)
56 {
57 switch (instr->op) {
58 case nir_op_iand:
59 case nir_op_ior:
60 case nir_op_ixor:
61 return src_is_bool(instr->src[0].src) && src_is_bool(instr->src[1].src);
62 case nir_op_inot:
63 return src_is_bool(instr->src[0].src);
64 default:
65 return nir_op_infos[instr->op].output_type == nir_type_bool;
66 }
67 }
68
69 static bool
70 match_value(const nir_search_value *value, nir_alu_instr *instr, unsigned src,
71 unsigned num_components, const uint8_t *swizzle,
72 struct match_state *state)
73 {
74 uint8_t new_swizzle[4];
75
76 /* If the source is an explicitly sized source, then we need to reset
77 * both the number of components and the swizzle.
78 */
79 if (nir_op_infos[instr->op].input_sizes[src] != 0) {
80 num_components = nir_op_infos[instr->op].input_sizes[src];
81 swizzle = identity_swizzle;
82 }
83
84 for (unsigned i = 0; i < num_components; ++i)
85 new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];
86
87 switch (value->type) {
88 case nir_search_value_expression:
89 if (!instr->src[src].src.is_ssa)
90 return false;
91
92 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
93 return false;
94
95 return match_expression(nir_search_value_as_expression(value),
96 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
97 num_components, new_swizzle, state);
98
99 case nir_search_value_variable: {
100 nir_search_variable *var = nir_search_value_as_variable(value);
101 assert(var->variable < NIR_SEARCH_MAX_VARIABLES);
102
103 if (state->variables_seen & (1 << var->variable)) {
104 if (!nir_srcs_equal(state->variables[var->variable].src,
105 instr->src[src].src))
106 return false;
107
108 assert(!instr->src[src].abs && !instr->src[src].negate);
109
110 for (unsigned i = 0; i < num_components; ++i) {
111 if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
112 return false;
113 }
114
115 return true;
116 } else {
117 if (var->is_constant &&
118 instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
119 return false;
120
121 if (var->type != nir_type_invalid) {
122 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
123 return false;
124
125 nir_alu_instr *src_alu =
126 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr);
127
128 if (nir_op_infos[src_alu->op].output_type != var->type &&
129 !(var->type == nir_type_bool && alu_instr_is_bool(src_alu)))
130 return false;
131 }
132
133 state->variables_seen |= (1 << var->variable);
134 state->variables[var->variable].src = instr->src[src].src;
135 state->variables[var->variable].abs = false;
136 state->variables[var->variable].negate = false;
137
138 for (unsigned i = 0; i < 4; ++i) {
139 if (i < num_components)
140 state->variables[var->variable].swizzle[i] = new_swizzle[i];
141 else
142 state->variables[var->variable].swizzle[i] = 0;
143 }
144
145 return true;
146 }
147 }
148
149 case nir_search_value_constant: {
150 nir_search_constant *const_val = nir_search_value_as_constant(value);
151
152 if (!instr->src[src].src.is_ssa)
153 return false;
154
155 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
156 return false;
157
158 nir_load_const_instr *load =
159 nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
160
161 switch (nir_op_infos[instr->op].input_types[src]) {
162 case nir_type_float:
163 for (unsigned i = 0; i < num_components; ++i) {
164 if (load->value.f32[new_swizzle[i]] != const_val->data.f)
165 return false;
166 }
167 return true;
168 case nir_type_int:
169 case nir_type_uint:
170 case nir_type_bool:
171 for (unsigned i = 0; i < num_components; ++i) {
172 if (load->value.i32[new_swizzle[i]] != const_val->data.i)
173 return false;
174 }
175 return true;
176 default:
177 unreachable("Invalid alu source type");
178 }
179 }
180
181 default:
182 unreachable("Invalid search value type");
183 }
184 }
185
186 static bool
187 match_expression(const nir_search_expression *expr, nir_alu_instr *instr,
188 unsigned num_components, const uint8_t *swizzle,
189 struct match_state *state)
190 {
191 if (instr->op != expr->opcode)
192 return false;
193
194 assert(!instr->dest.saturate);
195 assert(nir_op_infos[instr->op].num_inputs > 0);
196
197 /* If we have an explicitly sized destination, we can only handle the
198 * identity swizzle. While dot(vec3(a, b, c).zxy) is a valid
199 * expression, we don't have the information right now to propagate that
200 * swizzle through. We can only properly propagate swizzles if the
201 * instruction is vectorized.
202 */
203 if (nir_op_infos[instr->op].output_size != 0) {
204 for (unsigned i = 0; i < num_components; i++) {
205 if (swizzle[i] != i)
206 return false;
207 }
208 }
209
210 /* Stash off the current variables_seen bitmask. This way we can
211 * restore it prior to matching in the commutative case below.
212 */
213 unsigned variables_seen_stash = state->variables_seen;
214
215 bool matched = true;
216 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
217 if (!match_value(expr->srcs[i], instr, i, num_components,
218 swizzle, state)) {
219 matched = false;
220 break;
221 }
222 }
223
224 if (matched)
225 return true;
226
227 if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) {
228 assert(nir_op_infos[instr->op].num_inputs == 2);
229
230 /* Restore the variables_seen bitmask. If we don't do this, then we
231 * could end up with an erroneous failure due to variables found in the
232 * first match attempt above not matching those in the second.
233 */
234 state->variables_seen = variables_seen_stash;
235
236 if (!match_value(expr->srcs[0], instr, 1, num_components,
237 swizzle, state))
238 return false;
239
240 return match_value(expr->srcs[1], instr, 0, num_components,
241 swizzle, state);
242 } else {
243 return false;
244 }
245 }
246
247 static nir_alu_src
248 construct_value(const nir_search_value *value, nir_alu_type type,
249 unsigned num_components, struct match_state *state,
250 nir_instr *instr, void *mem_ctx)
251 {
252 switch (value->type) {
253 case nir_search_value_expression: {
254 const nir_search_expression *expr = nir_search_value_as_expression(value);
255
256 if (nir_op_infos[expr->opcode].output_size != 0)
257 num_components = nir_op_infos[expr->opcode].output_size;
258
259 nir_alu_instr *alu = nir_alu_instr_create(mem_ctx, expr->opcode);
260 nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components, 32, NULL);
261 alu->dest.write_mask = (1 << num_components) - 1;
262 alu->dest.saturate = false;
263
264 for (unsigned i = 0; i < nir_op_infos[expr->opcode].num_inputs; i++) {
265 /* If the source is an explicitly sized source, then we need to reset
266 * the number of components to match.
267 */
268 if (nir_op_infos[alu->op].input_sizes[i] != 0)
269 num_components = nir_op_infos[alu->op].input_sizes[i];
270
271 alu->src[i] = construct_value(expr->srcs[i],
272 nir_op_infos[alu->op].input_types[i],
273 num_components,
274 state, instr, mem_ctx);
275 }
276
277 nir_instr_insert_before(instr, &alu->instr);
278
279 nir_alu_src val;
280 val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
281 val.negate = false;
282 val.abs = false,
283 memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
284
285 return val;
286 }
287
288 case nir_search_value_variable: {
289 const nir_search_variable *var = nir_search_value_as_variable(value);
290 assert(state->variables_seen & (1 << var->variable));
291
292 nir_alu_src val = { NIR_SRC_INIT };
293 nir_alu_src_copy(&val, &state->variables[var->variable], mem_ctx);
294
295 assert(!var->is_constant);
296
297 return val;
298 }
299
300 case nir_search_value_constant: {
301 const nir_search_constant *c = nir_search_value_as_constant(value);
302 nir_load_const_instr *load = nir_load_const_instr_create(mem_ctx, 1);
303
304 switch (type) {
305 case nir_type_float:
306 load->def.name = ralloc_asprintf(mem_ctx, "%f", c->data.f);
307 load->value.f32[0] = c->data.f;
308 break;
309 case nir_type_int:
310 load->def.name = ralloc_asprintf(mem_ctx, "%d", c->data.i);
311 load->value.i32[0] = c->data.i;
312 break;
313 case nir_type_uint:
314 case nir_type_bool:
315 load->value.u32[0] = c->data.u;
316 break;
317 default:
318 unreachable("Invalid alu source type");
319 }
320
321 nir_instr_insert_before(instr, &load->instr);
322
323 nir_alu_src val;
324 val.src = nir_src_for_ssa(&load->def);
325 val.negate = false;
326 val.abs = false,
327 memset(val.swizzle, 0, sizeof val.swizzle);
328
329 return val;
330 }
331
332 default:
333 unreachable("Invalid search value type");
334 }
335 }
336
337 nir_alu_instr *
338 nir_replace_instr(nir_alu_instr *instr, const nir_search_expression *search,
339 const nir_search_value *replace, void *mem_ctx)
340 {
341 uint8_t swizzle[4] = { 0, 0, 0, 0 };
342
343 for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
344 swizzle[i] = i;
345
346 assert(instr->dest.dest.is_ssa);
347
348 struct match_state state;
349 state.variables_seen = 0;
350
351 if (!match_expression(search, instr, instr->dest.dest.ssa.num_components,
352 swizzle, &state))
353 return NULL;
354
355 /* Inserting a mov may be unnecessary. However, it's much easier to
356 * simply let copy propagation clean this up than to try to go through
357 * and rewrite swizzles ourselves.
358 */
359 nir_alu_instr *mov = nir_alu_instr_create(mem_ctx, nir_op_imov);
360 mov->dest.write_mask = instr->dest.write_mask;
361 nir_ssa_dest_init(&mov->instr, &mov->dest.dest,
362 instr->dest.dest.ssa.num_components,
363 instr->dest.dest.ssa.bit_size, NULL);
364
365 mov->src[0] = construct_value(replace, nir_op_infos[instr->op].output_type,
366 instr->dest.dest.ssa.num_components, &state,
367 &instr->instr, mem_ctx);
368 nir_instr_insert_before(&instr->instr, &mov->instr);
369
370 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa,
371 nir_src_for_ssa(&mov->dest.dest.ssa));
372
373 /* We know this one has no more uses because we just rewrote them all,
374 * so we can remove it. The rest of the matched expression, however, we
375 * don't know so much about. We'll just let dead code clean them up.
376 */
377 nir_instr_remove(&instr->instr);
378
379 return mov;
380 }