#include "nir.h"
#include "nir_builder.h"
+#define COND_LOWER_OP(b, name, ...) \
+ (b->shader->options->lower_int64_options & \
+ nir_lower_int64_op_to_options_mask(nir_op_##name)) ? \
+ lower_##name##64(b, __VA_ARGS__) : nir_##name(b, __VA_ARGS__)
+
+#define COND_LOWER_CMP(b, name, ...) \
+ (b->shader->options->lower_int64_options & \
+ nir_lower_int64_op_to_options_mask(nir_op_##name)) ? \
+ lower_int64_compare(b, nir_op_##name, __VA_ARGS__) : \
+ nir_##name(b, __VA_ARGS__)
+
+#define COND_LOWER_CAST(b, name, ...) \
+ (b->shader->options->lower_int64_options & \
+ nir_lower_int64_op_to_options_mask(nir_op_##name)) ? \
+ lower_##name(b, __VA_ARGS__) : \
+ nir_##name(b, __VA_ARGS__)
+
static nir_ssa_def *
lower_b2i64(nir_builder *b, nir_ssa_def *x)
{
lower_i2i64(nir_builder *b, nir_ssa_def *x)
{
nir_ssa_def *x32 = x->bit_size == 32 ? x : nir_i2i32(b, x);
- return nir_pack_64_2x32_split(b, x32, nir_ishr(b, x32, nir_imm_int(b, 31)));
+ return nir_pack_64_2x32_split(b, x32, nir_ishr_imm(b, x32, 31));
}
static nir_ssa_def *
x32[0] = nir_unpack_64_2x32_split_x(b, x);
x32[1] = nir_unpack_64_2x32_split_y(b, x);
if (sign_extend) {
- x32[2] = x32[3] = nir_ishr(b, x32[1], nir_imm_int(b, 31));
+ x32[2] = x32[3] = nir_ishr_imm(b, x32[1], 31);
} else {
x32[2] = x32[3] = nir_imm_int(b, 0);
}
y32[0] = nir_unpack_64_2x32_split_x(b, y);
y32[1] = nir_unpack_64_2x32_split_y(b, y);
if (sign_extend) {
- y32[2] = y32[3] = nir_ishr(b, y32[1], nir_imm_int(b, 31));
+ y32[2] = y32[3] = nir_ishr_imm(b, y32[1], 31);
} else {
y32[2] = y32[3] = nir_imm_int(b, 0);
}
if (carry)
tmp = nir_iadd(b, tmp, carry);
res[i + j] = nir_u2u32(b, tmp);
- carry = nir_ushr(b, tmp, nir_imm_int(b, 32));
+ carry = nir_ushr_imm(b, tmp, 32);
}
res[i + 4] = nir_u2u32(b, carry);
}
nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
nir_ssa_def *is_non_zero = nir_i2b(b, nir_ior(b, x_lo, x_hi));
- nir_ssa_def *res_hi = nir_ishr(b, x_hi, nir_imm_int(b, 31));
+ nir_ssa_def *res_hi = nir_ishr_imm(b, x_hi, 31);
nir_ssa_def *res_lo = nir_ior(b, res_hi, nir_b2i32(b, is_non_zero));
return nir_pack_64_2x32_split(b, res_lo, res_hi);
nir_ssa_def *d_lo = nir_unpack_64_2x32_split_x(b, d);
nir_ssa_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
- nir_const_value v = { .u32 = { 0, 0, 0, 0 } };
- nir_ssa_def *q_lo = nir_build_imm(b, n->num_components, 32, v);
- nir_ssa_def *q_hi = nir_build_imm(b, n->num_components, 32, v);
+ nir_ssa_def *q_lo = nir_imm_zero(b, n->num_components, 32);
+ nir_ssa_def *q_hi = nir_imm_zero(b, n->num_components, 32);
nir_ssa_def *n_hi_before_if = n_hi;
nir_ssa_def *q_hi_before_if = q_hi;
return nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);
}
+static nir_ssa_def *
+lower_extract(nir_builder *b, nir_op op, nir_ssa_def *x, nir_ssa_def *c)
+{
+ assert(op == nir_op_extract_u8 || op == nir_op_extract_i8 ||
+ op == nir_op_extract_u16 || op == nir_op_extract_i16);
+
+ const int chunk = nir_src_as_uint(nir_src_for_ssa(c));
+ const int chunk_bits =
+ (op == nir_op_extract_u8 || op == nir_op_extract_i8) ? 8 : 16;
+ const int num_chunks_in_32 = 32 / chunk_bits;
+
+ nir_ssa_def *extract32;
+ if (chunk < num_chunks_in_32) {
+ extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_x(b, x),
+ nir_imm_int(b, chunk),
+ NULL, NULL);
+ } else {
+ extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_y(b, x),
+ nir_imm_int(b, chunk - num_chunks_in_32),
+ NULL, NULL);
+ }
+
+ if (op == nir_op_extract_i8 || op == nir_op_extract_i16)
+ return lower_i2i64(b, extract32);
+ else
+ return lower_u2u64(b, extract32);
+}
+
+static nir_ssa_def *
+lower_ufind_msb64(nir_builder *b, nir_ssa_def *x)
+{
+
+ nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
+ nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
+ nir_ssa_def *lo_count = nir_ufind_msb(b, x_lo);
+ nir_ssa_def *hi_count = nir_ufind_msb(b, x_hi);
+ nir_ssa_def *valid_hi_bits = nir_ine(b, x_hi, nir_imm_int(b, 0));
+ nir_ssa_def *hi_res = nir_iadd(b, nir_imm_intN_t(b, 32, 32), hi_count);
+ return nir_bcsel(b, valid_hi_bits, hi_res, lo_count);
+}
+
+static nir_ssa_def *
+lower_2f(nir_builder *b, nir_ssa_def *x, unsigned dest_bit_size,
+ bool src_is_signed)
+{
+ nir_ssa_def *x_sign = NULL;
+
+ if (src_is_signed) {
+ x_sign = nir_bcsel(b, COND_LOWER_CMP(b, ilt, x, nir_imm_int64(b, 0)),
+ nir_imm_floatN_t(b, -1, dest_bit_size),
+ nir_imm_floatN_t(b, 1, dest_bit_size));
+ x = COND_LOWER_OP(b, iabs, x);
+ }
+
+ nir_ssa_def *exp = COND_LOWER_OP(b, ufind_msb, x);
+ unsigned significand_bits;
+
+ switch (dest_bit_size) {
+ case 32:
+ significand_bits = 23;
+ break;
+ case 16:
+ significand_bits = 10;
+ break;
+ default:
+ unreachable("Invalid dest_bit_size");
+ }
+
+ nir_ssa_def *discard =
+ nir_imax(b, nir_isub(b, exp, nir_imm_int(b, significand_bits)),
+ nir_imm_int(b, 0));
+ nir_ssa_def *significand =
+ COND_LOWER_CAST(b, u2u32, COND_LOWER_OP(b, ushr, x, discard));
+
+ /* Round-to-nearest-even implementation:
+ * - if the non-representable part of the significand is higher than half
+ * the minimum representable significand, we round-up
+ * - if the non-representable part of the significand is equal to half the
+ * minimum representable significand and the representable part of the
+ * significand is odd, we round-up
+ * - in any other case, we round-down
+ */
+ nir_ssa_def *lsb_mask = COND_LOWER_OP(b, ishl, nir_imm_int64(b, 1), discard);
+ nir_ssa_def *rem_mask = COND_LOWER_OP(b, isub, lsb_mask, nir_imm_int64(b, 1));
+ nir_ssa_def *half = COND_LOWER_OP(b, ishr, lsb_mask, nir_imm_int(b, 1));
+ nir_ssa_def *rem = COND_LOWER_OP(b, iand, x, rem_mask);
+ nir_ssa_def *halfway = nir_iand(b, COND_LOWER_CMP(b, ieq, rem, half),
+ nir_ine(b, discard, nir_imm_int(b, 0)));
+ nir_ssa_def *is_odd = nir_i2b(b, nir_iand(b, significand, nir_imm_int(b, 1)));
+ nir_ssa_def *round_up = nir_ior(b, COND_LOWER_CMP(b, ilt, half, rem),
+ nir_iand(b, halfway, is_odd));
+ significand = nir_iadd(b, significand, nir_b2i32(b, round_up));
+
+ nir_ssa_def *res;
+
+ if (dest_bit_size == 32)
+ res = nir_fmul(b, nir_u2f32(b, significand),
+ nir_fexp2(b, nir_u2f32(b, discard)));
+ else
+ res = nir_fmul(b, nir_u2f16(b, significand),
+ nir_fexp2(b, nir_u2f16(b, discard)));
+
+ if (src_is_signed)
+ res = nir_fmul(b, res, x_sign);
+
+ return res;
+}
+
+static nir_ssa_def *
+lower_f2(nir_builder *b, nir_ssa_def *x, bool dst_is_signed)
+{
+ assert(x->bit_size == 16 || x->bit_size == 32);
+ nir_ssa_def *x_sign = NULL;
+
+ if (dst_is_signed)
+ x_sign = nir_fsign(b, x);
+ else
+ x = nir_fmin(b, x, nir_imm_floatN_t(b, UINT64_MAX, x->bit_size));
+
+ x = nir_ftrunc(b, x);
+
+ if (dst_is_signed) {
+ x = nir_fmin(b, x, nir_imm_floatN_t(b, INT64_MAX, x->bit_size));
+ x = nir_fmax(b, x, nir_imm_floatN_t(b, INT64_MIN, x->bit_size));
+ x = nir_fabs(b, x);
+ }
+
+ nir_ssa_def *div = nir_imm_floatN_t(b, 1ULL << 32, x->bit_size);
+ nir_ssa_def *res_hi = nir_f2u32(b, nir_fdiv(b, x, div));
+ nir_ssa_def *res_lo = nir_f2u32(b, nir_frem(b, x, div));
+ nir_ssa_def *res = nir_pack_64_2x32_split(b, res_lo, res_hi);
+
+ if (dst_is_signed)
+ res = nir_bcsel(b, nir_flt(b, x_sign, nir_imm_float(b, 0)),
+ nir_ineg(b, res), res);
+
+ return res;
+}
+
+static nir_ssa_def *
+lower_bit_count64(nir_builder *b, nir_ssa_def *x)
+{
+ nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
+ nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
+ nir_ssa_def *lo_count = nir_bit_count(b, x_lo);
+ nir_ssa_def *hi_count = nir_bit_count(b, x_hi);
+ return nir_iadd(b, lo_count, hi_count);
+}
+
nir_lower_int64_options
nir_lower_int64_op_to_options_mask(nir_op opcode)
{
switch (opcode) {
case nir_op_imul:
+ case nir_op_amul:
return nir_lower_imul64;
case nir_op_imul_2x32_64:
case nir_op_umul_2x32_64:
return nir_lower_divmod64;
case nir_op_b2i64:
case nir_op_i2b1:
+ case nir_op_i2i8:
+ case nir_op_i2i16:
case nir_op_i2i32:
case nir_op_i2i64:
+ case nir_op_u2u8:
+ case nir_op_u2u16:
case nir_op_u2u32:
case nir_op_u2u64:
+ case nir_op_i2f32:
+ case nir_op_u2f32:
+ case nir_op_i2f16:
+ case nir_op_u2f16:
+ case nir_op_f2i64:
+ case nir_op_f2u64:
case nir_op_bcsel:
return nir_lower_mov64;
case nir_op_ieq:
case nir_op_ishr:
case nir_op_ushr:
return nir_lower_shift64;
+ case nir_op_extract_u8:
+ case nir_op_extract_i8:
+ case nir_op_extract_u16:
+ case nir_op_extract_i16:
+ return nir_lower_extract64;
+ case nir_op_ufind_msb:
+ return nir_lower_ufind_msb64;
+ case nir_op_bit_count:
+ return nir_lower_bit_count64;
default:
return 0;
}
}
static nir_ssa_def *
-lower_int64_alu_instr(nir_builder *b, nir_alu_instr *alu)
+lower_int64_alu_instr(nir_builder *b, nir_instr *instr, void *_state)
{
+ nir_alu_instr *alu = nir_instr_as_alu(instr);
+
nir_ssa_def *src[4];
for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
src[i] = nir_ssa_for_alu_src(b, alu, i);
switch (alu->op) {
case nir_op_imul:
+ case nir_op_amul:
return lower_imul64(b, src[0], src[1]);
case nir_op_imul_2x32_64:
return lower_mul_2x32_64(b, src[0], src[1], true);
return lower_ishr64(b, src[0], src[1]);
case nir_op_ushr:
return lower_ushr64(b, src[0], src[1]);
+ case nir_op_extract_u8:
+ case nir_op_extract_i8:
+ case nir_op_extract_u16:
+ case nir_op_extract_i16:
+ return lower_extract(b, alu->op, src[0], src[1]);
+ case nir_op_ufind_msb:
+ return lower_ufind_msb64(b, src[0]);
+ case nir_op_bit_count:
+ return lower_bit_count64(b, src[0]);
+ case nir_op_i2f64:
+ case nir_op_i2f32:
+ case nir_op_i2f16:
+ return lower_2f(b, src[0], nir_dest_bit_size(alu->dest.dest), true);
+ case nir_op_u2f64:
+ case nir_op_u2f32:
+ case nir_op_u2f16:
+ return lower_2f(b, src[0], nir_dest_bit_size(alu->dest.dest), false);
+ case nir_op_f2i64:
+ case nir_op_f2u64:
+ /* We don't support f64toi64 (yet?). */
+ if (src[0]->bit_size > 32)
+ return false;
+
+ return lower_f2(b, src[0], alu->op == nir_op_f2i64);
default:
unreachable("Invalid ALU opcode to lower");
}
}
static bool
-lower_int64_impl(nir_function_impl *impl, nir_lower_int64_options options)
-{
- nir_builder b;
- nir_builder_init(&b, impl);
-
- bool progress = false;
- nir_foreach_block(block, impl) {
- nir_foreach_instr_safe(instr, block) {
- if (instr->type != nir_instr_type_alu)
- continue;
-
- nir_alu_instr *alu = nir_instr_as_alu(instr);
- switch (alu->op) {
- case nir_op_i2b1:
- case nir_op_i2i32:
- case nir_op_u2u32:
- assert(alu->src[0].src.is_ssa);
- if (alu->src[0].src.ssa->bit_size != 64)
- continue;
- break;
- case nir_op_bcsel:
- assert(alu->src[1].src.is_ssa);
- assert(alu->src[2].src.is_ssa);
- assert(alu->src[1].src.ssa->bit_size ==
- alu->src[2].src.ssa->bit_size);
- if (alu->src[1].src.ssa->bit_size != 64)
- continue;
- break;
- case nir_op_ieq:
- case nir_op_ine:
- case nir_op_ult:
- case nir_op_ilt:
- case nir_op_uge:
- case nir_op_ige:
- assert(alu->src[0].src.is_ssa);
- assert(alu->src[1].src.is_ssa);
- assert(alu->src[0].src.ssa->bit_size ==
- alu->src[1].src.ssa->bit_size);
- if (alu->src[0].src.ssa->bit_size != 64)
- continue;
- break;
- default:
- assert(alu->dest.dest.is_ssa);
- if (alu->dest.dest.ssa.bit_size != 64)
- continue;
- break;
- }
-
- if (!(options & nir_lower_int64_op_to_options_mask(alu->op)))
- continue;
+should_lower_int64_alu_instr(const nir_instr *instr, const void *_data)
+{
+ const nir_shader_compiler_options *options =
+ (const nir_shader_compiler_options *)_data;
- b.cursor = nir_before_instr(instr);
+ if (instr->type != nir_instr_type_alu)
+ return false;
- nir_ssa_def *lowered = lower_int64_alu_instr(&b, alu);
- nir_ssa_def_rewrite_uses(&alu->dest.dest.ssa,
- nir_src_for_ssa(lowered));
- nir_instr_remove(&alu->instr);
- progress = true;
- }
- }
+ const nir_alu_instr *alu = nir_instr_as_alu(instr);
- if (progress) {
- nir_metadata_preserve(impl, nir_metadata_none);
- } else {
-#ifndef NDEBUG
- impl->valid_metadata &= ~nir_metadata_not_properly_reset;
-#endif
+ switch (alu->op) {
+ case nir_op_i2b1:
+ case nir_op_i2i8:
+ case nir_op_i2i16:
+ case nir_op_i2i32:
+ case nir_op_u2u8:
+ case nir_op_u2u16:
+ case nir_op_u2u32:
+ assert(alu->src[0].src.is_ssa);
+ if (alu->src[0].src.ssa->bit_size != 64)
+ return false;
+ break;
+ case nir_op_bcsel:
+ assert(alu->src[1].src.is_ssa);
+ assert(alu->src[2].src.is_ssa);
+ assert(alu->src[1].src.ssa->bit_size ==
+ alu->src[2].src.ssa->bit_size);
+ if (alu->src[1].src.ssa->bit_size != 64)
+ return false;
+ break;
+ case nir_op_ieq:
+ case nir_op_ine:
+ case nir_op_ult:
+ case nir_op_ilt:
+ case nir_op_uge:
+ case nir_op_ige:
+ assert(alu->src[0].src.is_ssa);
+ assert(alu->src[1].src.is_ssa);
+ assert(alu->src[0].src.ssa->bit_size ==
+ alu->src[1].src.ssa->bit_size);
+ if (alu->src[0].src.ssa->bit_size != 64)
+ return false;
+ break;
+ case nir_op_ufind_msb:
+ case nir_op_bit_count:
+ assert(alu->src[0].src.is_ssa);
+ if (alu->src[0].src.ssa->bit_size != 64)
+ return false;
+ break;
+ case nir_op_amul:
+ assert(alu->dest.dest.is_ssa);
+ if (options->has_imul24)
+ return false;
+ if (alu->dest.dest.ssa.bit_size != 64)
+ return false;
+ break;
+ case nir_op_i2f64:
+ case nir_op_u2f64:
+ case nir_op_i2f32:
+ case nir_op_u2f32:
+ case nir_op_i2f16:
+ case nir_op_u2f16:
+ assert(alu->src[0].src.is_ssa);
+ if (alu->src[0].src.ssa->bit_size != 64)
+ return false;
+ break;
+ case nir_op_f2u64:
+ case nir_op_f2i64:
+ /* fall-through */
+ default:
+ assert(alu->dest.dest.is_ssa);
+ if (alu->dest.dest.ssa.bit_size != 64)
+ return false;
+ break;
}
- return progress;
+ unsigned mask = nir_lower_int64_op_to_options_mask(alu->op);
+ return (options->lower_int64_options & mask) != 0;
}
bool
-nir_lower_int64(nir_shader *shader, nir_lower_int64_options options)
+nir_lower_int64(nir_shader *shader)
{
- bool progress = false;
-
- nir_foreach_function(function, shader) {
- if (function->impl)
- progress |= lower_int64_impl(function->impl, options);
- }
-
- return progress;
+ return nir_shader_lower_instructions(shader,
+ should_lower_int64_alu_instr,
+ lower_int64_alu_instr,
+ (void *)shader->options);
}