pan/bit: Interpret CMP
[mesa.git] / src / panfrost / bifrost / test / bi_interpret.c
1 /*
2 * Copyright (C) 2020 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 * SOFTWARE.
22 *
23 * Authors (Collabora):
24 * Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com>
25 */
26
27 #include <math.h>
28 #include "bit.h"
29 #include "util/half_float.h"
30
/* A 64-bit scratch value with typed views over the same storage. The lane
 * arrays cover only the first four bytes (the low 32 bits on a
 * little-endian host). u64 is deliberately the first member so that a
 * `{ 0 }` initializer zeroes all eight bytes. */
typedef union {
        /* 64-bit views */
        uint64_t u64;
        int64_t i64;
        double f64;

        /* 32-bit scalar views */
        uint32_t u32;
        int32_t i32;
        float f32;

        /* Two 16-bit lanes; f16 holds raw IEEE half-precision bit patterns */
        uint16_t u16[2];
        int16_t i16[2];
        uint16_t f16[2];

        /* Four 8-bit lanes */
        uint8_t u8[4];
        int8_t i8[4];
} bit_t;
44
45 /* Interprets a subset of Bifrost IR required for automated testing */
46
47 static uint64_t
48 bit_read(struct bit_state *s, bi_instruction *ins, unsigned index, nir_alu_type T, bool FMA)
49 {
50 if (index & BIR_INDEX_REGISTER) {
51 uint32_t reg = index & ~BIR_INDEX_REGISTER;
52 assert(reg < 64);
53 return s->r[reg];
54 } else if (index & BIR_INDEX_UNIFORM) {
55 unreachable("Uniform registers to be implemented");
56 } else if (index & BIR_INDEX_CONSTANT) {
57 return ins->constant.u64 >> (index & ~BIR_INDEX_CONSTANT);
58 } else if (index & BIR_INDEX_ZERO) {
59 return 0;
60 } else if (index & (BIR_INDEX_PASS | BIFROST_SRC_STAGE)) {
61 return FMA ? 0 : s->T;
62 } else if (index & (BIR_INDEX_PASS | BIFROST_SRC_PASS_FMA)) {
63 return s->T0;
64 } else if (index & (BIR_INDEX_PASS | BIFROST_SRC_PASS_ADD)) {
65 return s->T1;
66 } else if (!index) {
67 /* Placeholder */
68 return 0;
69 } else {
70 unreachable("Invalid source");
71 }
72 }
73
74 static void
75 bit_write(struct bit_state *s, unsigned index, nir_alu_type T, bit_t value, bool FMA)
76 {
77 /* Always write stage passthrough */
78 if (FMA)
79 s->T = value.u32;
80
81 if (index & BIR_INDEX_REGISTER) {
82 uint32_t reg = index & ~BIR_INDEX_REGISTER;
83 assert(reg < 64);
84 s->r[reg] = value.u32;
85 } else if (!index) {
86 /* Nothing to do */
87 } else {
88 unreachable("Invalid destination");
89 }
90 }
91
/* Shorthands for converting between float and raw fp16 bit patterns */
#define bh _mesa_float_to_half
#define bf _mesa_half_to_float

/* Lane/scalar invocation helpers. They expand inside bit_step(), where
 * `ins`, `srcs` and `dest` are in scope, and call a four-operand function
 * `fxn` (operands the callee does not use are simply ignored). */

/* Per-lane fp16: swizzle each source lane, widen to float, apply fxn, and
 * narrow the result back to half */
#define bv2f16(fxn) \
        for (unsigned c = 0; c < 2; ++c) { \
                dest.f16[c] = bh(fxn(bf(srcs[0].f16[ins->swizzle[0][c]]), \
                                        bf(srcs[1].f16[ins->swizzle[1][c]]), \
                                        bf(srcs[2].f16[ins->swizzle[2][c]]), \
                                        bf(srcs[3].f16[ins->swizzle[3][c]]))); \
        }

/* Per-lane 16-bit integer. Reads and writes go through the signed i16 view
 * to match the int16_t helpers this is used with (it previously wrote
 * through the f16 view). */
#define bv2i16(fxn) \
        for (unsigned c = 0; c < 2; ++c) { \
                dest.i16[c] = fxn(srcs[0].i16[ins->swizzle[0][c]], \
                                srcs[1].i16[ins->swizzle[1][c]], \
                                srcs[2].i16[ins->swizzle[2][c]], \
                                srcs[3].i16[ins->swizzle[3][c]]); \
        }

