nir/flrp: Lower flrp(#a, #b, c) differently
src/compiler/nir/nir_lower_flrp.c
/*
 * Copyright © 2018 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
#include <math.h>
#include "nir.h"
#include "nir_builder.h"
#include "util/u_vector.h"

/**
 * Lower flrp instructions.
 *
 * Unlike the lowerings that are possible in nir_opt_algebraic, this pass can
 * examine more global information to determine a possibly more efficient
 * lowering for each flrp.
 */

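/**
 * Remember a lowered flrp so that it can be removed after the whole pass
 * has run.  See the "DO NOT REMOVE" comments in the replace_with_* helpers
 * below for why removal must be deferred.
 */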
static void
append_flrp_to_dead_list(struct u_vector *dead_flrp, struct nir_alu_instr *alu)
{
   struct nir_alu_instr **tail = u_vector_add(dead_flrp);
   *tail = alu;
}

/**
 * Replace flrp(a, b, c) with ffma(b, c, ffma(-a, c, a)).
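 *
 * Expanding the nested FMAs gives b*c + (-a*c + a) = a(1 - c) + b*c, so this
 * is the strict GLSL formulation evaluated with two fused operations.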
 */
static void
replace_with_strict_ffma(struct nir_builder *bld, struct u_vector *dead_flrp,
                         struct nir_alu_instr *alu)
{
   nir_ssa_def *const a = nir_ssa_for_alu_src(bld, alu, 0);
   nir_ssa_def *const b = nir_ssa_for_alu_src(bld, alu, 1);
   nir_ssa_def *const c = nir_ssa_for_alu_src(bld, alu, 2);

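   /* Propagate the exact flag of the original flrp to each instruction in
    * the replacement sequence so that, when the flrp was marked precise,
    * later algebraic passes cannot rearrange the new instructions in ways
    * that would change the result.
    */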
   nir_ssa_def *const neg_a = nir_fneg(bld, a);
   nir_instr_as_alu(neg_a->parent_instr)->exact = alu->exact;

   nir_ssa_def *const inner_ffma = nir_ffma(bld, neg_a, c, a);
   nir_instr_as_alu(inner_ffma->parent_instr)->exact = alu->exact;

   nir_ssa_def *const outer_ffma = nir_ffma(bld, b, c, inner_ffma);
   nir_instr_as_alu(outer_ffma->parent_instr)->exact = alu->exact;

   nir_ssa_def_rewrite_uses(&alu->dest.dest.ssa, nir_src_for_ssa(outer_ffma));

   /* DO NOT REMOVE the original flrp yet. Many of the lowering choices are
    * based on other uses of the sources. Removing the flrp may cause the
    * last flrp in a sequence to make a different, incorrect choice.
    */
   append_flrp_to_dead_list(dead_flrp, alu);
}

/**
 * Replace flrp(a, b, c) with a(1-c) + bc.
 */
static void
replace_with_strict(struct nir_builder *bld, struct u_vector *dead_flrp,
                    struct nir_alu_instr *alu)
{
   nir_ssa_def *const a = nir_ssa_for_alu_src(bld, alu, 0);
   nir_ssa_def *const b = nir_ssa_for_alu_src(bld, alu, 1);
   nir_ssa_def *const c = nir_ssa_for_alu_src(bld, alu, 2);

   nir_ssa_def *const neg_c = nir_fneg(bld, c);
   nir_instr_as_alu(neg_c->parent_instr)->exact = alu->exact;

   nir_ssa_def *const one_minus_c =
      nir_fadd(bld, nir_imm_float(bld, 1.0f), neg_c);
   nir_instr_as_alu(one_minus_c->parent_instr)->exact = alu->exact;

   nir_ssa_def *const first_product = nir_fmul(bld, a, one_minus_c);
   nir_instr_as_alu(first_product->parent_instr)->exact = alu->exact;

   nir_ssa_def *const second_product = nir_fmul(bld, b, c);
   nir_instr_as_alu(second_product->parent_instr)->exact = alu->exact;

   nir_ssa_def *const sum = nir_fadd(bld, first_product, second_product);
   nir_instr_as_alu(sum->parent_instr)->exact = alu->exact;

   nir_ssa_def_rewrite_uses(&alu->dest.dest.ssa, nir_src_for_ssa(sum));

   /* DO NOT REMOVE the original flrp yet. Many of the lowering choices are
    * based on other uses of the sources. Removing the flrp may cause the
    * last flrp in a sequence to make a different, incorrect choice.
    */
   append_flrp_to_dead_list(dead_flrp, alu);
}

/**
 * Replace flrp(a, b, c) with a + c(b-a).
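 *
 * This is the cheap formulation: an add, a multiply, and an add (or a
 * subtract and an FMA after nir_opt_algebraic), but it loses precision when
 * the magnitudes of a and b differ greatly.  See the discussion in
 * convert_flrp_instruction().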
 */
static void
replace_with_fast(struct nir_builder *bld, struct u_vector *dead_flrp,
                  struct nir_alu_instr *alu)
{
   nir_ssa_def *const a = nir_ssa_for_alu_src(bld, alu, 0);
   nir_ssa_def *const b = nir_ssa_for_alu_src(bld, alu, 1);
   nir_ssa_def *const c = nir_ssa_for_alu_src(bld, alu, 2);

   nir_ssa_def *const neg_a = nir_fneg(bld, a);
   nir_instr_as_alu(neg_a->parent_instr)->exact = alu->exact;

   nir_ssa_def *const b_minus_a = nir_fadd(bld, b, neg_a);
   nir_instr_as_alu(b_minus_a->parent_instr)->exact = alu->exact;

   nir_ssa_def *const product = nir_fmul(bld, c, b_minus_a);
   nir_instr_as_alu(product->parent_instr)->exact = alu->exact;

   nir_ssa_def *const sum = nir_fadd(bld, a, product);
   nir_instr_as_alu(sum->parent_instr)->exact = alu->exact;

   nir_ssa_def_rewrite_uses(&alu->dest.dest.ssa, nir_src_for_ssa(sum));

   /* DO NOT REMOVE the original flrp yet. Many of the lowering choices are
    * based on other uses of the sources. Removing the flrp may cause the
    * last flrp in a sequence to make a different, incorrect choice.
    */
   append_flrp_to_dead_list(dead_flrp, alu);
}

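/**
 * Determine whether both interpolation endpoints are constants with similar
 * binary magnitudes.
 *
 * As an illustration (values not taken from any shader): frexpf(1.0f, &e)
 * stores e = 1 while frexpf(1e38f, &e) stores e = 127, so the pair
 * (1.0, 1e38) fails the 32-bit check below because 126 > 23 / 2.
 */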
static bool
sources_are_constants_with_similar_magnitudes(const nir_alu_instr *instr)
{
   nir_const_value *val0 = nir_src_as_const_value(instr->src[0].src);
   nir_const_value *val1 = nir_src_as_const_value(instr->src[1].src);

   if (val0 == NULL || val1 == NULL)
      return false;

   const uint8_t *const swizzle0 = instr->src[0].swizzle;
   const uint8_t *const swizzle1 = instr->src[1].swizzle;
   const unsigned num_components = nir_dest_num_components(instr->dest.dest);

   if (instr->dest.dest.ssa.bit_size == 32) {
      for (unsigned i = 0; i < num_components; i++) {
         int exp0;
         int exp1;

         frexpf(val0[swizzle0[i]].f32, &exp0);
         frexpf(val1[swizzle1[i]].f32, &exp1);

         /* If the difference between the exponents is >= 24, then A+B will
          * always have the value of whichever of A and B has the larger
          * magnitude.  So, [0, 23] is the valid range.  The smaller the
          * limit value, the more precision will be maintained at a
          * potential performance cost.  Somewhat arbitrarily split the
          * range in half.
          */
         if (abs(exp0 - exp1) > (23 / 2))
            return false;
      }
   } else {
      for (unsigned i = 0; i < num_components; i++) {
         int exp0;
         int exp1;

         frexp(val0[swizzle0[i]].f64, &exp0);
         frexp(val1[swizzle1[i]].f64, &exp1);

         /* If the difference between the exponents is >= 53, then A+B will
          * always have the value of whichever of A and B has the larger
          * magnitude.  So, [0, 52] is the valid range.  The smaller the
          * limit value, the more precision will be maintained at a
          * potential performance cost.  Somewhat arbitrarily split the
          * range in half.
          */
         if (abs(exp0 - exp1) > (52 / 2))
            return false;
      }
   }

   return true;
}

static void
convert_flrp_instruction(nir_builder *bld,
                         struct u_vector *dead_flrp,
                         nir_alu_instr *alu,
                         bool always_precise,
                         bool have_ffma)
{
   bld->cursor = nir_before_instr(&alu->instr);

   /* There are two methods to implement flrp(x, y, t).  The strictly correct
    * implementation according to the GLSL spec is:
    *
    *    x(1 - t) + yt
    *
    * This can also be implemented using two chained FMAs:
    *
    *    fma(y, t, fma(-x, t, x))
    *
    * This method, using either formulation, has better precision when the
    * difference between x and y is very large.  It guarantees that
    * flrp(x, y, 1) = y.  For example, flrp(1e38, 1.0, 1.0) is 1.0.  This is
    * correct.
    *
    * The other possible implementation is:
    *
    *    x + t(y - x)
    *
    * This can also be formulated as an FMA:
    *
    *    fma(y - x, t, x)
    *
    * For this implementation, flrp(1e38, 1.0, 1.0) is 0.0.  Since 1.0 was
    * expected, that's a pretty significant error.
    *
    * The choice made for lowering depends on a number of factors.
    *
    * - If the flrp is marked precise and FMA is supported:
    *
    *        fma(y, t, fma(-x, t, x))
    *
    *   This is strictly correct (maybe?), and the cost is two FMA
    *   instructions.  It at least maintains the flrp(x, y, 1.0) == y
    *   condition.
    *
    * - If the flrp is marked precise and FMA is not supported:
    *
    *        x(1 - t) + yt
    *
    *   This is strictly correct, and the cost is 4 instructions.  If FMA is
    *   supported, this may or may not be reduced to 3 instructions (a
    *   subtract, a multiply, and an FMA)... but in that case the other
    *   formulation should have been used.
    */
   if (alu->exact) {
      if (have_ffma)
         replace_with_strict_ffma(bld, dead_flrp, alu);
      else
         replace_with_strict(bld, dead_flrp, alu);

      return;
   }

   /*
    * - If x and y are both immediates and the relative magnitude of the
    *   values is similar (such that y - x does not lose too much precision):
    *
    *        x + t(y - x)
    *
    *   We rely on constant folding to eliminate y - x, and we rely on
    *   nir_opt_algebraic to possibly generate an FMA.  The cost is either
    *   one FMA or two instructions.
    */
   if (sources_are_constants_with_similar_magnitudes(alu)) {
      replace_with_fast(bld, dead_flrp, alu);
      return;
   }

   if (have_ffma) {
      if (always_precise) {
         replace_with_strict_ffma(bld, dead_flrp, alu);
         return;
      }
   } else {
      if (always_precise) {
         replace_with_strict(bld, dead_flrp, alu);
         return;
      }
   }

   /*
    * - Otherwise
    *
    *        x + t(y - x)
    */
   replace_with_fast(bld, dead_flrp, alu);
}

static void
lower_flrp_impl(nir_function_impl *impl,
                struct u_vector *dead_flrp,
                unsigned lowering_mask,
                bool always_precise,
                bool have_ffma)
{
   nir_builder b;
   nir_builder_init(&b, impl);

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type == nir_instr_type_alu) {
            nir_alu_instr *const alu = nir_instr_as_alu(instr);

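            /* NIR bit sizes are powers of two, so ANDing the destination
             * bit size with the caller's mask selects exactly the sizes
             * that were requested for lowering.
             */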
            if (alu->op == nir_op_flrp &&
                (alu->dest.dest.ssa.bit_size & lowering_mask)) {
               convert_flrp_instruction(&b, dead_flrp, alu, always_precise,
                                        have_ffma);
            }
         }
      }
   }

   nir_metadata_preserve(impl, nir_metadata_block_index |
                               nir_metadata_dominance);
}

/**
 * \param lowering_mask - Bitwise-or of the bit sizes that need to be lowered
 *                        (e.g., 16 | 64 if only 16-bit and 64-bit flrp need
 *                        lowering).
 * \param always_precise - Always use a precise lowering for flrp.  When set,
 *                         flrp is lowered to ffma(b, c, ffma(-a, c, a)) if
 *                         \c have_ffma is true and to (a * (1 - c)) + (b * c)
 *                         otherwise, except when both endpoints are constants
 *                         of similar magnitude (the cheap lowering is safe in
 *                         that case).
 * \param have_ffma - Set to true if the GPU has an FFMA instruction that
 *                    should be used.
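 *
 * A sketch of a typical call site (hypothetical driver code; the
 * lower_flrp16/32/64 option names are assumed to come from the shader's
 * nir_shader_compiler_options):
 *
 *    const unsigned lower_flrp =
 *       (options->lower_flrp16 ? 16 : 0) |
 *       (options->lower_flrp32 ? 32 : 0) |
 *       (options->lower_flrp64 ? 64 : 0);
 *
 *    if (lower_flrp) {
 *       bool progress = false;
 *
 *       NIR_PASS(progress, shader, nir_lower_flrp, lower_flrp,
 *                false, have_ffma);
 *    }
 *
 * Here false is passed for always_precise, and have_ffma reflects whether
 * the target GPU has a usable FFMA instruction.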
 */
bool
nir_lower_flrp(nir_shader *shader,
               unsigned lowering_mask,
               bool always_precise,
               bool have_ffma)
{
   struct u_vector dead_flrp;

   if (!u_vector_init(&dead_flrp, sizeof(struct nir_alu_instr *), 64))
      return false;

   nir_foreach_function(function, shader) {
      if (function->impl) {
         lower_flrp_impl(function->impl, &dead_flrp, lowering_mask,
                         always_precise, have_ffma);
      }
   }

   /* Progress was made if the dead list is not empty.  Remove all the
    * instructions from the dead list.
    */
   const bool progress = u_vector_length(&dead_flrp) != 0;

   struct nir_alu_instr **instr;
   u_vector_foreach(instr, &dead_flrp)
      nir_instr_remove(&(*instr)->instr);

   u_vector_finish(&dead_flrp);

   return progress;
}