/*
 * Copyright (C) 2020 Collabora Ltd.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 * Authors (Collabora):
 *    Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com>
 */
29 #include "util/half_float.h"
45 /* Interprets a subset of Bifrost IR required for automated testing */
/* Resolve the source operand `index` of `ins` to a raw value.
 *
 * The kind of source is decoded from flag bits carried in `index`:
 *  - BIR_INDEX_REGISTER: the register number is `index` with the flag cleared.
 *  - BIR_INDEX_UNIFORM:  not implemented, hits unreachable().
 *  - BIR_INDEX_CONSTANT: ins->constant.u64 shifted right by `index` with the
 *                        flag cleared (i.e. the low bits are a bit offset).
 *  - BIR_INDEX_ZERO and the BIR_INDEX_PASS variants: see the branches below.
 * Anything else hits unreachable("Invalid source").
 *
 * `FMA` tells whether the reading instruction is in the FMA stage; it only
 * affects the stage-passthrough branch. `T` is not referenced by the branches
 * shown here. */
48 bit_read(struct bit_state
*s
, bi_instruction
*ins
, unsigned index
, nir_alu_type T
, bool FMA
)
50 if (index
& BIR_INDEX_REGISTER
) {
51 uint32_t reg
= index
& ~BIR_INDEX_REGISTER
;
54 } else if (index
& BIR_INDEX_UNIFORM
) {
55 unreachable("Uniform registers to be implemented");
56 } else if (index
& BIR_INDEX_CONSTANT
) {
57 return ins
->constant
.u64
>> (index
& ~BIR_INDEX_CONSTANT
);
58 } else if (index
& BIR_INDEX_ZERO
) {
/* NOTE(review): the three passthrough tests below AND `index` with a mask
 * that always contains BIR_INDEX_PASS. Assuming BIR_INDEX_PASS is a non-zero
 * flag, any index carrying it already satisfies the first test, which would
 * make the PASS_FMA / PASS_ADD branches unreachable. Confirm whether an
 * equality comparison against each mask was intended. */
60 } else if (index
& (BIR_INDEX_PASS
| BIFROST_SRC_STAGE
)) {
/* Stage passthrough: reads 0 from the FMA stage, s->T otherwise */
61 return FMA
? 0 : s
->T
;
62 } else if (index
& (BIR_INDEX_PASS
| BIFROST_SRC_PASS_FMA
)) {
64 } else if (index
& (BIR_INDEX_PASS
| BIFROST_SRC_PASS_ADD
)) {
70 unreachable("Invalid source");
/* Commit the result `value` of an instruction to destination `index`.
 *
 * For a register destination (BIR_INDEX_REGISTER set) the register number is
 * `index` with the flag cleared, and the 32-bit view of the value (value.u32)
 * is stored to s->r[reg]. Destinations not handled above the final statement
 * hit unreachable("Invalid destination").
 *
 * `FMA` identifies the stage of the writing instruction; `T` is not
 * referenced by the statements shown here. */
75 bit_write(struct bit_state
*s
, unsigned index
, nir_alu_type T
, bit_t value
, bool FMA
)
77 /* Always write stage passthrough */
81 if (index
& BIR_INDEX_REGISTER
) {
82 uint32_t reg
= index
& ~BIR_INDEX_REGISTER
;
/* Only 32 bits are written back, whatever the width of `value` */
84 s
->r
[reg
] = value
.u32
;
88 unreachable("Invalid destination");
/* Terse aliases for the Mesa half-float helpers, used by every fp16 path in
 * the interpreter:
 *   bf - half bits  -> float (_mesa_half_to_float)
 *   bh - float      -> half bits (_mesa_float_to_half)
 * They are plain object-like aliases, so they can be called or passed on
 * exactly like the functions they name. */
#define bf _mesa_half_to_float
#define bh _mesa_float_to_half
96 for (unsigned c = 0; c < 2; ++c) { \
97 dest.f16[c] = bh(fxn(bf(srcs[0].f16[ins->swizzle[0][c]]), \
98 bf(srcs[1].f16[ins->swizzle[1][c]]), \
99 bf(srcs[2].f16[ins->swizzle[2][c]]), \
100 bf(srcs[3].f16[ins->swizzle[3][c]]))); \
/* Two-lane 16-bit integer evaluation. For each lane c in {0, 1}, the
 * swizzle-selected u16 component of each of srcs[0..3] is passed to `fxn`
 * and the result is stored in lane c of `dest`. Expects `ins`, `srcs` and
 * `dest` to be in scope at the expansion site.
 *
 * NOTE(review): the result is stored through dest.f16 although this is the
 * integer variant. Presumably f16 is a plain 16-bit integer view of the same
 * storage, making this a naming oddity only; confirm against the bit_t
 * definition. */
103 #define bv2i16(fxn) \
104 for (unsigned c = 0; c < 2; ++c) { \
105 dest.f16[c] = fxn(srcs[0].u16[ins->swizzle[0][c]], \
106 srcs[1].u16[ins->swizzle[1][c]], \
107 srcs[2].u16[ins->swizzle[2][c]], \
108 srcs[3].u16[ins->swizzle[3][c]]); \
112 for (unsigned c = 0; c < 4; ++c) { \
113 dest.u8[c] = fxn(srcs[0].u8[ins->swizzle[0][c]], \
114 srcs[1].u8[ins->swizzle[1][c]], \
115 srcs[2].u8[ins->swizzle[2][c]], \
116 srcs[3].u8[ins->swizzle[3][c]]); \
/* Scalar 32-bit evaluation helpers.
 *
 * Both expect `dest` (bit_t) and `srcs[4]` (bit_t) to be in scope at the
 * expansion site and expand to one parenthesised assignment expression:
 *   bf32(fxn) - calls fxn on the four float views, stores into dest.f32
 *   bi32(fxn) - calls fxn on the four signed views, stores into dest.i32
 *
 * bi32 reads every operand through .i32 so that all four arguments reach the
 * signed helper with the same (signed) interpretation; reading some through
 * .u32 relied on an implementation-defined unsigned-to-signed conversion. */
#define bf32(fxn) (dest.f32 = fxn(srcs[0].f32, srcs[1].f32, srcs[2].f32, srcs[3].f32))
#define bi32(fxn) (dest.i32 = fxn(srcs[0].i32, srcs[1].i32, srcs[2].i32, srcs[3].i32))
/* Dispatch a floating-point operation on ins->dest_type.
 *
 * float64 is not implemented and hits unreachable("TODO: 64-bit"); float32
 * and float16 each have their own arm. `fxn64` / `fxn32` name the per-width
 * scalar helpers (see bit_make_float_2). Expands to a bare if / else-if
 * chain, so it may only be used where a statement is valid, and it relies on
 * `ins` being in scope. */
122 #define bfloat(fxn64, fxn32) \
123 if (ins->dest_type == nir_type_float64) { \
124 unreachable("TODO: 64-bit"); \
125 } else if (ins->dest_type == nir_type_float32) { \
128 } else if (ins->dest_type == nir_type_float16) { \
/* Dispatch an integer operation on ins->dest_type.
 *
 * Signed and unsigned types of the same width share an arm. The 64-bit arm
 * is not implemented and hits unreachable("TODO: 64-bit"); 32-, 16- and
 * 8-bit types each have their own arm. `fxn64` .. `fxn8` name the per-width
 * scalar helpers (see bit_make_int). Expands to a bare if / else-if chain,
 * so it may only be used where a statement is valid, and it relies on `ins`
 * being in scope. */
133 #define bint(fxn64, fxn32, fxn16, fxn8) \
134 if (ins->dest_type == nir_type_int64 || ins->dest_type == nir_type_uint64) { \
135 unreachable("TODO: 64-bit"); \
136 } else if (ins->dest_type == nir_type_int32 || ins->dest_type == nir_type_uint32) { \
139 } else if (ins->dest_type == nir_type_int16 || ins->dest_type == nir_type_uint16) { \
142 } else if (ins->dest_type == nir_type_int8 || ins->dest_type == nir_type_uint8) { \
/* Type-generic dispatch for the operation called `name`.
 *
 * Runs the float dispatch (bfloat) with bit_f64<name> / bit_f32<name>, then
 * the integer dispatch (bint) with bit_i64 / i32 / i16 / i8<name>, and ends
 * in unreachable("Invalid type").
 *
 * This deliberately expands to a sequence of statements rather than a single
 * wrapped block. NOTE(review): presumably every matching arm of bfloat/bint
 * leaves the enclosing case before the trailing unreachable() is reached;
 * confirm before wrapping this in any extra block or loop construct. */
#define bpoly(name) \
        bfloat(bit_f64 ## name, bit_f32 ## name); \
        bint(bit_i64 ## name, bit_i32 ## name, bit_i16 ## name, bit_i8 ## name); \
        unreachable("Invalid type");
/* Generate the scalar float helpers for one operation:
 *   double bit_f64<name>(double a, double b, double c, double d)
 *   float  bit_f32<name>(float  a, float  b, float  c, float  d)
 * `expr32` / `expr64` are expressions over a, b, c, d for the respective
 * width, so width-specific libm calls can be used (e.g. floorf vs floor).
 * NOTE(review): presumably each helper simply returns its expression;
 * confirm against the function bodies. */
152 #define bit_make_float_2(name, expr32, expr64) \
153 static inline double \
154 bit_f64 ## name(double a, double b, double c, double d) \
158 static inline float \
159 bit_f32 ## name(float a, float b, float c, float d) \
/* Convenience form of bit_make_float_2 for operations whose expression is
 * spelled the same way at both float widths: `expr` is used for both the
 * 32-bit and the 64-bit helper. */
#define bit_make_float(name, expr) bit_make_float_2(name, expr, expr)
/* Generate the scalar signed-integer helpers for one operation:
 *   int64_t bit_i64<name>(int64_t a, int64_t b, int64_t c, int64_t d)
 *   int32_t bit_i32<name>(int32_t a, int32_t b, int32_t c, int32_t d)
 *   int16_t bit_i16<name>(int16_t a, int16_t b, int16_t c, int16_t d)
 *   int8_t  bit_i8<name> (int8_t  a, int8_t  b, int8_t  c, int8_t  d)
 * `expr` is one expression over a, b, c, d shared by all four widths.
 * NOTE(review): presumably each helper simply returns `expr`; confirm
 * against the function bodies. */
167 #define bit_make_int(name, expr) \
168 static inline int64_t \
169 bit_i64 ## name (int64_t a, int64_t b, int64_t c, int64_t d) \
174 static inline int32_t \
175 bit_i32 ## name (int32_t a, int32_t b, int32_t c, int32_t d) \
180 static inline int16_t \
181 bit_i16 ## name (int16_t a, int16_t b, int16_t c, int16_t d) \
186 static inline int8_t \
187 bit_i8 ## name (int8_t a, int8_t b, int8_t c, int8_t d) \
/* Generate both the float and the integer scalar helpers for an operation
 * whose expression `expr` (over a, b, c, d) is valid for every type, by
 * expanding bit_make_float and bit_make_int with the same arguments.
 *
 * The final line intentionally has no line continuation: a trailing
 * backslash would splice whatever line follows the macro into its body. */
#define bit_make_poly(name, expr) \
        bit_make_float(name, expr) \
        bit_make_int(name, expr)
196 bit_make_poly(add
, a
+ b
);
197 bit_make_int(sub
, a
- b
);
198 bit_make_float(fma
, (a
* b
) + c
);
199 bit_make_poly(mov
, a
);
200 bit_make_poly(min
, MIN2(a
, b
));
201 bit_make_poly(max
, MAX2(a
, b
));
202 bit_make_float_2(floor
, floorf(a
), floor(a
));
203 bit_make_float_2(ceil
, ceilf(a
), ceil(a
));
204 bit_make_float_2(trunc
, truncf(a
), trunc(a
));
205 bit_make_float_2(nearbyint
, nearbyintf(a
), nearbyint(a
));
/* Apply the output modifier `mod` to the float result `raw` and return the
 * modified value. The modifiers visible here are all clamps: a lower bound
 * at 0.0 (MAX2), the signed range [-1.0, 1.0] for BIFROST_SAT_SIGNED
 * (CLAMP), and SATURATE(raw). */
210 bit_outmod(float raw
, enum bifrost_outmod mod
)
214 return MAX2(raw
, 0.0);
215 case BIFROST_SAT_SIGNED
:
216 return CLAMP(raw
, -1.0, 1.0);
218 return SATURATE(raw
);
225 bit_srcmod(float raw
, bool abs
, bool neg
)
/* Evaluate the comparison `cond` on `left` and `right` and RETURN the result
 * from the enclosing function.
 *
 * cond  - a BI_COND_* value selecting the comparison
 * left, right - operands, compared with the usual C operators in whatever
 *               type the caller supplies (so signedness / float semantics
 *               follow the argument types)
 *
 * Any condition other than LT/LE/GE/GT/EQ/NE yields true. The macro expands
 * to an if / else-if chain that always returns, so it must be used as a
 * statement inside a function returning a boolean-compatible type.
 *
 * Every argument is parenthesised so that compound expressions (e.g. a
 * conditional or a bitwise expression) are compared as a whole. */
#define BIT_COND(cond, left, right) \
        if ((cond) == BI_COND_LT) return (left) < (right); \
        else if ((cond) == BI_COND_LE) return (left) <= (right); \
        else if ((cond) == BI_COND_GE) return (left) >= (right); \
        else if ((cond) == BI_COND_GT) return (left) > (right); \
        else if ((cond) == BI_COND_EQ) return (left) == (right); \
        else if ((cond) == BI_COND_NE) return (left) != (right); \
        else { return true; }
/* Evaluate condition `cond` on operands `l` and `r` interpreted as type `T`.
 *
 * cl / cr select the 16-bit lane of l / r for the 16-bit types; they are not
 * used for the 32-bit types. Each handled type returns from inside BIT_COND;
 * a type not listed here hits unreachable("Unknown type evaluated").
 *
 * NOTE(review): only 16- and 32-bit types are handled, yet bit_step calls
 * bit_cmp (and thus this function) for 8-bit and 64-bit source types as
 * well; those paths would end in the unreachable(). Confirm whether they are
 * meant to be supported. */
246 bit_eval_cond(enum bi_cond cond
, bit_t l
, bit_t r
, nir_alu_type T
, unsigned cl
, unsigned cr
)
248 if (T
== nir_type_float32
) {
249 BIT_COND(cond
, l
.f32
, r
.f32
);
250 } else if (T
== nir_type_float16
) {
/* Half lanes are widened to float before comparing */
251 float left
= bf(l
.f16
[cl
]);
252 float right
= bf(r
.f16
[cr
]);
253 BIT_COND(cond
, left
, right
);
254 } else if (T
== nir_type_int32
) {
/* Signed compare: the 32-bit payload is read through .u32 and converted
 * to int32_t (NOTE(review): the int16 branch reads .i16 directly) */
255 int32_t left
= l
.u32
;
256 int32_t right
= r
.u32
;
257 BIT_COND(cond
, left
, right
);
258 } else if (T
== nir_type_int16
) {
259 int16_t left
= l
.i16
[cl
];
260 int16_t right
= r
.i16
[cr
];
261 BIT_COND(cond
, left
, right
);
262 } else if (T
== nir_type_uint32
) {
263 BIT_COND(cond
, l
.u32
, r
.u32
);
264 } else if (T
== nir_type_uint16
) {
265 BIT_COND(cond
, l
.u16
[cl
], r
.u16
[cr
]);
267 unreachable("Unknown type evaluated");
/* Compare `l` and `r` under `cond` (see bit_eval_cond for T / cl / cr) and
 * encode the boolean outcome as an integer:
 *   false          -> 0
 *   true, !d3d     -> 1
 *   true,  d3d     -> all-ones mask as wide as T, capped at 32 bits */
272 bit_cmp(enum bi_cond cond
, bit_t l
, bit_t r
, nir_alu_type T
, unsigned cl
, unsigned cr
, bool d3d
)
274 bool v
= bit_eval_cond(cond
, l
, r
, T
, cl
, cr
);
276 /* Fill for D3D but only up to 32-bit... 64-bit is only partial
277 * (although we probably need a cleverer representation for 64-bit) */
279 unsigned sz
= MIN2(nir_alu_type_get_type_size(T
), 32);
/* sz == 32 is special-cased: 1 << 32 would shift by the full width of int */
280 unsigned max
= (sz
== 32) ? (~0) : ((1 << sz
) - 1);
282 return v
? (d3d
? max
: 1) : 0;
/* Reference result for special-function op `op` applied to `Q`.
 *
 * BI_SPECIAL_FRCP yields the reciprocal 1.0 / Q. BI_SPECIAL_FRSQ starts from
 * the reciprocal square root 1.0 / sqrt(Q), computed in double precision.
 * Any other op hits unreachable("Invalid special"). */
286 biti_special(float Q
, enum bi_special_op op
)
289 case BI_SPECIAL_FRCP
: return 1.0 / Q
;
290 case BI_SPECIAL_FRSQ
: {
291 double Qf
= 1.0 / sqrt(Q
);
294 default: unreachable("Invalid special");
298 /* For BI_CONVERT. */
/* Translate a Bifrost rounding mode into the FP_INT_* rounding-direction
 * constant that is passed to fromfpf() / ufromfpf() by the bit_as_* helpers:
 *   BIFROST_RTZ -> FP_INT_TOWARDZERO
 *   BIFROST_RTE -> FP_INT_TONEAREST
 *   BIFROST_RTN -> FP_INT_DOWNWARD
 *
 * NOTE(review): `mode` is expanded unparenthesised and several times, so
 * only pass a simple side-effect-free expression. The name also begins with
 * an underscore followed by an upper-case letter, which C reserves for the
 * implementation. */
300 #define _AS_ROUNDMODE(mode) \
301 ((mode == BIFROST_RTZ) ? FP_INT_TOWARDZERO : \
302 (mode == BIFROST_RTE) ? FP_INT_TONEAREST : \
303 (mode == BIFROST_RTN) ? FP_INT_DOWNWARD : \
/* Convert `src`, interpreted as type `T`, to a 32-bit float value.
 *
 * int32 / uint32 sources are converted by value; a float16 source is taken
 * from 16-bit lane `C` and widened with bf(). `C` is only used for the
 * 16-bit source type. Any other source type hits unreachable("Invalid"). */
307 bit_as_float32(nir_alu_type T
, bit_t src
, unsigned C
)
310 case nir_type_int32
: return src
.i32
;
311 case nir_type_uint32
: return src
.u32
;
312 case nir_type_float16
: return bf(src
.u16
[C
]);
313 default: unreachable("Invalid");
/* Convert `src`, interpreted as float type `T`, to an unsigned 32-bit value.
 *
 * A float32 source is rounded with ufromfpf() using rounding mode `rm` and a
 * width of 32 bits. A float16 source is taken from 16-bit lane `C` and
 * widened with bf(). Any other source type hits unreachable("Invalid").
 *
 * NOTE(review): the float16 path returns the widened float directly, so
 * `rm` is ignored there and the conversion to the return type is left to
 * the implicit C conversion (assuming an integer return type, that
 * truncates and is undefined for negative or out-of-range values). The
 * float32 path and bit_as_uint16's float16 path both go through
 * ufromfpf(); confirm whether this one should too. */
318 bit_as_uint32(nir_alu_type T
, bit_t src
, unsigned C
, enum bifrost_roundmode rm
)
321 case nir_type_float16
: return bf(src
.u16
[C
]);
322 case nir_type_float32
: return ufromfpf(src
.f32
, _AS_ROUNDMODE(rm
), 32);
323 default: unreachable("Invalid");
/* Convert `src`, interpreted as float type `T`, to a signed 32-bit value.
 *
 * A float32 source is rounded with fromfpf() using rounding mode `rm` and a
 * width of 32 bits. A float16 source is taken from 16-bit lane `C` and
 * widened with bf(). Any other source type hits unreachable("Invalid").
 *
 * NOTE(review): the float16 path returns the widened float directly, so
 * `rm` is ignored there and rounding is left to the implicit C conversion
 * to the return type. The float32 path and bit_as_int16's float16 path both
 * go through fromfpf(); confirm whether this one should too. */
328 bit_as_int32(nir_alu_type T
, bit_t src
, unsigned C
, enum bifrost_roundmode rm
)
331 case nir_type_float16
: return bf(src
.u16
[C
]);
332 case nir_type_float32
: return fromfpf(src
.f32
, _AS_ROUNDMODE(rm
), 32);
333 default: unreachable("Invalid");
/* Convert `src`, interpreted as type `T`, to half-float bits via bh().
 *
 * 32-bit sources (int32, uint32, float32) use the whole 32-bit view; 16-bit
 * integer sources (int16, uint16) are taken from lane `C`. In every case the
 * value is first converted to float by the call to bh(). Any other source
 * type hits unreachable("Invalid"). */
338 bit_as_float16(nir_alu_type T
, bit_t src
, unsigned C
)
341 case nir_type_int32
: return bh(src
.i32
);
342 case nir_type_uint32
: return bh(src
.u32
);
343 case nir_type_float32
: return bh(src
.f32
);
344 case nir_type_int16
: return bh(src
.i16
[C
]);
345 case nir_type_uint16
: return bh(src
.u16
[C
]);
346 default: unreachable("Invalid");
/* Convert `src`, interpreted as type `T`, to an unsigned 16-bit value.
 *
 * int32 / uint32 sources are returned by value (narrowing is left to the
 * conversion to the return type). A float16 source is taken from lane `C`,
 * widened with bf() and rounded with ufromfpf() using rounding mode `rm` and
 * a width of 16 bits. Any other source type hits unreachable("Invalid").
 *
 * NOTE(review): the float32 path returns src.f32 directly, so `rm` is
 * ignored there and rounding is left to the implicit C conversion, unlike
 * the float16 path. Confirm whether it should go through ufromfpf() too. */
351 bit_as_uint16(nir_alu_type T
, bit_t src
, unsigned C
, enum bifrost_roundmode rm
)
354 case nir_type_int32
: return src
.i32
;
355 case nir_type_uint32
: return src
.u32
;
356 case nir_type_float16
: return ufromfpf(bf(src
.u16
[C
]), _AS_ROUNDMODE(rm
), 16);
357 case nir_type_float32
: return src
.f32
;
358 default: unreachable("Invalid");
/* Convert `src`, interpreted as type `T`, to a signed 16-bit value.
 *
 * int32 / uint32 sources are returned by value (narrowing is left to the
 * conversion to the return type). A float16 source is taken from lane `C`,
 * widened with bf() and rounded with fromfpf() using rounding mode `rm` and
 * a width of 16 bits. Any other source type hits unreachable("Invalid").
 *
 * NOTE(review): the float32 path returns src.f32 directly, so `rm` is
 * ignored there and rounding is left to the implicit C conversion, unlike
 * the float16 path. Confirm whether it should go through fromfpf() too. */
363 bit_as_int16(nir_alu_type T
, bit_t src
, unsigned C
, enum bifrost_roundmode rm
)
366 case nir_type_int32
: return src
.i32
;
367 case nir_type_uint32
: return src
.u32
;
368 case nir_type_float16
: return fromfpf(bf(src
.u16
[C
]), _AS_ROUNDMODE(rm
), 16);
369 case nir_type_float32
: return src
.f32
;
370 default: unreachable("Invalid");
/* frexp variant used by the log-related ops: splits `x` into a mantissa and
 * an exponent, with the exponent written through `e`.
 *
 * As the step comments describe, the sign is set aside first, frexpf() then
 * yields a mantissa in [0.5, 1), and that mantissa is renormalised to
 * [0.75, 1.5). Callers in this file either keep only the exponent (passing
 * the destination as `e`) or only the returned float (passing a scratch
 * int as `e`). */
375 frexp_log(float x
, int *e
)
377 /* Ignore sign until end */
380 /* frexp reduces to [0.5, 1) */
381 float f
= frexpf(xa
, e
);
383 /* reduce to [0.75, 1.5) */
397 bit_step(struct bit_state
*s
, bi_instruction
*ins
, bool FMA
)
399 /* First, load sources */
400 bit_t srcs
[BIR_SRC_COUNT
] = { 0 };
402 bi_foreach_src(ins
, src
)
403 srcs
[src
].u64
= bit_read(s
, ins
, ins
->src
[src
], ins
->src_types
[src
], FMA
);
405 /* Apply source modifiers if we need to */
406 if (bi_has_source_mods(ins
)) {
407 bi_foreach_src(ins
, src
) {
408 if (ins
->src_types
[src
] == nir_type_float16
) {
409 for (unsigned c
= 0; c
< 2; ++c
) {
410 srcs
[src
].f16
[c
] = bh(bit_srcmod(bf(srcs
[src
].f16
[c
]),
414 } else if (ins
->src_types
[src
] == nir_type_float32
) {
415 srcs
[src
].f32
= bit_srcmod(srcs
[src
].f32
,
422 /* Next, do the action of the instruction */
430 unreachable("Unsupported op");
433 nir_alu_type T
= ins
->src_types
[0];
434 unsigned sz
= nir_alu_type_get_type_size(T
);
436 if (sz
== 32 || sz
== 64) {
437 dest
.u32
= bit_cmp(ins
->cond
, srcs
[0], srcs
[1], T
, 0, 0, false);
438 } else if (sz
== 16) {
439 for (unsigned c
= 0; c
< 2; ++c
) {
440 dest
.u16
[c
] = bit_cmp(ins
->cond
, srcs
[0], srcs
[1],
441 T
, ins
->swizzle
[0][c
], ins
->swizzle
[1][c
],
444 } else if (sz
== 8) {
445 for (unsigned c
= 0; c
< 4; ++c
) {
446 dest
.u8
[c
] = bit_cmp(ins
->cond
, srcs
[0], srcs
[1],
447 T
, ins
->swizzle
[0][c
], ins
->swizzle
[1][c
],
451 unreachable("Invalid");
458 /* Apply inverts first */
459 if (ins
->bitwise
.src_invert
[0])
460 srcs
[0].u64
= ~srcs
[0].u64
;
462 if (ins
->bitwise
.src_invert
[1])
463 srcs
[1].u64
= ~srcs
[1].u64
;
466 assert(srcs
[2].u32
== 0);
468 if (ins
->op
.bitwise
== BI_BITWISE_AND
)
469 dest
.u64
= srcs
[0].u64
& srcs
[1].u64
;
470 else if (ins
->op
.bitwise
== BI_BITWISE_OR
)
471 dest
.u64
= srcs
[0].u64
| srcs
[1].u64
;
472 else if (ins
->op
.bitwise
== BI_BITWISE_XOR
)
473 dest
.u64
= srcs
[0].u64
^ srcs
[1].u64
;
475 unreachable("Unsupported op");
482 unsigned comp
= ins
->swizzle
[0][1];
484 if (ins
->dest_type
== nir_type_float32
)
485 dest
.f32
= bit_as_float32(ins
->src_types
[0], srcs
[0], comp
);
486 else if (ins
->dest_type
== nir_type_uint32
)
487 dest
.u32
= bit_as_uint32(ins
->src_types
[0], srcs
[0], comp
, ins
->roundmode
);
488 else if (ins
->dest_type
== nir_type_int32
)
489 dest
.i32
= bit_as_int32(ins
->src_types
[0], srcs
[0], comp
, ins
->roundmode
);
490 else if (ins
->dest_type
== nir_type_float16
) {
491 dest
.u16
[0] = bit_as_float16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][0]);
493 if (ins
->src_types
[0] == nir_type_float32
) {
494 /* TODO: Second argument */
497 dest
.u16
[1] = bit_as_float16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][1]);
499 } else if (ins
->dest_type
== nir_type_uint16
) {
500 dest
.u16
[0] = bit_as_uint16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][0], ins
->roundmode
);
501 dest
.u16
[1] = bit_as_uint16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][1], ins
->roundmode
);
502 } else if (ins
->dest_type
== nir_type_int16
) {
503 dest
.i16
[0] = bit_as_int16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][0], ins
->roundmode
);
504 dest
.i16
[1] = bit_as_int16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][1], ins
->roundmode
);
506 unreachable("Unknown convert type");
513 bool direct
= ins
->cond
== BI_COND_ALWAYS
;
514 unsigned sz
= nir_alu_type_get_type_size(ins
->src_types
[0]);
517 bool cond
= direct
? srcs
[0].u32
:
518 bit_eval_cond(ins
->cond
, srcs
[0], srcs
[1], ins
->src_types
[0], 0, 0);
520 dest
= cond
? srcs
[2] : srcs
[3];
521 } else if (sz
== 16) {
522 for (unsigned c
= 0; c
< 2; ++c
) {
523 bool cond
= direct
? srcs
[0].u16
[c
] :
524 bit_eval_cond(ins
->cond
, srcs
[0], srcs
[1], ins
->src_types
[0], c
, c
);
526 dest
.u16
[c
] = cond
? srcs
[2].u16
[c
] : srcs
[3].u16
[c
];
529 unreachable("Remaining types todo");
536 bfloat(bit_f64fma
, bit_f32fma
);
537 unreachable("Unknown type");
541 if (ins
->src_types
[0] != nir_type_float32
)
542 unreachable("Unknown frexp type");
545 if (ins
->op
.frexp
== BI_FREXPE_LOG
)
546 frexp_log(srcs
[0].f32
, &dest
.i32
);
548 unreachable("Unknown frexp");
554 if (ins
->op
.imath
== BI_IMATH_ADD
) {
555 bint(bit_i64add
, bit_i32add
, bit_i16add
, bit_i8add
);
556 } else if (ins
->op
.imath
== BI_IMATH_SUB
) {
557 bint(bit_i64sub
, bit_i32sub
, bit_i16sub
, bit_i8sub
);
559 unreachable("Unsupported op");
566 if (ins
->op
.minmax
== BI_MINMAX_MIN
) {
576 case BI_REDUCE_FMA
: {
577 if (ins
->src_types
[0] != nir_type_float32
)
578 unreachable("Unknown reduce type");
580 if (ins
->op
.reduce
== BI_REDUCE_ADD_FREXPM
) {
582 float f
= frexp_log(srcs
[1].f32
, &_nop
);
583 dest
.f32
= srcs
[0].f32
+ f
;
585 unreachable("Unknown reduce");
592 assert(nir_alu_type_get_base_type(ins
->dest_type
) == nir_type_float
);
593 assert(ins
->dest_type
!= nir_type_float64
);
595 if (ins
->op
.special
== BI_SPECIAL_EXP2_LOW
) {
596 assert(ins
->dest_type
== nir_type_float32
);
597 dest
.f32
= exp2f(srcs
[1].f32
);
601 float Q
= (ins
->dest_type
== nir_type_float16
) ?
602 bf(srcs
[0].u16
[ins
->swizzle
[0][0]]) :
605 float R
= biti_special(Q
, ins
->op
.special
);
607 if (ins
->dest_type
== nir_type_float16
) {
610 if (!ins
->swizzle
[0][0] && ins
->op
.special
== BI_SPECIAL_FRSQ
) {
621 if (ins
->op
.table
== BI_TABLE_LOG2_U_OVER_U_1_LOW
) {
622 assert(ins
->dest_type
== nir_type_float32
);
624 float f
= frexp_log(srcs
[0].f32
, &_nop
);
625 dest
.f32
= log2f(f
) / (f
- 1.0);
626 dest
.u32
++; /* Sorry. */
628 unreachable("Unknown table op");
634 if (ins
->src_types
[0] == nir_type_uint16
) {
635 for (unsigned c
= 0; c
< 2; ++c
)
636 dest
.u16
[c
] = srcs
[c
].u16
[ins
->swizzle
[c
][0]];
637 } else if (ins
->src_types
[0] == nir_type_uint8
) {
638 for (unsigned c
= 0; c
< 4; ++c
)
639 dest
.u8
[c
] = srcs
[c
].u8
[ins
->swizzle
[c
][0]];
641 unreachable("Unknown type");
647 unreachable("Unsupported op");
650 if (ins
->roundmode
== BIFROST_RTP
) {
651 bfloat(bit_f64ceil
, bit_f32ceil
);
652 } else if (ins
->roundmode
== BIFROST_RTN
) {
653 bfloat(bit_f64floor
, bit_f32floor
);
654 } else if (ins
->roundmode
== BIFROST_RTE
) {
655 bfloat(bit_f64nearbyint
, bit_f32nearbyint
);
656 } else if (ins
->roundmode
== BIFROST_RTZ
) {
657 bfloat(bit_f64trunc
, bit_f32trunc
);
659 unreachable("Invalid");
664 /* We only interpret vertex shaders */
669 unreachable("Fragment op used in interpreter");
671 /* Modeling main memory is more than I bargained for */
672 case BI_LOAD_UNIFORM
:
674 case BI_LOAD_VAR_ADDRESS
:
679 unreachable("Unsupported I/O in interpreter");
682 unreachable("Unsupported op");
686 if ((ins
->type
== BI_FMA
|| ins
->type
== BI_ADD
) && ins
->op
.mscale
) {
687 unsigned idx
= (ins
->type
== BI_FMA
) ? 3 : 2;
689 assert(ins
->src_types
[idx
] == nir_type_int32
);
690 assert(ins
->dest_type
== nir_type_float32
);
692 int32_t scale
= srcs
[idx
].i32
;
693 dest
.f32
*= exp2f(scale
);
697 if (bi_has_outmod(ins
) && ins
->outmod
!= BIFROST_NONE
) {
698 if (ins
->dest_type
== nir_type_float16
) {
699 for (unsigned c
= 0; c
< 2; ++c
)
700 dest
.f16
[c
] = bh(bit_outmod(bf(dest
.f16
[c
]), ins
->outmod
));
702 dest
.f32
= bit_outmod(dest
.f32
, ins
->outmod
);
706 /* Finally, store the result */
707 bit_write(s
, ins
->dest
, ins
->dest_type
, dest
, FMA
);
709 /* For ADD - change out the passthrough */