2 * Copyright (C) 2020 Collabora Ltd.
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 * Authors (Collabora):
24 * Alyssa Rosenzweig <alyssa.rosenzweig@collabora.com>
29 #include "util/half_float.h"
45 /* Interprets a subset of Bifrost IR required for automated testing */
/* Fetch the value of source operand `index` for instruction `ins`.
 *
 * The index is decoded by its flag bits, tested in this order:
 *  - BIR_INDEX_REGISTER: the register number is the index with the flag
 *    cleared.
 *  - BIR_INDEX_UNIFORM: not implemented, hits unreachable().
 *  - BIR_INDEX_CONSTANT: ins->constant.u64 shifted right by the index with
 *    the flag cleared.
 *  - BIR_INDEX_ZERO.
 *  - The passthrough sources (stage, FMA, ADD). The stage passthrough reads
 *    as 0 when FMA is set and as s->T otherwise.
 * The last visible statement is unreachable("Invalid source").
 *
 * NOTE(review): the passthrough branches test
 * `index & (BIR_INDEX_PASS | BIFROST_SRC_*)`, which is true as soon as any
 * one of those bits is set. If BIR_INDEX_PASS is a single flag bit, the
 * first passthrough branch captures every passthrough index and the FMA/ADD
 * branches can never be taken -- this probably wants an equality compare.
 * Confirm against the BIR_INDEX_* / BIFROST_SRC_* definitions.
 */
