+/**
+ * If the \p s is an SSA value that was generated by a negation instruction,
+ * that instruction is returned as a \c nir_alu_instr. Otherwise \c NULL is
+ * returned.
+ */
+static nir_alu_instr *
+get_neg_instr(nir_src s)
+{
+ nir_alu_instr *alu = nir_src_as_alu_instr(s);
+
+ return alu != NULL && (alu->op == nir_op_fneg || alu->op == nir_op_ineg)
+ ? alu : NULL;
+}
+
+bool
+nir_const_value_negative_equal(nir_const_value c1,
+ nir_const_value c2,
+ nir_alu_type full_type)
+{
+ assert(nir_alu_type_get_base_type(full_type) != nir_type_invalid);
+ assert(nir_alu_type_get_type_size(full_type) != 0);
+
+ switch (full_type) {
+ case nir_type_float16:
+ return _mesa_half_to_float(c1.u16) == -_mesa_half_to_float(c2.u16);
+
+ case nir_type_float32:
+ return c1.f32 == -c2.f32;
+
+ case nir_type_float64:
+ return c1.f64 == -c2.f64;
+
+ case nir_type_int8:
+ case nir_type_uint8:
+ return c1.i8 == -c2.i8;
+
+ case nir_type_int16:
+ case nir_type_uint16:
+ return c1.i16 == -c2.i16;
+
+ case nir_type_int32:
+ case nir_type_uint32:
+ return c1.i32 == -c2.i32;
+
+ case nir_type_int64:
+ case nir_type_uint64:
+ return c1.i64 == -c2.i64;
+
+ default:
+ break;
+ }
+
+ return false;
+}
+
+/**
+ * Shallow compare of ALU srcs to determine if one is the negation of the other
+ *
+ * This function detects cases where \p alu1 is a constant and \p alu2 is a
+ * constant that is its negation. It will also detect cases where \p alu2 is
+ * an SSA value that is a \c nir_op_fneg applied to \p alu1 (and vice versa).
+ *
+ * This function does not detect the general case when \p alu1 and \p alu2 are
+ * SSA values that are the negations of each other (e.g., \p alu1 represents
+ * (a * b) and \p alu2 represents (-a * b)).
+ *
+ * \warning
+ * It is the responsibility of the caller to ensure that the component counts,
+ * write masks, and base types of the sources being compared are compatible.
+ */
+bool
+nir_alu_srcs_negative_equal(const nir_alu_instr *alu1,
+ const nir_alu_instr *alu2,
+ unsigned src1, unsigned src2)
+{
+#ifndef NDEBUG
+ for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
+ assert(nir_alu_instr_channel_used(alu1, src1, i) ==
+ nir_alu_instr_channel_used(alu2, src2, i));
+ }
+
+ if (nir_op_infos[alu1->op].input_types[src1] == nir_type_float) {
+ assert(nir_op_infos[alu1->op].input_types[src1] ==
+ nir_op_infos[alu2->op].input_types[src2]);
+ } else {
+ assert(nir_op_infos[alu1->op].input_types[src1] == nir_type_int);
+ assert(nir_op_infos[alu2->op].input_types[src2] == nir_type_int);
+ }
+#endif
+
+ if (alu1->src[src1].abs != alu2->src[src2].abs)
+ return false;
+
+ bool parity = alu1->src[src1].negate != alu2->src[src2].negate;
+
+ /* Handling load_const instructions is tricky. */
+
+ const nir_const_value *const const1 =
+ nir_src_as_const_value(alu1->src[src1].src);
+
+ if (const1 != NULL) {
+ /* Assume that constant folding will eliminate source mods and unary
+ * ops.
+ */
+ if (parity)
+ return false;
+
+ const nir_const_value *const const2 =
+ nir_src_as_const_value(alu2->src[src2].src);
+
+ if (const2 == NULL)
+ return false;
+
+ if (nir_src_bit_size(alu1->src[src1].src) !=
+ nir_src_bit_size(alu2->src[src2].src))
+ return false;
+
+ const nir_alu_type full_type = nir_op_infos[alu1->op].input_types[src1] |
+ nir_src_bit_size(alu1->src[src1].src);
+ for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) {
+ if (nir_alu_instr_channel_used(alu1, src1, i) &&
+ !nir_const_value_negative_equal(const1[alu1->src[src1].swizzle[i]],
+ const2[alu2->src[src2].swizzle[i]],
+ full_type))
+ return false;
+ }
+
+ return true;
+ }
+
+ uint8_t alu1_swizzle[4] = {0};
+ nir_src alu1_actual_src;
+ nir_alu_instr *neg1 = get_neg_instr(alu1->src[src1].src);
+
+ if (neg1) {
+ parity = !parity;
+ alu1_actual_src = neg1->src[0].src;
+
+ for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(neg1, 0); i++)
+ alu1_swizzle[i] = neg1->src[0].swizzle[i];
+ } else {
+ alu1_actual_src = alu1->src[src1].src;
+
+ for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(alu1, src1); i++)
+ alu1_swizzle[i] = i;
+ }
+
+ uint8_t alu2_swizzle[4] = {0};
+ nir_src alu2_actual_src;
+ nir_alu_instr *neg2 = get_neg_instr(alu2->src[src2].src);
+
+ if (neg2) {
+ parity = !parity;
+ alu2_actual_src = neg2->src[0].src;
+
+ for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(neg2, 0); i++)
+ alu2_swizzle[i] = neg2->src[0].swizzle[i];
+ } else {
+ alu2_actual_src = alu2->src[src2].src;
+
+ for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(alu2, src2); i++)
+ alu2_swizzle[i] = i;
+ }
+
+ for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(alu1, src1); i++) {
+ if (alu1_swizzle[alu1->src[src1].swizzle[i]] !=
+ alu2_swizzle[alu2->src[src2].swizzle[i]])
+ return false;
+ }
+
+ return parity && nir_srcs_equal(alu1_actual_src, alu2_actual_src);
+}
+