2 * Copyright (C) 2020 Collabora Ltd.
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 * Authors (Collabora):
24 * Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com>
29 #include "util/half_float.h"
45 /* Interprets a subset of Bifrost IR required for automated testing */
48 bit_read(struct bit_state
*s
, bi_instruction
*ins
, unsigned index
, nir_alu_type T
, bool FMA
)
50 if (index
& BIR_INDEX_REGISTER
) {
51 uint32_t reg
= index
& ~BIR_INDEX_REGISTER
;
54 } else if (index
& BIR_INDEX_UNIFORM
) {
55 unreachable("Uniform registers to be implemented");
56 } else if (index
& BIR_INDEX_CONSTANT
) {
57 return ins
->constant
.u64
>> (index
& ~BIR_INDEX_CONSTANT
);
58 } else if (index
& BIR_INDEX_ZERO
) {
60 } else if (index
& (BIR_INDEX_PASS
| BIFROST_SRC_STAGE
)) {
61 return FMA
? 0 : s
->T
;
62 } else if (index
& (BIR_INDEX_PASS
| BIFROST_SRC_PASS_FMA
)) {
64 } else if (index
& (BIR_INDEX_PASS
| BIFROST_SRC_PASS_ADD
)) {
70 unreachable("Invalid source");
75 bit_write(struct bit_state
*s
, unsigned index
, nir_alu_type T
, bit_t value
, bool FMA
)
77 /* Always write stage passthrough */
81 if (index
& BIR_INDEX_REGISTER
) {
82 uint32_t reg
= index
& ~BIR_INDEX_REGISTER
;
84 s
->r
[reg
] = value
.u32
;
88 unreachable("Invalid destination");
/* Short aliases for the half-precision helpers: bh maps to
 * _mesa_float_to_half and bf maps to _mesa_half_to_float. They keep the
 * 16-bit paths of the interpreter readable. */
92 #define bh _mesa_float_to_half
93 #define bf _mesa_half_to_float
96 for (unsigned c = 0; c < 2; ++c) { \
97 dest.f16[c] = bh(fxn(bf(srcs[0].f16[ins->swizzle[0][c]]), \
98 bf(srcs[1].f16[ins->swizzle[1][c]]), \
99 bf(srcs[2].f16[ins->swizzle[2][c]]), \
100 bf(srcs[3].f16[ins->swizzle[3][c]]))); \
/* Apply a four-operand 16-bit integer function to both lanes of a packed
 * 2x16 value. For every source, the lane that feeds output lane `c` is
 * picked through that source's swizzle. The result is stored through the
 * signed 16-bit view of `dest`, so an integer operation never goes through
 * the half-float member.
 *
 * Expects `dest`, `srcs` and `ins` to be in scope where the macro expands. */
#define bv2i16(fxn) \
        for (unsigned c = 0; c < 2; ++c) { \
                dest.i16[c] = fxn(srcs[0].u16[ins->swizzle[0][c]], \
                                  srcs[1].u16[ins->swizzle[1][c]], \
                                  srcs[2].u16[ins->swizzle[2][c]], \
                                  srcs[3].u16[ins->swizzle[3][c]]); \
        }
/* Scalar 32-bit dispatch helpers: evaluate a four-operand function on the
 * 32-bit view of each source and store the result in the matching view of
 * `dest`. The integer variant reads every operand through the signed member,
 * the same view it writes, instead of mixing unsigned and signed reads.
 *
 * Both expect `dest` and `srcs` to be in scope where the macro expands. */
#define bf32(fxn) dest.f32 = fxn(srcs[0].f32, srcs[1].f32, srcs[2].f32, srcs[3].f32)
#define bi32(fxn) dest.i32 = fxn(srcs[0].i32, srcs[1].i32, srcs[2].i32, srcs[3].i32)
114 #define bfloat(fxn64, fxn32) \
115 if (ins->dest_type == nir_type_float64) { \
116 unreachable("TODO: 64-bit"); \
117 } else if (ins->dest_type == nir_type_float32) { \
120 } else if (ins->dest_type == nir_type_float16) { \
125 #define bint(fxn64, fxn32, fxn16, fxn8) \
126 if (ins->dest_type == nir_type_int64 || ins->dest_type == nir_type_uint64) { \
127 unreachable("TODO: 64-bit"); \
128 } else if (ins->dest_type == nir_type_int32 || ins->dest_type == nir_type_uint32) { \
131 } else if (ins->dest_type == nir_type_int16 || ins->dest_type == nir_type_uint16) { \
134 } else if (ins->dest_type == nir_type_int8 || ins->dest_type == nir_type_uint8) { \
135 unreachable("TODO: 8-bit"); \
/* Type-generic dispatch for the operation called `name`. Expands, in order,
 * to the floating-point dispatch (bfloat) over bit_f64<name>/bit_f32<name>,
 * the integer dispatch (bint) over bit_i64/i32/i16/i8<name>, and finally
 * unreachable("Invalid type"), presumably reached only when neither dispatch
 * handled the destination type (the bfloat/bint bodies decide that). */
138 #define bpoly(name) \
139 bfloat(bit_f64 ## name, bit_f32 ## name); \
140 bint(bit_i64 ## name, bit_i32 ## name, bit_i16 ## name, bit_i8 ## name); \
141 unreachable("Invalid type");
143 #define bit_make_float(name, expr) \
144 static inline double \
145 bit_f64 ## name(double a, double b, double c, double d) \
149 static inline float \
150 bit_f32 ## name(float a, float b, float c, float d) \
155 #define bit_make_int(name, expr) \
156 static inline int64_t \
157 bit_i64 ## name (int64_t a, int64_t b, int64_t c, int64_t d) \
162 static inline int32_t \
163 bit_i32 ## name (int32_t a, int32_t b, int32_t c, int32_t d) \
168 static inline int16_t \
169 bit_i16 ## name (int16_t a, int16_t b, int16_t c, int16_t d) \
174 static inline int8_t \
175 bit_i8 ## name (int8_t a, int8_t b, int8_t c, int8_t d) \
180 #define bit_make_poly(name, expr) \
181 bit_make_float(name, expr) \
182 bit_make_int(name, expr) \
184 bit_make_poly(add
, a
+ b
);
185 bit_make_float(fma
, (a
* b
) + c
);
186 bit_make_poly(mov
, a
);
187 bit_make_poly(min
, MIN2(a
, b
));
188 bit_make_poly(max
, MAX2(a
, b
));
193 bit_outmod(float raw
, enum bifrost_outmod mod
)
197 return MAX2(raw
, 0.0);
198 case BIFROST_SAT_SIGNED
:
199 return CLAMP(raw
, -1.0, 1.0);
201 return CLAMP(raw
, 0.0, 1.0);
208 bit_srcmod(float raw
, bool abs
, bool neg
)
/* Evaluate the comparison `cond` on `left` and `right` and return the result
 * from the ENCLOSING function. Any condition other than the six comparisons
 * listed here yields true.
 *
 * Every argument is parenthesized so compound expressions keep their meaning
 * after substitution, and the body is wrapped in do/while(0) so the macro
 * behaves as a single statement wherever it is used (terminate it with a
 * semicolon at the call site). */
#define BIT_COND(cond, left, right) \
        do { \
                if ((cond) == BI_COND_LT) return (left) < (right); \
                else if ((cond) == BI_COND_LE) return (left) <= (right); \
                else if ((cond) == BI_COND_GE) return (left) >= (right); \
                else if ((cond) == BI_COND_GT) return (left) > (right); \
                else if ((cond) == BI_COND_EQ) return (left) == (right); \
                else if ((cond) == BI_COND_NE) return (left) != (right); \
                else return true; \
        } while (0)
229 bit_eval_cond(enum bi_cond cond
, bit_t l
, bit_t r
, nir_alu_type T
, unsigned c
)
231 if (T
== nir_type_float32
) {
232 BIT_COND(cond
, l
.f32
, r
.f32
);
233 } else if (T
== nir_type_float16
) {
234 float left
= bf(l
.f16
[c
]);
235 float right
= bf(r
.f16
[c
]);
236 BIT_COND(cond
, left
, right
);
237 } else if (T
== nir_type_int32
) {
238 int32_t left
= (int32_t) l
.u32
;
239 int32_t right
= (int32_t) r
.u32
;
240 BIT_COND(cond
, left
, right
);
241 } else if (T
== nir_type_int16
) {
242 int16_t left
= (int16_t) l
.u32
;
243 int16_t right
= (int16_t) r
.u32
;
244 BIT_COND(cond
, left
, right
);
245 } else if (T
== nir_type_uint32
) {
246 BIT_COND(cond
, l
.u32
, r
.u32
);
247 } else if (T
== nir_type_uint16
) {
248 BIT_COND(cond
, l
.u16
[c
], r
.u16
[c
]);
250 unreachable("Unknown type evaluated");
255 biti_special(float Q
, enum bi_special_op op
)
258 case BI_SPECIAL_FRCP
: return 1.0 / Q
;
259 case BI_SPECIAL_FRSQ
: {
260 double Qf
= 1.0 / sqrt(Q
);
263 default: unreachable("Invalid special");
267 /* For BI_CONVERT. */
269 #define _AS_ROUNDMODE(mode) \
270 ((mode == BIFROST_RTZ) ? FP_INT_TOWARDZERO : \
271 (mode == BIFROST_RTE) ? FP_INT_TONEAREST : \
272 (mode == BIFROST_RTN) ? FP_INT_DOWNWARD : \
276 bit_as_float32(nir_alu_type T
, bit_t src
, unsigned C
)
279 case nir_type_int32
: return src
.i32
;
280 case nir_type_uint32
: return src
.u32
;
281 case nir_type_float16
: return bf(src
.u16
[C
]);
282 default: unreachable("Invalid");
287 bit_as_uint32(nir_alu_type T
, bit_t src
, unsigned C
, enum bifrost_roundmode rm
)
290 case nir_type_float16
: return bf(src
.u16
[C
]);
291 case nir_type_float32
: return ufromfpf(src
.f32
, _AS_ROUNDMODE(rm
), 32);
292 default: unreachable("Invalid");
297 bit_as_int32(nir_alu_type T
, bit_t src
, unsigned C
, enum bifrost_roundmode rm
)
300 case nir_type_float16
: return bf(src
.u16
[C
]);
301 case nir_type_float32
: return fromfpf(src
.f32
, _AS_ROUNDMODE(rm
), 32);
302 default: unreachable("Invalid");
307 bit_as_float16(nir_alu_type T
, bit_t src
, unsigned C
)
310 case nir_type_int32
: return bh(src
.i32
);
311 case nir_type_uint32
: return bh(src
.u32
);
312 case nir_type_float32
: return bh(src
.f32
);
313 case nir_type_int16
: return bh(src
.i16
[C
]);
314 case nir_type_uint16
: return bh(src
.u16
[C
]);
315 default: unreachable("Invalid");
320 bit_as_uint16(nir_alu_type T
, bit_t src
, unsigned C
, enum bifrost_roundmode rm
)
323 case nir_type_int32
: return src
.i32
;
324 case nir_type_uint32
: return src
.u32
;
325 case nir_type_float16
: return ufromfpf(bf(src
.u16
[C
]), _AS_ROUNDMODE(rm
), 16);
326 case nir_type_float32
: return src
.f32
;
327 default: unreachable("Invalid");
332 bit_as_int16(nir_alu_type T
, bit_t src
, unsigned C
, enum bifrost_roundmode rm
)
335 case nir_type_int32
: return src
.i32
;
336 case nir_type_uint32
: return src
.u32
;
337 case nir_type_float16
: return fromfpf(bf(src
.u16
[C
]), _AS_ROUNDMODE(rm
), 16);
338 case nir_type_float32
: return src
.f32
;
339 default: unreachable("Invalid");
344 frexp_log(float x
, int *e
)
346 /* Ignore sign until end */
349 /* frexp reduces to [0.5, 1) */
350 float f
= frexpf(xa
, e
);
352 /* reduce to [0.75, 1.5) */
366 bit_step(struct bit_state
*s
, bi_instruction
*ins
, bool FMA
)
368 /* First, load sources */
369 bit_t srcs
[BIR_SRC_COUNT
] = { 0 };
371 bi_foreach_src(ins
, src
)
372 srcs
[src
].u64
= bit_read(s
, ins
, ins
->src
[src
], ins
->src_types
[src
], FMA
);
374 /* Apply source modifiers if we need to */
375 if (bi_has_source_mods(ins
)) {
376 bi_foreach_src(ins
, src
) {
377 if (ins
->src_types
[src
] == nir_type_float16
) {
378 for (unsigned c
= 0; c
< 2; ++c
) {
379 srcs
[src
].f16
[c
] = bh(bit_srcmod(bf(srcs
[src
].f16
[c
]),
383 } else if (ins
->src_types
[src
] == nir_type_float32
) {
384 srcs
[src
].f32
= bit_srcmod(srcs
[src
].f32
,
391 /* Next, do the action of the instruction */
401 unreachable("Unsupported op");
405 unsigned comp
= ins
->swizzle
[0][1];
407 if (ins
->dest_type
== nir_type_float32
)
408 dest
.f32
= bit_as_float32(ins
->src_types
[0], srcs
[0], comp
);
409 else if (ins
->dest_type
== nir_type_uint32
)
410 dest
.u32
= bit_as_uint32(ins
->src_types
[0], srcs
[0], comp
, ins
->roundmode
);
411 else if (ins
->dest_type
== nir_type_int32
)
412 dest
.i32
= bit_as_int32(ins
->src_types
[0], srcs
[0], comp
, ins
->roundmode
);
413 else if (ins
->dest_type
== nir_type_float16
) {
414 dest
.u16
[0] = bit_as_float16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][0]);
416 if (ins
->src_types
[0] == nir_type_float32
) {
417 /* TODO: Second argument */
420 dest
.u16
[1] = bit_as_float16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][1]);
422 } else if (ins
->dest_type
== nir_type_uint16
) {
423 dest
.u16
[0] = bit_as_uint16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][0], ins
->roundmode
);
424 dest
.u16
[1] = bit_as_uint16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][1], ins
->roundmode
);
425 } else if (ins
->dest_type
== nir_type_int16
) {
426 dest
.i16
[0] = bit_as_int16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][0], ins
->roundmode
);
427 dest
.i16
[1] = bit_as_int16(ins
->src_types
[0], srcs
[0], ins
->swizzle
[0][1], ins
->roundmode
);
429 unreachable("Unknown convert type");
436 bool direct
= ins
->csel_cond
== BI_COND_ALWAYS
;
437 bool cond
= direct
? srcs
[0].u32
:
438 bit_eval_cond(ins
->csel_cond
, srcs
[0], srcs
[1], ins
->src_types
[0], 0);
440 dest
= cond
? srcs
[2] : srcs
[3];
445 bfloat(bit_f64fma
, bit_f32fma
);
446 unreachable("Unknown type");
450 if (ins
->src_types
[0] != nir_type_float32
)
451 unreachable("Unknown frexp type");
454 if (ins
->op
.frexp
== BI_FREXPE_LOG
)
455 frexp_log(srcs
[0].f32
, &dest
.i32
);
457 unreachable("Unknown frexp");
462 unreachable("Unsupported op");
465 if (ins
->op
.minmax
== BI_MINMAX_MIN
) {
475 case BI_REDUCE_FMA
: {
476 if (ins
->src_types
[0] != nir_type_float32
)
477 unreachable("Unknown reduce type");
479 if (ins
->op
.reduce
== BI_REDUCE_ADD_FREXPM
) {
481 float f
= frexp_log(srcs
[1].f32
, &_nop
);
482 dest
.f32
= srcs
[0].f32
+ f
;
484 unreachable("Unknown reduce");
491 assert(nir_alu_type_get_base_type(ins
->dest_type
) == nir_type_float
);
492 assert(ins
->dest_type
!= nir_type_float64
);
494 if (ins
->op
.special
== BI_SPECIAL_EXP2_LOW
) {
495 assert(ins
->dest_type
== nir_type_float32
);
496 dest
.f32
= exp2f(srcs
[1].f32
);
500 float Q
= (ins
->dest_type
== nir_type_float16
) ?
501 bf(srcs
[0].u16
[ins
->swizzle
[0][0]]) :
504 float R
= biti_special(Q
, ins
->op
.special
);
506 if (ins
->dest_type
== nir_type_float16
) {
509 if (!ins
->swizzle
[0][0] && ins
->op
.special
== BI_SPECIAL_FRSQ
) {
520 if (ins
->op
.table
== BI_TABLE_LOG2_U_OVER_U_1_LOW
) {
521 assert(ins
->dest_type
== nir_type_float32
);
523 float f
= frexp_log(srcs
[0].f32
, &_nop
);
524 dest
.f32
= log2f(f
) / (f
- 1.0);
525 dest
.u32
++; /* Sorry. */
527 unreachable("Unknown table op");
533 if (ins
->dest_type
== nir_type_uint16
) {
534 for (unsigned c
= 0; c
< 2; ++c
)
535 dest
.u16
[c
] = srcs
[c
].u16
[ins
->swizzle
[c
][0]];
536 } else if (ins
->dest_type
== nir_type_uint8
) {
537 for (unsigned c
= 0; c
< 4; ++c
)
538 dest
.u8
[c
] = srcs
[c
].u8
[ins
->swizzle
[c
][0]];
540 unreachable("Unknown type");
547 unreachable("Unsupported op");
549 /* We only interpret vertex shaders */
554 unreachable("Fragment op used in interpreter");
556 /* Modeling main memory is more than I bargained for */
557 case BI_LOAD_UNIFORM
:
559 case BI_LOAD_VAR_ADDRESS
:
564 unreachable("Unsupported I/O in interpreter");
567 unreachable("Unsupported op");
571 if ((ins
->type
== BI_FMA
|| ins
->type
== BI_ADD
) && ins
->op
.mscale
) {
572 unsigned idx
= (ins
->type
== BI_FMA
) ? 3 : 2;
574 assert(ins
->src_types
[idx
] == nir_type_int32
);
575 assert(ins
->dest_type
== nir_type_float32
);
577 int32_t scale
= srcs
[idx
].i32
;
578 dest
.f32
*= exp2f(scale
);
582 if (bi_has_outmod(ins
) && ins
->outmod
!= BIFROST_NONE
) {
583 if (ins
->dest_type
== nir_type_float16
) {
584 for (unsigned c
= 0; c
< 2; ++c
)
585 dest
.f16
[c
] = bh(bit_outmod(bf(dest
.f16
[c
]), ins
->outmod
));
587 dest
.f32
= bit_outmod(dest
.f32
, ins
->outmod
);
591 /* Finally, store the result */
592 bit_write(s
, ins
->dest
, ins
->dest_type
, dest
, FMA
);
594 /* For ADD - change out the passthrough */