From 936c58c8fcceee086d3c492712595555afe82266 Mon Sep 17 00:00:00 2001 From: Boris Brezillon Date: Fri, 19 Jun 2020 17:28:09 +0200 Subject: [PATCH] nir: Extend nir_lower_int64() to support i2f/f2i lowering That's an attempt at replacing the complex __int64_to_float() and __float_to_int64() implementations found in float64.glsl by a simpler native NIR equivalent. Thanks to that, we can have lower those conversion without having to compile a GLSL shader, which would be quite annoying for OpenCL kernels. Signed-off-by: Boris Brezillon Reviewed-by: Matt Turner Acked-by: Jason Ekstrand Part-of: --- src/compiler/nir/nir_lower_int64.c | 155 ++++++++++++++++++++++++++++- 1 file changed, 154 insertions(+), 1 deletion(-) diff --git a/src/compiler/nir/nir_lower_int64.c b/src/compiler/nir/nir_lower_int64.c index 2c188d5eaf6..03d7c156807 100644 --- a/src/compiler/nir/nir_lower_int64.c +++ b/src/compiler/nir/nir_lower_int64.c @@ -24,6 +24,23 @@ #include "nir.h" #include "nir_builder.h" +#define COND_LOWER_OP(b, name, ...) \ + (b->shader->options->lower_int64_options & \ + nir_lower_int64_op_to_options_mask(nir_op_##name)) ? \ + lower_##name##64(b, __VA_ARGS__) : nir_##name(b, __VA_ARGS__) + +#define COND_LOWER_CMP(b, name, ...) \ + (b->shader->options->lower_int64_options & \ + nir_lower_int64_op_to_options_mask(nir_op_##name)) ? \ + lower_int64_compare(b, nir_op_##name, __VA_ARGS__) : \ + nir_##name(b, __VA_ARGS__) + +#define COND_LOWER_CAST(b, name, ...) \ + (b->shader->options->lower_int64_options & \ + nir_lower_int64_op_to_options_mask(nir_op_##name)) ? \ + lower_##name(b, __VA_ARGS__) : \ + nir_##name(b, __VA_ARGS__) + static nir_ssa_def * lower_b2i64(nir_builder *b, nir_ssa_def *x) { @@ -670,6 +687,109 @@ lower_ufind_msb64(nir_builder *b, nir_ssa_def *x) return nir_bcsel(b, valid_hi_bits, hi_res, lo_count); } +static nir_ssa_def * +lower_2f(nir_builder *b, nir_ssa_def *x, unsigned dest_bit_size, + bool src_is_signed) +{ + nir_ssa_def *x_sign = NULL; + + if (src_is_signed) { + x_sign = nir_bcsel(b, COND_LOWER_CMP(b, ilt, x, nir_imm_int64(b, 0)), + nir_imm_floatN_t(b, -1, dest_bit_size), + nir_imm_floatN_t(b, 1, dest_bit_size)); + x = COND_LOWER_OP(b, iabs, x); + } + + nir_ssa_def *exp = COND_LOWER_OP(b, ufind_msb, x); + unsigned significand_bits; + + switch (dest_bit_size) { + case 32: + significand_bits = 23; + break; + case 16: + significand_bits = 10; + break; + default: + unreachable("Invalid dest_bit_size"); + } + + /* We keep one more bit than can fit in the significand field to let the + * u2f32 conversion do the rounding for us. + */ + nir_ssa_def *discard = + nir_imax(b, nir_isub(b, exp, nir_imm_int(b, significand_bits + 1)), + nir_imm_int(b, 0)); + + /* Part of the "round to nearest" has to be taken care of before we discard + * the LSB, and that's what this extra iadd is for. + * "Round to nearest even" is handled by u2f. That works because the + * shifted value either fits in the significand field (which means no + * rounding is required) or contains one extra bit that forces the + * conversion op to round things properly. + */ + nir_ssa_def *add = COND_LOWER_OP(b, ishl, nir_imm_int64(b, 1), discard); + add = COND_LOWER_OP(b, isub, add, nir_imm_int64(b, 1)); + nir_ssa_def *rounded_x = COND_LOWER_OP(b, iadd, x, add); + + /* Signed Values can't overflow because we've saved the sign and promoted + * them to unsigned values. + */ + if (!src_is_signed) { + nir_ssa_def *overflow = COND_LOWER_CMP(b, ult, rounded_x, x); + rounded_x = COND_LOWER_OP(b, bcsel, overflow, + nir_imm_int64(b, UINT64_MAX), rounded_x); + } + + nir_ssa_def *significand = COND_LOWER_OP(b, ushr, rounded_x, discard); + significand = COND_LOWER_CAST(b, u2u32, significand); + + nir_ssa_def *res; + + if (dest_bit_size == 32) + res = nir_fmul(b, nir_u2f32(b, significand), + nir_fexp2(b, nir_u2f32(b, discard))); + else + res = nir_fmul(b, nir_u2f16(b, significand), + nir_fexp2(b, nir_u2f16(b, discard))); + + if (src_is_signed) + res = nir_fmul(b, res, x_sign); + + return res; +} + +static nir_ssa_def * +lower_f2(nir_builder *b, nir_ssa_def *x, bool dst_is_signed) +{ + assert(x->bit_size == 16 || x->bit_size == 32); + nir_ssa_def *x_sign = NULL; + + if (dst_is_signed) + x_sign = nir_fsign(b, x); + else + x = nir_fmin(b, x, nir_imm_floatN_t(b, UINT64_MAX, x->bit_size)); + + x = nir_ftrunc(b, x); + + if (dst_is_signed) { + x = nir_fmin(b, x, nir_imm_floatN_t(b, INT64_MAX, x->bit_size)); + x = nir_fmax(b, x, nir_imm_floatN_t(b, INT64_MIN, x->bit_size)); + x = nir_fabs(b, x); + } + + nir_ssa_def *div = nir_imm_floatN_t(b, 1ULL << 32, x->bit_size); + nir_ssa_def *res_hi = nir_f2u32(b, nir_fdiv(b, x, div)); + nir_ssa_def *res_lo = nir_f2u32(b, nir_frem(b, x, div)); + nir_ssa_def *res = nir_pack_64_2x32_split(b, res_lo, res_hi); + + if (dst_is_signed) + res = nir_bcsel(b, nir_flt(b, x_sign, nir_imm_float(b, 0)), + nir_ineg(b, res), res); + + return res; +} + nir_lower_int64_options nir_lower_int64_op_to_options_mask(nir_op opcode) { @@ -701,6 +821,12 @@ nir_lower_int64_op_to_options_mask(nir_op opcode) case nir_op_u2u16: case nir_op_u2u32: case nir_op_u2u64: + case nir_op_i2f32: + case nir_op_u2f32: + case nir_op_i2f16: + case nir_op_u2f16: + case nir_op_f2i64: + case nir_op_f2u64: case nir_op_bcsel: return nir_lower_mov64; case nir_op_ieq: @@ -860,7 +986,21 @@ lower_int64_alu_instr(nir_builder *b, nir_instr *instr, void *_state) return lower_extract(b, alu->op, src[0], src[1]); case nir_op_ufind_msb: return lower_ufind_msb64(b, src[0]); - break; + case nir_op_i2f64: + case nir_op_i2f32: + case nir_op_i2f16: + return lower_2f(b, src[0], nir_dest_bit_size(alu->dest.dest), true); + case nir_op_u2f64: + case nir_op_u2f32: + case nir_op_u2f16: + return lower_2f(b, src[0], nir_dest_bit_size(alu->dest.dest), false); + case nir_op_f2i64: + case nir_op_f2u64: + /* We don't support f64toi64 (yet?). */ + if (src[0]->bit_size > 32) + return false; + + return lower_f2(b, src[0], alu->op == nir_op_f2i64); default: unreachable("Invalid ALU opcode to lower"); } @@ -922,6 +1062,19 @@ should_lower_int64_alu_instr(const nir_instr *instr, const void *_data) if (alu->dest.dest.ssa.bit_size != 64) return false; break; + case nir_op_i2f64: + case nir_op_u2f64: + case nir_op_i2f32: + case nir_op_u2f32: + case nir_op_i2f16: + case nir_op_u2f16: + assert(alu->src[0].src.is_ssa); + if (alu->src[0].src.ssa->bit_size != 64) + return false; + break; + case nir_op_f2u64: + case nir_op_f2i64: + /* fall-through */ default: assert(alu->dest.dest.is_ssa); if (alu->dest.dest.ssa.bit_size != 64) -- 2.30.2