/*
 * Copyright © 2018 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
#include <math.h>
#include <stdlib.h>
#include <string.h>

#include "nir_builder.h"
#include "util/u_vector.h"
/**
 * \file
 * Lower flrp instructions.
 *
 * Unlike the lowerings that are possible in nir_opt_algebraic, this pass can
 * examine more global information to determine a possibly more efficient
 * lowering for each flrp.
 */
/**
 * Queue a lowered flrp for deletion at the end of the pass.
 *
 * The instruction must not be removed immediately: later lowering decisions
 * inspect other uses of this flrp's sources, and deleting it now could make
 * a subsequent flrp choose a different (wrong) lowering.
 */
static void
append_flrp_to_dead_list(struct u_vector *dead_flrp,
                         struct nir_alu_instr *alu)
{
   struct nir_alu_instr **tail = u_vector_add(dead_flrp);
   *tail = alu;
}
44 * Replace flrp(a, b, c) with ffma(b, c, ffma(-a, c, a)).
47 replace_with_strict_ffma(struct nir_builder
*bld
, struct u_vector
*dead_flrp
,
48 struct nir_alu_instr
*alu
)
50 nir_ssa_def
*const a
= nir_ssa_for_alu_src(bld
, alu
, 0);
51 nir_ssa_def
*const b
= nir_ssa_for_alu_src(bld
, alu
, 1);
52 nir_ssa_def
*const c
= nir_ssa_for_alu_src(bld
, alu
, 2);
54 nir_ssa_def
*const neg_a
= nir_fneg(bld
, a
);
55 nir_instr_as_alu(neg_a
->parent_instr
)->exact
= alu
->exact
;
57 nir_ssa_def
*const inner_ffma
= nir_ffma(bld
, neg_a
, c
, a
);
58 nir_instr_as_alu(inner_ffma
->parent_instr
)->exact
= alu
->exact
;
60 nir_ssa_def
*const outer_ffma
= nir_ffma(bld
, b
, c
, inner_ffma
);
61 nir_instr_as_alu(outer_ffma
->parent_instr
)->exact
= alu
->exact
;
63 nir_ssa_def_rewrite_uses(&alu
->dest
.dest
.ssa
, nir_src_for_ssa(outer_ffma
));
65 /* DO NOT REMOVE the original flrp yet. Many of the lowering choices are
66 * based on other uses of the sources. Removing the flrp may cause the
67 * last flrp in a sequence to make a different, incorrect choice.
69 append_flrp_to_dead_list(dead_flrp
, alu
);
73 * Replace flrp(a, b, c) with ffma(a, (1 - c), bc)
76 replace_with_single_ffma(struct nir_builder
*bld
, struct u_vector
*dead_flrp
,
77 struct nir_alu_instr
*alu
)
79 nir_ssa_def
*const a
= nir_ssa_for_alu_src(bld
, alu
, 0);
80 nir_ssa_def
*const b
= nir_ssa_for_alu_src(bld
, alu
, 1);
81 nir_ssa_def
*const c
= nir_ssa_for_alu_src(bld
, alu
, 2);
83 nir_ssa_def
*const neg_c
= nir_fneg(bld
, c
);
84 nir_instr_as_alu(neg_c
->parent_instr
)->exact
= alu
->exact
;
86 nir_ssa_def
*const one_minus_c
=
87 nir_fadd(bld
, nir_imm_float(bld
, 1.0f
), neg_c
);
88 nir_instr_as_alu(one_minus_c
->parent_instr
)->exact
= alu
->exact
;
90 nir_ssa_def
*const b_times_c
= nir_fmul(bld
, b
, c
);
91 nir_instr_as_alu(b_times_c
->parent_instr
)->exact
= alu
->exact
;
93 nir_ssa_def
*const final_ffma
= nir_ffma(bld
, a
, one_minus_c
, b_times_c
);
94 nir_instr_as_alu(final_ffma
->parent_instr
)->exact
= alu
->exact
;
96 nir_ssa_def_rewrite_uses(&alu
->dest
.dest
.ssa
, nir_src_for_ssa(final_ffma
));
98 /* DO NOT REMOVE the original flrp yet. Many of the lowering choices are
99 * based on other uses of the sources. Removing the flrp may cause the
100 * last flrp in a sequence to make a different, incorrect choice.
102 append_flrp_to_dead_list(dead_flrp
, alu
);
106 * Replace flrp(a, b, c) with a(1-c) + bc.
109 replace_with_strict(struct nir_builder
*bld
, struct u_vector
*dead_flrp
,
110 struct nir_alu_instr
*alu
)
112 nir_ssa_def
*const a
= nir_ssa_for_alu_src(bld
, alu
, 0);
113 nir_ssa_def
*const b
= nir_ssa_for_alu_src(bld
, alu
, 1);
114 nir_ssa_def
*const c
= nir_ssa_for_alu_src(bld
, alu
, 2);
116 nir_ssa_def
*const neg_c
= nir_fneg(bld
, c
);
117 nir_instr_as_alu(neg_c
->parent_instr
)->exact
= alu
->exact
;
119 nir_ssa_def
*const one_minus_c
=
120 nir_fadd(bld
, nir_imm_float(bld
, 1.0f
), neg_c
);
121 nir_instr_as_alu(one_minus_c
->parent_instr
)->exact
= alu
->exact
;
123 nir_ssa_def
*const first_product
= nir_fmul(bld
, a
, one_minus_c
);
124 nir_instr_as_alu(first_product
->parent_instr
)->exact
= alu
->exact
;
126 nir_ssa_def
*const second_product
= nir_fmul(bld
, b
, c
);
127 nir_instr_as_alu(second_product
->parent_instr
)->exact
= alu
->exact
;
129 nir_ssa_def
*const sum
= nir_fadd(bld
, first_product
, second_product
);
130 nir_instr_as_alu(sum
->parent_instr
)->exact
= alu
->exact
;
132 nir_ssa_def_rewrite_uses(&alu
->dest
.dest
.ssa
, nir_src_for_ssa(sum
));
134 /* DO NOT REMOVE the original flrp yet. Many of the lowering choices are
135 * based on other uses of the sources. Removing the flrp may cause the
136 * last flrp in a sequence to make a different, incorrect choice.
138 append_flrp_to_dead_list(dead_flrp
, alu
);
142 * Replace flrp(a, b, c) with a + c(b-a).
145 replace_with_fast(struct nir_builder
*bld
, struct u_vector
*dead_flrp
,
146 struct nir_alu_instr
*alu
)
148 nir_ssa_def
*const a
= nir_ssa_for_alu_src(bld
, alu
, 0);
149 nir_ssa_def
*const b
= nir_ssa_for_alu_src(bld
, alu
, 1);
150 nir_ssa_def
*const c
= nir_ssa_for_alu_src(bld
, alu
, 2);
152 nir_ssa_def
*const neg_a
= nir_fneg(bld
, a
);
153 nir_instr_as_alu(neg_a
->parent_instr
)->exact
= alu
->exact
;
155 nir_ssa_def
*const b_minus_a
= nir_fadd(bld
, b
, neg_a
);
156 nir_instr_as_alu(b_minus_a
->parent_instr
)->exact
= alu
->exact
;
158 nir_ssa_def
*const product
= nir_fmul(bld
, c
, b_minus_a
);
159 nir_instr_as_alu(product
->parent_instr
)->exact
= alu
->exact
;
161 nir_ssa_def
*const sum
= nir_fadd(bld
, a
, product
);
162 nir_instr_as_alu(sum
->parent_instr
)->exact
= alu
->exact
;
164 nir_ssa_def_rewrite_uses(&alu
->dest
.dest
.ssa
, nir_src_for_ssa(sum
));
166 /* DO NOT REMOVE the original flrp yet. Many of the lowering choices are
167 * based on other uses of the sources. Removing the flrp may cause the
168 * last flrp in a sequence to make a different, incorrect choice.
170 append_flrp_to_dead_list(dead_flrp
, alu
);
174 * Replace flrp(a, b, c) with (b*c ± c) + a
177 replace_with_expanded_ffma_and_add(struct nir_builder
*bld
,
178 struct u_vector
*dead_flrp
,
179 struct nir_alu_instr
*alu
, bool subtract_c
)
181 nir_ssa_def
*const a
= nir_ssa_for_alu_src(bld
, alu
, 0);
182 nir_ssa_def
*const b
= nir_ssa_for_alu_src(bld
, alu
, 1);
183 nir_ssa_def
*const c
= nir_ssa_for_alu_src(bld
, alu
, 2);
185 nir_ssa_def
*const b_times_c
= nir_fadd(bld
, b
, c
);
186 nir_instr_as_alu(b_times_c
->parent_instr
)->exact
= alu
->exact
;
188 nir_ssa_def
*inner_sum
;
191 nir_ssa_def
*const neg_c
= nir_fneg(bld
, c
);
192 nir_instr_as_alu(neg_c
->parent_instr
)->exact
= alu
->exact
;
194 inner_sum
= nir_fadd(bld
, b_times_c
, neg_c
);
196 inner_sum
= nir_fadd(bld
, b_times_c
, c
);
199 nir_instr_as_alu(inner_sum
->parent_instr
)->exact
= alu
->exact
;
201 nir_ssa_def
*const outer_sum
= nir_fadd(bld
, inner_sum
, a
);
202 nir_instr_as_alu(outer_sum
->parent_instr
)->exact
= alu
->exact
;
204 nir_ssa_def_rewrite_uses(&alu
->dest
.dest
.ssa
, nir_src_for_ssa(outer_sum
));
206 /* DO NOT REMOVE the original flrp yet. Many of the lowering choices are
207 * based on other uses of the sources. Removing the flrp may cause the
208 * last flrp in a sequence to make a different, incorrect choice.
210 append_flrp_to_dead_list(dead_flrp
, alu
);
214 * Determines whether a swizzled source is constant w/ all components the same.
216 * The value of the constant is stored in \c result.
219 * True if all components of the swizzled source are the same constant.
220 * Otherwise false is returned.
223 all_same_constant(const nir_alu_instr
*instr
, unsigned src
, double *result
)
225 nir_const_value
*val
= nir_src_as_const_value(instr
->src
[src
].src
);
230 const uint8_t *const swizzle
= instr
->src
[src
].swizzle
;
231 const unsigned num_components
= nir_dest_num_components(instr
->dest
.dest
);
233 if (instr
->dest
.dest
.ssa
.bit_size
== 32) {
234 const float first
= val
[swizzle
[0]].f32
;
236 for (unsigned i
= 1; i
< num_components
; i
++) {
237 if (val
[swizzle
[i
]].f32
!= first
)
243 const double first
= val
[swizzle
[0]].f64
;
245 for (unsigned i
= 1; i
< num_components
; i
++) {
246 if (val
[swizzle
[i
]].f64
!= first
)
257 sources_are_constants_with_similar_magnitudes(const nir_alu_instr
*instr
)
259 nir_const_value
*val0
= nir_src_as_const_value(instr
->src
[0].src
);
260 nir_const_value
*val1
= nir_src_as_const_value(instr
->src
[1].src
);
262 if (val0
== NULL
|| val1
== NULL
)
265 const uint8_t *const swizzle0
= instr
->src
[0].swizzle
;
266 const uint8_t *const swizzle1
= instr
->src
[1].swizzle
;
267 const unsigned num_components
= nir_dest_num_components(instr
->dest
.dest
);
269 if (instr
->dest
.dest
.ssa
.bit_size
== 32) {
270 for (unsigned i
= 0; i
< num_components
; i
++) {
274 frexpf(val0
[swizzle0
[i
]].f32
, &exp0
);
275 frexpf(val1
[swizzle1
[i
]].f32
, &exp1
);
277 /* If the difference between exponents is >= 24, then A+B will always
278 * have the value whichever between A and B has the largest absolute
279 * value. So, [0, 23] is the valid range. The smaller the limit
280 * value, the more precision will be maintained at a potential
281 * performance cost. Somewhat arbitrarilly split the range in half.
283 if (abs(exp0
- exp1
) > (23 / 2))
287 for (unsigned i
= 0; i
< num_components
; i
++) {
291 frexp(val0
[swizzle0
[i
]].f64
, &exp0
);
292 frexp(val1
[swizzle1
[i
]].f64
, &exp1
);
294 /* If the difference between exponents is >= 53, then A+B will always
295 * have the value whichever between A and B has the largest absolute
296 * value. So, [0, 52] is the valid range. The smaller the limit
297 * value, the more precision will be maintained at a potential
298 * performance cost. Somewhat arbitrarilly split the range in half.
300 if (abs(exp0
- exp1
) > (52 / 2))
/**
 * Counts of similar types of nir_op_flrp instructions
 *
 * If a similar instruction fits into more than one category, it will only be
 * counted once.  The assumption is that no other instruction will have all
 * sources the same, or CSE would have removed one of the instructions.
 */
