nir/opt_algebraic: lower 64-bit fmin3/fmax3/fmed3
[mesa.git] / src / compiler / nir / nir_opt_algebraic.py
index 7d9775950a4d596d519df681ce6f86e605583173..00d18402bd1aa399df49493e5ee186987e92d07d 100644 (file)
@@ -109,7 +109,6 @@ optimizations = [
 
    (('~fneg', ('fneg', a)), a),
    (('ineg', ('ineg', a)), a),
-   (('fabs', ('fabs', a)), ('fabs', a)),
    (('fabs', ('fneg', a)), ('fabs', a)),
    (('fabs', ('u2f', a)), ('u2f', a)),
    (('iabs', ('iabs', a)), ('iabs', a)),
@@ -560,6 +559,16 @@ optimizations.extend([
    (('umin', ('umax', ('umin', ('umax', a, b), c), b), c), ('umin', ('umax', a, b), c)),
    (('fmax', ('fsat', a), '#b@32(is_zero_to_one)'), ('fsat', ('fmax', a, b))),
    (('fmin', ('fsat', a), '#b@32(is_zero_to_one)'), ('fsat', ('fmin', a, b))),
+
+   # If a in [0,b] then b-a is also in [0,b].  Since b in [0,1], max(b-a, 0) =
+   # fsat(b-a).
+   #
+   # If a > b, then b-a < 0 and max(b-a, 0) = fsat(b-a) = 0
+   #
+   # This should be NaN safe since max(NaN, 0) = fsat(NaN) = 0.
+   (('fmax', ('fadd(is_used_once)', ('fneg', 'a(is_not_negative)'), '#b@32(is_zero_to_one)'), 0.0),
+    ('fsat', ('fadd', ('fneg',  a), b)), '!options->lower_fsat'),
+
    (('extract_u8', ('imin', ('imax', a, 0), 0xff), 0), ('imin', ('imax', a, 0), 0xff)),
    (('~ior', ('flt(is_used_once)', a, b), ('flt', a, c)), ('flt', a, ('fmax', b, c))),
    (('~ior', ('flt(is_used_once)', a, c), ('flt', b, c)), ('flt', ('fmin', a, b), c)),
@@ -629,6 +638,17 @@ optimizations.extend([
    (('iand', ('ieq', 'a@32', 0), ('ieq', 'b@32', 0)), ('ieq', ('ior', a, b), 0), '!options->lower_bitops'),
    (('ior',  ('ine', 'a@32', 0), ('ine', 'b@32', 0)), ('ine', ('ior', a, b), 0), '!options->lower_bitops'),
 
+   # This pattern occurs coutresy of __flt64_nonnan in the soft-fp64 code.
+   # The first part of the iand comes from the !__feq64_nonnan.
+   #
+   # The second pattern is a reformulation of the first based on the relation
+   # (a == 0 || y == 0) <=> umin(a, y) == 0, where b in the first equation
+   # happens to be y == 0.
+   (('iand', ('inot', ('iand', ('ior', ('ieq', a, 0),  b), c)), ('ilt', a, 0)),
+    ('iand', ('inot', ('iand',                         b , c)), ('ilt', a, 0))),
+   (('iand', ('inot', ('iand', ('ieq', ('umin', a, b), 0), c)), ('ilt', a, 0)),
+    ('iand', ('inot', ('iand', ('ieq',             b , 0), c)), ('ilt', a, 0))),
+
    # These patterns can result when (a < b || a < c) => (a < min(b, c))
    # transformations occur before constant propagation and loop-unrolling.
    (('~flt', a, ('fmax', b, a)), ('flt', a, b)),
@@ -776,6 +796,7 @@ optimizations.extend([
    (('~fexp2', ('fmul', ('flog2', a), b)), ('fpow', a, b), '!options->lower_fpow'), # 2^(lg2(a)*b) = a^b
    (('~fexp2', ('fadd', ('fmul', ('flog2', a), b), ('fmul', ('flog2', c), d))),
     ('~fmul', ('fpow', a, b), ('fpow', c, d)), '!options->lower_fpow'), # 2^(lg2(a) * b + lg2(c) + d) = a^b * c^d
+   (('~fexp2', ('fmul', ('flog2', a), 0.5)), ('fsqrt', a)),
    (('~fexp2', ('fmul', ('flog2', a), 2.0)), ('fmul', a, a)),
    (('~fexp2', ('fmul', ('flog2', a), 4.0)), ('fmul', ('fmul', a, a), ('fmul', a, a))),
    (('~fpow', a, 1.0), a),
@@ -1055,6 +1076,10 @@ optimizations.extend([
 
    (('bcsel', ('ine', a, -1), ('ifind_msb', a), -1), ('ifind_msb', a)),
 
+   (('fmin3@64', a, b, c), ('fmin@64', a, ('fmin@64', b, c))),
+   (('fmax3@64', a, b, c), ('fmax@64', a, ('fmax@64', b, c))),
+   (('fmed3@64', a, b, c), ('fmax@64', ('fmin@64', ('fmax@64', a, b), c), ('fmin@64', a, b))),
+
    # Misc. lowering
    (('fmod', a, b), ('fsub', a, ('fmul', b, ('ffloor', ('fdiv', a, b)))), 'options->lower_fmod'),
    (('frem', a, b), ('fsub', a, ('fmul', b, ('ftrunc', ('fdiv', a, b)))), 'options->lower_fmod'),
@@ -1347,7 +1372,7 @@ for x, y in itertools.product(['f', 'u', 'i'], ['f', 'u', 'i']):
    optimizations.append(((x2yN, (b2x, a)), (b2y, a)))
 
 # Optimize away x2xN(a@N)
-for t in ['int', 'uint', 'float']:
+for t in ['int', 'uint', 'float', 'bool']:
    for N in type_sizes(t):
       x2xN = '{0}2{0}{1}'.format(t[0], N)
       aN = 'a@{0}'.format(N)
@@ -1388,6 +1413,15 @@ for N, M in itertools.product(type_sizes('uint'), type_sizes('uint')):
       # The N == M case is handled by other optimizations
       pass
 
+# Downcast operations should be able to see through pack
+for t in ['i', 'u']:
+    for N in [8, 16, 32]:
+        x2xN = '{0}2{0}{1}'.format(t, N)
+        optimizations += [
+            ((x2xN, ('pack_64_2x32_split', a, b)), (x2xN, a)),
+            ((x2xN, ('pack_64_2x32_split', a, b)), (x2xN, a)),
+        ]
+
 # Optimize comparisons with up-casts
 for t in ['int', 'uint', 'float']:
     for N, M in itertools.product(type_sizes(t), repeat=2):
@@ -1832,8 +1866,43 @@ for op in ['ffma']:
         (('bcsel', a, (op, b, c, d), (op + '(is_used_once)', b, e, d)), (op, b, ('bcsel', a, c, e), d)),
     ]
 
+distribute_src_mods = [
+   # Try to remove some spurious negations rather than pushing them down.
+   (('fmul', ('fneg', a), ('fneg', b)), ('fmul', a, b)),
+   (('ffma', ('fneg', a), ('fneg', b), c), ('ffma', a, b, c)),
+   (('fdot_replicated2', ('fneg', a), ('fneg', b)), ('fdot_replicated2', a, b)),
+   (('fdot_replicated3', ('fneg', a), ('fneg', b)), ('fdot_replicated3', a, b)),
+   (('fdot_replicated4', ('fneg', a), ('fneg', b)), ('fdot_replicated4', a, b)),
+   (('fneg', ('fneg', a)), a),
+
+   (('fneg', ('ffma(is_used_once)', a, b, c)), ('ffma', ('fneg', a), b, ('fneg', c))),
+   (('fneg', ('flrp(is_used_once)', a, b, c)), ('flrp', ('fneg', a), ('fneg', b), c)),
+   (('fneg', ('fadd(is_used_once)', a, b)), ('fadd', ('fneg', a), ('fneg', b))),
+
+   # Note that fmin <-> fmax.  I don't think there is a way to distribute
+   # fabs() into fmin or fmax.
+   (('fneg', ('fmin(is_used_once)', a, b)), ('fmax', ('fneg', a), ('fneg', b))),
+   (('fneg', ('fmax(is_used_once)', a, b)), ('fmin', ('fneg', a), ('fneg', b))),
+
+   # fdph works mostly like fdot, but to get the correct result, the negation
+   # must be applied to the second source.
+   (('fneg', ('fdph_replicated(is_used_once)', a, b)), ('fdph_replicated', a, ('fneg', b))),
+   (('fabs', ('fdph_replicated(is_used_once)', a, b)), ('fdph_replicated', ('fabs', a), ('fabs', b))),
+
+   (('fneg', ('fsign(is_used_once)', a)), ('fsign', ('fneg', a))),
+   (('fabs', ('fsign(is_used_once)', a)), ('fsign', ('fabs', a))),
+]
+
+for op in ['fmul', 'fdot_replicated2', 'fdot_replicated3', 'fdot_replicated4']:
+   distribute_src_mods.extend([
+       (('fneg', (op + '(is_used_once)', a, b)), (op, ('fneg', a), b)),
+       (('fabs', (op + '(is_used_once)', a, b)), (op, ('fabs', a), ('fabs', b))),
+   ])
+
 print(nir_algebraic.AlgebraicPass("nir_opt_algebraic", optimizations).render())
 print(nir_algebraic.AlgebraicPass("nir_opt_algebraic_before_ffma",
                                   before_ffma_optimizations).render())
 print(nir_algebraic.AlgebraicPass("nir_opt_algebraic_late",
                                   late_optimizations).render())
+print(nir_algebraic.AlgebraicPass("nir_opt_algebraic_distribute_src_mods",
+                                  distribute_src_mods).render())