/* Scalar 32-bit float and signed integer; every operand uses the same view */
#define bf32(fxn) dest.f32 = fxn(srcs[0].f32, srcs[1].f32, srcs[2].f32, srcs[3].f32)
#define bi32(fxn) dest.i32 = fxn(srcs[0].i32, srcs[1].i32, srcs[2].i32, srcs[3].i32)
113
/* Float dispatch on ins->dest_type, expanded inside the switch in
 * bit_step(); each handled type ends in `break`, and any non-float type
 * falls through to whatever follows the macro. Note the width shift: fp32
 * results come from the double-precision helper and fp16 lanes from the
 * single-precision helper (presumably so intermediates carry extra
 * precision -- confirm). fp64 is not implemented. */
#define bfloat(fxn64, fxn32) \
        if (ins->dest_type == nir_type_float32) { \
                bf32(fxn64); \
                break; \
        } else if (ins->dest_type == nir_type_float16) { \
                bv2f16(fxn32); \
                break; \
        } else if (ins->dest_type == nir_type_float64) { \
                unreachable("TODO: 64-bit"); \
        }
124
/* Integer dispatch on ins->dest_type, expanded inside the switch in
 * bit_step(). 32-bit results use bi32(), 16-bit lanes bv2i16(); 64-bit and
 * 8-bit are not implemented. Each handled type ends in `break`; anything
 * else falls through to whatever follows the macro. Within bpoly(),
 * unsigned types are already consumed by buint() before reaching here. */
#define bint(fxn64, fxn32, fxn16, fxn8) \
        if (ins->dest_type == nir_type_int64 || ins->dest_type == nir_type_uint64) { \
                unreachable("TODO: 64-bit"); \
        } else if (ins->dest_type == nir_type_int32 || ins->dest_type == nir_type_uint32) { \
                bi32(fxn32); \
                break; \
        } else if (ins->dest_type == nir_type_int16 || ins->dest_type == nir_type_uint16) { \
                bv2i16(fxn16); \
                break; \
        } else if (ins->dest_type == nir_type_int8 || ins->dest_type == nir_type_uint8) { \
                unreachable("TODO: 8-bit"); \
        }

/* Unsigned counterparts of bi32()/bv2i16(): operands and result use the
 * unsigned views so signedness-sensitive ops (min/max) compare correctly */
#define bu32(fxn) dest.u32 = fxn(srcs[0].u32, srcs[1].u32, srcs[2].u32, srcs[3].u32)

#define bv2u16(fxn) \
        for (unsigned c = 0; c < 2; ++c) { \
                dest.u16[c] = fxn(srcs[0].u16[ins->swizzle[0][c]], \
                                srcs[1].u16[ins->swizzle[1][c]], \
                                srcs[2].u16[ins->swizzle[2][c]], \
                                srcs[3].u16[ins->swizzle[3][c]]); \
        }

/* Unsigned dispatch, same contract as bint() but for nir_type_uint* only */
#define buint(fxn64, fxn32, fxn16, fxn8) \
        if (ins->dest_type == nir_type_uint64) { \
                unreachable("TODO: 64-bit"); \
        } else if (ins->dest_type == nir_type_uint32) { \
                bu32(fxn32); \
                break; \
        } else if (ins->dest_type == nir_type_uint16) { \
                bv2u16(fxn16); \
                break; \
        } else if (ins->dest_type == nir_type_uint8) { \
                unreachable("TODO: 8-bit"); \
        }

/* Dispatch polymorphic op `name` over float, unsigned and signed
 * destination types (in that order); any other type is invalid */
