nir: Rename parallel_copy_copy to parallel_copy_entry and add a foreach macro
[mesa.git] src/glsl/nir/nir_opt_constant_folding.c
/*
 * Copyright © 2014 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * Authors:
 *    Jason Ekstrand (jason@jlekstrand.net)
 *
 */

#include "nir.h"
#include <math.h>

/*
 * Implements SSA-based constant folding.
 */
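
/*
 * Rough sketch of the transformation (the SSA names here are illustrative,
 * not taken from real output): if every source of an ALU instruction is
 * produced by a load_const, e.g.
 *
 *    ssa_0 = load_const (1.0)
 *    ssa_1 = load_const (2.0)
 *    ssa_2 = fadd ssa_0, ssa_1
 *
 * then the result is evaluated at compile time, a new load_const (3.0) is
 * inserted, every use of ssa_2 is rewritten to point at it, and the fadd
 * is removed.
 */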

struct constant_fold_state {
   void *mem_ctx;
   nir_function_impl *impl;
   bool progress;
};

/* Accessors for one component of a source (read through its swizzle) and of
 * the destination; `i` is the component index provided by FOLD_PER_COMP
 * below.
 */
#define SRC_COMP(T, IDX, CMP) src[IDX]->value.T[instr->src[IDX].swizzle[CMP]]
#define SRC(T, IDX) SRC_COMP(T, IDX, i)
#define DEST_COMP(T, CMP) dest->value.T[CMP]
#define DEST(T) DEST_COMP(T, i)

#define FOLD_PER_COMP(EXPR) \
   for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; i++) { \
      EXPR; \
   }
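
/*
 * For instance, FOLD_PER_COMP(DEST(f) = SRC(f, 0) + SRC(f, 1)) expands to a
 * per-component loop that reads each source through its swizzle and writes
 * the folded value into the new load_const:
 *
 *    for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; i++) {
 *       dest->value.f[i] = src[0]->value.f[instr->src[0].swizzle[i]] +
 *                          src[1]->value.f[instr->src[1].swizzle[i]];
 *    }
 */
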
static bool
constant_fold_alu_instr(nir_alu_instr *instr, void *mem_ctx)
{
   nir_load_const_instr *src[4], *dest;

   if (!instr->dest.dest.is_ssa)
      return false;

   for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
      if (!instr->src[i].src.is_ssa)
         return false;

      if (instr->src[i].src.ssa->parent_instr->type != nir_instr_type_load_const)
         return false;

      /* We shouldn't have any source modifiers in the optimization loop. */
      assert(!instr->src[i].abs && !instr->src[i].negate);

      src[i] = nir_instr_as_load_const(instr->src[i].src.ssa->parent_instr);
   }

   /* We shouldn't have any saturate modifiers in the optimization loop. */
   assert(!instr->dest.saturate);

   dest = nir_load_const_instr_create(mem_ctx,
                                      instr->dest.dest.ssa.num_components);

   switch (instr->op) {
   case nir_op_ineg:
      FOLD_PER_COMP(DEST(i) = -SRC(i, 0));
      break;
   case nir_op_fneg:
      FOLD_PER_COMP(DEST(f) = -SRC(f, 0));
      break;
   case nir_op_inot:
      FOLD_PER_COMP(DEST(i) = ~SRC(i, 0));
      break;
   case nir_op_fnot:
      FOLD_PER_COMP(DEST(f) = (SRC(f, 0) == 0.0f) ? 1.0f : 0.0f);
      break;
   case nir_op_frcp:
      FOLD_PER_COMP(DEST(f) = 1.0f / SRC(f, 0));
      break;
   case nir_op_frsq:
      FOLD_PER_COMP(DEST(f) = 1.0f / sqrtf(SRC(f, 0)));
      break;
   case nir_op_fsqrt:
      FOLD_PER_COMP(DEST(f) = sqrtf(SRC(f, 0)));
      break;
   case nir_op_fexp:
      FOLD_PER_COMP(DEST(f) = expf(SRC(f, 0)));
      break;
   case nir_op_flog:
      FOLD_PER_COMP(DEST(f) = logf(SRC(f, 0)));
      break;
   case nir_op_fexp2:
      FOLD_PER_COMP(DEST(f) = exp2f(SRC(f, 0)));
      break;
   case nir_op_flog2:
      FOLD_PER_COMP(DEST(f) = log2f(SRC(f, 0)));
      break;
   case nir_op_f2i:
      FOLD_PER_COMP(DEST(i) = SRC(f, 0));
      break;
   case nir_op_f2u:
      FOLD_PER_COMP(DEST(u) = SRC(f, 0));
      break;
   case nir_op_i2f:
      FOLD_PER_COMP(DEST(f) = SRC(i, 0));
      break;
   case nir_op_f2b:
      FOLD_PER_COMP(DEST(u) = (SRC(f, 0) == 0.0f) ? NIR_FALSE : NIR_TRUE);
      break;
   case nir_op_b2f:
      FOLD_PER_COMP(DEST(f) = SRC(u, 0) ? 1.0f : 0.0f);
      break;
   case nir_op_i2b:
      FOLD_PER_COMP(DEST(u) = SRC(i, 0) ? NIR_TRUE : NIR_FALSE);
      break;
   case nir_op_u2f:
      FOLD_PER_COMP(DEST(f) = SRC(u, 0));
      break;
   case nir_op_bany2:
      DEST_COMP(u, 0) = (SRC_COMP(u, 0, 0) || SRC_COMP(u, 0, 1)) ?
                        NIR_TRUE : NIR_FALSE;
      break;
   case nir_op_fadd:
      FOLD_PER_COMP(DEST(f) = SRC(f, 0) + SRC(f, 1));
      break;
   case nir_op_iadd:
      FOLD_PER_COMP(DEST(i) = SRC(i, 0) + SRC(i, 1));
      break;
   case nir_op_fsub:
      FOLD_PER_COMP(DEST(f) = SRC(f, 0) - SRC(f, 1));
      break;
   case nir_op_isub:
      FOLD_PER_COMP(DEST(i) = SRC(i, 0) - SRC(i, 1));
      break;
   case nir_op_fmul:
      FOLD_PER_COMP(DEST(f) = SRC(f, 0) * SRC(f, 1));
      break;
   case nir_op_imul:
      FOLD_PER_COMP(DEST(i) = SRC(i, 0) * SRC(i, 1));
      break;
   case nir_op_fdiv:
      FOLD_PER_COMP(DEST(f) = SRC(f, 0) / SRC(f, 1));
      break;
   case nir_op_idiv:
      FOLD_PER_COMP(DEST(i) = SRC(i, 0) / SRC(i, 1));
      break;
   case nir_op_udiv:
      FOLD_PER_COMP(DEST(u) = SRC(u, 0) / SRC(u, 1));
      break;
   case nir_op_flt:
      FOLD_PER_COMP(DEST(u) = (SRC(f, 0) < SRC(f, 1)) ? NIR_TRUE : NIR_FALSE);
      break;
   case nir_op_fge:
      FOLD_PER_COMP(DEST(u) = (SRC(f, 0) >= SRC(f, 1)) ? NIR_TRUE : NIR_FALSE);
      break;
   case nir_op_feq:
      FOLD_PER_COMP(DEST(u) = (SRC(f, 0) == SRC(f, 1)) ? NIR_TRUE : NIR_FALSE);
      break;
   case nir_op_fne:
      FOLD_PER_COMP(DEST(u) = (SRC(f, 0) != SRC(f, 1)) ? NIR_TRUE : NIR_FALSE);
      break;
   case nir_op_ilt:
      FOLD_PER_COMP(DEST(u) = (SRC(i, 0) < SRC(i, 1)) ? NIR_TRUE : NIR_FALSE);
      break;
   case nir_op_ige:
      FOLD_PER_COMP(DEST(u) = (SRC(i, 0) >= SRC(i, 1)) ? NIR_TRUE : NIR_FALSE);
      break;
   case nir_op_ieq:
      FOLD_PER_COMP(DEST(u) = (SRC(i, 0) == SRC(i, 1)) ? NIR_TRUE : NIR_FALSE);
      break;
   case nir_op_ine:
      FOLD_PER_COMP(DEST(u) = (SRC(i, 0) != SRC(i, 1)) ? NIR_TRUE : NIR_FALSE);
      break;
   case nir_op_ult:
      FOLD_PER_COMP(DEST(u) = (SRC(u, 0) < SRC(u, 1)) ? NIR_TRUE : NIR_FALSE);
      break;
   case nir_op_uge:
      FOLD_PER_COMP(DEST(u) = (SRC(u, 0) >= SRC(u, 1)) ? NIR_TRUE : NIR_FALSE);
      break;
   case nir_op_ishl:
      FOLD_PER_COMP(DEST(i) = SRC(i, 0) << SRC(i, 1));
      break;
   case nir_op_ishr:
      FOLD_PER_COMP(DEST(i) = SRC(i, 0) >> SRC(i, 1));
      break;
   case nir_op_ushr:
      FOLD_PER_COMP(DEST(u) = SRC(u, 0) >> SRC(u, 1));
      break;
   case nir_op_iand:
      FOLD_PER_COMP(DEST(i) = SRC(i, 0) & SRC(i, 1));
      break;
   case nir_op_ior:
      FOLD_PER_COMP(DEST(i) = SRC(i, 0) | SRC(i, 1));
      break;
   case nir_op_ixor:
      FOLD_PER_COMP(DEST(i) = SRC(i, 0) ^ SRC(i, 1));
      break;
   default:
      ralloc_free(dest);
      return false;
   }

   /* Insert the new constant right before the ALU instruction so that it
    * dominates every use we are about to rewrite.
    */
   nir_instr_insert_before(&instr->instr, &dest->instr);

   nir_src new_src = {
      .is_ssa = true,
      .ssa = &dest->def,
   };

   nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, new_src, mem_ctx);

   nir_instr_remove(&instr->instr);
   ralloc_free(instr);

   return true;
}

static bool
constant_fold_deref(nir_instr *instr, nir_deref_var *deref)
{
   bool progress = false;

   for (nir_deref *tail = deref->deref.child; tail; tail = tail->child) {
      if (tail->deref_type != nir_deref_type_array)
         continue;

      nir_deref_array *arr = nir_deref_as_array(tail);

      if (arr->deref_array_type == nir_deref_array_type_indirect &&
          arr->indirect.is_ssa &&
          arr->indirect.ssa->parent_instr->type == nir_instr_type_load_const) {
         nir_load_const_instr *indirect =
            nir_instr_as_load_const(arr->indirect.ssa->parent_instr);

         arr->base_offset += indirect->value.u[0];

         nir_src empty = {
            .is_ssa = true,
            .ssa = NULL,
         };

         nir_instr_rewrite_src(instr, &arr->indirect, empty);

         arr->deref_array_type = nir_deref_array_type_direct;

         progress = true;
      }
   }

   return progress;
}
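
/*
 * For example (illustrative): an indirect array access such as arr[2 + ssa_3],
 * where ssa_3 comes from load_const (5), is folded into the direct access
 * arr[7] by adding the constant into base_offset, dropping the indirect
 * source, and marking the deref direct.
 */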

static bool
constant_fold_intrinsic_instr(nir_intrinsic_instr *instr)
{
   bool progress = false;

   unsigned num_vars = nir_intrinsic_infos[instr->intrinsic].num_variables;
   for (unsigned i = 0; i < num_vars; i++) {
      progress |= constant_fold_deref(&instr->instr, instr->variables[i]);
   }

   return progress;
}

static bool
constant_fold_tex_instr(nir_tex_instr *instr)
{
   if (instr->sampler)
      return constant_fold_deref(&instr->instr, instr->sampler);
   else
      return false;
}

static bool
constant_fold_block(nir_block *block, void *void_state)
{
   struct constant_fold_state *state = void_state;

   nir_foreach_instr_safe(block, instr) {
      switch (instr->type) {
      case nir_instr_type_alu:
         state->progress |= constant_fold_alu_instr(nir_instr_as_alu(instr),
                                                    state->mem_ctx);
         break;
      case nir_instr_type_intrinsic:
         state->progress |=
            constant_fold_intrinsic_instr(nir_instr_as_intrinsic(instr));
         break;
      case nir_instr_type_tex:
         state->progress |= constant_fold_tex_instr(nir_instr_as_tex(instr));
         break;
      default:
         /* Don't know how to constant fold */
         break;
      }
   }

   return true;
}

static bool
nir_opt_constant_folding_impl(nir_function_impl *impl)
{
   struct constant_fold_state state;

   state.mem_ctx = ralloc_parent(impl);
   state.impl = impl;
   state.progress = false;

   nir_foreach_block(impl, constant_fold_block, &state);

   if (state.progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return state.progress;
}

bool
nir_opt_constant_folding(nir_shader *shader)
{
   bool progress = false;

   nir_foreach_overload(shader, overload) {
      if (overload->impl)
         progress |= nir_opt_constant_folding_impl(overload->impl);
   }

   return progress;
}
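
/*
 * Typical use (a sketch, not part of this file): drivers run this pass in
 * their NIR optimization loop, repeating until no pass reports progress so
 * that folding one instruction can expose more constants to the next round:
 *
 *    bool progress;
 *    do {
 *       progress = false;
 *       progress |= nir_opt_constant_folding(shader);
 *       progress |= nir_opt_dce(shader);
 *    } while (progress);
 */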