/*
 * Copyright © 2015 Red Hat
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * Authors:
 *    Rob Clark <robclark@freedesktop.org>
 */

#include "nir.h"
#include "nir_builder.h"

/* This pass has two paths:
 *
 * One (nir_lower_idiv_fast) lowers idiv/udiv/umod and is based on
 * NV50LegalizeSSA::handleDIV().
 *
 * Note that this path probably does not have enough precision for
 * compute shaders.  Perhaps we want a second, higher-precision (looping)
 * version of this?  Or perhaps we assume that if you can do compute
 * shaders you can also branch out to a pre-optimized shader library
 * routine.
 *
 * The other path (nir_lower_idiv_precise) is based on code used by LLVM's
 * AMDGPU target.  It should handle 32-bit idiv/irem/imod/udiv/umod exactly.
 */

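/* Typical invocation (a sketch only; path selection and pass ordering are
 * driver-specific):
 *
 *    NIR_PASS(progress, shader, nir_lower_idiv, nir_lower_idiv_fast);
 *
 * or, when exact 32-bit results are required:
 *
 *    NIR_PASS(progress, shader, nir_lower_idiv, nir_lower_idiv_precise);
 */
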
static bool
convert_instr(nir_builder *bld, nir_alu_instr *alu)
{
   nir_ssa_def *numer, *denom, *af, *bf, *a, *b, *q, *r, *rt;
   nir_op op = alu->op;
   bool is_signed;

   if ((op != nir_op_idiv) &&
       (op != nir_op_udiv) &&
       (op != nir_op_imod) &&
       (op != nir_op_umod) &&
       (op != nir_op_irem))
      return false;

   is_signed = (op == nir_op_idiv ||
                op == nir_op_imod ||
                op == nir_op_irem);

   bld->cursor = nir_before_instr(&alu->instr);

   numer = nir_ssa_for_alu_src(bld, alu, 0);
   denom = nir_ssa_for_alu_src(bld, alu, 1);

   if (is_signed) {
      af = nir_i2f32(bld, numer);
      bf = nir_i2f32(bld, denom);
      af = nir_fabs(bld, af);
      bf = nir_fabs(bld, bf);
      a = nir_iabs(bld, numer);
      b = nir_iabs(bld, denom);
   } else {
      af = nir_u2f32(bld, numer);
      bf = nir_u2f32(bld, denom);
      a = numer;
      b = denom;
   }

   /* get first result: */
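   /* The reciprocal is deliberately biased low: subtracting 2 from the raw
    * f32 bits drops it by a couple of ULPs so that, following the intent of
    * NV50LegalizeSSA::handleDIV, the quotient estimate errs on the small
    * side and the error/correction steps below only ever have to add.
    */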
   bf = nir_frcp(bld, bf);
   bf = nir_isub(bld, bf, nir_imm_int(bld, 2));   /* yes, really */
   q = nir_fmul(bld, af, bf);

   if (is_signed) {
      q = nir_f2i32(bld, q);
   } else {
      q = nir_f2u32(bld, q);
   }

   /* get error of first result: */
   r = nir_imul(bld, q, b);
   r = nir_isub(bld, a, r);
   r = nir_u2f32(bld, r);
   r = nir_fmul(bld, r, bf);
   r = nir_f2u32(bld, r);

   /* add quotients: */
   q = nir_iadd(bld, q, r);

   /* correction: if modulus >= divisor, add 1 */
   r = nir_imul(bld, q, b);
   r = nir_isub(bld, a, r);
   rt = nir_uge(bld, r, b);

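   /* For umod the corrected remainder is itself the result; every other
    * opcode wants the corrected quotient.
    */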
   if (op == nir_op_umod) {
      q = nir_bcsel(bld, rt, nir_isub(bld, r, b), r);
   } else {
      r = nir_b2i32(bld, rt);

      q = nir_iadd(bld, q, r);
      if (is_signed) {
         /* fix the sign: */
         r = nir_ixor(bld, numer, denom);
         r = nir_ilt(bld, r, nir_imm_int(bld, 0));
         b = nir_ineg(bld, q);
         q = nir_bcsel(bld, r, b, q);

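         /* For imod/irem, recover the remainder as numer - q * denom, which
          * carries the numerator's sign; imod then adds the divisor when the
          * operand signs differ and the remainder is non-zero, so the result
          * takes the divisor's sign.
          */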
         if (op == nir_op_imod || op == nir_op_irem) {
            q = nir_imul(bld, q, denom);
            q = nir_isub(bld, numer, q);
            if (op == nir_op_imod) {
               q = nir_bcsel(bld, nir_ieq(bld, q, nir_imm_int(bld, 0)),
                             nir_imm_int(bld, 0),
                             nir_bcsel(bld, r, nir_iadd(bld, q, denom), q));
            }
         }
      }
   }

   assert(alu->dest.dest.is_ssa);
   nir_ssa_def_rewrite_uses(&alu->dest.dest.ssa, nir_src_for_ssa(q));

   return true;
}

/* ported from LLVM's AMDGPUTargetLowering::LowerUDIVREM */
static nir_ssa_def *
emit_udiv(nir_builder *bld, nir_ssa_def *numer, nir_ssa_def *denom, bool modulo)
{
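   /* Initial fixed-point reciprocal estimate: rcp ~= 2^32 / denom, built by
    * scaling the hardware's f32 reciprocal.
    */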
   nir_ssa_def *rcp = nir_frcp(bld, nir_u2f32(bld, denom));
   rcp = nir_f2u32(bld, nir_fmul_imm(bld, rcp, 4294967296.0));
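   /* A perfect reciprocal would satisfy rcp * denom == 2^32.  Measure the
    * error of the estimate and fold it back into rcp; this is the single
    * refinement step from LLVM's LowerUDIVREM.
    */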
   nir_ssa_def *rcp_lo = nir_imul(bld, rcp, denom);
   nir_ssa_def *rcp_hi = nir_umul_high(bld, rcp, denom);
   nir_ssa_def *rcp_hi_ne_zero = nir_ine(bld, rcp_hi, nir_imm_int(bld, 0));
   nir_ssa_def *neg_rcp_lo = nir_ineg(bld, rcp_lo);
   nir_ssa_def *abs_rcp_lo = nir_bcsel(bld, rcp_hi_ne_zero, rcp_lo, neg_rcp_lo);
   nir_ssa_def *e = nir_umul_high(bld, abs_rcp_lo, rcp);
   nir_ssa_def *rcp_plus_e = nir_iadd(bld, rcp, e);
   nir_ssa_def *rcp_minus_e = nir_isub(bld, rcp, e);
   nir_ssa_def *tmp0 = nir_bcsel(bld, rcp_hi_ne_zero, rcp_minus_e, rcp_plus_e);
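   /* Quotient estimate: (numer * tmp0) >> 32.  It can still be off by one,
    * so also compute the remainder and the two conditions that drive the
    * final fix-up.
    */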
   nir_ssa_def *quotient = nir_umul_high(bld, tmp0, numer);
   nir_ssa_def *num_s_remainder = nir_imul(bld, quotient, denom);
   nir_ssa_def *remainder = nir_isub(bld, numer, num_s_remainder);
   nir_ssa_def *remainder_ge_den = nir_uge(bld, remainder, denom);
   nir_ssa_def *remainder_ge_zero = nir_uge(bld, numer, num_s_remainder);
   nir_ssa_def *tmp1 = nir_iand(bld, remainder_ge_den, remainder_ge_zero);

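   /* Fix-up: tmp1 means the estimate was one too small, !remainder_ge_zero
    * means it was one too large.  Adjust the quotient by +/-1 or the
    * remainder by +/-denom accordingly.
    */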
   if (modulo) {
      nir_ssa_def *rem = nir_bcsel(bld, tmp1,
                                   nir_isub(bld, remainder, denom), remainder);
      return nir_bcsel(bld, remainder_ge_zero,
                       rem, nir_iadd(bld, remainder, denom));
   } else {
      nir_ssa_def *one = nir_imm_int(bld, 1);
      nir_ssa_def *div = nir_bcsel(bld, tmp1,
                                   nir_iadd(bld, quotient, one), quotient);
      return nir_bcsel(bld, remainder_ge_zero,
                       div, nir_isub(bld, quotient, one));
   }
}

/* ported from LLVM's AMDGPUTargetLowering::LowerSDIVREM */
static nir_ssa_def *
emit_idiv(nir_builder *bld, nir_ssa_def *numer, nir_ssa_def *denom, nir_op op)
{
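   /* lh_sign/rh_sign are 0 for non-negative operands and ~0 for negative
    * ones.  (x + sign) ^ sign is the two's-complement absolute value, so
    * the unsigned helper above can do the actual division.
    */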
   nir_ssa_def *lh_sign = nir_ilt(bld, numer, nir_imm_int(bld, 0));
   nir_ssa_def *rh_sign = nir_ilt(bld, denom, nir_imm_int(bld, 0));
   lh_sign = nir_bcsel(bld, lh_sign, nir_imm_int(bld, -1), nir_imm_int(bld, 0));
   rh_sign = nir_bcsel(bld, rh_sign, nir_imm_int(bld, -1), nir_imm_int(bld, 0));

   nir_ssa_def *lhs = nir_iadd(bld, numer, lh_sign);
   nir_ssa_def *rhs = nir_iadd(bld, denom, rh_sign);
   lhs = nir_ixor(bld, lhs, lh_sign);
   rhs = nir_ixor(bld, rhs, rh_sign);

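   /* Undo the sign handling: for idiv the quotient takes sign(numer) ^
    * sign(denom); for irem the remainder takes the numerator's sign, and
    * imod additionally folds a non-zero remainder into the divisor's sign.
    */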
   if (op == nir_op_idiv) {
      nir_ssa_def *d_sign = nir_ixor(bld, lh_sign, rh_sign);
      nir_ssa_def *res = emit_udiv(bld, lhs, rhs, false);
      res = nir_ixor(bld, res, d_sign);
      return nir_isub(bld, res, d_sign);
   } else {
      nir_ssa_def *res = emit_udiv(bld, lhs, rhs, true);
      res = nir_ixor(bld, res, lh_sign);
      res = nir_isub(bld, res, lh_sign);
      if (op == nir_op_imod) {
         nir_ssa_def *cond = nir_ieq(bld, res, nir_imm_int(bld, 0));
         cond = nir_ior(bld, nir_ieq(bld, lh_sign, rh_sign), cond);
         res = nir_bcsel(bld, cond, res, nir_iadd(bld, res, denom));
      }
      return res;
   }
}

static bool
convert_instr_precise(nir_builder *bld, nir_alu_instr *alu)
{
   nir_op op = alu->op;

   if ((op != nir_op_idiv) &&
       (op != nir_op_imod) &&
       (op != nir_op_irem) &&
       (op != nir_op_udiv) &&
       (op != nir_op_umod))
      return false;

   if (alu->dest.dest.ssa.bit_size != 32)
      return false;

   bld->cursor = nir_before_instr(&alu->instr);

   nir_ssa_def *numer = nir_ssa_for_alu_src(bld, alu, 0);
   nir_ssa_def *denom = nir_ssa_for_alu_src(bld, alu, 1);

   nir_ssa_def *res = NULL;

   if (op == nir_op_udiv || op == nir_op_umod)
      res = emit_udiv(bld, numer, denom, op == nir_op_umod);
   else
      res = emit_idiv(bld, numer, denom, op);

   assert(alu->dest.dest.is_ssa);
   nir_ssa_def_rewrite_uses(&alu->dest.dest.ssa, nir_src_for_ssa(res));

   return true;
}

static bool
convert_impl(nir_function_impl *impl, enum nir_lower_idiv_path path)
{
   nir_builder b;
   nir_builder_init(&b, impl);
   bool progress = false;

   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type == nir_instr_type_alu && path == nir_lower_idiv_precise)
            progress |= convert_instr_precise(&b, nir_instr_as_alu(instr));
         else if (instr->type == nir_instr_type_alu)
            progress |= convert_instr(&b, nir_instr_as_alu(instr));
      }
   }

   nir_metadata_preserve(impl, nir_metadata_block_index |
                               nir_metadata_dominance);

   return progress;
}

bool
nir_lower_idiv(nir_shader *shader, enum nir_lower_idiv_path path)
{
   bool progress = false;

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= convert_impl(function->impl, path);
   }

   return progress;
}