nir/lower_int64: Lower 8 and 16-bit downcasts with nir_lower_mov64
[mesa.git] src/compiler/nir/nir_lower_int64.c
/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"

static nir_ssa_def *
lower_b2i64(nir_builder *b, nir_ssa_def *x)
{
   return nir_pack_64_2x32_split(b, nir_b2i32(b, x), nir_imm_int(b, 0));
}

static nir_ssa_def *
lower_i2b(nir_builder *b, nir_ssa_def *x)
{
   return nir_ine(b, nir_ior(b, nir_unpack_64_2x32_split_x(b, x),
                                nir_unpack_64_2x32_split_y(b, x)),
                  nir_imm_int(b, 0));
}

static nir_ssa_def *
lower_i2i8(nir_builder *b, nir_ssa_def *x)
{
   return nir_i2i8(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_ssa_def *
lower_i2i16(nir_builder *b, nir_ssa_def *x)
{
   return nir_i2i16(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_ssa_def *
lower_i2i32(nir_builder *b, nir_ssa_def *x)
{
   return nir_unpack_64_2x32_split_x(b, x);
}

static nir_ssa_def *
lower_i2i64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x32 = x->bit_size == 32 ? x : nir_i2i32(b, x);
   return nir_pack_64_2x32_split(b, x32, nir_ishr(b, x32, nir_imm_int(b, 31)));
}

static nir_ssa_def *
lower_u2u8(nir_builder *b, nir_ssa_def *x)
{
   return nir_u2u8(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_ssa_def *
lower_u2u16(nir_builder *b, nir_ssa_def *x)
{
   return nir_u2u16(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_ssa_def *
lower_u2u32(nir_builder *b, nir_ssa_def *x)
{
   return nir_unpack_64_2x32_split_x(b, x);
}

static nir_ssa_def *
lower_u2u64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x32 = x->bit_size == 32 ? x : nir_u2u32(b, x);
   return nir_pack_64_2x32_split(b, x32, nir_imm_int(b, 0));
}

static nir_ssa_def *
lower_bcsel64(nir_builder *b, nir_ssa_def *cond, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_bcsel(b, cond, x_lo, y_lo),
                                    nir_bcsel(b, cond, x_hi, y_hi));
}

static nir_ssa_def *
lower_inot64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   return nir_pack_64_2x32_split(b, nir_inot(b, x_lo), nir_inot(b, x_hi));
}

static nir_ssa_def *
lower_iand64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_iand(b, x_lo, y_lo),
                                    nir_iand(b, x_hi, y_hi));
}

static nir_ssa_def *
lower_ior64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_ior(b, x_lo, y_lo),
                                    nir_ior(b, x_hi, y_hi));
}

static nir_ssa_def *
lower_ixor64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_ixor(b, x_lo, y_lo),
                                    nir_ixor(b, x_hi, y_hi));
}

static nir_ssa_def *
lower_ishl64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   /* Implemented as
    *
    * uint64_t lshift(uint64_t x, int c)
    * {
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x), hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo << c;
    *       uint32_t hi_shifted = hi << c;
    *       uint32_t lo_shifted_hi = lo >> abs(32 - c);
    *       return pack_64(lo_shifted, hi_shifted | lo_shifted_hi);
    *    } else {
    *       uint32_t lo_shifted_hi = lo << abs(32 - c);
    *       return pack_64(0, lo_shifted_hi);
    *    }
    * }
    */
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   nir_ssa_def *reverse_count = nir_iabs(b, nir_iadd(b, y, nir_imm_int(b, -32)));
   nir_ssa_def *lo_shifted = nir_ishl(b, x_lo, y);
   nir_ssa_def *hi_shifted = nir_ishl(b, x_hi, y);
   nir_ssa_def *lo_shifted_hi = nir_ushr(b, x_lo, reverse_count);

   nir_ssa_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, lo_shifted,
                             nir_ior(b, hi_shifted, lo_shifted_hi));
   nir_ssa_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_imm_int(b, 0),
                             nir_ishl(b, x_lo, reverse_count));

   return nir_bcsel(b,
                    nir_ieq(b, y, nir_imm_int(b, 0)), x,
                    nir_bcsel(b, nir_uge(b, y, nir_imm_int(b, 32)),
                              res_if_ge_32, res_if_lt_32));
}

static nir_ssa_def *
lower_ishr64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   /* Implemented as (pack_64() takes the low dword, then the high dword,
    * matching the lshift pseudocode above)
    *
    * uint64_t arshift(uint64_t x, int c)
    * {
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x);
    *    int32_t  hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo >> c;
    *       uint32_t hi_shifted = hi >> c;
    *       uint32_t hi_shifted_lo = hi << abs(32 - c);
    *       return pack_64(hi_shifted_lo | lo_shifted, hi_shifted);
    *    } else {
    *       uint32_t hi_shifted = hi >> 31;
    *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
    *       return pack_64(hi_shifted_lo, hi_shifted);
    *    }
    * }
    */
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   nir_ssa_def *reverse_count = nir_iabs(b, nir_iadd(b, y, nir_imm_int(b, -32)));
   nir_ssa_def *lo_shifted = nir_ushr(b, x_lo, y);
   nir_ssa_def *hi_shifted = nir_ishr(b, x_hi, y);
   nir_ssa_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);

   nir_ssa_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
                             hi_shifted);
   nir_ssa_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_ishr(b, x_hi, reverse_count),
                             nir_ishr(b, x_hi, nir_imm_int(b, 31)));

   return nir_bcsel(b,
                    nir_ieq(b, y, nir_imm_int(b, 0)), x,
                    nir_bcsel(b, nir_uge(b, y, nir_imm_int(b, 32)),
                              res_if_ge_32, res_if_lt_32));
}

static nir_ssa_def *
lower_ushr64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   /* Implemented as (pack_64() again takes the low dword, then the high dword)
    *
    * uint64_t rshift(uint64_t x, int c)
    * {
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x), hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo >> c;
    *       uint32_t hi_shifted = hi >> c;
    *       uint32_t hi_shifted_lo = hi << abs(32 - c);
    *       return pack_64(hi_shifted_lo | lo_shifted, hi_shifted);
    *    } else {
    *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
    *       return pack_64(hi_shifted_lo, 0);
    *    }
    * }
    */
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   nir_ssa_def *reverse_count = nir_iabs(b, nir_iadd(b, y, nir_imm_int(b, -32)));
   nir_ssa_def *lo_shifted = nir_ushr(b, x_lo, y);
   nir_ssa_def *hi_shifted = nir_ushr(b, x_hi, y);
   nir_ssa_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);

   nir_ssa_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
                             hi_shifted);
   nir_ssa_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_ushr(b, x_hi, reverse_count),
                             nir_imm_int(b, 0));

   return nir_bcsel(b,
                    nir_ieq(b, y, nir_imm_int(b, 0)), x,
                    nir_bcsel(b, nir_uge(b, y, nir_imm_int(b, 32)),
                              res_if_ge_32, res_if_lt_32));
}

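/* 64-bit addition as two 32-bit additions.  The carry out of the low dwords
 * is recovered with an unsigned comparison: the 32-bit sum wrapped around
 * iff it is smaller than either addend, so (res_lo < x_lo) is exactly the
 * carry that has to be folded into the high dword sum.
 */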
static nir_ssa_def *
lower_iadd64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_ssa_def *res_lo = nir_iadd(b, x_lo, y_lo);
   nir_ssa_def *carry = nir_b2i32(b, nir_ult(b, res_lo, x_lo));
   nir_ssa_def *res_hi = nir_iadd(b, carry, nir_iadd(b, x_hi, y_hi));

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}

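/* 64-bit subtraction as two 32-bit subtractions.  A borrow out of the low
 * dwords occurred iff x_lo < y_lo (unsigned); it is negated to -1/0 and
 * added into the high dword difference.
 */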
static nir_ssa_def *
lower_isub64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_ssa_def *res_lo = nir_isub(b, x_lo, y_lo);
   nir_ssa_def *borrow = nir_ineg(b, nir_b2i32(b, nir_ult(b, x_lo, y_lo)));
   nir_ssa_def *res_hi = nir_iadd(b, nir_isub(b, x_hi, y_hi), borrow);

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}

static nir_ssa_def *
lower_ineg64(nir_builder *b, nir_ssa_def *x)
{
   /* Since isub is the same number of instructions (with better dependencies)
    * as iadd, subtraction is actually more efficient for ineg than the usual
    * 2's complement "flip the bits and add one".
    */
   return lower_isub64(b, nir_imm_int64(b, 0), x);
}

static nir_ssa_def *
lower_iabs64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *x_is_neg = nir_ilt(b, x_hi, nir_imm_int(b, 0));
   return nir_bcsel(b, x_is_neg, nir_ineg(b, x), x);
}

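/* 64-bit comparisons in terms of 32-bit comparisons: equality checks both
 * dword pairs, while the ordered comparisons are decided by the high dwords
 * and fall back to an unsigned comparison of the low dwords when the high
 * dwords are equal.  Only the high dword comparison needs to care about
 * signedness; the low dwords are always compared unsigned.
 */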
static nir_ssa_def *
lower_int64_compare(nir_builder *b, nir_op op, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   switch (op) {
   case nir_op_ieq:
      return nir_iand(b, nir_ieq(b, x_hi, y_hi), nir_ieq(b, x_lo, y_lo));
   case nir_op_ine:
      return nir_ior(b, nir_ine(b, x_hi, y_hi), nir_ine(b, x_lo, y_lo));
   case nir_op_ult:
      return nir_ior(b, nir_ult(b, x_hi, y_hi),
                     nir_iand(b, nir_ieq(b, x_hi, y_hi),
                              nir_ult(b, x_lo, y_lo)));
   case nir_op_ilt:
      return nir_ior(b, nir_ilt(b, x_hi, y_hi),
                     nir_iand(b, nir_ieq(b, x_hi, y_hi),
                              nir_ult(b, x_lo, y_lo)));
      break;
   case nir_op_uge:
      /* Lower as !(x < y) in the hopes of better CSE */
      return nir_inot(b, lower_int64_compare(b, nir_op_ult, x, y));
   case nir_op_ige:
      /* Lower as !(x < y) in the hopes of better CSE */
      return nir_inot(b, lower_int64_compare(b, nir_op_ilt, x, y));
   default:
      unreachable("Invalid comparison");
   }
}

static nir_ssa_def *
lower_umax64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ult, x, y), y, x);
}

static nir_ssa_def *
lower_imax64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ilt, x, y), y, x);
}

static nir_ssa_def *
lower_umin64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ult, x, y), x, y);
}

static nir_ssa_def *
lower_imin64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ilt, x, y), x, y);
}

static nir_ssa_def *
lower_mul_2x32_64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
                  bool sign_extend)
{
   nir_ssa_def *res_hi = sign_extend ? nir_imul_high(b, x, y)
                                     : nir_umul_high(b, x, y);

   return nir_pack_64_2x32_split(b, nir_imul(b, x, y), res_hi);
}

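/* Low 64 bits of the 128-bit product:
 *
 *    x * y = x_lo * y_lo + ((x_lo * y_hi + x_hi * y_lo) << 32)   (mod 2^64)
 *
 * so the low dword comes straight from the 32x32->64 multiply of the low
 * dwords, and the high dword is its upper half plus the two cross terms
 * (computed as plain 32-bit multiplies, since their upper halves shift out).
 */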
static nir_ssa_def *
lower_imul64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_ssa_def *mul_lo = nir_umul_2x32_64(b, x_lo, y_lo);
   nir_ssa_def *res_hi = nir_iadd(b, nir_unpack_64_2x32_split_y(b, mul_lo),
                                  nir_iadd(b, nir_imul(b, x_lo, y_hi),
                                           nir_imul(b, x_hi, y_lo)));

   return nir_pack_64_2x32_split(b, nir_unpack_64_2x32_split_x(b, mul_lo),
                                 res_hi);
}

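/* High 64 bits of the 128-bit product, computed by schoolbook multiplication
 * on 32-bit digits: each operand is widened to four dwords (sign- or
 * zero-extended), the 4x4 partial products are accumulated with carry
 * propagation, and only digits 2 and 3 of the result are kept.
 */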
static nir_ssa_def *
lower_mul_high64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
                 bool sign_extend)
{
   nir_ssa_def *x32[4], *y32[4];
   x32[0] = nir_unpack_64_2x32_split_x(b, x);
   x32[1] = nir_unpack_64_2x32_split_y(b, x);
   if (sign_extend) {
      x32[2] = x32[3] = nir_ishr(b, x32[1], nir_imm_int(b, 31));
   } else {
      x32[2] = x32[3] = nir_imm_int(b, 0);
   }

   y32[0] = nir_unpack_64_2x32_split_x(b, y);
   y32[1] = nir_unpack_64_2x32_split_y(b, y);
   if (sign_extend) {
      y32[2] = y32[3] = nir_ishr(b, y32[1], nir_imm_int(b, 31));
   } else {
      y32[2] = y32[3] = nir_imm_int(b, 0);
   }

   nir_ssa_def *res[8] = { NULL, };

   /* Yes, the following generates a pile of code.  However, we throw res[0]
    * and res[1] away in the end and, if we're in the umul case, four of our
    * eight dword operands will be constant zero and opt_algebraic will clean
    * this up nicely.
    */
   for (unsigned i = 0; i < 4; i++) {
      nir_ssa_def *carry = NULL;
      for (unsigned j = 0; j < 4; j++) {
         /* The maximum values of x32[i] and y32[j] are UINT32_MAX so the
          * maximum value of tmp is UINT32_MAX * UINT32_MAX.  The maximum
          * value that will fit in tmp is
          *
          *    UINT64_MAX = UINT32_MAX << 32 + UINT32_MAX
          *               = UINT32_MAX * (UINT32_MAX + 1) + UINT32_MAX
          *               = UINT32_MAX * UINT32_MAX + 2 * UINT32_MAX
          *
          * so we're guaranteed that we can add in two more 32-bit values
          * without overflowing tmp.
          */
         nir_ssa_def *tmp = nir_umul_2x32_64(b, x32[i], y32[j]);

         if (res[i + j])
            tmp = nir_iadd(b, tmp, nir_u2u64(b, res[i + j]));
         if (carry)
            tmp = nir_iadd(b, tmp, carry);
         res[i + j] = nir_u2u32(b, tmp);
         carry = nir_ushr(b, tmp, nir_imm_int(b, 32));
      }
      res[i + 4] = nir_u2u32(b, carry);
   }

   return nir_pack_64_2x32_split(b, res[2], res[3]);
}

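/* isign(x) is -1, 0 or 1.  x_hi >> 31 (arithmetic) is -1 for negative x and
 * 0 otherwise; OR-ing a 1 into the low dword when x is non-zero then yields
 * exactly -1, 0 or 1 as a 64-bit value.
 */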
static nir_ssa_def *
lower_isign64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   nir_ssa_def *is_non_zero = nir_i2b(b, nir_ior(b, x_lo, x_hi));
   nir_ssa_def *res_hi = nir_ishr(b, x_hi, nir_imm_int(b, 31));
   nir_ssa_def *res_lo = nir_ior(b, res_hi, nir_b2i32(b, is_non_zero));

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}

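/* Unsigned 64-bit division and modulo by binary long division: quotient bits
 * are produced from the top down by testing whether the shifted denominator
 * still fits under what remains of the numerator.  The first loop (inside
 * the if) produces the high 32 quotient bits, which can only be non-zero
 * when the denominator fits in 32 bits; the second loop produces the low 32
 * bits, and whatever is left of the numerator is the remainder.
 */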
static void
lower_udiv64_mod64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d,
                   nir_ssa_def **q, nir_ssa_def **r)
{
   /* TODO: We should specially handle the case where the denominator is a
    * constant.  In that case, we should be able to reduce it to a multiply by
    * a constant, some shifts, and an add.
    */
   nir_ssa_def *n_lo = nir_unpack_64_2x32_split_x(b, n);
   nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_ssa_def *d_lo = nir_unpack_64_2x32_split_x(b, d);
   nir_ssa_def *d_hi = nir_unpack_64_2x32_split_y(b, d);

   nir_ssa_def *q_lo = nir_imm_zero(b, n->num_components, 32);
   nir_ssa_def *q_hi = nir_imm_zero(b, n->num_components, 32);

   nir_ssa_def *n_hi_before_if = n_hi;
   nir_ssa_def *q_hi_before_if = q_hi;

   /* If the upper 32 bits of denom are non-zero, it is impossible for shifts
    * greater than 32 bits to occur.  If the upper 32 bits of the numerator
    * are zero, it is impossible for (denom << [63, 32]) <= numer unless
    * denom == 0.
    */
   nir_ssa_def *need_high_div =
      nir_iand(b, nir_ieq(b, d_hi, nir_imm_int(b, 0)), nir_uge(b, n_hi, d_lo));
   nir_push_if(b, nir_bany(b, need_high_div));
   {
      /* If we only have one component, then the bany above goes away and
       * this is always true within the if statement.
       */
      if (n->num_components == 1)
         need_high_div = nir_imm_true(b);

      nir_ssa_def *log2_d_lo = nir_ufind_msb(b, d_lo);

      for (int i = 31; i >= 0; i--) {
         /* if ((d.x << i) <= n.y) {
          *    n.y -= d.x << i;
          *    quot.y |= 1U << i;
          * }
          */
         nir_ssa_def *d_shift = nir_ishl(b, d_lo, nir_imm_int(b, i));
         nir_ssa_def *new_n_hi = nir_isub(b, n_hi, d_shift);
         nir_ssa_def *new_q_hi = nir_ior(b, q_hi, nir_imm_int(b, 1u << i));
         nir_ssa_def *cond = nir_iand(b, need_high_div,
                                      nir_uge(b, n_hi, d_shift));
         if (i != 0) {
            /* log2_d_lo is always <= 31, so we don't need to bother with it
             * in the last iteration.
             */
            cond = nir_iand(b, cond,
                            nir_ige(b, nir_imm_int(b, 31 - i), log2_d_lo));
         }
         n_hi = nir_bcsel(b, cond, new_n_hi, n_hi);
         q_hi = nir_bcsel(b, cond, new_q_hi, q_hi);
      }
   }
   nir_pop_if(b, NULL);
   n_hi = nir_if_phi(b, n_hi, n_hi_before_if);
   q_hi = nir_if_phi(b, q_hi, q_hi_before_if);

   nir_ssa_def *log2_denom = nir_ufind_msb(b, d_hi);

   n = nir_pack_64_2x32_split(b, n_lo, n_hi);
   d = nir_pack_64_2x32_split(b, d_lo, d_hi);
   for (int i = 31; i >= 0; i--) {
      /* if ((d64 << i) <= n64) {
       *    n64 -= d64 << i;
       *    quot.x |= 1U << i;
       * }
       */
      nir_ssa_def *d_shift = nir_ishl(b, d, nir_imm_int(b, i));
      nir_ssa_def *new_n = nir_isub(b, n, d_shift);
      nir_ssa_def *new_q_lo = nir_ior(b, q_lo, nir_imm_int(b, 1u << i));
      nir_ssa_def *cond = nir_uge(b, n, d_shift);
      if (i != 0) {
         /* log2_denom is always <= 31, so we don't need to bother with it
          * in the last iteration.
          */
         cond = nir_iand(b, cond,
                         nir_ige(b, nir_imm_int(b, 31 - i), log2_denom));
      }
      n = nir_bcsel(b, cond, new_n, n);
      q_lo = nir_bcsel(b, cond, new_q_lo, q_lo);
   }

   *q = nir_pack_64_2x32_split(b, q_lo, q_hi);
   *r = n;
}

static nir_ssa_def *
lower_udiv64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, n, d, &q, &r);
   return q;
}

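/* Signed division via the unsigned helper: divide the absolute values and
 * negate the quotient when exactly one operand is negative (the signs are
 * taken from the high dwords).
 */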
static nir_ssa_def *
lower_idiv64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_ssa_def *d_hi = nir_unpack_64_2x32_split_y(b, d);

   nir_ssa_def *negate = nir_ine(b, nir_ilt(b, n_hi, nir_imm_int(b, 0)),
                                 nir_ilt(b, d_hi, nir_imm_int(b, 0)));
   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
   return nir_bcsel(b, negate, nir_ineg(b, q), q);
}

static nir_ssa_def *
lower_umod64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, n, d, &q, &r);
   return r;
}

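/* imod and irem are both built on the unsigned divide/mod of the absolute
 * values.  For imod, a non-zero remainder takes the sign of the denominator:
 * the remainder is first given the sign of the numerator and, when the
 * operand signs differ, the denominator is added to flip it.  irem (below)
 * simply takes the sign of the numerator.
 */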
static nir_ssa_def *
lower_imod64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_ssa_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
   nir_ssa_def *n_is_neg = nir_ilt(b, n_hi, nir_imm_int(b, 0));
   nir_ssa_def *d_is_neg = nir_ilt(b, d_hi, nir_imm_int(b, 0));

   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);

   nir_ssa_def *rem = nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);

   return nir_bcsel(b, nir_ieq(b, r, nir_imm_int64(b, 0)), nir_imm_int64(b, 0),
                    nir_bcsel(b, nir_ieq(b, n_is_neg, d_is_neg), rem,
                              nir_iadd(b, rem, d)));
}

static nir_ssa_def *
lower_irem64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_ssa_def *n_is_neg = nir_ilt(b, n_hi, nir_imm_int(b, 0));

   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
   return nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);
}

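/* extract_[ui](8|16) with a constant chunk index: chunks that fall in the
 * low dword are extracted from it directly, higher chunks from the high
 * dword with the index rebased, and the 32-bit result is then sign- or
 * zero-extended back to 64 bits.
 */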
static nir_ssa_def *
lower_extract(nir_builder *b, nir_op op, nir_ssa_def *x, nir_ssa_def *c)
{
   assert(op == nir_op_extract_u8 || op == nir_op_extract_i8 ||
          op == nir_op_extract_u16 || op == nir_op_extract_i16);

   const int chunk = nir_src_as_uint(nir_src_for_ssa(c));
   const int chunk_bits =
      (op == nir_op_extract_u8 || op == nir_op_extract_i8) ? 8 : 16;
   const int num_chunks_in_32 = 32 / chunk_bits;

   nir_ssa_def *extract32;
   if (chunk < num_chunks_in_32) {
      extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_x(b, x),
                                nir_imm_int(b, chunk),
                                NULL, NULL);
   } else {
      extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_y(b, x),
                                nir_imm_int(b, chunk - num_chunks_in_32),
                                NULL, NULL);
   }

   if (op == nir_op_extract_i8 || op == nir_op_extract_i16)
      return lower_i2i64(b, extract32);
   else
      return lower_u2u64(b, extract32);
}

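/* The most significant set bit lives in the high dword whenever that dword
 * is non-zero, in which case 32 is added to its 32-bit ufind_msb result;
 * otherwise the low dword's result is used directly.  Since ufind_msb
 * returns -1 for zero, a zero 64-bit input still yields -1.
 */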
static nir_ssa_def *
lower_ufind_msb64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *lo_count = nir_ufind_msb(b, x_lo);
   nir_ssa_def *hi_count = nir_ufind_msb(b, x_hi);
   nir_ssa_def *valid_hi_bits = nir_ine(b, x_hi, nir_imm_int(b, 0));
   nir_ssa_def *hi_res = nir_iadd(b, nir_imm_intN_t(b, 32, 32), hi_count);
   return nir_bcsel(b, valid_hi_bits, hi_res, lo_count);
}

nir_lower_int64_options
nir_lower_int64_op_to_options_mask(nir_op opcode)
{
   switch (opcode) {
   case nir_op_imul:
      return nir_lower_imul64;
   case nir_op_imul_2x32_64:
   case nir_op_umul_2x32_64:
      return nir_lower_imul_2x32_64;
   case nir_op_imul_high:
   case nir_op_umul_high:
      return nir_lower_imul_high64;
   case nir_op_isign:
      return nir_lower_isign64;
   case nir_op_udiv:
   case nir_op_idiv:
   case nir_op_umod:
   case nir_op_imod:
   case nir_op_irem:
      return nir_lower_divmod64;
   case nir_op_b2i64:
   case nir_op_i2b1:
   case nir_op_i2i8:
   case nir_op_i2i16:
   case nir_op_i2i32:
   case nir_op_i2i64:
   case nir_op_u2u8:
   case nir_op_u2u16:
   case nir_op_u2u32:
   case nir_op_u2u64:
   case nir_op_bcsel:
      return nir_lower_mov64;
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      return nir_lower_icmp64;
   case nir_op_iadd:
   case nir_op_isub:
      return nir_lower_iadd64;
   case nir_op_imin:
   case nir_op_imax:
   case nir_op_umin:
   case nir_op_umax:
      return nir_lower_minmax64;
   case nir_op_iabs:
      return nir_lower_iabs64;
   case nir_op_ineg:
      return nir_lower_ineg64;
   case nir_op_iand:
   case nir_op_ior:
   case nir_op_ixor:
   case nir_op_inot:
      return nir_lower_logic64;
   case nir_op_ishl:
   case nir_op_ishr:
   case nir_op_ushr:
      return nir_lower_shift64;
   case nir_op_extract_u8:
   case nir_op_extract_i8:
   case nir_op_extract_u16:
   case nir_op_extract_i16:
      return nir_lower_extract64;
   case nir_op_ufind_msb:
      return nir_lower_ufind_msb64;
   default:
      return 0;
   }
}

static nir_ssa_def *
lower_int64_alu_instr(nir_builder *b, nir_instr *instr, void *_state)
{
   nir_alu_instr *alu = nir_instr_as_alu(instr);

   nir_ssa_def *src[4];
   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
      src[i] = nir_ssa_for_alu_src(b, alu, i);

   switch (alu->op) {
   case nir_op_imul:
      return lower_imul64(b, src[0], src[1]);
   case nir_op_imul_2x32_64:
      return lower_mul_2x32_64(b, src[0], src[1], true);
   case nir_op_umul_2x32_64:
      return lower_mul_2x32_64(b, src[0], src[1], false);
   case nir_op_imul_high:
      return lower_mul_high64(b, src[0], src[1], true);
   case nir_op_umul_high:
      return lower_mul_high64(b, src[0], src[1], false);
   case nir_op_isign:
      return lower_isign64(b, src[0]);
   case nir_op_udiv:
      return lower_udiv64(b, src[0], src[1]);
   case nir_op_idiv:
      return lower_idiv64(b, src[0], src[1]);
   case nir_op_umod:
      return lower_umod64(b, src[0], src[1]);
   case nir_op_imod:
      return lower_imod64(b, src[0], src[1]);
   case nir_op_irem:
      return lower_irem64(b, src[0], src[1]);
   case nir_op_b2i64:
      return lower_b2i64(b, src[0]);
   case nir_op_i2b1:
      return lower_i2b(b, src[0]);
   case nir_op_i2i8:
      return lower_i2i8(b, src[0]);
   case nir_op_i2i16:
      return lower_i2i16(b, src[0]);
   case nir_op_i2i32:
      return lower_i2i32(b, src[0]);
   case nir_op_i2i64:
      return lower_i2i64(b, src[0]);
   case nir_op_u2u8:
      return lower_u2u8(b, src[0]);
   case nir_op_u2u16:
      return lower_u2u16(b, src[0]);
   case nir_op_u2u32:
      return lower_u2u32(b, src[0]);
   case nir_op_u2u64:
      return lower_u2u64(b, src[0]);
   case nir_op_bcsel:
      return lower_bcsel64(b, src[0], src[1], src[2]);
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      return lower_int64_compare(b, alu->op, src[0], src[1]);
   case nir_op_iadd:
      return lower_iadd64(b, src[0], src[1]);
   case nir_op_isub:
      return lower_isub64(b, src[0], src[1]);
   case nir_op_imin:
      return lower_imin64(b, src[0], src[1]);
   case nir_op_imax:
      return lower_imax64(b, src[0], src[1]);
   case nir_op_umin:
      return lower_umin64(b, src[0], src[1]);
   case nir_op_umax:
      return lower_umax64(b, src[0], src[1]);
   case nir_op_iabs:
      return lower_iabs64(b, src[0]);
   case nir_op_ineg:
      return lower_ineg64(b, src[0]);
   case nir_op_iand:
      return lower_iand64(b, src[0], src[1]);
   case nir_op_ior:
      return lower_ior64(b, src[0], src[1]);
   case nir_op_ixor:
      return lower_ixor64(b, src[0], src[1]);
   case nir_op_inot:
      return lower_inot64(b, src[0]);
   case nir_op_ishl:
      return lower_ishl64(b, src[0], src[1]);
   case nir_op_ishr:
      return lower_ishr64(b, src[0], src[1]);
   case nir_op_ushr:
      return lower_ushr64(b, src[0], src[1]);
   case nir_op_extract_u8:
   case nir_op_extract_i8:
   case nir_op_extract_u16:
   case nir_op_extract_i16:
      return lower_extract(b, alu->op, src[0], src[1]);
   case nir_op_ufind_msb:
      return lower_ufind_msb64(b, src[0]);
      break;
   default:
      unreachable("Invalid ALU opcode to lower");
   }
}

static bool
should_lower_int64_alu_instr(const nir_instr *instr, const void *_options)
{
   const nir_lower_int64_options options =
      *(const nir_lower_int64_options *)_options;

   if (instr->type != nir_instr_type_alu)
      return false;

   const nir_alu_instr *alu = nir_instr_as_alu(instr);

   switch (alu->op) {
   case nir_op_i2b1:
   case nir_op_i2i8:
   case nir_op_i2i16:
   case nir_op_i2i32:
   case nir_op_u2u8:
   case nir_op_u2u16:
   case nir_op_u2u32:
      assert(alu->src[0].src.is_ssa);
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_bcsel:
      assert(alu->src[1].src.is_ssa);
      assert(alu->src[2].src.is_ssa);
      assert(alu->src[1].src.ssa->bit_size ==
             alu->src[2].src.ssa->bit_size);
      if (alu->src[1].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      assert(alu->src[0].src.is_ssa);
      assert(alu->src[1].src.is_ssa);
      assert(alu->src[0].src.ssa->bit_size ==
             alu->src[1].src.ssa->bit_size);
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_ufind_msb:
      assert(alu->src[0].src.is_ssa);
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   default:
      assert(alu->dest.dest.is_ssa);
      if (alu->dest.dest.ssa.bit_size != 64)
         return false;
      break;
   }

   return (options & nir_lower_int64_op_to_options_mask(alu->op)) != 0;
}

bool
nir_lower_int64(nir_shader *shader, nir_lower_int64_options options)
{
   return nir_shader_lower_instructions(shader,
                                        should_lower_int64_alu_instr,
                                        lower_int64_alu_instr,
                                        &options);
}