2 * Copyright (C) 2020 Collabora Ltd.
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 * Authors (Collabora):
24 * Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com>
29 #include "util/half_float.h"
45 /* Interprets a subset of Bifrost IR required for automated testing */
/* Fetch the raw bits of one instruction source.
 *
 * `index` is a tagged source index; the tag bits are tested in order
 * (register, uniform, constant, zero, passthrough). `ins` supplies the
 * embedded constant and `s` the interpreter state. `FMA` says whether the
 * instruction being interpreted runs in the FMA slot. `T` is not referenced
 * by any line visible here.
 *
 * NOTE(review): this excerpt is missing several original lines (the return
 * statements of the register, zero and FMA/ADD passthrough branches among
 * them), so only the visible branches are described. */
48 bit_read(struct bit_state
*s
, bi_instruction
*ins
, unsigned index
, nir_alu_type T
, bool FMA
)
50 if (index
& BIR_INDEX_REGISTER
) {
/* Strip the tag bit to recover the register number */
51 uint32_t reg
= index
& ~BIR_INDEX_REGISTER
;
54 } else if (index
& BIR_INDEX_UNIFORM
) {
55 unreachable("Uniform registers to be implemented");
56 } else if (index
& BIR_INDEX_CONSTANT
) {
/* The untagged index is used as a bit shift into the 64-bit constant */
57 return ins
->constant
.u64
>> (index
& ~BIR_INDEX_CONSTANT
);
58 } else if (index
& BIR_INDEX_ZERO
) {
/* NOTE(review): `index & (A | B)` is true when index shares ANY bit with
 * A or B. If BIR_INDEX_PASS is a single tag bit, the first passthrough test
 * below already matches every passthrough index and the two following
 * branches can never be taken. An equality test against
 * (BIR_INDEX_PASS | ...) looks intended - confirm against the definitions. */
60 } else if (index
& (BIR_INDEX_PASS
| BIFROST_SRC_STAGE
)) {
/* Stage passthrough reads as 0 from the FMA slot, else the stored s->T */
61 return FMA
? 0 : s
->T
;
62 } else if (index
& (BIR_INDEX_PASS
| BIFROST_SRC_PASS_FMA
)) {
64 } else if (index
& (BIR_INDEX_PASS
| BIFROST_SRC_PASS_ADD
)) {
70 unreachable("Invalid source");
/* Store an instruction result into the interpreter state.
 *
 * For a register destination (BIR_INDEX_REGISTER tag set) the 32-bit view of
 * `value` is written to s->r[reg]. Any destination not handled by the
 * visible/elided branches traps with "Invalid destination".
 *
 * NOTE(review): the statements following the "stage passthrough" comment and
 * the branches between the register store and the trap are missing from this
 * excerpt; `T` and `FMA` are not referenced by any visible line. */
75 bit_write(struct bit_state
*s
, unsigned index
, nir_alu_type T
, bit_t value
, bool FMA
)
77 /* Always write stage passthrough */
81 if (index
& BIR_INDEX_REGISTER
) {
/* Strip the tag bit to recover the register number */
82 uint32_t reg
= index
& ~BIR_INDEX_REGISTER
;
/* Only the low 32 bits of the value are stored */
84 s
->r
[reg
] = value
.u32
;
88 unreachable("Invalid destination");
/* Short aliases for the half-float conversion helpers used throughout this
 * file: bh is _mesa_float_to_half and bf is _mesa_half_to_float. */
92 #define bh _mesa_float_to_half
93 #define bf _mesa_half_to_float
96 for (unsigned c = 0; c < 2; ++c) { \
97 dest.f16[c] = bh(fxn(bf(srcs[0].f16[ins->swizzle[0][c]]), \
98 bf(srcs[1].f16[ins->swizzle[1][c]]), \
99 bf(srcs[2].f16[ins->swizzle[2][c]]), \
100 bf(srcs[3].f16[ins->swizzle[3][c]]))); \
/* Apply a four-operand 16-bit integer function to both lanes of a 2x16
 * vector. Expects `dest`, `srcs` and `ins` in scope at the expansion site;
 * lane c of each source is selected through ins->swizzle[src][c].
 *
 * Operands and result go through the signed 16-bit view so that an integer
 * result is never stored through the half-float view of the destination and
 * no unsigned-to-signed conversion happens at the call. */
#define bv2i16(fxn) \
        for (unsigned c = 0; c < 2; ++c) { \
                dest.i16[c] = fxn(srcs[0].i16[ins->swizzle[0][c]], \
                                  srcs[1].i16[ins->swizzle[1][c]], \
                                  srcs[2].i16[ins->swizzle[2][c]], \
                                  srcs[3].i16[ins->swizzle[3][c]]); \
        }
/* Apply a four-operand float function to the 32-bit float view of the four
 * sources and store the result in dest.f32. Expects `dest` and `srcs` in
 * scope at the expansion site. */
111 #define bf32(fxn) dest.f32 = fxn(srcs[0].f32, srcs[1].f32, srcs[2].f32, srcs[3].f32)
/* Apply a four-operand 32-bit integer function to the four sources and store
 * the result in dest.i32. Expects `dest` and `srcs` in scope at the expansion
 * site. All operands are read through the signed view, matching the signed
 * destination and avoiding a mix of unsigned and signed arguments. */
#define bi32(fxn) dest.i32 = fxn(srcs[0].i32, srcs[1].i32, srcs[2].i32, srcs[3].i32)
/* Dispatch a floating-point operation on ins->dest_type (float64, float32,
 * float16). The 64-bit case is unimplemented and traps.
 * NOTE(review): the bodies of the float32 and float16 branches and the end of
 * the macro are missing from this excerpt; presumably fxn32 is applied there -
 * confirm against the full file. */
114 #define bfloat(fxn64, fxn32) \
115 if (ins->dest_type == nir_type_float64) { \
116 unreachable("TODO: 64-bit"); \
117 } else if (ins->dest_type == nir_type_float32) { \
120 } else if (ins->dest_type == nir_type_float16) { \
/* Dispatch an integer operation on ins->dest_type; signed and unsigned types
 * of the same width share a branch. The 64-bit and 8-bit cases are
 * unimplemented and trap.
 * NOTE(review): the bodies of the 32-bit and 16-bit branches and the end of
 * the macro are missing from this excerpt. */
125 #define bint(fxn64, fxn32, fxn16, fxn8) \
126 if (ins->dest_type == nir_type_int64 || ins->dest_type == nir_type_uint64) { \
127 unreachable("TODO: 64-bit"); \
128 } else if (ins->dest_type == nir_type_int32 || ins->dest_type == nir_type_uint32) { \
131 } else if (ins->dest_type == nir_type_int16 || ins->dest_type == nir_type_uint16) { \
134 } else if (ins->dest_type == nir_type_int8 || ins->dest_type == nir_type_uint8) { \
135 unreachable("TODO: 8-bit"); \
/* Type-polymorphic dispatch for operation `name`: runs the float dispatch
 * with bit_f64<name>/bit_f32<name>, then the integer dispatch with
 * bit_i64<name>..bit_i8<name>, and traps with "Invalid type" if control
 * reaches the end of both. */
138 #define bpoly(name) \
139 bfloat(bit_f64 ## name, bit_f32 ## name); \
140 bint(bit_i64 ## name, bit_i32 ## name, bit_i16 ## name, bit_i8 ## name); \
141 unreachable("Invalid type");
/* Generate the float helpers bit_f64<name> (double) and bit_f32<name>
 * (float), each taking four operands a, b, c, d.
 * NOTE(review): the generated function bodies are missing from this excerpt;
 * presumably each returns `expr` - confirm against the full file. */
143 #define bit_make_float(name, expr) \
144 static inline double \
145 bit_f64 ## name(double a, double b, double c, double d) \
149 static inline float \
150 bit_f32 ## name(float a, float b, float c, float d) \
/* Generate the signed integer helpers bit_i64<name>, bit_i32<name>,
 * bit_i16<name> and bit_i8<name>, each taking four operands a, b, c, d of
 * the matching width.
 * NOTE(review): the generated function bodies are missing from this excerpt;
 * presumably each returns `expr` - confirm against the full file. */
155 #define bit_make_int(name, expr) \
156 static inline int64_t \
157 bit_i64 ## name (int64_t a, int64_t b, int64_t c, int64_t d) \
162 static inline int32_t \
163 bit_i32 ## name (int32_t a, int32_t b, int32_t c, int32_t d) \
168 static inline int16_t \
169 bit_i16 ## name (int16_t a, int16_t b, int16_t c, int16_t d) \
174 static inline int8_t \
175 bit_i8 ## name (int8_t a, int8_t b, int8_t c, int8_t d) \
/* Generate both the float and the integer helper families for `name` from
 * the same expression. */
180 #define bit_make_poly(name, expr) \
181 bit_make_float(name, expr) \
182 bit_make_int(name, expr) \
/* Instantiate the arithmetic helpers: add, mov, min and max exist for both
 * float and integer types; fma is generated for float types only. */
184 bit_make_poly(add
, a
+ b
);
185 bit_make_float(fma
, (a
* b
) + c
);
186 bit_make_poly(mov
, a
);
187 bit_make_poly(min
, MIN2(a
, b
));
188 bit_make_poly(max
, MAX2(a
, b
));
/* Apply output modifier `mod` to the float result `raw`.
 *
 * Visible behaviours: one case clamps below at 0 (MAX2(raw, 0.0)),
 * BIFROST_SAT_SIGNED clamps to [-1, 1], and one case clamps to [0, 1].
 * NOTE(review): the switch head, the case labels of the first and last
 * returns and any default case are missing from this excerpt. */
193 bit_outmod(float raw
, enum bifrost_outmod mod
)
197 return MAX2(raw
, 0.0);
198 case BIFROST_SAT_SIGNED
:
199 return CLAMP(raw
, -1.0, 1.0);
201 return CLAMP(raw
, 0.0, 1.0);
/* Apply source modifiers to the float operand `raw`; bit_step calls this on
 * every float16/float32 source when the instruction has source modifiers.
 * NOTE(review): the body is missing from this excerpt - presumably it takes
 * the absolute value when `abs` is set and negates when `neg` is set;
 * confirm against the full file. */
208 bit_srcmod(float raw
, bool abs
, bool neg
)
/* Return, from the enclosing function, the result of comparing `left` with
 * `right` under condition `cond`. Any condition other than the six listed
 * comparisons yields true.
 *
 * `cond` is evaluated exactly once and both operands are parenthesized, so
 * arbitrary expressions may be passed safely. */
#define BIT_COND(cond, left, right) \
        switch (cond) { \
        case BI_COND_LT: return (left) < (right); \
        case BI_COND_LE: return (left) <= (right); \
        case BI_COND_GE: return (left) >= (right); \
        case BI_COND_GT: return (left) > (right); \
        case BI_COND_EQ: return (left) == (right); \
        case BI_COND_NE: return (left) != (right); \
        default: return true; \
        }
229 bit_eval_cond(enum bi_cond cond
, bit_t l
, bit_t r
, nir_alu_type T
, unsigned cl
, unsigned cr
)
231 if (T
== nir_type_float32
) {
232 BIT_COND(cond
, l
.f32
, r
.f32
);
233 } else if (T
== nir_type_float16
) {
234 float left
= bf(l
.f16
[cl
]);
235 float right
= bf(r
.f16
[cr
]);
236 BIT_COND(cond
, left
, right
);
237 } else if (T
== nir_type_int32
) {
238 int32_t left
= l
.u32
;
239 int32_t right
= r
.u32
;
240 BIT_COND(cond
, left
, right
);
241 } else if (T
== nir_type_int16
) {
242 int16_t left
= l
.i16
[cl
];
243 int16_t right
= r
.i16
[cr
];
244 BIT_COND(cond
, left
, right
);
245 } else if (T
== nir_type_uint32
) {
246 BIT_COND(cond
, l
.u32
, r
.u32
);
247 } else if (T
== nir_type_uint16
) {
248 BIT_COND(cond
, l
.u16
[cl
], r
.u16
[cr
]);
250 unreachable("Unknown type evaluated");
255 bit_cmp(enum bi_cond cond
, bit_t l
, bit_t r
, nir_alu_type T
, unsigned cl
, unsigned cr
, bool d3d
)
257 bool v
= bit_eval_cond(cond
, l
, r
, T
, cl
, cr
);
259 /* Fill for D3D but only up to 32-bit... 64-bit is only partial
260 * (although we probably need a cleverer representation for 64-bit) */
262 unsigned sz
= MIN2(nir_alu_type_get_type_size(T
), 32);
263 unsigned max
= (sz
== 32) ? (~0) : ((1 << sz
) - 1);
265 return v
? (d3d
? max
: 1) : 0;
/* Evaluate special-function `op` on Q.
 *
 * BI_SPECIAL_FRCP yields the reciprocal 1.0 / Q; BI_SPECIAL_FRSQ computes the
 * reciprocal square root 1.0 / sqrt(Q) in double precision. Any other op
 * traps with "Invalid special".
 * NOTE(review): the switch head and the statement that returns Qf are
 * missing from this excerpt. */
269 biti_special(float Q
, enum bi_special_op op
)
272 case BI_SPECIAL_FRCP
: return 1.0 / Q
;
273 case BI_SPECIAL_FRSQ
: {
274 double Qf
= 1.0 / sqrt(Q
);
277 default: unreachable("Invalid special");
281 /* For BI_CONVERT. */
/* Translate a Bifrost rounding mode into the <math.h> FP_INT_* constant
 * passed to fromfpf()/ufromfpf(): RTZ -> toward zero, RTE -> to nearest,
 * RTN -> downward.
 * NOTE(review): the final (fallback) arm of the conditional is missing from
 * this excerpt. `mode` is used unparenthesized, and an identifier starting
 * with underscore + uppercase letter is reserved for the implementation. */
283 #define _AS_ROUNDMODE(mode) \
284 ((mode == BIFROST_RTZ) ? FP_INT_TOWARDZERO : \
285 (mode == BIFROST_RTE) ? FP_INT_TONEAREST : \
286 (mode == BIFROST_RTN) ? FP_INT_DOWNWARD : \
290 bit_as_float32(nir_alu_type T
, bit_t src
, unsigned C
)
293 case nir_type_int32
: return src
.i32
;
294 case nir_type_uint32
: return src
.u32
;
295 case nir_type_float16
: return bf(src
.u16
[C
]);
296 default: unreachable("Invalid");
301 bit_as_uint32(nir_alu_type T
, bit_t src
, unsigned C
, enum bifrost_roundmode rm
)
304 case nir_type_float16
: return bf(src
.u16
[C
]);
305 case nir_type_float32
: return ufromfpf(src
.f32
, _AS_ROUNDMODE(rm
), 32);
306 default: unreachable("Invalid");
311 bit_as_int32(nir_alu_type T
, bit_t src
, unsigned C
, enum bifrost_roundmode rm
)
314 case nir_type_float16
: return bf(src
.u16
[C
]);
315 case nir_type_float32
: return fromfpf(src
.f32
, _AS_ROUNDMODE(rm
), 32);
316 default: unreachable("Invalid");
321 bit_as_float16(nir_alu_type T
, bit_t src
, unsigned C
)
324 case nir_type_int32
: return bh(src
.i32
);
325 case nir_type_uint32
: return bh(src
.u32
);
326 case nir_type_float32
: return bh(src
.f32
);
327 case nir_type_int16
: return bh(src
.i16
[C
]);
328 case nir_type_uint16
: return bh(src
.u16
[C
]);
329 default: unreachable("Invalid");
334 bit_as_uint16(nir_alu_type T
, bit_t src
, unsigned C
, enum bifrost_roundmode rm
)
337 case nir_type_int32
: return src
.i32
;
338 case nir_type_uint32
: return src
.u32
;
339 case nir_type_float16
: return ufromfpf(bf(src
.u16
[C
]), _AS_ROUNDMODE(rm
), 16);
340 case nir_type_float32
: return src
.f32
;
341 default: unreachable("Invalid");
346 bit_as_int16(nir_alu_type T
, bit_t src
, unsigned C
, enum bifrost_roundmode rm
)
349 case nir_type_int32
: return src
.i32
;
350 case nir_type_uint32
: return src
.u32
;
351 case nir_type_float16
: return fromfpf(bf(src
.u16
[C
]), _AS_ROUNDMODE(rm
), 16);
352 case nir_type_float32
: return src
.f32
;
353 default: unreachable("Invalid");
/* Split `x` into a mantissa and a binary exponent using a range suited to
 * logarithm evaluation. The exponent is stored through `e` by frexpf();
 * callers in this file use the returned float as the mantissa.
 *
 * NOTE(review): most of the body is missing from this excerpt, including the
 * definition of `xa` (presumably the magnitude of x), the re-centering of the
 * mantissa from [0.5, 1) to [0.75, 1.5) and the sign handling. */
358 frexp_log(float x
, int *e
)
360 /* Ignore sign until end */
363 /* frexp reduces to [0.5, 1) */
364 float f
= frexpf(xa
, e
);
366 /* reduce to [0.75, 1.5) */
380 bit_step(struct bit_state
*s
, bi_instruction
*ins
, bool FMA
)
382 /* First, load sources */
383 bit_t srcs
[BIR_SRC_COUNT
] = { 0 };
385 bi_foreach_src(ins
, src
)
386 srcs
[src
].u64
= bit_read(s
, ins
, ins
->src
[src
], ins
->src_types
[src
], FMA
);
388 /* Apply source modifiers if we need to */
389 if (bi_has_source_mods(ins
)) {
390 bi_foreach_src(ins
, src
) {
391 if (ins
->src_types
[src
] == nir_type_float16
) {
392 for (unsigned c
= 0; c
< 2; ++c
) {
393 srcs
[src
].f16
[c
] = bh(bit_srcmod(bf(srcs
[src
].f16
[c
]),
397 } else if (ins
->src_types
[src
] == nir_type_float32
) {
398 srcs
[src
].f32
= bit_srcmod(srcs
[src
].f32
,
405 /* Next, do the action of the instruction */
413 unreachable("Unsupported op");
416 nir_alu_type T
= ins
->src_types
[0];
417 unsigned sz
= nir_alu_type_get_type_size(T
);
419 if (sz
== 32 || sz
== 64) {
420 dest
.u32
= bit_cmp(ins
->cond
, srcs
[0], srcs
[1], T
, 0, 0, false);
421 } else if (sz
== 16) {
422 for (unsigned c
= 0; c
< 2; ++c
) {
423 dest
.u16
[c
] = bit_cmp(ins
->cond
, srcs
[0], srcs
[1],
424 T
, ins
->swizzle
[0][c
], ins
->swizzle
[1][c
],
427 } else if (sz
== 8) {
428 for (unsigned c
= 0; c
< 4; ++c
) {
429 dest
.u8
[c
] = bit_cmp(ins
->cond
, srcs
[0], srcs
[1],
430 T
, ins
->swizzle
[0][c
], ins
->swizzle
[1][c
],
434 unreachable("Invalid");
441 unreachable("Unsupported op");
445 unsigned comp
= ins
->swizzle
[0][1];
447 if (ins
->dest_type
== nir_type_float32
)
448 dest
.f32
= bit_as_float32(ins
->src_types
[0], srcs
[0], comp
);
449 else if (ins
->dest_type
== nir_type_uint32
)
450 dest
.u32
= bit_as_uint32(ins
->src_types
[0], srcs
[0], comp
, ins
->roundmode
);
451 else if (ins
->dest_type
== nir_type_int32
)
452 dest
.i32
= bit_as_int32(ins
->src_types
[0], srcs
[0], comp
, ins
->roundmode
);
453 else if (ins
->dest_type
== nir_type_float16
) {
454 dest
.u16
[0] = bit_as_float16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][0]);
456 if (ins
->src_types
[0] == nir_type_float32
) {
457 /* TODO: Second argument */
460 dest
.u16
[1] = bit_as_float16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][1]);
462 } else if (ins
->dest_type
== nir_type_uint16
) {
463 dest
.u16
[0] = bit_as_uint16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][0], ins
->roundmode
);
464 dest
.u16
[1] = bit_as_uint16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][1], ins
->roundmode
);
465 } else if (ins
->dest_type
== nir_type_int16
) {
466 dest
.i16
[0] = bit_as_int16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][0], ins
->roundmode
);
467 dest
.i16
[1] = bit_as_int16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][1], ins
->roundmode
);
469 unreachable("Unknown convert type");
476 bool direct
= ins
->cond
== BI_COND_ALWAYS
;
477 bool cond
= direct
? srcs
[0].u32
:
478 bit_eval_cond(ins
->cond
, srcs
[0], srcs
[1], ins
->src_types
[0], 0, 0);
480 dest
= cond
? srcs
[2] : srcs
[3];
485 bfloat(bit_f64fma
, bit_f32fma
);
486 unreachable("Unknown type");
490 if (ins
->src_types
[0] != nir_type_float32
)
491 unreachable("Unknown frexp type");
494 if (ins
->op
.frexp
== BI_FREXPE_LOG
)
495 frexp_log(srcs
[0].f32
, &dest
.i32
);
497 unreachable("Unknown frexp");
502 unreachable("Unsupported op");
505 if (ins
->op
.minmax
== BI_MINMAX_MIN
) {
515 case BI_REDUCE_FMA
: {
516 if (ins
->src_types
[0] != nir_type_float32
)
517 unreachable("Unknown reduce type");
519 if (ins
->op
.reduce
== BI_REDUCE_ADD_FREXPM
) {
521 float f
= frexp_log(srcs
[1].f32
, &_nop
);
522 dest
.f32
= srcs
[0].f32
+ f
;
524 unreachable("Unknown reduce");
531 assert(nir_alu_type_get_base_type(ins
->dest_type
) == nir_type_float
);
532 assert(ins
->dest_type
!= nir_type_float64
);
534 if (ins
->op
.special
== BI_SPECIAL_EXP2_LOW
) {
535 assert(ins
->dest_type
== nir_type_float32
);
536 dest
.f32
= exp2f(srcs
[1].f32
);
540 float Q
= (ins
->dest_type
== nir_type_float16
) ?
541 bf(srcs
[0].u16
[ins
->swizzle
[0][0]]) :
544 float R
= biti_special(Q
, ins
->op
.special
);
546 if (ins
->dest_type
== nir_type_float16
) {
549 if (!ins
->swizzle
[0][0] && ins
->op
.special
== BI_SPECIAL_FRSQ
) {
560 if (ins
->op
.table
== BI_TABLE_LOG2_U_OVER_U_1_LOW
) {
561 assert(ins
->dest_type
== nir_type_float32
);
563 float f
= frexp_log(srcs
[0].f32
, &_nop
);
564 dest
.f32
= log2f(f
) / (f
- 1.0);
565 dest
.u32
++; /* Sorry. */
567 unreachable("Unknown table op");
573 if (ins
->src_types
[0] == nir_type_uint16
) {
574 for (unsigned c
= 0; c
< 2; ++c
)
575 dest
.u16
[c
] = srcs
[c
].u16
[ins
->swizzle
[c
][0]];
576 } else if (ins
->src_types
[0] == nir_type_uint8
) {
577 for (unsigned c
= 0; c
< 4; ++c
)
578 dest
.u8
[c
] = srcs
[c
].u8
[ins
->swizzle
[c
][0]];
580 unreachable("Unknown type");
587 unreachable("Unsupported op");
589 /* We only interpret vertex shaders */
594 unreachable("Fragment op used in interpreter");
596 /* Modeling main memory is more than I bargained for */
597 case BI_LOAD_UNIFORM
:
599 case BI_LOAD_VAR_ADDRESS
:
604 unreachable("Unsupported I/O in interpreter");
607 unreachable("Unsupported op");
611 if ((ins
->type
== BI_FMA
|| ins
->type
== BI_ADD
) && ins
->op
.mscale
) {
612 unsigned idx
= (ins
->type
== BI_FMA
) ? 3 : 2;
614 assert(ins
->src_types
[idx
] == nir_type_int32
);
615 assert(ins
->dest_type
== nir_type_float32
);
617 int32_t scale
= srcs
[idx
].i32
;
618 dest
.f32
*= exp2f(scale
);
622 if (bi_has_outmod(ins
) && ins
->outmod
!= BIFROST_NONE
) {
623 if (ins
->dest_type
== nir_type_float16
) {
624 for (unsigned c
= 0; c
< 2; ++c
)
625 dest
.f16
[c
] = bh(bit_outmod(bf(dest
.f16
[c
]), ins
->outmod
));
627 dest
.f32
= bit_outmod(dest
.f32
, ins
->outmod
);
631 /* Finally, store the result */
632 bit_write(s
, ins
->dest
, ins
->dest_type
, dest
, FMA
);
634 /* For ADD - change out the passthrough */