nir/glsl: Add another way of doing lower_imul64 for gen8+
[mesa.git] / src/compiler/nir/nir_lower_int64.c
/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"

static nir_ssa_def *
lower_b2i64(nir_builder *b, nir_ssa_def *x)
{
   return nir_pack_64_2x32_split(b, nir_b2i32(b, x), nir_imm_int(b, 0));
}

static nir_ssa_def *
lower_i2b(nir_builder *b, nir_ssa_def *x)
{
   return nir_ine(b, nir_ior(b, nir_unpack_64_2x32_split_x(b, x),
                                nir_unpack_64_2x32_split_y(b, x)),
                     nir_imm_int(b, 0));
}

static nir_ssa_def *
lower_i2i8(nir_builder *b, nir_ssa_def *x)
{
   return nir_i2i8(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_ssa_def *
lower_i2i16(nir_builder *b, nir_ssa_def *x)
{
   return nir_i2i16(b, nir_unpack_64_2x32_split_x(b, x));
}


static nir_ssa_def *
lower_i2i32(nir_builder *b, nir_ssa_def *x)
{
   return nir_unpack_64_2x32_split_x(b, x);
}

static nir_ssa_def *
lower_i2i64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x32 = x->bit_size == 32 ? x : nir_i2i32(b, x);
   return nir_pack_64_2x32_split(b, x32, nir_ishr(b, x32, nir_imm_int(b, 31)));
}

static nir_ssa_def *
lower_u2u8(nir_builder *b, nir_ssa_def *x)
{
   return nir_u2u8(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_ssa_def *
lower_u2u16(nir_builder *b, nir_ssa_def *x)
{
   return nir_u2u16(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_ssa_def *
lower_u2u32(nir_builder *b, nir_ssa_def *x)
{
   return nir_unpack_64_2x32_split_x(b, x);
}

static nir_ssa_def *
lower_u2u64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x32 = x->bit_size == 32 ? x : nir_u2u32(b, x);
   return nir_pack_64_2x32_split(b, x32, nir_imm_int(b, 0));
}

static nir_ssa_def *
lower_bcsel64(nir_builder *b, nir_ssa_def *cond, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_bcsel(b, cond, x_lo, y_lo),
                                    nir_bcsel(b, cond, x_hi, y_hi));
}

static nir_ssa_def *
lower_inot64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   return nir_pack_64_2x32_split(b, nir_inot(b, x_lo), nir_inot(b, x_hi));
}

static nir_ssa_def *
lower_iand64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_iand(b, x_lo, y_lo),
                                    nir_iand(b, x_hi, y_hi));
}

static nir_ssa_def *
lower_ior64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_ior(b, x_lo, y_lo),
                                    nir_ior(b, x_hi, y_hi));
}

static nir_ssa_def *
lower_ixor64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_ixor(b, x_lo, y_lo),
                                    nir_ixor(b, x_hi, y_hi));
}

static nir_ssa_def *
lower_ishl64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   /* Implemented as
    *
    * uint64_t lshift(uint64_t x, int c)
    * {
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x), hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo << c;
    *       uint32_t hi_shifted = hi << c;
    *       uint32_t lo_shifted_hi = lo >> abs(32 - c);
    *       return pack_64(lo_shifted, hi_shifted | lo_shifted_hi);
    *    } else {
    *       uint32_t lo_shifted_hi = lo << abs(32 - c);
    *       return pack_64(0, lo_shifted_hi);
    *    }
    * }
    */
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   nir_ssa_def *reverse_count = nir_iabs(b, nir_iadd(b, y, nir_imm_int(b, -32)));
   nir_ssa_def *lo_shifted = nir_ishl(b, x_lo, y);
   nir_ssa_def *hi_shifted = nir_ishl(b, x_hi, y);
   nir_ssa_def *lo_shifted_hi = nir_ushr(b, x_lo, reverse_count);

   nir_ssa_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, lo_shifted,
                             nir_ior(b, hi_shifted, lo_shifted_hi));
   nir_ssa_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_imm_int(b, 0),
                             nir_ishl(b, x_lo, reverse_count));

   return nir_bcsel(b,
                    nir_ieq(b, y, nir_imm_int(b, 0)), x,
                    nir_bcsel(b, nir_uge(b, y, nir_imm_int(b, 32)),
                              res_if_ge_32, res_if_lt_32));
}
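
/* Illustrative sketch (not part of the pass): walking the lshift() comment
 * above with concrete values, for x = 0x00000001_80000000:
 *
 *    c = 1  -> lo_shifted    = 0x80000000 << 1  = 0x00000000
 *              hi_shifted    = 0x00000001 << 1  = 0x00000002
 *              lo_shifted_hi = 0x80000000 >> 31 = 0x00000001
 *              result        = 0x00000003_00000000
 *
 *    c = 33 -> reverse_count = |33 - 32| = 1
 *              result        = pack_64(0, 0x80000000 << 1) = 0
 *
 * Both arms are always built and selected with bcsel.  The c == 0 case is
 * special-cased because reverse_count would then be 32, and a 32-bit shift
 * by 32 does not behave as a full shift on hardware that masks the count.
 */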

static nir_ssa_def *
lower_ishr64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   /* Implemented as
    *
    * uint64_t arshift(uint64_t x, int c)
    * {
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x);
    *    int32_t  hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo >> c;
    *       uint32_t hi_shifted = hi >> c;
    *       uint32_t hi_shifted_lo = hi << abs(32 - c);
    *       return pack_64(hi_shifted_lo | lo_shifted, hi_shifted);
    *    } else {
    *       uint32_t hi_shifted = hi >> 31;
    *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
    *       return pack_64(hi_shifted_lo, hi_shifted);
    *    }
    * }
    */
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   nir_ssa_def *reverse_count = nir_iabs(b, nir_iadd(b, y, nir_imm_int(b, -32)));
   nir_ssa_def *lo_shifted = nir_ushr(b, x_lo, y);
   nir_ssa_def *hi_shifted = nir_ishr(b, x_hi, y);
   nir_ssa_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);

   nir_ssa_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
                             hi_shifted);
   nir_ssa_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_ishr(b, x_hi, reverse_count),
                             nir_ishr(b, x_hi, nir_imm_int(b, 31)));

   return nir_bcsel(b,
                    nir_ieq(b, y, nir_imm_int(b, 0)), x,
                    nir_bcsel(b, nir_uge(b, y, nir_imm_int(b, 32)),
                              res_if_ge_32, res_if_lt_32));
}
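
/* Illustrative sketch (not part of the pass): for c >= 32 the result's low
 * word comes from the high word shifted by (c - 32) and the high word is
 * filled with copies of the sign bit, e.g. for x = 0xFFFFFFFF_00000000
 * (i.e. -2^32) and c = 36:
 *
 *    reverse_count = |36 - 32| = 4
 *    lo = 0xFFFFFFFF >> 4  = 0xFFFFFFFF   (arithmetic shift keeps the sign)
 *    hi = 0xFFFFFFFF >> 31 = 0xFFFFFFFF
 *    result = 0xFFFFFFFF_FFFFFFFF = -1    (-2^32 >> 36, rounding toward -inf)
 */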

static nir_ssa_def *
lower_ushr64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   /* Implemented as
    *
    * uint64_t rshift(uint64_t x, int c)
    * {
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x), hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo >> c;
    *       uint32_t hi_shifted = hi >> c;
    *       uint32_t hi_shifted_lo = hi << abs(32 - c);
    *       return pack_64(hi_shifted_lo | lo_shifted, hi_shifted);
    *    } else {
    *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
    *       return pack_64(hi_shifted_lo, 0);
    *    }
    * }
    */

   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   nir_ssa_def *reverse_count = nir_iabs(b, nir_iadd(b, y, nir_imm_int(b, -32)));
   nir_ssa_def *lo_shifted = nir_ushr(b, x_lo, y);
   nir_ssa_def *hi_shifted = nir_ushr(b, x_hi, y);
   nir_ssa_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);

   nir_ssa_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
                             hi_shifted);
   nir_ssa_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_ushr(b, x_hi, reverse_count),
                             nir_imm_int(b, 0));

   return nir_bcsel(b,
                    nir_ieq(b, y, nir_imm_int(b, 0)), x,
                    nir_bcsel(b, nir_uge(b, y, nir_imm_int(b, 32)),
                              res_if_ge_32, res_if_lt_32));
}
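
/* Illustrative sketch (not part of the pass): the logical shift differs from
 * lower_ishr64 only in what fills the vacated high bits, e.g. for
 * x = 0x80000000_00000000 and c = 32:
 *
 *    reverse_count = 0
 *    lo = 0x80000000 >> 0 = 0x80000000
 *    hi = 0
 *    result = 0x00000000_80000000
 */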

static nir_ssa_def *
lower_iadd64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_ssa_def *res_lo = nir_iadd(b, x_lo, y_lo);
   nir_ssa_def *carry = nir_b2i32(b, nir_ult(b, res_lo, x_lo));
   nir_ssa_def *res_hi = nir_iadd(b, carry, nir_iadd(b, x_hi, y_hi));

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}
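
/* Illustrative sketch (not part of the pass): the carry is recovered from the
 * fact that an unsigned 32-bit add wraps exactly when the result is smaller
 * than either operand, e.g.
 *
 *    x_lo = 0xF0000000, y_lo = 0x20000000
 *    res_lo = 0x10000000            (wrapped around 2^32)
 *    carry  = (res_lo < x_lo) = 1   -> added into the high words
 */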

static nir_ssa_def *
lower_isub64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_ssa_def *res_lo = nir_isub(b, x_lo, y_lo);
   nir_ssa_def *borrow = nir_ineg(b, nir_b2i32(b, nir_ult(b, x_lo, y_lo)));
   nir_ssa_def *res_hi = nir_iadd(b, nir_isub(b, x_hi, y_hi), borrow);

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}
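
/* Illustrative sketch (not part of the pass): a borrow is needed exactly when
 * x_lo < y_lo, and it is applied as -1 on the high words, e.g.
 *
 *    x = 0x00000001_00000000, y = 0x00000000_00000001
 *    res_lo = 0x00000000 - 0x00000001 = 0xFFFFFFFF
 *    borrow = -(x_lo < y_lo) = -1
 *    res_hi = (1 - 0) + (-1) = 0     -> result 0x00000000_FFFFFFFF
 */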

static nir_ssa_def *
lower_ineg64(nir_builder *b, nir_ssa_def *x)
{
   /* Since isub is the same number of instructions (with better dependencies)
    * as iadd, subtraction is actually more efficient for ineg than the usual
    * 2's complement "flip the bits and add one".
    */
   return lower_isub64(b, nir_imm_int64(b, 0), x);
}

static nir_ssa_def *
lower_iabs64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *x_is_neg = nir_ilt(b, x_hi, nir_imm_int(b, 0));
   return nir_bcsel(b, x_is_neg, nir_ineg(b, x), x);
}

static nir_ssa_def *
lower_int64_compare(nir_builder *b, nir_op op, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   switch (op) {
   case nir_op_ieq:
      return nir_iand(b, nir_ieq(b, x_hi, y_hi), nir_ieq(b, x_lo, y_lo));
   case nir_op_ine:
      return nir_ior(b, nir_ine(b, x_hi, y_hi), nir_ine(b, x_lo, y_lo));
   case nir_op_ult:
      return nir_ior(b, nir_ult(b, x_hi, y_hi),
                     nir_iand(b, nir_ieq(b, x_hi, y_hi),
                              nir_ult(b, x_lo, y_lo)));
   case nir_op_ilt:
      return nir_ior(b, nir_ilt(b, x_hi, y_hi),
                     nir_iand(b, nir_ieq(b, x_hi, y_hi),
                              nir_ult(b, x_lo, y_lo)));
   case nir_op_uge:
      /* Lower as !(x < y) in the hopes of better CSE */
      return nir_inot(b, lower_int64_compare(b, nir_op_ult, x, y));
   case nir_op_ige:
      /* Lower as !(x < y) in the hopes of better CSE */
      return nir_inot(b, lower_int64_compare(b, nir_op_ilt, x, y));
   default:
      unreachable("Invalid comparison");
   }
}
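
/* Illustrative sketch (not part of the pass): the ult/ilt lowering is a
 * lexicographic compare -- the high words decide unless they are equal, in
 * which case the low words decide, e.g.
 *
 *    x = 0x00000001_00000000, y = 0x00000001_00000005
 *    x_hi == y_hi, x_lo < y_lo  ->  x < y
 *
 * For ilt only the high-word compare is signed; the low words carry plain
 * magnitude bits and are still compared with ult.
 */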

static nir_ssa_def *
lower_umax64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ult, x, y), y, x);
}

static nir_ssa_def *
lower_imax64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ilt, x, y), y, x);
}

static nir_ssa_def *
lower_umin64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ult, x, y), x, y);
}

static nir_ssa_def *
lower_imin64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ilt, x, y), x, y);
}

static nir_ssa_def *
lower_mul_2x32_64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
                  bool sign_extend)
{
   nir_ssa_def *res_hi = sign_extend ? nir_imul_high(b, x, y)
                                     : nir_umul_high(b, x, y);

   return nir_pack_64_2x32_split(b, nir_imul(b, x, y), res_hi);
}

static nir_ssa_def *
lower_imul64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_ssa_def *mul_lo = nir_umul_2x32_64(b, x_lo, y_lo);
   nir_ssa_def *res_hi = nir_iadd(b, nir_unpack_64_2x32_split_y(b, mul_lo),
                                  nir_iadd(b, nir_imul(b, x_lo, y_hi),
                                           nir_imul(b, x_hi, y_lo)));

   return nir_pack_64_2x32_split(b, nir_unpack_64_2x32_split_x(b, mul_lo),
                                 res_hi);
}
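
/* Illustrative sketch (not part of the pass): writing x = x_hi * 2^32 + x_lo
 * and y = y_hi * 2^32 + y_lo, the product modulo 2^64 is
 *
 *    x * y = x_lo * y_lo                          (full 64-bit product)
 *          + (x_lo * y_hi + x_hi * y_lo) * 2^32   (only low 32 bits survive)
 *          + x_hi * y_hi * 2^64                   (vanishes mod 2^64)
 *
 * which is exactly the umul_2x32_64 plus the two 32-bit imuls folded into the
 * high word above.
 */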

static nir_ssa_def *
lower_mul_high64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
                 bool sign_extend)
{
   nir_ssa_def *x32[4], *y32[4];
   x32[0] = nir_unpack_64_2x32_split_x(b, x);
   x32[1] = nir_unpack_64_2x32_split_y(b, x);
   if (sign_extend) {
      x32[2] = x32[3] = nir_ishr(b, x32[1], nir_imm_int(b, 31));
   } else {
      x32[2] = x32[3] = nir_imm_int(b, 0);
   }

   y32[0] = nir_unpack_64_2x32_split_x(b, y);
   y32[1] = nir_unpack_64_2x32_split_y(b, y);
   if (sign_extend) {
      y32[2] = y32[3] = nir_ishr(b, y32[1], nir_imm_int(b, 31));
   } else {
      y32[2] = y32[3] = nir_imm_int(b, 0);
   }

   nir_ssa_def *res[8] = { NULL, };

   /* Yes, the following generates a pile of code.  However, we throw res[0]
    * and res[1] away in the end and, if we're in the umul case, four of our
    * eight dword operands will be constant zero and opt_algebraic will clean
    * this up nicely.
    */
   for (unsigned i = 0; i < 4; i++) {
      nir_ssa_def *carry = NULL;
      for (unsigned j = 0; j < 4; j++) {
         /* The maximum values of x32[i] and y32[j] are UINT32_MAX so the
          * maximum value of tmp is UINT32_MAX * UINT32_MAX.  The maximum
          * value that will fit in tmp is
          *
          *    UINT64_MAX = UINT32_MAX << 32 + UINT32_MAX
          *               = UINT32_MAX * (UINT32_MAX + 1) + UINT32_MAX
          *               = UINT32_MAX * UINT32_MAX + 2 * UINT32_MAX
          *
          * so we're guaranteed that we can add in two more 32-bit values
          * without overflowing tmp.
          */
         nir_ssa_def *tmp = nir_umul_2x32_64(b, x32[i], y32[j]);

         if (res[i + j])
            tmp = nir_iadd(b, tmp, nir_u2u64(b, res[i + j]));
         if (carry)
            tmp = nir_iadd(b, tmp, carry);
         res[i + j] = nir_u2u32(b, tmp);
         carry = nir_ushr(b, tmp, nir_imm_int(b, 32));
      }
      res[i + 4] = nir_u2u32(b, carry);
   }

   return nir_pack_64_2x32_split(b, res[2], res[3]);
}
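
/* Illustrative sketch (not part of the pass): the loop above is the usual
 * schoolbook multiply on 32-bit limbs.  A plain C equivalent of the unsigned
 * case, assuming uint64_t arithmetic is available, would look roughly like:
 *
 *    uint32_t x32[4] = { LO(x), HI(x), 0, 0 }, y32[4] = { LO(y), HI(y), 0, 0 };
 *    uint32_t res[8] = { 0, };
 *    for (unsigned i = 0; i < 4; i++) {
 *       uint64_t carry = 0;
 *       for (unsigned j = 0; j < 4; j++) {
 *          uint64_t tmp = (uint64_t)x32[i] * y32[j] + res[i + j] + carry;
 *          res[i + j] = (uint32_t)tmp;
 *          carry = tmp >> 32;
 *       }
 *       res[i + 4] = (uint32_t)carry;
 *    }
 *    // umul_high(x, y) == pack_64(res[2], res[3])
 */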

static nir_ssa_def *
lower_isign64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   nir_ssa_def *is_non_zero = nir_i2b(b, nir_ior(b, x_lo, x_hi));
   nir_ssa_def *res_hi = nir_ishr(b, x_hi, nir_imm_int(b, 31));
   nir_ssa_def *res_lo = nir_ior(b, res_hi, nir_b2i32(b, is_non_zero));

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}
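
/* Illustrative sketch (not part of the pass): res_hi broadcasts the sign bit
 * and res_lo folds in the "x is non-zero" bit, which covers all three cases:
 *
 *    x  < 0 : res_hi = 0xFFFFFFFF, res_lo = 0xFFFFFFFF | 1 = 0xFFFFFFFF  -> -1
 *    x == 0 : res_hi = 0,          res_lo = 0 | 0 = 0                    ->  0
 *    x  > 0 : res_hi = 0,          res_lo = 0 | 1 = 1                    -> +1
 */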

static void
lower_udiv64_mod64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d,
                   nir_ssa_def **q, nir_ssa_def **r)
{
   /* TODO: We should specially handle the case where the denominator is a
    * constant.  In that case, we should be able to reduce it to a multiply by
    * a constant, some shifts, and an add.
    */
   nir_ssa_def *n_lo = nir_unpack_64_2x32_split_x(b, n);
   nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_ssa_def *d_lo = nir_unpack_64_2x32_split_x(b, d);
   nir_ssa_def *d_hi = nir_unpack_64_2x32_split_y(b, d);

   nir_const_value v = { .u32 = { 0, 0, 0, 0 } };
   nir_ssa_def *q_lo = nir_build_imm(b, n->num_components, 32, v);
   nir_ssa_def *q_hi = nir_build_imm(b, n->num_components, 32, v);

   nir_ssa_def *n_hi_before_if = n_hi;
   nir_ssa_def *q_hi_before_if = q_hi;

   /* If the upper 32 bits of denom are non-zero, it is impossible for shifts
    * greater than 32 bits to occur.  If the upper 32 bits of the numerator
    * are zero, it is impossible for (denom << [63, 32]) <= numer unless
    * denom == 0.
    */
   nir_ssa_def *need_high_div =
      nir_iand(b, nir_ieq(b, d_hi, nir_imm_int(b, 0)), nir_uge(b, n_hi, d_lo));
   nir_push_if(b, nir_bany(b, need_high_div));
   {
      /* If we only have one component, then the bany above goes away and
       * this is always true within the if statement.
       */
      if (n->num_components == 1)
         need_high_div = nir_imm_true(b);

      nir_ssa_def *log2_d_lo = nir_ufind_msb(b, d_lo);

      for (int i = 31; i >= 0; i--) {
         /* if ((d.x << i) <= n.y) {
          *    n.y -= d.x << i;
          *    quot.y |= 1U << i;
          * }
          */
         nir_ssa_def *d_shift = nir_ishl(b, d_lo, nir_imm_int(b, i));
         nir_ssa_def *new_n_hi = nir_isub(b, n_hi, d_shift);
         nir_ssa_def *new_q_hi = nir_ior(b, q_hi, nir_imm_int(b, 1u << i));
         nir_ssa_def *cond = nir_iand(b, need_high_div,
                                      nir_uge(b, n_hi, d_shift));
         if (i != 0) {
            /* log2_d_lo is always <= 31, so we don't need to bother with it
             * in the last iteration.
             */
            cond = nir_iand(b, cond,
                            nir_ige(b, nir_imm_int(b, 31 - i), log2_d_lo));
         }
         n_hi = nir_bcsel(b, cond, new_n_hi, n_hi);
         q_hi = nir_bcsel(b, cond, new_q_hi, q_hi);
      }
   }
   nir_pop_if(b, NULL);
   n_hi = nir_if_phi(b, n_hi, n_hi_before_if);
   q_hi = nir_if_phi(b, q_hi, q_hi_before_if);

   nir_ssa_def *log2_denom = nir_ufind_msb(b, d_hi);

   n = nir_pack_64_2x32_split(b, n_lo, n_hi);
   d = nir_pack_64_2x32_split(b, d_lo, d_hi);
   for (int i = 31; i >= 0; i--) {
      /* if ((d64 << i) <= n64) {
       *    n64 -= d64 << i;
       *    quot.x |= 1U << i;
       * }
       */
      nir_ssa_def *d_shift = nir_ishl(b, d, nir_imm_int(b, i));
      nir_ssa_def *new_n = nir_isub(b, n, d_shift);
      nir_ssa_def *new_q_lo = nir_ior(b, q_lo, nir_imm_int(b, 1u << i));
      nir_ssa_def *cond = nir_uge(b, n, d_shift);
      if (i != 0) {
         /* log2_denom is always <= 31, so we don't need to bother with it
          * in the last iteration.
          */
         cond = nir_iand(b, cond,
                         nir_ige(b, nir_imm_int(b, 31 - i), log2_denom));
      }
      n = nir_bcsel(b, cond, new_n, n);
      q_lo = nir_bcsel(b, cond, new_q_lo, q_lo);
   }

   *q = nir_pack_64_2x32_split(b, q_lo, q_hi);
   *r = n;
}
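
/* Illustrative sketch (not part of the pass): with real control flow instead
 * of bcsel, the second loop above is plain restoring shift-subtract division
 * producing the low 32 quotient bits from the (possibly already reduced)
 * numerator, roughly:
 *
 *    uint64_t quot = 0, rem = n;
 *    for (int i = 31; i >= 0; i--) {
 *       if ((d << i) <= rem) {
 *          rem -= d << i;
 *          quot |= 1ull << i;
 *       }
 *    }
 *
 * The first loop does the same thing with d_lo shifted against n_hi to
 * recover quotient bits 32..63 (only possible when d_hi == 0), and the log2
 * checks merely skip iterations where d << i would overflow 64 bits and thus
 * could never be <= n.
 */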

static nir_ssa_def *
lower_udiv64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, n, d, &q, &r);
   return q;
}

static nir_ssa_def *
lower_idiv64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_ssa_def *d_hi = nir_unpack_64_2x32_split_y(b, d);

   nir_ssa_def *negate = nir_ine(b, nir_ilt(b, n_hi, nir_imm_int(b, 0)),
                                 nir_ilt(b, d_hi, nir_imm_int(b, 0)));
   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
   return nir_bcsel(b, negate, nir_ineg(b, q), q);
}

static nir_ssa_def *
lower_umod64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, n, d, &q, &r);
   return r;
}

static nir_ssa_def *
lower_imod64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_ssa_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
   nir_ssa_def *n_is_neg = nir_ilt(b, n_hi, nir_imm_int(b, 0));
   nir_ssa_def *d_is_neg = nir_ilt(b, d_hi, nir_imm_int(b, 0));

   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);

   nir_ssa_def *rem = nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);

   return nir_bcsel(b, nir_ieq(b, r, nir_imm_int64(b, 0)), nir_imm_int64(b, 0),
                    nir_bcsel(b, nir_ieq(b, n_is_neg, d_is_neg), rem,
                              nir_iadd(b, rem, d)));
}

static nir_ssa_def *
lower_irem64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_ssa_def *n_is_neg = nir_ilt(b, n_hi, nir_imm_int(b, 0));

   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
   return nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);
}
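
/* Illustrative sketch (not part of the pass): imod follows the sign of the
 * divisor (GLSL-style modulo) while irem follows the sign of the dividend
 * (C-style remainder), e.g. for n = -7, d = 3:
 *
 *    |n| / |d| = 2, r = 1
 *    irem: -r     = -1
 *    imod: -r + d =  2   (signs of n and d differ and r != 0)
 *
 * and both return 0 when r == 0.
 */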

static nir_lower_int64_options
opcode_to_options_mask(nir_op opcode)
{
   switch (opcode) {
   case nir_op_imul:
      return nir_lower_imul64;
   case nir_op_imul_2x32_64:
   case nir_op_umul_2x32_64:
      return nir_lower_imul_2x32_64;
   case nir_op_imul_high:
   case nir_op_umul_high:
      return nir_lower_imul_high64;
   case nir_op_isign:
      return nir_lower_isign64;
   case nir_op_udiv:
   case nir_op_idiv:
   case nir_op_umod:
   case nir_op_imod:
   case nir_op_irem:
      return nir_lower_divmod64;
   case nir_op_b2i64:
   case nir_op_i2b1:
   case nir_op_i2i32:
   case nir_op_i2i64:
   case nir_op_u2u32:
   case nir_op_u2u64:
   case nir_op_bcsel:
      return nir_lower_mov64;
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      return nir_lower_icmp64;
   case nir_op_iadd:
   case nir_op_isub:
      return nir_lower_iadd64;
   case nir_op_imin:
   case nir_op_imax:
   case nir_op_umin:
   case nir_op_umax:
      return nir_lower_minmax64;
   case nir_op_iabs:
      return nir_lower_iabs64;
   case nir_op_ineg:
      return nir_lower_ineg64;
   case nir_op_iand:
   case nir_op_ior:
   case nir_op_ixor:
   case nir_op_inot:
      return nir_lower_logic64;
   case nir_op_ishl:
   case nir_op_ishr:
   case nir_op_ushr:
      return nir_lower_shift64;
   default:
      return 0;
   }
}

static nir_ssa_def *
lower_int64_alu_instr(nir_builder *b, nir_alu_instr *alu)
{
   nir_ssa_def *src[4];
   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
      src[i] = nir_ssa_for_alu_src(b, alu, i);

   switch (alu->op) {
   case nir_op_imul:
      return lower_imul64(b, src[0], src[1]);
   case nir_op_imul_2x32_64:
      return lower_mul_2x32_64(b, src[0], src[1], true);
   case nir_op_umul_2x32_64:
      return lower_mul_2x32_64(b, src[0], src[1], false);
   case nir_op_imul_high:
      return lower_mul_high64(b, src[0], src[1], true);
   case nir_op_umul_high:
      return lower_mul_high64(b, src[0], src[1], false);
   case nir_op_isign:
      return lower_isign64(b, src[0]);
   case nir_op_udiv:
      return lower_udiv64(b, src[0], src[1]);
   case nir_op_idiv:
      return lower_idiv64(b, src[0], src[1]);
   case nir_op_umod:
      return lower_umod64(b, src[0], src[1]);
   case nir_op_imod:
      return lower_imod64(b, src[0], src[1]);
   case nir_op_irem:
      return lower_irem64(b, src[0], src[1]);
   case nir_op_b2i64:
      return lower_b2i64(b, src[0]);
   case nir_op_i2b1:
      return lower_i2b(b, src[0]);
   case nir_op_i2i8:
      return lower_i2i8(b, src[0]);
   case nir_op_i2i16:
      return lower_i2i16(b, src[0]);
   case nir_op_i2i32:
      return lower_i2i32(b, src[0]);
   case nir_op_i2i64:
      return lower_i2i64(b, src[0]);
   case nir_op_u2u8:
      return lower_u2u8(b, src[0]);
   case nir_op_u2u16:
      return lower_u2u16(b, src[0]);
   case nir_op_u2u32:
      return lower_u2u32(b, src[0]);
   case nir_op_u2u64:
      return lower_u2u64(b, src[0]);
   case nir_op_bcsel:
      return lower_bcsel64(b, src[0], src[1], src[2]);
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      return lower_int64_compare(b, alu->op, src[0], src[1]);
   case nir_op_iadd:
      return lower_iadd64(b, src[0], src[1]);
   case nir_op_isub:
      return lower_isub64(b, src[0], src[1]);
   case nir_op_imin:
      return lower_imin64(b, src[0], src[1]);
   case nir_op_imax:
      return lower_imax64(b, src[0], src[1]);
   case nir_op_umin:
      return lower_umin64(b, src[0], src[1]);
   case nir_op_umax:
      return lower_umax64(b, src[0], src[1]);
   case nir_op_iabs:
      return lower_iabs64(b, src[0]);
   case nir_op_ineg:
      return lower_ineg64(b, src[0]);
   case nir_op_iand:
      return lower_iand64(b, src[0], src[1]);
   case nir_op_ior:
      return lower_ior64(b, src[0], src[1]);
   case nir_op_ixor:
      return lower_ixor64(b, src[0], src[1]);
   case nir_op_inot:
      return lower_inot64(b, src[0]);
   case nir_op_ishl:
      return lower_ishl64(b, src[0], src[1]);
   case nir_op_ishr:
      return lower_ishr64(b, src[0], src[1]);
   case nir_op_ushr:
      return lower_ushr64(b, src[0], src[1]);
   default:
      unreachable("Invalid ALU opcode to lower");
   }
}

static bool
lower_int64_impl(nir_function_impl *impl, nir_lower_int64_options options)
{
   nir_builder b;
   nir_builder_init(&b, impl);

   bool progress = false;
   nir_foreach_block(block, impl) {
      nir_foreach_instr_safe(instr, block) {
         if (instr->type != nir_instr_type_alu)
            continue;

         nir_alu_instr *alu = nir_instr_as_alu(instr);
         switch (alu->op) {
         case nir_op_i2b1:
         case nir_op_i2i32:
         case nir_op_u2u32:
            assert(alu->src[0].src.is_ssa);
            if (alu->src[0].src.ssa->bit_size != 64)
               continue;
            break;
         case nir_op_bcsel:
            assert(alu->src[1].src.is_ssa);
            assert(alu->src[2].src.is_ssa);
            assert(alu->src[1].src.ssa->bit_size ==
                   alu->src[2].src.ssa->bit_size);
            if (alu->src[1].src.ssa->bit_size != 64)
               continue;
            break;
         case nir_op_ieq:
         case nir_op_ine:
         case nir_op_ult:
         case nir_op_ilt:
         case nir_op_uge:
         case nir_op_ige:
            assert(alu->src[0].src.is_ssa);
            assert(alu->src[1].src.is_ssa);
            assert(alu->src[0].src.ssa->bit_size ==
                   alu->src[1].src.ssa->bit_size);
            if (alu->src[0].src.ssa->bit_size != 64)
               continue;
            break;
         default:
            assert(alu->dest.dest.is_ssa);
            if (alu->dest.dest.ssa.bit_size != 64)
               continue;
            break;
         }

         if (!(options & opcode_to_options_mask(alu->op)))
            continue;

         b.cursor = nir_before_instr(instr);

         nir_ssa_def *lowered = lower_int64_alu_instr(&b, alu);
         nir_ssa_def_rewrite_uses(&alu->dest.dest.ssa,
                                  nir_src_for_ssa(lowered));
         nir_instr_remove(&alu->instr);
         progress = true;
      }
   }

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_none);
   } else {
#ifndef NDEBUG
      impl->valid_metadata &= ~nir_metadata_not_properly_reset;
#endif
   }

   return progress;
}

bool
nir_lower_int64(nir_shader *shader, nir_lower_int64_options options)
{
   bool progress = false;

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= lower_int64_impl(function->impl, options);
   }

   return progress;
}
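
/* Usage sketch (not part of this file): a driver typically runs this pass
 * with the option bits it advertises in its nir_shader_compiler_options,
 * roughly:
 *
 *    if (nir->options->lower_int64_options)
 *       NIR_PASS(progress, nir, nir_lower_int64,
 *                nir->options->lower_int64_options);
 *
 * followed by nir_opt_algebraic and nir_opt_cse so the expanded 32-bit code
 * gets cleaned up.
 */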