48 bit_read(struct bit_state
*s
, bi_instruction
*ins
, unsigned index
, nir_alu_type T
, bool FMA
)
50 if (index
& BIR_INDEX_REGISTER
) {
51 uint32_t reg
= index
& ~BIR_INDEX_REGISTER
;
54 } else if (index
& BIR_INDEX_UNIFORM
) {
55 unreachable("Uniform registers to be implemented");
56 } else if (index
& BIR_INDEX_CONSTANT
) {
57 return ins
->constant
.u64
>> (index
& ~BIR_INDEX_CONSTANT
);
58 } else if (index
& BIR_INDEX_ZERO
) {
60 } else if (index
& (BIR_INDEX_PASS
| BIFROST_SRC_STAGE
)) {
61 return FMA
? 0 : s
->T
;
62 } else if (index
& (BIR_INDEX_PASS
| BIFROST_SRC_PASS_FMA
)) {
64 } else if (index
& (BIR_INDEX_PASS
| BIFROST_SRC_PASS_ADD
)) {
70 unreachable("Invalid source");
/* Store `value` to the destination described by `index`.
 *
 * For a register destination (BIR_INDEX_REGISTER set) the register number
 * is the index with the flag cleared, and the 32-bit view of the value
 * (value.u32) is written to s->r[reg]. The last visible statement is
 * unreachable("Invalid destination").
 */
75 bit_write(struct bit_state
*s
, unsigned index
, nir_alu_type T
, bit_t value
, bool FMA
)
77 /* Always write stage passthrough */
81 if (index
& BIR_INDEX_REGISTER
) {
82 uint32_t reg
= index
& ~BIR_INDEX_REGISTER
;
84 s
->r
[reg
] = value
.u32
;
88 unreachable("Invalid destination");
/* Short aliases for the half <-> single precision float conversions, used
 * throughout the interpreter macros below: bh = float to half,
 * bf = half to float. */
92 #define bh _mesa_float_to_half
93 #define bf _mesa_half_to_float
96 for (unsigned c = 0; c < 2; ++c) { \
97 dest.f16[c] = bh(fxn(bf(srcs[0].f16[ins->swizzle[0][c]]), \
98 bf(srcs[1].f16[ins->swizzle[1][c]]), \
99 bf(srcs[2].f16[ins->swizzle[2][c]]), \
100 bf(srcs[3].f16[ins->swizzle[3][c]]))); \
/* Apply a four-operand 16-bit integer function to both lanes of a packed
 * 2x16 value. For lane c, operand i is srcs[i].u16[ins->swizzle[i][c]].
 * The result is stored through the integer view of the destination
 * (dest.u16), matching the .u16 view the operands are read through.
 *
 * Expects `srcs`, `dest` and `ins` to be in scope at the expansion site. */
#define bv2i16(fxn) \
        for (unsigned c = 0; c < 2; ++c) { \
                dest.u16[c] = fxn(srcs[0].u16[ins->swizzle[0][c]], \
                                  srcs[1].u16[ins->swizzle[1][c]], \
                                  srcs[2].u16[ins->swizzle[2][c]], \
                                  srcs[3].u16[ins->swizzle[3][c]]); \
        }
/* Scalar 32-bit wrappers: call a four-operand function on the 32-bit view
 * of each source and store the result in the matching view of dest.
 * bi32 reads every operand through the signed view (.i32), the same view it
 * writes, instead of mixing .u32 and .i32 between operands.
 *
 * Expects `srcs` and `dest` to be in scope at the expansion site. */
#define bf32(fxn) dest.f32 = fxn(srcs[0].f32, srcs[1].f32, srcs[2].f32, srcs[3].f32)
#define bi32(fxn) dest.i32 = fxn(srcs[0].i32, srcs[1].i32, srcs[2].i32, srcs[3].i32)
/* Dispatch a floating-point operation on ins->dest_type. 64-bit floats are
 * not implemented (unreachable); float32 and float16 each get their own
 * branch. fxn64/fxn32 name the per-width implementations. */
114 #define bfloat(fxn64, fxn32) \
115 if (ins->dest_type == nir_type_float64) { \
116 unreachable("TODO: 64-bit"); \
117 } else if (ins->dest_type == nir_type_float32) { \
120 } else if (ins->dest_type == nir_type_float16) { \
/* Dispatch an integer operation on ins->dest_type. The signed and unsigned
 * type of each width share a branch; the 64-bit and 8-bit branches are not
 * implemented (unreachable). fxn64/fxn32/fxn16/fxn8 name the per-width
 * implementations. */
125 #define bint(fxn64, fxn32, fxn16, fxn8) \
126 if (ins->dest_type == nir_type_int64 || ins->dest_type == nir_type_uint64) { \
127 unreachable("TODO: 64-bit"); \
128 } else if (ins->dest_type == nir_type_int32 || ins->dest_type == nir_type_uint32) { \
131 } else if (ins->dest_type == nir_type_int16 || ins->dest_type == nir_type_uint16) { \
134 } else if (ins->dest_type == nir_type_int8 || ins->dest_type == nir_type_uint8) { \
135 unreachable("TODO: 8-bit"); \
/* Dispatch a type-polymorphic operation `name`: expands bfloat with the
 * bit_f64<name>/bit_f32<name> helpers, then bint with the
 * bit_i64/i32/i16/i8<name> helpers, and ends in unreachable("Invalid type").
 * Presumably a matching bfloat/bint branch leaves the enclosing construct
 * before the trailing unreachable is reached -- confirm against those
 * macros. */
138 #define bpoly(name) \
139 bfloat(bit_f64 ## name, bit_f32 ## name); \
140 bint(bit_i64 ## name, bit_i32 ## name, bit_i16 ## name, bit_i8 ## name); \
141 unreachable("Invalid type");
/* Generate the floating-point helpers for operation `name`:
 * bit_f64<name> on doubles and bit_f32<name> on floats, each taking four
 * operands a, b, c, d. `expr` is presumably the expression each helper
 * returns, written in terms of those operands -- confirm against the
 * generated bodies. */
143 #define bit_make_float(name, expr) \
144 static inline double \
145 bit_f64 ## name(double a, double b, double c, double d) \
149 static inline float \
150 bit_f32 ## name(float a, float b, float c, float d) \
/* Generate the integer helpers for operation `name`: bit_i64<name>,
 * bit_i32<name>, bit_i16<name> and bit_i8<name>, each taking four signed
 * operands a, b, c, d of the corresponding width. `expr` is presumably the
 * expression each helper returns -- confirm against the generated bodies. */
155 #define bit_make_int(name, expr) \
156 static inline int64_t \
157 bit_i64 ## name (int64_t a, int64_t b, int64_t c, int64_t d) \
162 static inline int32_t \
163 bit_i32 ## name (int32_t a, int32_t b, int32_t c, int32_t d) \
168 static inline int16_t \
169 bit_i16 ## name (int16_t a, int16_t b, int16_t c, int16_t d) \
174 static inline int8_t \
175 bit_i8 ## name (int8_t a, int8_t b, int8_t c, int8_t d) \
/* Generate both the floating-point and the integer helpers for `name` from
 * a single expression, by expanding bit_make_float and bit_make_int. */
180 #define bit_make_poly(name, expr) \
181 bit_make_float(name, expr) \
182 bit_make_int(name, expr) \
/* Instantiate the helpers: add (a + b) and mov (a) for both float and
 * integer types, and fma ((a * b) + c) for float types only. */
184 bit_make_poly(add
, a
+ b
);
185 bit_make_float(fma
, (a
* b
) + c
);
186 bit_make_poly(mov
, a
);
/* Apply output modifier `mod` to the result `raw`. The visible cases clamp
 * from below at zero (MAX2(raw, 0.0)), clamp to [-1, 1] for
 * BIFROST_SAT_SIGNED, and clamp to [0, 1]. bit_step only calls this when
 * the instruction has an output modifier other than BIFROST_NONE. */
191 bit_outmod(float raw
, enum bifrost_outmod mod
)
195 return MAX2(raw
, 0.0);
196 case BIFROST_SAT_SIGNED
:
197 return CLAMP(raw
, -1.0, 1.0);
199 return CLAMP(raw
, 0.0, 1.0);
/* Source-modifier helper for a single float operand. bit_step applies it
 * to each float16 lane (after widening to float) and to float32 sources
 * when the instruction has source modifiers.
 * NOTE(review): `abs` and `neg` presumably select absolute value and
 * negation -- confirm against the body. */
206 bit_srcmod(float raw
, bool abs
, bool neg
)
/* Return from the enclosing function with the result of comparing `left`
 * against `right` under condition `cond`. The six relational conditions map
 * to the matching C operator; any other condition returns true.
 *
 * Every argument is parenthesised so that compound expressions keep their
 * meaning, and the chain is wrapped in do { } while (0) so the macro acts as
 * a single statement. Invoke it with a trailing semicolon. `cond` may be
 * evaluated several times, so it must be free of side effects. */
#define BIT_COND(cond, left, right) \
        do { \
                if ((cond) == BI_COND_LT) return (left) < (right); \
                else if ((cond) == BI_COND_LE) return (left) <= (right); \
                else if ((cond) == BI_COND_GE) return (left) >= (right); \
                else if ((cond) == BI_COND_GT) return (left) > (right); \
                else if ((cond) == BI_COND_EQ) return (left) == (right); \
                else if ((cond) == BI_COND_NE) return (left) != (right); \
                else return true; \
        } while (0)
227 bit_eval_cond(enum bi_cond cond
, bit_t l
, bit_t r
, nir_alu_type T
, unsigned c
)
229 if (T
== nir_type_float32
) {
230 BIT_COND(cond
, l
.f32
, r
.f32
);
231 } else if (T
== nir_type_float16
) {
232 float left
= bf(l
.f16
[c
]);
233 float right
= bf(r
.f16
[c
]);
234 BIT_COND(cond
, left
, right
);
235 } else if (T
== nir_type_int32
) {
236 int32_t left
= (int32_t) l
.u32
;
237 int32_t right
= (int32_t) r
.u32
;
238 BIT_COND(cond
, left
, right
);
239 } else if (T
== nir_type_int16
) {
240 int16_t left
= (int16_t) l
.u32
;
241 int16_t right
= (int16_t) r
.u32
;
242 BIT_COND(cond
, left
, right
);
243 } else if (T
== nir_type_uint32
) {
244 BIT_COND(cond
, l
.u32
, r
.u32
);
245 } else if (T
== nir_type_uint16
) {
246 BIT_COND(cond
, l
.u16
[c
], r
.u16
[c
]);
248 unreachable("Unknown type evaluated");
/* Evaluate special-function operation `op` on input Q.
 * BI_SPECIAL_FRCP returns the reciprocal 1.0 / Q. BI_SPECIAL_FRSQ starts
 * from the reciprocal square root 1.0 / sqrt(Q), held in a double. Any
 * other op is unreachable. */
253 biti_special(float Q
, enum bi_special_op op
)
256 case BI_SPECIAL_FRCP
: return 1.0 / Q
;
257 case BI_SPECIAL_FRSQ
: {
258 double Qf
= 1.0 / sqrt(Q
);
261 default: unreachable("Invalid special");
266 bit_step(struct bit_state
*s
, bi_instruction
*ins
, bool FMA
)
268 /* First, load sources */
269 bit_t srcs
[BIR_SRC_COUNT
] = { 0 };
271 bi_foreach_src(ins
, src
)
272 srcs
[src
].u64
= bit_read(s
, ins
, ins
->src
[src
], ins
->src_types
[src
], FMA
);
274 /* Apply source modifiers if we need to */
275 if (bi_has_source_mods(ins
)) {
276 bi_foreach_src(ins
, src
) {
277 if (ins
->src_types
[src
] == nir_type_float16
) {
278 for (unsigned c
= 0; c
< 2; ++c
) {
279 srcs
[src
].f16
[c
] = bh(bit_srcmod(bf(srcs
[src
].f16
[c
]),
283 } else if (ins
->src_types
[src
] == nir_type_float32
) {
284 srcs
[src
].f32
= bit_srcmod(srcs
[src
].f32
,
291 /* Next, do the action of the instruction */
302 unreachable("Unsupported op");
305 bool direct
= ins
->csel_cond
== BI_COND_ALWAYS
;
306 bool cond
= direct
? srcs
[0].u32
:
307 bit_eval_cond(ins
->csel_cond
, srcs
[0], srcs
[1], ins
->src_types
[0], 0);
309 dest
= cond
? srcs
[2] : srcs
[3];
314 bfloat(bit_f64fma
, bit_f32fma
);
315 unreachable("Unknown type");
321 unreachable("Unsupported op");
327 assert(nir_alu_type_get_base_type(ins
->dest_type
) == nir_type_float
);
328 assert(nir_alu_type_get_base_type(ins
->dest_type
) != nir_type_float64
);
329 float Q
= (ins
->dest_type
== nir_type_float16
) ?
330 bf(srcs
[0].u16
[ins
->swizzle
[0][0]]) :
333 float R
= biti_special(Q
, ins
->op
.special
);
335 if (ins
->dest_type
== nir_type_float16
) {
338 if (!ins
->swizzle
[0][0] && ins
->op
.special
== BI_SPECIAL_FRSQ
) {
351 unreachable("Unsupported op");
353 /* We only interpret vertex shaders */
358 unreachable("Fragment op used in interpreter");
360 /* Modeling main memory is more than I bargained for */
361 case BI_LOAD_UNIFORM
:
363 case BI_LOAD_VAR_ADDRESS
:
368 unreachable("Unsupported I/O in interpreter");
371 unreachable("Unsupported op");
375 if (bi_has_outmod(ins
) && ins
->outmod
!= BIFROST_NONE
) {
376 if (ins
->dest_type
== nir_type_float16
) {
377 for (unsigned c
= 0; c
< 2; ++c
)
378 dest
.f16
[c
] = bh(bit_outmod(bf(dest
.f16
[c
]), ins
->outmod
));
380 dest
.f32
= bit_outmod(dest
.f32
, ins
->outmod
);
384 /* Finally, store the result */
385 bit_write(s
, ins
->dest
, ins
->dest_type
, dest
, FMA
);
387 /* For ADD - change out the passthrough */