nir: Add a lowering pass to split 64bit phis
[mesa.git] / src / compiler / nir / nir_lower_flrp.c
index 41342403d484f230e7bec173793d83aa53aa2452..38be18ecc6ba04fb42424f4253776ae3584cbb92 100644 (file)
@@ -84,7 +84,7 @@ replace_with_single_ffma(struct nir_builder *bld, struct u_vector *dead_flrp,
    nir_instr_as_alu(neg_c->parent_instr)->exact = alu->exact;
 
    nir_ssa_def *const one_minus_c =
-      nir_fadd(bld, nir_imm_float(bld, 1.0f), neg_c);
+      nir_fadd(bld, nir_imm_floatN_t(bld, 1.0f, c->bit_size), neg_c);
    nir_instr_as_alu(one_minus_c->parent_instr)->exact = alu->exact;
 
    nir_ssa_def *const b_times_c = nir_fmul(bld, b, c);
@@ -117,7 +117,7 @@ replace_with_strict(struct nir_builder *bld, struct u_vector *dead_flrp,
    nir_instr_as_alu(neg_c->parent_instr)->exact = alu->exact;
 
    nir_ssa_def *const one_minus_c =
-      nir_fadd(bld, nir_imm_float(bld, 1.0f), neg_c);
+      nir_fadd(bld, nir_imm_floatN_t(bld, 1.0f, c->bit_size), neg_c);
    nir_instr_as_alu(one_minus_c->parent_instr)->exact = alu->exact;
 
    nir_ssa_def *const first_product = nir_fmul(bld, a, one_minus_c);
@@ -171,7 +171,7 @@ replace_with_fast(struct nir_builder *bld, struct u_vector *dead_flrp,
 }
 
 /**
- * Replace flrp(a, b, c) with (b*c ± c) + a
+ * Replace flrp(a, b, c) with (b*c ± c) + a => b*c + (a ± c)
  *
  * \note: This only works if a = ±1.
  */
@@ -193,14 +193,14 @@ replace_with_expanded_ffma_and_add(struct nir_builder *bld,
       nir_ssa_def *const neg_c = nir_fneg(bld, c);
       nir_instr_as_alu(neg_c->parent_instr)->exact = alu->exact;
 
-      inner_sum = nir_fadd(bld, b_times_c, neg_c);
+      inner_sum = nir_fadd(bld, a, neg_c);
    } else {
-      inner_sum = nir_fadd(bld, b_times_c, c);
+      inner_sum = nir_fadd(bld, a, c);
    }
 
    nir_instr_as_alu(inner_sum->parent_instr)->exact = alu->exact;
 
-   nir_ssa_def *const outer_sum = nir_fadd(bld, inner_sum, a);
+   nir_ssa_def *const outer_sum = nir_fadd(bld, inner_sum, b_times_c);
    nir_instr_as_alu(outer_sum->parent_instr)->exact = alu->exact;
 
    nir_ssa_def_rewrite_uses(&alu->dest.dest.ssa, nir_src_for_ssa(outer_sum));