struct similar_flrp_stats {
   unsigned src0_and_src2;   /**< Other flrps sharing src0 and src2. */
   unsigned src1_and_src2;   /**< Other flrps sharing src1 and src2. */
};
322 * Collection counts of similar FLRP instructions.
324 * This function only cares about similar instructions that have src2 in
328 get_similar_flrp_stats(nir_alu_instr
*alu
, struct similar_flrp_stats
*st
)
330 memset(st
, 0, sizeof(*st
));
332 nir_foreach_use(other_use
, alu
->src
[2].src
.ssa
) {
333 /* Is the use also a flrp? */
334 nir_instr
*const other_instr
= other_use
->parent_instr
;
335 if (other_instr
->type
!= nir_instr_type_alu
)
338 /* Eh-hem... don't match the instruction with itself. */
339 if (other_instr
== &alu
->instr
)
342 nir_alu_instr
*const other_alu
= nir_instr_as_alu(other_instr
);
343 if (other_alu
->op
!= nir_op_flrp
)
346 /* Does the other flrp use source 2 from the first flrp as its source 2
349 if (!nir_alu_srcs_equal(alu
, other_alu
, 2, 2))
352 if (nir_alu_srcs_equal(alu
, other_alu
, 0, 0))
354 else if (nir_alu_srcs_equal(alu
, other_alu
, 1, 1))
362 convert_flrp_instruction(nir_builder
*bld
,
363 struct u_vector
*dead_flrp
,
368 bld
->cursor
= nir_before_instr(&alu
->instr
);
370 /* There are two methods to implement flrp(x, y, t). The strictly correct
371 * implementation according to the GLSL spec is:
375 * This can also be implemented using two chained FMAs
377 * fma(y, t, fma(-x, t, x))
379 * This method, using either formulation, has better precision when the
380 * difference between x and y is very large. It guarantess that flrp(x, y,
381 * 1) = y. For example, flrp(1e38, 1.0, 1.0) is 1.0. This is correct.
383 * The other possible implementation is:
387 * This can also be formuated as an FMA:
391 * For this implementation, flrp(1e38, 1.0, 1.0) is 0.0. Since 1.0 was
392 * expected, that's a pretty significant error.
394 * The choice made for lowering depends on a number of factors.
396 * - If the flrp is marked precise and FMA is supported:
398 * fma(y, t, fma(-x, t, x))
400 * This is strictly correct (maybe?), and the cost is two FMA
401 * instructions. It at least maintains the flrp(x, y, 1.0) == y
404 * - If the flrp is marked precise and FMA is not supported:
408 * This is strictly correct, and the cost is 4 instructions. If FMA is
409 * supported, this may or may not be reduced to 3 instructions (a
410 * subtract, a multiply, and an FMA)... but in that case the other
411 * formulation should have been used.
415 replace_with_strict_ffma(bld
, dead_flrp
, alu
);
417 replace_with_strict(bld
, dead_flrp
, alu
);
423 * - If x and y are both immediates and the relative magnitude of the
424 * values is similar (such that x-y does not lose too much precision):
428 * We rely on constant folding to eliminate x-y, and we rely on
429 * nir_opt_algebraic to possibly generate an FMA. The cost is either one
430 * FMA or two instructions.
432 if (sources_are_constants_with_similar_magnitudes(alu
)) {
433 replace_with_fast(bld
, dead_flrp
, alu
);
446 * In both cases, x is used in place of ±1 for simplicity. Both forms
447 * lend to ffma generation on platforms that support ffma.
449 double src0_as_constant
;
450 if (all_same_constant(alu
, 0, &src0_as_constant
)) {
451 if (src0_as_constant
== 1.0) {
452 replace_with_expanded_ffma_and_add(bld
, dead_flrp
, alu
,
453 true /* subtract t */);
455 } else if (src0_as_constant
== -1.0) {
456 replace_with_expanded_ffma_and_add(bld
, dead_flrp
, alu
,
467 * In this case either the multiply in yt will be eliminated by
468 * nir_opt_algebraic. If FMA is supported, this results in fma(x, (1 -
469 * t), ±t) for two instructions. If FMA is not supported, then the cost
470 * is 3 instructions. We rely on nir_opt_algebraic to generate the FMA
471 * instructions as well.
473 * Another possible replacement is
477 * Some groupings of this may be better on some platforms in some
478 * circumstances, bit it is probably dependent on scheduling. Futher
479 * investigation may be required.
481 double src1_as_constant
;
482 if ((all_same_constant(alu
, 1, &src1_as_constant
) &&
483 (src1_as_constant
== -1.0 || src1_as_constant
== 1.0))) {
484 replace_with_strict(bld
, dead_flrp
, alu
);
489 if (always_precise
) {
490 replace_with_strict_ffma(bld
, dead_flrp
, alu
);
495 * - If FMA is supported and other flrp(x, _, t) exists:
497 * fma(y, t, fma(-x, t, x))
499 * The hope is that the inner FMA calculation will be shared with the
500 * other lowered flrp. This results in two FMA instructions for the
501 * first flrp and one FMA instruction for each additional flrp. It
502 * also means that the live range for x might be complete after the
503 * inner ffma instead of after the last flrp.
505 struct similar_flrp_stats st
;
507 get_similar_flrp_stats(alu
, &st
);
508 if (st
.src0_and_src2
> 0) {
509 replace_with_strict_ffma(bld
, dead_flrp
, alu
);
514 * - If FMA is supported and another flrp(_, y, t) exists:
516 * fma(x, (1 - t), yt)
518 * The hope is that the (1 - t) and the yt will be shared with the
519 * other lowered flrp. This results in 3 insructions for the first
520 * flrp and 1 for each additional flrp.
522 if (st
.src1_and_src2
> 0) {
523 replace_with_single_ffma(bld
, dead_flrp
, alu
);
527 if (always_precise
) {
528 replace_with_strict(bld
, dead_flrp
, alu
);
533 * - If FMA is not supported and another flrp(x, _, t) exists:
537 * The hope is that the x(1 - t) will be shared with the other lowered
538 * flrp. This results in 4 insructions for the first flrp and 2 for
539 * each additional flrp.
541 * - If FMA is not supported and another flrp(_, y, t) exists:
545 * The hope is that the (1 - t) and the yt will be shared with the
546 * other lowered flrp. This results in 4 insructions for the first
547 * flrp and 2 for each additional flrp.
549 struct similar_flrp_stats st
;
551 get_similar_flrp_stats(alu
, &st
);
552 if (st
.src0_and_src2
> 0 || st
.src1_and_src2
> 0) {
553 replace_with_strict(bld
, dead_flrp
, alu
);
559 * - If t is constant:
563 * The cost is three instructions without FMA or two instructions with
564 * FMA. This is the same cost as the imprecise lowering, but it gives
565 * the instruction scheduler a little more freedom.
567 * There is no need to handle t = 0.5 specially. nir_opt_algebraic
568 * already has optimizations to convert 0.5x + 0.5y to 0.5(x + y).
570 if (alu
->src
[2].src
.ssa
->parent_instr
->type
== nir_instr_type_load_const
) {
571 replace_with_strict(bld
, dead_flrp
, alu
);
580 replace_with_fast(bld
, dead_flrp
, alu
);
584 lower_flrp_impl(nir_function_impl
*impl
,
585 struct u_vector
*dead_flrp
,
586 unsigned lowering_mask
,
591 nir_builder_init(&b
, impl
);
593 nir_foreach_block(block
, impl
) {
594 nir_foreach_instr_safe(instr
, block
) {
595 if (instr
->type
== nir_instr_type_alu
) {
596 nir_alu_instr
*const alu
= nir_instr_as_alu(instr
);
598 if (alu
->op
== nir_op_flrp
&&
599 (alu
->dest
.dest
.ssa
.bit_size
& lowering_mask
)) {
600 convert_flrp_instruction(&b
, dead_flrp
, alu
, always_precise
,
607 nir_metadata_preserve(impl
, nir_metadata_block_index
|
608 nir_metadata_dominance
);
612 * \param lowering_mask - Bitwise-or of the bit sizes that need to be lowered
613 * (e.g., 16 | 64 if only 16-bit and 64-bit flrp need
615 * \param always_precise - Always require precise lowering for flrp. This
616 * will always lower flrp to (a * (1 - c)) + (b * c).
617 * \param have_ffma - Set to true if the GPU has an FFMA instruction that
621 nir_lower_flrp(nir_shader
*shader
,
622 unsigned lowering_mask
,
626 struct u_vector dead_flrp
;
628 if (!u_vector_init(&dead_flrp
, sizeof(struct nir_alu_instr
*), 64))
631 nir_foreach_function(function
, shader
) {
632 if (function
->impl
) {
633 lower_flrp_impl(function
->impl
, &dead_flrp
, lowering_mask
,
634 always_precise
, have_ffma
);
638 /* Progress was made if the dead list is not empty. Remove all the
639 * instructions from the dead list.
641 const bool progress
= u_vector_length(&dead_flrp
) != 0;
643 struct nir_alu_instr
**instr
;
644 u_vector_foreach(instr
, &dead_flrp
)
645 nir_instr_remove(&(*instr
)->instr
);
647 u_vector_finish(&dead_flrp
);