nir: Add an ALU lowering pass for mul_high.
[mesa.git] / src / compiler / nir / nir_lower_alu.c
1 /*
2 * Copyright © 2010 Intel Corporation
3 * Copyright © 2018 Broadcom
4 *
5 * Permission is hereby granted, free of charge, to any person obtaining a
6 * copy of this software and associated documentation files (the "Software"),
7 * to deal in the Software without restriction, including without limitation
8 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9 * and/or sell copies of the Software, and to permit persons to whom the
10 * Software is furnished to do so, subject to the following conditions:
11 *
12 * The above copyright notice and this permission notice (including the next
13 * paragraph) shall be included in all copies or substantial portions of the
14 * Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
19 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
22 * DEALINGS IN THE SOFTWARE.
23 */
24
25 #include "nir.h"
26 #include "nir_builder.h"
27
28 /** nir_lower_alu.c
29 *
30 * NIR's home for miscellaneous ALU operation lowering implementations.
31 *
32 * Most NIR ALU lowering occurs in nir_opt_algebraic.py, since it's generally
33 * easy to write them there. However, if terms appear multiple times in the
34 * lowered code, it can get very verbose and cause a lot of work for CSE, so
35 * it may end up being easier to write out in C code.
36 *
37 * The shader must be in SSA for this pass.
38 */
39
40 #define LOWER_MUL_HIGH (1 << 0)
41
42 static bool
43 lower_alu_instr(nir_alu_instr *instr, nir_builder *b)
44 {
45 nir_ssa_def *lowered = NULL;
46
47 assert(instr->dest.dest.is_ssa);
48
49 b->cursor = nir_before_instr(&instr->instr);
50 b->exact = instr->exact;
51
52 switch (instr->op) {
53 case nir_op_imul_high:
54 case nir_op_umul_high:
55 if (b->shader->options->lower_mul_high) {
56 nir_ssa_def *c1 = nir_imm_int(b, 1);
57 nir_ssa_def *c16 = nir_imm_int(b, 16);
58
59 nir_ssa_def *src0 = nir_ssa_for_alu_src(b, instr, 0);
60 nir_ssa_def *src1 = nir_ssa_for_alu_src(b, instr, 1);
61 nir_ssa_def *different_signs = NULL;
62 if (instr->op == nir_op_imul_high) {
63 nir_ssa_def *c0 = nir_imm_int(b, 0);
64 different_signs = nir_ixor(b,
65 nir_ilt(b, src0, c0),
66 nir_ilt(b, src1, c0));
67 src0 = nir_iabs(b, src0);
68 src1 = nir_iabs(b, src1);
69 }
70
71 /* ABCD
72 * * EFGH
73 * ======
74 * (GH * CD) + (GH * AB) << 16 + (EF * CD) << 16 + (EF * AB) << 32
75 *
76 * Start by splitting into the 4 multiplies.
77 */
78 nir_ssa_def *src0l = nir_iand(b, src0, nir_imm_int(b, 0xffff));
79 nir_ssa_def *src1l = nir_iand(b, src1, nir_imm_int(b, 0xffff));
80 nir_ssa_def *src0h = nir_ushr(b, src0, c16);
81 nir_ssa_def *src1h = nir_ushr(b, src1, c16);
82
83 nir_ssa_def *lo = nir_imul(b, src0l, src1l);
84 nir_ssa_def *m1 = nir_imul(b, src0l, src1h);
85 nir_ssa_def *m2 = nir_imul(b, src0h, src1l);
86 nir_ssa_def *hi = nir_imul(b, src0h, src1h);
87
88 nir_ssa_def *tmp;
89
90 tmp = nir_ishl(b, m1, c16);
91 hi = nir_iadd(b, hi, nir_iand(b, nir_uadd_carry(b, lo, tmp), c1));
92 lo = nir_iadd(b, lo, tmp);
93 hi = nir_iadd(b, hi, nir_ushr(b, m1, c16));
94
95 tmp = nir_ishl(b, m2, c16);
96 hi = nir_iadd(b, hi, nir_iand(b, nir_uadd_carry(b, lo, tmp), c1));
97 lo = nir_iadd(b, lo, tmp);
98 hi = nir_iadd(b, hi, nir_ushr(b, m2, c16));
99
100 if (instr->op == nir_op_imul_high) {
101 /* For channels where different_signs is set we have to perform a
102 * 64-bit negation. This is *not* the same as just negating the
103 * high 32-bits. Consider -3 * 2. The high 32-bits is 0, but the
104 * desired result is -1, not -0! Recall -x == ~x + 1.
105 */
106 hi = nir_bcsel(b, different_signs,
107 nir_iadd(b,
108 nir_inot(b, hi),
109 nir_iand(b,
110 nir_uadd_carry(b,
111 nir_inot(b, lo),
112 c1),
113 nir_imm_int(b, 1))),
114 hi);
115 }
116
117 lowered = hi;
118 }
119 break;
120
121 default:
122 break;
123 }
124
125 if (lowered) {
126 nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, nir_src_for_ssa(lowered));
127 nir_instr_remove(&instr->instr);
128 return true;
129 } else {
130 return false;
131 }
132 }
133
134 bool
135 nir_lower_alu(nir_shader *shader)
136 {
137 bool progress = false;
138
139 if (!shader->options->lower_mul_high)
140 return false;
141
142 nir_foreach_function(function, shader) {
143 if (function->impl) {
144 nir_builder builder;
145 nir_builder_init(&builder, function->impl);
146
147 nir_foreach_block(block, function->impl) {
148 nir_foreach_instr_safe(instr, block) {
149 if (instr->type == nir_instr_type_alu) {
150 progress = lower_alu_instr(nir_instr_as_alu(instr),
151 &builder) || progress;
152 }
153 }
154 }
155
156 if (progress) {
157 nir_metadata_preserve(function->impl,
158 nir_metadata_block_index |
159 nir_metadata_dominance);
160 }
161 }
162 }
163
164 return progress;
165 }