#define bpoly(name) \
        bfloat(bit_f64 ## name, bit_f32 ## name); \
        buint(bit_u64 ## name, bit_u32 ## name, bit_u16 ## name, bit_u8 ## name); \
        bint(bit_i64 ## name, bit_i32 ## name, bit_i16 ## name, bit_i8 ## name); \
        unreachable("Invalid type");

/* Generators for four-operand helpers bit_<type><name>(a, b, c, d) that
 * return `expr`; operands not mentioned in expr are ignored */
#define bit_make_float(name, expr) \
        static inline double \
        bit_f64 ## name(double a, double b, double c, double d) \
        { \
                return expr; \
        } \
        static inline float \
        bit_f32 ## name(float a, float b, float c, float d) \
        { \
                return expr; \
        }

#define bit_make_int(name, expr) \
        static inline int64_t \
        bit_i64 ## name (int64_t a, int64_t b, int64_t c, int64_t d) \
        { \
                return expr; \
        } \
        static inline int32_t \
        bit_i32 ## name (int32_t a, int32_t b, int32_t c, int32_t d) \
        { \
                return expr; \
        } \
        static inline int16_t \
        bit_i16 ## name (int16_t a, int16_t b, int16_t c, int16_t d) \
        { \
                return expr; \
        } \
        static inline int8_t \
        bit_i8 ## name (int8_t a, int8_t b, int8_t c, int8_t d) \
        { \
                return expr; \
        }

#define bit_make_uint(name, expr) \
        static inline uint64_t \
        bit_u64 ## name (uint64_t a, uint64_t b, uint64_t c, uint64_t d) \
        { \
                return expr; \
        } \
        static inline uint32_t \
        bit_u32 ## name (uint32_t a, uint32_t b, uint32_t c, uint32_t d) \
        { \
                return expr; \
        } \
        static inline uint16_t \
        bit_u16 ## name (uint16_t a, uint16_t b, uint16_t c, uint16_t d) \
        { \
                return expr; \
        } \
        static inline uint8_t \
        bit_u8 ## name (uint8_t a, uint8_t b, uint8_t c, uint8_t d) \
        { \
                return expr; \
        }

#define bit_make_poly(name, expr) \
        bit_make_float(name, expr) \
        bit_make_int(name, expr) \
        bit_make_uint(name, expr)

/* Integer add is expected to wrap (two's complement). The signed helpers
 * add in uint64_t so overflow is well-defined rather than UB; narrowing to
 * the signed return type keeps the low bits (implementation-defined in ISO
 * C, modular on GCC/Clang). */
bit_make_float(add, a + b);
bit_make_int(add, (uint64_t) a + (uint64_t) b);
bit_make_uint(add, a + b);

bit_make_float(fma, (a * b) + c);
bit_make_poly(mov, a);
bit_make_poly(min, (a < b) ? a : b);
bit_make_poly(max, (a > b) ? a : b);
189
190 /* Modifiers */
191
192 static float
193 bit_outmod(float raw, enum bifrost_outmod mod)
194 {
195 switch (mod) {
196 case BIFROST_POS:
197 return MAX2(raw, 0.0);
198 case BIFROST_SAT_SIGNED:
199 return CLAMP(raw, -1.0, 1.0);
200 case BIFROST_SAT:
201 return CLAMP(raw, 0.0, 1.0);
202 default:
203 return raw;
204 }
205 }
206
/* Apply float source modifiers: absolute value first, then negation, so
 * requesting both yields -|raw|. */
static float
bit_srcmod(float raw, bool abs, bool neg)
{
        float magnitude = abs ? (float) fabs(raw) : raw;

        return neg ? -magnitude : magnitude;
}
218
/* Evaluate comparison `cond` on (left, right) and return the result from
 * the enclosing function. Conditions other than LT/LE/GE/GT/EQ/NE (e.g.
 * BI_COND_ALWAYS) return true. */
#define BIT_COND(cond, left, right) \
        switch (cond) { \
        case BI_COND_LT: return (left) < (right); \
        case BI_COND_LE: return (left) <= (right); \
        case BI_COND_GE: return (left) >= (right); \
        case BI_COND_GT: return (left) > (right); \
        case BI_COND_EQ: return (left) == (right); \
        case BI_COND_NE: return (left) != (right); \
        default: return true; \
        }
227
228 static bool
229 bit_eval_cond(enum bi_cond cond, bit_t l, bit_t r, nir_alu_type T, unsigned cl, unsigned cr)
230 {
231 if (T == nir_type_float32) {
232 BIT_COND(cond, l.f32, r.f32);
233 } else if (T == nir_type_float16) {
234 float left = bf(l.f16[cl]);
235 float right = bf(r.f16[cr]);
236 BIT_COND(cond, left, right);
237 } else if (T == nir_type_int32) {
238 int32_t left = l.u32;
239 int32_t right = r.u32;
240 BIT_COND(cond, left, right);
241 } else if (T == nir_type_int16) {
242 int16_t left = l.i16[cl];
243 int16_t right = r.i16[cr];
244 BIT_COND(cond, left, right);
245 } else if (T == nir_type_uint32) {
246 BIT_COND(cond, l.u32, r.u32);
247 } else if (T == nir_type_uint16) {
248 BIT_COND(cond, l.u16[cl], r.u16[cr]);
249 } else {
250 unreachable("Unknown type evaluated");
251 }
252 }
253
254 static unsigned
255 bit_cmp(enum bi_cond cond, bit_t l, bit_t r, nir_alu_type T, unsigned cl, unsigned cr, bool d3d)
256 {
257 bool v = bit_eval_cond(cond, l, r, T, cl, cr);
258
259 /* Fill for D3D but only up to 32-bit... 64-bit is only partial
260 * (although we probably need a cleverer representation for 64-bit) */
261
262 unsigned sz = MIN2(nir_alu_type_get_type_size(T), 32);
263 unsigned max = (sz == 32) ? (~0) : ((1 << sz) - 1);
264
265 return v ? (d3d ? max : 1) : 0;
266 }
267
268 static float
269 biti_special(float Q, enum bi_special_op op)
270 {
271 switch (op) {
272 case BI_SPECIAL_FRCP: return 1.0 / Q;
273 case BI_SPECIAL_FRSQ: {
274 double Qf = 1.0 / sqrt(Q);
275 return Qf;
276 }
277 default: unreachable("Invalid special");
278 }
279 }
280
281 /* For BI_CONVERT. */
282
/* Map a Bifrost rounding mode onto the FP_INT_* rounding directions taken
 * by fromfpf()/ufromfpf(). Any mode other than RTZ, RTE or RTN is treated
 * as round-up. The argument is parenthesized so expression arguments
 * compare correctly.
 * NOTE(review): a leading underscore followed by an uppercase letter is a
 * reserved identifier in C; renaming requires updating every user. */
#define _AS_ROUNDMODE(mode) \
        (((mode) == BIFROST_RTZ) ? FP_INT_TOWARDZERO : \
         ((mode) == BIFROST_RTE) ? FP_INT_TONEAREST : \
         ((mode) == BIFROST_RTN) ? FP_INT_DOWNWARD : \
         FP_INT_UPWARD)
288
289 static float
290 bit_as_float32(nir_alu_type T, bit_t src, unsigned C)
291 {
292 switch (T) {
293 case nir_type_int32: return src.i32;
294 case nir_type_uint32: return src.u32;
295 case nir_type_float16: return bf(src.u16[C]);
296 default: unreachable("Invalid");
297 }
298 }
299
300 static uint32_t
301 bit_as_uint32(nir_alu_type T, bit_t src, unsigned C, enum bifrost_roundmode rm)
302 {
303 switch (T) {
304 case nir_type_float16: return bf(src.u16[C]);
305 case nir_type_float32: return ufromfpf(src.f32, _AS_ROUNDMODE(rm), 32);
306 default: unreachable("Invalid");
307 }
308 }
309
310 static int32_t
311 bit_as_int32(nir_alu_type T, bit_t src, unsigned C, enum bifrost_roundmode rm)
312 {
313 switch (T) {
314 case nir_type_float16: return bf(src.u16[C]);
315 case nir_type_float32: return fromfpf(src.f32, _AS_ROUNDMODE(rm), 32);
316 default: unreachable("Invalid");
317 }
318 }
319
320 static uint16_t
321 bit_as_float16(nir_alu_type T, bit_t src, unsigned C)
322 {
323 switch (T) {
324 case nir_type_int32: return bh(src.i32);
325 case nir_type_uint32: return bh(src.u32);
326 case nir_type_float32: return bh(src.f32);
327 case nir_type_int16: return bh(src.i16[C]);
328 case nir_type_uint16: return bh(src.u16[C]);
329 default: unreachable("Invalid");
330 }
331 }
332
333 static uint16_t
334 bit_as_uint16(nir_alu_type T, bit_t src, unsigned C, enum bifrost_roundmode rm)
335 {
336 switch (T) {
337 case nir_type_int32: return src.i32;
338 case nir_type_uint32: return src.u32;
339 case nir_type_float16: return ufromfpf(bf(src.u16[C]), _AS_ROUNDMODE(rm), 16);
340 case nir_type_float32: return src.f32;
341 default: unreachable("Invalid");
342 }
343 }
344
345 static int16_t
346 bit_as_int16(nir_alu_type T, bit_t src, unsigned C, enum bifrost_roundmode rm)
347 {
348 switch (T) {
349 case nir_type_int32: return src.i32;
350 case nir_type_uint32: return src.u32;
351 case nir_type_float16: return fromfpf(bf(src.u16[C]), _AS_ROUNDMODE(rm), 16);
352 case nir_type_float32: return src.f32;
353 default: unreachable("Invalid");
354 }
355 }
356
/* Split x into a mantissa in [0.75, 1.5), carrying x's sign, and an
 * exponent stored in *e, so that x == mantissa * 2^(*e) for finite x. */
static float
frexp_log(float x, int *e)
{
        /* Work on the magnitude; the sign is reattached at the end */
        float xa = fabsf(x);

        /* frexp reduces to [0.5, 1) */
        float f = frexpf(xa, e);

        /* Re-centre to [0.75, 1.5) */
        if (f < 0.75f) {
                f *= 2.0f;
                (*e)--;
        }

        /* Reattach the sign. This must test x, not xa: xa is never
         * negative, so testing it made the reattachment dead code. */
        if (x < 0.0f)
                f = -f;

        return f;
}
378
/* Interpret one instruction, updating the interpreter state `s`.
 *
 * s    interpreter state: register file r[] and passthrough values
 *      T, T0 and T1
 * ins  the instruction to execute
 * FMA  true if ins executes in the FMA slot, false for the ADD slot
 *
 * Steps:
 * 1. Read every source.
 * 2. Apply float abs/neg source modifiers, if the instruction has them.
 * 3. Compute the result according to ins->type.
 * 4. Apply _MSCALE (FMA/ADD only) and the output modifier (if any).
 * 5. Write the destination.
 * 6. For the ADD slot, update the T0/T1 passthroughs.
 *
 * Unsupported operations and types hit unreachable(). */
void
bit_step(struct bit_state *s, bi_instruction *ins, bool FMA)
{
        /* First, load sources */
        bit_t srcs[BIR_SRC_COUNT] = { 0 };

        bi_foreach_src(ins, src)
                srcs[src].u64 = bit_read(s, ins, ins->src[src], ins->src_types[src], FMA);

        /* Apply source modifiers if we need to. For fp16 sources both lanes
         * are modified, independent of the swizzle. */
        if (bi_has_source_mods(ins)) {
                bi_foreach_src(ins, src) {
                        if (ins->src_types[src] == nir_type_float16) {
                                for (unsigned c = 0; c < 2; ++c) {
                                        srcs[src].f16[c] = bh(bit_srcmod(bf(srcs[src].f16[c]),
                                                                ins->src_abs[src],
                                                                ins->src_neg[src]));
                                }
                        } else if (ins->src_types[src] == nir_type_float32) {
                                srcs[src].f32 = bit_srcmod(srcs[src].f32,
                                                ins->src_abs[src],
                                                ins->src_neg[src]);
                        }
                }
        }

        /* Next, do the action of the instruction */
        bit_t dest = { 0 };

        /* bpoly() ends each handled destination type with `break` and every
         * other type with unreachable(), so cases using it need no break */
        switch (ins->type) {
        case BI_ADD:
                bpoly(add);

        case BI_BRANCH:
                unreachable("Unsupported op");

        case BI_CMP: {
                /* Produces 0/1 booleans (d3d = false). The source type sets
                 * the width: one 32-bit result, or one result per 16-/8-bit
                 * lane with the lanes chosen by each source's swizzle. */
                nir_alu_type T = ins->src_types[0];
                unsigned sz = nir_alu_type_get_type_size(T);

                if (sz == 32 || sz == 64) {
                        dest.u32 = bit_cmp(ins->cond, srcs[0], srcs[1], T, 0, 0, false);
                } else if (sz == 16) {
                        for (unsigned c = 0; c < 2; ++c) {
                                dest.u16[c] = bit_cmp(ins->cond, srcs[0], srcs[1],
                                                T, ins->swizzle[0][c], ins->swizzle[1][c],
                                                false);
                        }
                } else if (sz == 8) {
                        for (unsigned c = 0; c < 4; ++c) {
                                dest.u8[c] = bit_cmp(ins->cond, srcs[0], srcs[1],
                                                T, ins->swizzle[0][c], ins->swizzle[1][c],
                                                false);
                        }
                } else {
                        unreachable("Invalid");
                }

                break;
        }

        case BI_BITWISE:
                unreachable("Unsupported op");

        case BI_CONVERT: {
                /* Source lane used when converting a 16-bit source to a
                 * 32-bit destination, if the source has lanes.
                 * NOTE(review): taken from swizzle[0][1] rather than
                 * swizzle[0][0] -- confirm this is intended. */
                unsigned comp = ins->swizzle[0][1];

                if (ins->dest_type == nir_type_float32)
                        dest.f32 = bit_as_float32(ins->src_types[0], srcs[0], comp);
                else if (ins->dest_type == nir_type_uint32)
                        dest.u32 = bit_as_uint32(ins->src_types[0], srcs[0], comp, ins->roundmode);
                else if (ins->dest_type == nir_type_int32)
                        dest.i32 = bit_as_int32(ins->src_types[0], srcs[0], comp, ins->roundmode);
                else if (ins->dest_type == nir_type_float16) {
                        dest.u16[0] = bit_as_float16(ins->src_types[0], srcs[0], ins->swizzle[0][0]);

                        if (ins->src_types[0] == nir_type_float32) {
                                /* TODO: Second argument */
                                dest.u16[1] = 0;
                        } else {
                                dest.u16[1] = bit_as_float16(ins->src_types[0], srcs[0], ins->swizzle[0][1]);
                        }
                } else if (ins->dest_type == nir_type_uint16) {
                        dest.u16[0] = bit_as_uint16(ins->src_types[0], srcs[0], ins->swizzle[0][0], ins->roundmode);
                        dest.u16[1] = bit_as_uint16(ins->src_types[0], srcs[0], ins->swizzle[0][1], ins->roundmode);
                } else if (ins->dest_type == nir_type_int16) {
                        dest.i16[0] = bit_as_int16(ins->src_types[0], srcs[0], ins->swizzle[0][0], ins->roundmode);
                        dest.i16[1] = bit_as_int16(ins->src_types[0], srcs[0], ins->swizzle[0][1], ins->roundmode);
                } else {
                        unreachable("Unknown convert type");
                }

                break;
        }

        case BI_CSEL: {
                /* With BI_COND_ALWAYS, srcs[0] is used directly as the
                 * condition. Otherwise srcs[0] and srcs[1] are compared
                 * (lane 0) as src_types[0]. True selects srcs[2], false
                 * selects srcs[3]. */
                bool direct = ins->cond == BI_COND_ALWAYS;
                bool cond = direct ? srcs[0].u32 :
                        bit_eval_cond(ins->cond, srcs[0], srcs[1], ins->src_types[0], 0, 0);

                dest = cond ? srcs[2] : srcs[3];
                break;
        }

        case BI_FMA: {
                bfloat(bit_f64fma, bit_f32fma);
                unreachable("Unknown type");
        }

        case BI_FREXP: {
                if (ins->src_types[0] != nir_type_float32)
                        unreachable("Unknown frexp type");


                /* Only the exponent is kept: frexp_log() writes it straight
                 * into dest and the returned mantissa is discarded */
                if (ins->op.frexp == BI_FREXPE_LOG)
                        frexp_log(srcs[0].f32, &dest.i32);
                else
                        unreachable("Unknown frexp");

                break;
        }
        case BI_ISUB:
                unreachable("Unsupported op");

        case BI_MINMAX: {
                /* No break needed: both bpoly() expansions end in break or
                 * unreachable(), so control never falls into BI_MOV */
                if (ins->op.minmax == BI_MINMAX_MIN) {
                        bpoly(min);
                } else {
                        bpoly(max);
                }
        }

        case BI_MOV:
                bpoly(mov);

        case BI_REDUCE_FMA: {
                if (ins->src_types[0] != nir_type_float32)
                        unreachable("Unknown reduce type");

                if (ins->op.reduce == BI_REDUCE_ADD_FREXPM) {
                        /* srcs[0] plus the frexp_log() mantissa of srcs[1];
                         * the exponent is discarded */
                        int _nop = 0;
                        float f = frexp_log(srcs[1].f32, &_nop);
                        dest.f32 = srcs[0].f32 + f;
                } else {
                        unreachable("Unknown reduce");
                }

                break;
        }

        case BI_SPECIAL: {
                assert(nir_alu_type_get_base_type(ins->dest_type) == nir_type_float);
                assert(ins->dest_type != nir_type_float64);

                /* EXP2_LOW is fp32 only and reads srcs[1] */
                if (ins->op.special == BI_SPECIAL_EXP2_LOW) {
                        assert(ins->dest_type == nir_type_float32);
                        dest.f32 = exp2f(srcs[1].f32);
                        break;
                }

                /* Other special ops take one scalar operand: for fp16 the
                 * lane selected by swizzle[0][0], otherwise the fp32 value */
                float Q = (ins->dest_type == nir_type_float16) ?
                        bf(srcs[0].u16[ins->swizzle[0][0]]) :
                        srcs[0].f32;

                float R = biti_special(Q, ins->op.special);

                if (ins->dest_type == nir_type_float16) {
                        dest.f16[0] = bh(R);

                        if (!ins->swizzle[0][0] && ins->op.special == BI_SPECIAL_FRSQ) {
                                /* Sorry. */
                                /* NOTE(review): increments the raw fp16 bit
                                 * pattern of the FRSQ result when lane 0 is
                                 * read, presumably to match hardware
                                 * rounding -- confirm. */
                                dest.f16[0]++;
                        }
                } else {
                        dest.f32 = R;
                }
                break;
        }

        case BI_TABLE: {
                if (ins->op.table == BI_TABLE_LOG2_U_OVER_U_1_LOW) {
                        assert(ins->dest_type == nir_type_float32);
                        int _nop = 0;
                        float f = frexp_log(srcs[0].f32, &_nop);
                        /* log2(u) / (u - 1) of the mantissa u.
                         * NOTE(review): for an exact power of two u == 1.0,
                         * so this evaluates 0/0 = NaN, whereas the limit is
                         * 1/ln(2) -- confirm against hardware. */
                        dest.f32 = log2f(f) / (f - 1.0);
                        dest.u32++; /* Sorry. */
                } else {
                        unreachable("Unknown table op");
                }
                break;
        }

        case BI_SELECT: {
                /* Lane c of the result is lane swizzle[c][0] of source c */
                if (ins->src_types[0] == nir_type_uint16) {
                        for (unsigned c = 0; c < 2; ++c)
                                dest.u16[c] = srcs[c].u16[ins->swizzle[c][0]];
                } else if (ins->src_types[0] == nir_type_uint8) {
                        for (unsigned c = 0; c < 4; ++c)
                                dest.u8[c] = srcs[c].u8[ins->swizzle[c][0]];
                } else {
                        unreachable("Unknown type");
                }
                break;
        }

        case BI_SHIFT:
        case BI_ROUND:
                unreachable("Unsupported op");

        /* We only interpret vertex shaders */
        case BI_DISCARD:
        case BI_LOAD_VAR:
        case BI_ATEST:
        case BI_BLEND:
                unreachable("Fragment op used in interpreter");

        /* Modeling main memory is more than I bargained for */
        case BI_LOAD_UNIFORM:
        case BI_LOAD_ATTR:
        case BI_LOAD_VAR_ADDRESS:
        case BI_LOAD:
        case BI_STORE:
        case BI_STORE_VAR:
        case BI_TEX:
                unreachable("Unsupported I/O in interpreter");

        default:
                unreachable("Unsupported op");
        }

        /* Apply _MSCALE: multiply the fp32 result by 2^scale, where scale is
         * the extra int32 source (srcs[3] for FMA, srcs[2] for ADD) */
        if ((ins->type == BI_FMA || ins->type == BI_ADD) && ins->op.mscale) {
                unsigned idx = (ins->type == BI_FMA) ? 3 : 2;

                assert(ins->src_types[idx] == nir_type_int32);
                assert(ins->dest_type == nir_type_float32);

                int32_t scale = srcs[idx].i32;
                dest.f32 *= exp2f(scale);
        }

        /* Apply outmod: per lane for fp16, otherwise on the fp32 view */
        if (bi_has_outmod(ins) && ins->outmod != BIFROST_NONE) {
                if (ins->dest_type == nir_type_float16) {
                        for (unsigned c = 0; c < 2; ++c)
                                dest.f16[c] = bh(bit_outmod(bf(dest.f16[c]), ins->outmod));
                } else {
                        dest.f32 = bit_outmod(dest.f32, ins->outmod);
                }
        }

        /* Finally, store the result */
        bit_write(s, ins->dest, ins->dest_type, dest, FMA);

        /* For ADD - change out the passthrough: T0 takes the stage value T
         * and T1 takes this instruction's result (low 32 bits) */
        if (!FMA) {
                s->T0 = s->T;
                s->T1 = dest.u32;
        }
}
640
641 #undef bh
642 #undef bf