nir/algebraic: Optimize comparisons and up-casts
authorJason Ekstrand <jason@jlekstrand.net>
Sat, 13 Jul 2019 04:26:59 +0000 (23:26 -0500)
committerJason Ekstrand <jason@jlekstrand.net>
Wed, 17 Jul 2019 18:44:35 +0000 (18:44 +0000)
These seem like obvious enough optimizations in the world of multiple
integer bit sizes.  The only known thing which hits these at the moment
is some Vulkan CTS tests for 16-bit SSBO values which like to up-cast
and check for equality.  However, it's something that's bound to come up
as we start seeing more integers in shaders.

The optimizations of comparisons of casted values with constants are
something which we would ideally do with range analysis.  However,
lacking that, we can do it in opt_algebraic as long as one side is a
constant.

In dEQP-VK.ssbo.phys.layout.random.16bit.scalar.13, this commit, along
with the previous commit, reduce the number of instructions emitted on
Skylake from 55328 to 44546, a reduction of 20%.

Acked-by: Matt Turner <mattst88@gmail.com>
Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
src/compiler/nir/nir_opt_algebraic.py

index 42462d5befaeed1272882bf173b3d2b07b94c863..abefbb54756edae60d808e72ab52bac46045278b 100644 (file)
@@ -1037,6 +1037,73 @@ for N, M in itertools.product(type_sizes('uint'), type_sizes('uint')):
       # The N == M case is handled by other optimizations
       pass
 
+# Optimize comparisons with up-casts
+for t in ['int', 'uint', 'float']:
+    for N, M in itertools.product(type_sizes(t), repeat=2):
+        if N == 1 or N >= M:
+            continue
+
+        x2xM = '{0}2{0}{1}'.format(t[0], M)
+        x2xN = '{0}2{0}{1}'.format(t[0], N)
+        aN = 'a@' + str(N)
+        bN = 'b@' + str(N)
+        xeq = 'feq' if t == 'float' else 'ieq'
+        xne = 'fne' if t == 'float' else 'ine'
+        xge = '{0}ge'.format(t[0])
+        xlt = '{0}lt'.format(t[0])
+
+        # Up-casts are lossless so for correctly signed comparisons of
+        # up-casted values we can do the comparison at the largest of the two
+        # original sizes and drop one or both of the casts.  (We have
+        # optimizations to drop the no-op casts which this may generate.)
+        for P in type_sizes(t):
+            if P == 1 or P > N:
+                continue
+
+            bP = 'b@' + str(P)
+            optimizations += [
+                ((xeq, (x2xM, aN), (x2xM, bP)), (xeq, a, (x2xN, b))),
+                ((xne, (x2xM, aN), (x2xM, bP)), (xne, a, (x2xN, b))),
+                ((xge, (x2xM, aN), (x2xM, bP)), (xge, a, (x2xN, b))),
+                ((xlt, (x2xM, aN), (x2xM, bP)), (xlt, a, (x2xN, b))),
+                ((xge, (x2xM, bP), (x2xM, aN)), (xge, (x2xN, b), a)),
+                ((xlt, (x2xM, bP), (x2xM, aN)), (xlt, (x2xN, b), a)),
+            ]
+
+        # The next bit doesn't work on floats because the range checks would
+        # get way too complicated.
+        if t in ['int', 'uint']:
+            if t == 'int':
+                xN_min = -(1 << (N - 1))
+                xN_max = (1 << (N - 1)) - 1
+            elif t == 'uint':
+                xN_min = 0
+                xN_max = (1 << N) - 1
+            else:
+                assert False
+
+            # If we're up-casting and comparing to a constant, we can unfold
+            # the comparison into a comparison with the shrunk down constant
+            # and a check that the constant fits in the smaller bit size.
+            optimizations += [
+                ((xeq, (x2xM, aN), '#b'),
+                 ('iand', (xeq, a, (x2xN, b)), (xeq, (x2xM, (x2xN, b)), b))),
+                ((xne, (x2xM, aN), '#b'),
+                 ('ior', (xne, a, (x2xN, b)), (xne, (x2xM, (x2xN, b)), b))),
+                ((xlt, (x2xM, aN), '#b'),
+                 ('iand', (xlt, xN_min, b),
+                          ('ior', (xlt, xN_max, b), (xlt, a, (x2xN, b))))),
+                ((xlt, '#a', (x2xM, bN)),
+                 ('iand', (xlt, a, xN_max),
+                          ('ior', (xlt, a, xN_min), (xlt, (x2xN, a), b)))),
+                ((xge, (x2xM, aN), '#b'),
+                 ('iand', (xge, xN_max, b),
+                          ('ior', (xge, xN_min, b), (xge, a, (x2xN, b))))),
+                ((xge, '#a', (x2xM, bN)),
+                 ('iand', (xge, a, xN_min),
+                          ('ior', (xge, a, xN_max), (xge, (x2xN, a), b)))),
+            ]
+
 def fexp2i(exp, bits):
    # We assume that exp is already in the right range.
    if bits == 16: