nir: Call nir_metadata_preserve more places
[mesa.git] / src / glsl / nir / nir_opt_constant_folding.c
1 /*
2 * Copyright © 2014 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * Authors:
24 * Jason Ekstrand (jason@jlekstrand.net)
25 *
26 */
27
28 #include "nir.h"
29 #include <math.h>
30
31 /*
32 * Implements SSA-based constant folding.
33 */
34
/* Per-impl state threaded through the nir_foreach_block callback. */
struct constant_fold_state {
   void *mem_ctx;           /* ralloc context for newly created instructions */
   nir_function_impl *impl; /* function implementation being optimized */
   bool progress;           /* set when any instruction was folded */
};
40
/* Accessors for the constant values being folded.  T selects the value
 * union member (f, i, or u); IDX is the source index; CMP the component.
 * SRC_COMP honors the per-source swizzle.  The short SRC()/DEST() forms
 * implicitly use a loop variable named "i" -- they are only valid inside
 * FOLD_PER_COMP below, which declares it.
 */
#define SRC_COMP(T, IDX, CMP) src[IDX]->value.T[instr->src[IDX].swizzle[CMP]]
#define SRC(T, IDX) SRC_COMP(T, IDX, i)
#define DEST_COMP(T, CMP) dest->value.T[CMP]
#define DEST(T) DEST_COMP(T, i)

/* Evaluates EXPR once per destination component, with the component index
 * in scope as "i" (used implicitly by SRC()/DEST()).
 */
#define FOLD_PER_COMP(EXPR) \
   for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; i++) { \
      EXPR; \
   } \

51 static bool
52 constant_fold_alu_instr(nir_alu_instr *instr, void *mem_ctx)
53 {
54 nir_load_const_instr *src[4], *dest;
55
56 if (!instr->dest.dest.is_ssa)
57 return false;
58
59 for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
60 if (!instr->src[i].src.is_ssa)
61 return false;
62
63 if (instr->src[i].src.ssa->parent_instr->type != nir_instr_type_load_const)
64 return false;
65
66 /* We shouldn't have any source modifiers in the optimization loop. */
67 assert(!instr->src[i].abs && !instr->src[i].negate);
68
69 src[i] = nir_instr_as_load_const(instr->src[i].src.ssa->parent_instr);
70 }
71
72 /* We shouldn't have any saturate modifiers in the optimization loop. */
73 assert(!instr->dest.saturate);
74
75 dest = nir_load_const_instr_create(mem_ctx);
76 dest->array_elems = 0;
77 dest->num_components = instr->dest.dest.ssa.num_components;
78
79 switch (instr->op) {
80 case nir_op_ineg:
81 FOLD_PER_COMP(DEST(i) = -SRC(i, 0));
82 break;
83 case nir_op_fneg:
84 FOLD_PER_COMP(DEST(f) = -SRC(f, 0));
85 break;
86 case nir_op_inot:
87 FOLD_PER_COMP(DEST(i) = ~SRC(i, 0));
88 break;
89 case nir_op_fnot:
90 FOLD_PER_COMP(DEST(f) = (SRC(f, 0) == 0.0f) ? 1.0f : 0.0f);
91 break;
92 case nir_op_frcp:
93 FOLD_PER_COMP(DEST(f) = 1.0f / SRC(f, 0));
94 break;
95 case nir_op_frsq:
96 FOLD_PER_COMP(DEST(f) = 1.0f / sqrt(SRC(f, 0)));
97 break;
98 case nir_op_fsqrt:
99 FOLD_PER_COMP(DEST(f) = sqrtf(SRC(f, 0)));
100 break;
101 case nir_op_fexp:
102 FOLD_PER_COMP(DEST(f) = expf(SRC(f, 0)));
103 break;
104 case nir_op_flog:
105 FOLD_PER_COMP(DEST(f) = logf(SRC(f, 0)));
106 break;
107 case nir_op_fexp2:
108 FOLD_PER_COMP(DEST(f) = exp2f(SRC(f, 0)));
109 break;
110 case nir_op_flog2:
111 FOLD_PER_COMP(DEST(f) = log2f(SRC(f, 0)));
112 break;
113 case nir_op_f2i:
114 FOLD_PER_COMP(DEST(i) = SRC(f, 0));
115 break;
116 case nir_op_f2u:
117 FOLD_PER_COMP(DEST(u) = SRC(f, 0));
118 break;
119 case nir_op_i2f:
120 FOLD_PER_COMP(DEST(f) = SRC(i, 0));
121 break;
122 case nir_op_f2b:
123 FOLD_PER_COMP(DEST(u) = (SRC(i, 0) == 0.0f) ? NIR_FALSE : NIR_TRUE);
124 break;
125 case nir_op_b2f:
126 FOLD_PER_COMP(DEST(f) = SRC(u, 0) ? 1.0f : 0.0f);
127 break;
128 case nir_op_i2b:
129 FOLD_PER_COMP(DEST(u) = SRC(i, 0) ? NIR_TRUE : NIR_FALSE);
130 break;
131 case nir_op_u2f:
132 FOLD_PER_COMP(DEST(f) = SRC(u, 0));
133 break;
134 case nir_op_bany2:
135 DEST_COMP(u, 0) = (SRC_COMP(u, 0, 0) || SRC_COMP(u, 0, 1)) ?
136 NIR_TRUE : NIR_FALSE;
137 break;
138 case nir_op_fadd:
139 FOLD_PER_COMP(DEST(f) = SRC(f, 0) + SRC(f, 1));
140 break;
141 case nir_op_iadd:
142 FOLD_PER_COMP(DEST(i) = SRC(i, 0) + SRC(i, 1));
143 break;
144 case nir_op_fsub:
145 FOLD_PER_COMP(DEST(f) = SRC(f, 0) - SRC(f, 1));
146 break;
147 case nir_op_isub:
148 FOLD_PER_COMP(DEST(i) = SRC(i, 0) - SRC(i, 1));
149 break;
150 case nir_op_fmul:
151 FOLD_PER_COMP(DEST(f) = SRC(f, 0) * SRC(f, 1));
152 break;
153 case nir_op_imul:
154 FOLD_PER_COMP(DEST(i) = SRC(i, 0) * SRC(i, 1));
155 break;
156 case nir_op_fdiv:
157 FOLD_PER_COMP(DEST(f) = SRC(f, 0) / SRC(f, 1));
158 break;
159 case nir_op_idiv:
160 FOLD_PER_COMP(DEST(i) = SRC(i, 0) / SRC(i, 1));
161 break;
162 case nir_op_udiv:
163 FOLD_PER_COMP(DEST(u) = SRC(u, 0) / SRC(u, 1));
164 break;
165 case nir_op_flt:
166 FOLD_PER_COMP(DEST(u) = (SRC(f, 0) < SRC(f, 1)) ? NIR_TRUE : NIR_FALSE);
167 break;
168 case nir_op_fge:
169 FOLD_PER_COMP(DEST(u) = (SRC(f, 0) >= SRC(f, 1)) ? NIR_TRUE : NIR_FALSE);
170 break;
171 case nir_op_feq:
172 FOLD_PER_COMP(DEST(u) = (SRC(f, 0) == SRC(f, 1)) ? NIR_TRUE : NIR_FALSE);
173 break;
174 case nir_op_fne:
175 FOLD_PER_COMP(DEST(u) = (SRC(f, 0) != SRC(f, 1)) ? NIR_TRUE : NIR_FALSE);
176 break;
177 case nir_op_ilt:
178 FOLD_PER_COMP(DEST(u) = (SRC(i, 0) < SRC(i, 1)) ? NIR_TRUE : NIR_FALSE);
179 break;
180 case nir_op_ige:
181 FOLD_PER_COMP(DEST(u) = (SRC(i, 0) >= SRC(i, 1)) ? NIR_TRUE : NIR_FALSE);
182 break;
183 case nir_op_ieq:
184 FOLD_PER_COMP(DEST(u) = (SRC(i, 0) == SRC(i, 1)) ? NIR_TRUE : NIR_FALSE);
185 break;
186 case nir_op_ine:
187 FOLD_PER_COMP(DEST(u) = (SRC(i, 0) != SRC(i, 1)) ? NIR_TRUE : NIR_FALSE);
188 break;
189 case nir_op_ult:
190 FOLD_PER_COMP(DEST(u) = (SRC(u, 0) < SRC(u, 1)) ? NIR_TRUE : NIR_FALSE);
191 break;
192 case nir_op_uge:
193 FOLD_PER_COMP(DEST(u) = (SRC(u, 0) >= SRC(u, 1)) ? NIR_TRUE : NIR_FALSE);
194 break;
195 case nir_op_ishl:
196 FOLD_PER_COMP(DEST(i) = SRC(i, 0) << SRC(i, 1));
197 break;
198 case nir_op_ishr:
199 FOLD_PER_COMP(DEST(i) = SRC(i, 0) >> SRC(i, 1));
200 break;
201 case nir_op_ushr:
202 FOLD_PER_COMP(DEST(u) = SRC(u, 0) >> SRC(u, 1));
203 break;
204 case nir_op_iand:
205 FOLD_PER_COMP(DEST(i) = SRC(i, 0) & SRC(i, 1));
206 break;
207 case nir_op_ior:
208 FOLD_PER_COMP(DEST(i) = SRC(i, 0) | SRC(i, 1));
209 break;
210 case nir_op_ixor:
211 FOLD_PER_COMP(DEST(i) = SRC(i, 0) ^ SRC(i, 1));
212 break;
213 default:
214 ralloc_free(dest);
215 return false;
216 }
217
218 dest->dest.is_ssa = true;
219 nir_ssa_def_init(&dest->instr, &dest->dest.ssa,
220 instr->dest.dest.ssa.num_components,
221 instr->dest.dest.ssa.name);
222
223 nir_instr_insert_before(&instr->instr, &dest->instr);
224
225 nir_src new_src = {
226 .is_ssa = true,
227 .ssa = &dest->dest.ssa,
228 };
229
230 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, new_src, mem_ctx);
231
232 nir_instr_remove(&instr->instr);
233 ralloc_free(instr);
234
235 return true;
236 }
237
238 static bool
239 constant_fold_deref(nir_instr *instr, nir_deref_var *deref)
240 {
241 bool progress = false;
242
243 for (nir_deref *tail = deref->deref.child; tail; tail = tail->child) {
244 if (tail->deref_type != nir_deref_type_array)
245 continue;
246
247 nir_deref_array *arr = nir_deref_as_array(tail);
248
249 if (arr->deref_array_type == nir_deref_array_type_indirect &&
250 arr->indirect.is_ssa &&
251 arr->indirect.ssa->parent_instr->type == nir_instr_type_load_const) {
252 nir_load_const_instr *indirect =
253 nir_instr_as_load_const(arr->indirect.ssa->parent_instr);
254
255 arr->base_offset += indirect->value.u[0];
256
257 nir_src empty = {
258 .is_ssa = true,
259 .ssa = NULL,
260 };
261
262 nir_instr_rewrite_src(instr, &arr->indirect, empty);
263
264 arr->deref_array_type = nir_deref_array_type_direct;
265
266 progress = true;
267 }
268 }
269
270 return progress;
271 }
272
273 static bool
274 constant_fold_intrinsic_instr(nir_intrinsic_instr *instr)
275 {
276 bool progress = false;
277
278 unsigned num_vars = nir_intrinsic_infos[instr->intrinsic].num_variables;
279 for (unsigned i = 0; i < num_vars; i++) {
280 progress |= constant_fold_deref(&instr->instr, instr->variables[i]);
281 }
282
283 return progress;
284 }
285
286 static bool
287 constant_fold_tex_instr(nir_tex_instr *instr)
288 {
289 if (instr->sampler)
290 return constant_fold_deref(&instr->instr, instr->sampler);
291 else
292 return false;
293 }
294
295 static bool
296 constant_fold_block(nir_block *block, void *void_state)
297 {
298 struct constant_fold_state *state = void_state;
299
300 nir_foreach_instr_safe(block, instr) {
301 switch (instr->type) {
302 case nir_instr_type_alu:
303 state->progress |= constant_fold_alu_instr(nir_instr_as_alu(instr),
304 state->mem_ctx);
305 break;
306 case nir_instr_type_intrinsic:
307 state->progress |=
308 constant_fold_intrinsic_instr(nir_instr_as_intrinsic(instr));
309 break;
310 case nir_instr_type_tex:
311 state->progress |= constant_fold_tex_instr(nir_instr_as_tex(instr));
312 break;
313 default:
314 /* Don't know how to constant fold */
315 break;
316 }
317 }
318
319 return true;
320 }
321
322 static bool
323 nir_opt_constant_folding_impl(nir_function_impl *impl)
324 {
325 struct constant_fold_state state;
326
327 state.mem_ctx = ralloc_parent(impl);
328 state.impl = impl;
329 state.progress = false;
330
331 nir_foreach_block(impl, constant_fold_block, &state);
332
333 if (state.progress)
334 nir_metadata_preserve(impl, nir_metadata_block_index |
335 nir_metadata_dominance);
336
337 return state.progress;
338 }
339
340 bool
341 nir_opt_constant_folding(nir_shader *shader)
342 {
343 bool progress = false;
344
345 nir_foreach_overload(shader, overload) {
346 if (overload->impl)
347 progress |= nir_opt_constant_folding_impl(overload->impl);
348 }
349
350 return progress;
351 }