#include "nir.h"
#include "nir_builder.h"
-/* Lowers idiv/udiv/umod
- * Based on NV50LegalizeSSA::handleDIV()
+/* This pass has two paths:
+ * One (nir_lower_idiv_fast) lowers idiv/udiv/umod and is based on
+ * NV50LegalizeSSA::handleDIV().
*
- * Note that this is probably not enough precision for compute shaders.
- * Perhaps we want a second higher precision (looping) version of this?
- * Or perhaps we assume if you can do compute shaders you can also
- * branch out to a pre-optimized shader library routine..
+ * Note that this path probably does not have enough precision for compute
+ * shaders. Perhaps we want a second higher-precision (looping) version of
+ * this? Or perhaps we assume that if you can do compute shaders, you can
+ * also branch out to a pre-optimized shader library routine...
+ *
+ * The other path (nir_lower_idiv_precise) is based off of code used by LLVM's
+ * AMDGPU target. It should handle 32-bit idiv/irem/imod/udiv/umod exactly.
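+ *
+ * Callers select a path when running the pass, e.g.
+ * nir_lower_idiv(shader, nir_lower_idiv_precise).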
*/
static bool
return true;
}
+/* ported from LLVM's AMDGPUTargetLowering::LowerUDIVREM */
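+/* Computes numer / denom (or numer % denom when modulo is set) for 32-bit
+ * unsigned values: estimate the reciprocal with floats, then fix up the
+ * integer result. */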
+static nir_ssa_def *
+emit_udiv(nir_builder *bld, nir_ssa_def *numer, nir_ssa_def *denom, bool modulo)
+{
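+   /* Estimate 1/denom as a 0.32 fixed-point value: take the float
+    * reciprocal and scale it by 2^32. frcp is only approximate, so the
+    * estimate is refined and the final result corrected below. */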
+ nir_ssa_def *rcp = nir_frcp(bld, nir_u2f32(bld, denom));
+ rcp = nir_f2u32(bld, nir_fmul_imm(bld, rcp, 4294967296.0));
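+
+   /* Measure the error of the estimate: rcp * denom would be exactly 2^32
+    * for a perfect reciprocal. If the high 32 bits of the product are zero
+    * the estimate was low and the error is -rcp_lo (mod 2^32); otherwise
+    * it was high and the error is rcp_lo. */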
+ nir_ssa_def *rcp_lo = nir_imul(bld, rcp, denom);
+ nir_ssa_def *rcp_hi = nir_umul_high(bld, rcp, denom);
+ nir_ssa_def *rcp_hi_ne_zero = nir_ine(bld, rcp_hi, nir_imm_int(bld, 0));
+ nir_ssa_def *neg_rcp_lo = nir_ineg(bld, rcp_lo);
+ nir_ssa_def *abs_rcp_lo = nir_bcsel(bld, rcp_hi_ne_zero, rcp_lo, neg_rcp_lo);
+ nir_ssa_def *e = nir_umul_high(bld, abs_rcp_lo, rcp);
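+
+   /* Refine: e ~= |error| * rcp / 2^32 is the correction to the estimate;
+    * subtract it when the estimate was high, add it when it was low. */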
+ nir_ssa_def *rcp_plus_e = nir_iadd(bld, rcp, e);
+ nir_ssa_def *rcp_minus_e = nir_isub(bld, rcp, e);
+ nir_ssa_def *tmp0 = nir_bcsel(bld, rcp_hi_ne_zero, rcp_minus_e, rcp_plus_e);
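+
+   /* Initial quotient guess from the refined reciprocal; it may still be
+    * off by one in either direction, so also compute the corresponding
+    * remainder for the fix-up below. */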
+ nir_ssa_def *quotient = nir_umul_high(bld, tmp0, numer);
+ nir_ssa_def *num_s_remainder = nir_imul(bld, quotient, denom);
+ nir_ssa_def *remainder = nir_isub(bld, numer, num_s_remainder);
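+
+   /* remainder >= denom means the guess was one too small; a remainder
+    * that went "negative" (numer < quotient * denom) means it was one too
+    * large. */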
+ nir_ssa_def *remainder_ge_den = nir_uge(bld, remainder, denom);
+ nir_ssa_def *remainder_ge_zero = nir_uge(bld, numer, num_s_remainder);
+ nir_ssa_def *tmp1 = nir_iand(bld, remainder_ge_den, remainder_ge_zero);
+
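+   /* Apply at most one correction step to the remainder or quotient. */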
+ if (modulo) {
+ nir_ssa_def *rem = nir_bcsel(bld, tmp1,
+ nir_isub(bld, remainder, denom), remainder);
+ return nir_bcsel(bld, remainder_ge_zero,
+ rem, nir_iadd(bld, remainder, denom));
+ } else {
+ nir_ssa_def *one = nir_imm_int(bld, 1);
+ nir_ssa_def *div = nir_bcsel(bld, tmp1,
+ nir_iadd(bld, quotient, one), quotient);
+ return nir_bcsel(bld, remainder_ge_zero,
+ div, nir_isub(bld, quotient, one));
+ }
+}
+
+/* ported from LLVM's AMDGPUTargetLowering::LowerSDIVREM */
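+/* Reduces signed idiv/irem/imod to the unsigned helper above on absolute
+ * values, then fixes up the sign of the result. */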
+static nir_ssa_def *
+emit_idiv(nir_builder *bld, nir_ssa_def *numer, nir_ssa_def *denom, nir_op op)
+{
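+   /* Build all-ones masks for negative operands: -1 if the operand is
+    * negative, 0 otherwise. */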
+ nir_ssa_def *lh_sign = nir_ilt(bld, numer, nir_imm_int(bld, 0));
+ nir_ssa_def *rh_sign = nir_ilt(bld, denom, nir_imm_int(bld, 0));
+ lh_sign = nir_bcsel(bld, lh_sign, nir_imm_int(bld, -1), nir_imm_int(bld, 0));
+ rh_sign = nir_bcsel(bld, rh_sign, nir_imm_int(bld, -1), nir_imm_int(bld, 0));
+
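+   /* (x + sign) ^ sign computes |x| in two's complement: it is the
+    * identity when sign == 0, and ~(x - 1) == -x when sign == -1. */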
+ nir_ssa_def *lhs = nir_iadd(bld, numer, lh_sign);
+ nir_ssa_def *rhs = nir_iadd(bld, denom, rh_sign);
+ lhs = nir_ixor(bld, lhs, lh_sign);
+ rhs = nir_ixor(bld, rhs, rh_sign);
+
+ if (op == nir_op_idiv) {
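+      /* The quotient is negative iff exactly one operand is negative;
+       * (res ^ mask) - mask negates res when mask is all ones. */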
+ nir_ssa_def *d_sign = nir_ixor(bld, lh_sign, rh_sign);
+ nir_ssa_def *res = emit_udiv(bld, lhs, rhs, false);
+ res = nir_ixor(bld, res, d_sign);
+ return nir_isub(bld, res, d_sign);
+ } else {
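+      /* irem takes the sign of the numerator. */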
+ nir_ssa_def *res = emit_udiv(bld, lhs, rhs, true);
+ res = nir_ixor(bld, res, lh_sign);
+ res = nir_isub(bld, res, lh_sign);
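+
+      /* imod differs from irem only when the result is non-zero and the
+       * operand signs differ; adding denom then gives the result the sign
+       * of the denominator. */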
+ if (op == nir_op_imod) {
+ nir_ssa_def *cond = nir_ieq(bld, res, nir_imm_int(bld, 0));
+ cond = nir_ior(bld, nir_ieq(bld, lh_sign, rh_sign), cond);
+ res = nir_bcsel(bld, cond, res, nir_iadd(bld, res, denom));
+ }
+ return res;
+ }
+}
+
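+/* Lowers a single 32-bit idiv/irem/imod/udiv/umod instruction to the exact
+ * expansion above. Returns true if the instruction was lowered. */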
+static bool
+convert_instr_precise(nir_builder *bld, nir_alu_instr *alu)
+{
+ nir_op op = alu->op;
+
+ if ((op != nir_op_idiv) &&
+ (op != nir_op_imod) &&
+ (op != nir_op_irem) &&
+ (op != nir_op_udiv) &&
+ (op != nir_op_umod))
+ return false;
+
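+   /* The precise expansion is only implemented for 32-bit values; other
+    * bit sizes are left untouched. */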
+ if (alu->dest.dest.ssa.bit_size != 32)
+ return false;
+
+ bld->cursor = nir_before_instr(&alu->instr);
+
+ nir_ssa_def *numer = nir_ssa_for_alu_src(bld, alu, 0);
+ nir_ssa_def *denom = nir_ssa_for_alu_src(bld, alu, 1);
+
+ nir_ssa_def *res = NULL;
+
+ if (op == nir_op_udiv || op == nir_op_umod)
+ res = emit_udiv(bld, numer, denom, op == nir_op_umod);
+ else
+ res = emit_idiv(bld, numer, denom, op);
+
+ assert(alu->dest.dest.is_ssa);
+ nir_ssa_def_rewrite_uses(&alu->dest.dest.ssa, nir_src_for_ssa(res));
+
+ return true;
+}
+
static bool
-convert_impl(nir_function_impl *impl)
+convert_impl(nir_function_impl *impl, enum nir_lower_idiv_path path)
{
nir_builder b;
nir_builder_init(&b, impl);
nir_foreach_block(block, impl) {
nir_foreach_instr_safe(instr, block) {
- if (instr->type == nir_instr_type_alu)
+ if (instr->type == nir_instr_type_alu && path == nir_lower_idiv_precise)
+ progress |= convert_instr_precise(&b, nir_instr_as_alu(instr));
+ else if (instr->type == nir_instr_type_alu)
progress |= convert_instr(&b, nir_instr_as_alu(instr));
}
}
}
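+
+/* Runs the selected lowering path over every function in the shader;
+ * returns true if any instruction was lowered. */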
bool
-nir_lower_idiv(nir_shader *shader)
+nir_lower_idiv(nir_shader *shader, enum nir_lower_idiv_path path)
{
bool progress = false;
nir_foreach_function(function, shader) {
if (function->impl)
- progress |= convert_impl(function->impl);
+ progress |= convert_impl(function->impl, path);
}
return progress;