split out 2nd dct outer butterfly scheduler
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 22 Jul 2021 22:18:25 +0000 (23:18 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 22 Jul 2021 22:18:25 +0000 (23:18 +0100)
src/openpower/decoder/isa/remap_dct_yield.py

index e6598fa87e8f3ea462d6c67622e8ec2ebeb1c137..b88b67268335d0e9db43f76777e72ecd63541ebe 100644 (file)
@@ -43,7 +43,7 @@ def halfrev2(vec, pre_rev=True):
 
 # python "yield" can be iterated. use this to make it clear how
 # the indices are generated by using natural-looking nested loops
-def iterate_dct_butterfly_indices(SVSHAPE):
+def iterate_dct_inner_butterfly_indices(SVSHAPE):
     # get indices to iterate over, in the required order
     n = SVSHAPE.lims[0]
     # createing lists of indices to iterate over in each dimension
@@ -73,11 +73,11 @@ def iterate_dct_butterfly_indices(SVSHAPE):
     ji = list(range(n))
     inplace_mode = SVSHAPE.mode == 0b01 and SVSHAPE.skip not in [0b10, 0b11]
     if inplace_mode:
-        print ("inplace mode")
+        #print ("inplace mode")
         ji = halfrev2(ji, True)
 
-    print ("ri", ri)
-    print ("ji", ji)
+    #print ("ri", ri)
+    #print ("ji", ji)
 
     # start an infinite (wrapping) loop
     skip = 0
@@ -98,22 +98,108 @@ def iterate_dct_butterfly_indices(SVSHAPE):
                 jr = list(range(i+halfsize, i + size))
                 jr.reverse()
                 # invert if requested
-                if SVSHAPE.invxyz[2]: k_r.reverse()
                 if SVSHAPE.invxyz[2]: j_r.reverse()
                 hz2 = halfsize // 2 # zero stops reversing 1-item lists
                 # if you *really* want to do the in-place swapping manually,
                 # this allows you to do it.  good luck...
                 if SVSHAPE.mode == 0b01 and not inplace_mode:
-                    print ("swap mode")
+                    #print ("swap mode")
                     jr = j_r[:hz2]
-                print ("xform jr", jr)
+                #print ("xform jr", jr)
                 for jl, jh in zip(j, jr):   # loop over 1st order dimension
                     z_end = jl == j[-1]
-                    # now depending on MODE return the index
+                    # now depending on MODE return the index.  inner butterfly
+                    if SVSHAPE.mode == 0b01:
+                        if SVSHAPE.skip in [0b00, 0b10]:
+                            result = ri[ji[jl]]        # lower half
+                        elif SVSHAPE.skip in [0b01, 0b11]:
+                            result = ri[ji[jh]] # upper half, reverse order
+                    # outer butterfly
+                    elif SVSHAPE.mode == 0b10:
+                        if SVSHAPE.skip == 0b00:
+                            result = ri[ji[jl]]        # lower half
+                        elif SVSHAPE.skip == 0b01:
+                            result = ri[ji[jl+size]]   # upper half
+                    loopends = (z_end |
+                               ((y_end and z_end)<<1) |
+                                ((y_end and x_end and z_end)<<2))
+
+                    yield result + SVSHAPE.offset, loopends
+
+                # now in-place swap
+                if SVSHAPE.mode == 0b01 and inplace_mode:
+                    for ci, (jl, jh) in enumerate(zip(j[:hz2], jr[:hz2])):
+                        jlh = jl+halfsize
+                        #print ("inplace swap", jh, jlh)
+                        tmp1, tmp2 = ji[jlh], ji[jh]
+                        ji[jlh], ji[jh] = tmp2, tmp1
+
+
+# python "yield" can be iterated. use this to make it clear how
+# the indices are generated by using natural-looking nested loops
+def iterate_dct_outer_butterfly_indices(SVSHAPE):
+    # get indices to iterate over, in the required order
+    n = SVSHAPE.lims[0]
+    # createing lists of indices to iterate over in each dimension
+    # has to be done dynamically, because it depends on the size
+    # first, the size-based loop (which can be done statically)
+    x_r = []
+    size = n // 2
+    while size >= 2:
+        x_r.append(size)
+        size //= 2
+    # invert order if requested
+    if SVSHAPE.invxyz[0]:
+        x_r.reverse()
+
+    if len(x_r) == 0:
+        return
+
+    #print ("outer butterfly")
+
+    # reference (read/write) the in-place data in *reverse-bit-order*
+    ri = list(range(n))
+    if SVSHAPE.mode == 0b11:
+        levels = n.bit_length() - 1
+        ri = [ri[reverse_bits(i, levels)] for i in range(n)]
+
+    # reference list for not needing to do data-swaps, just swap what
+    # *indices* are referenced (two levels of indirection at the moment)
+    # pre-reverse the data-swap list so that it *ends up* in the order 0123..
+    ji = list(range(n))
+    inplace_mode = SVSHAPE.skip in [0b10, 0b11]
+    if inplace_mode:
+        #print ("inplace mode", SVSHAPE.skip)
+        ji = halfrev2(ji, True)
+
+    #print ("ri", ri)
+    #print ("ji", ji)
+
+    # start an infinite (wrapping) loop
+    skip = 0
+    while True:
+        for size in x_r:           # loop over 3rd order dimension (size)
+            halfsize = size//2
+            x_end = size == x_r[-1]
+            y_r = list(range(0, halfsize))
+            #print ("itersum", halfsize, size, y_r)
+            # invert if requested
+            if SVSHAPE.invxyz[1]: y_r.reverse()
+            for i in y_r:       # loop over 2nd order dimension
+                y_end = i == y_r[-1]
+                # one list to create iterative-sum schedule
+                jr = list(range(i+halfsize, i+n-halfsize, size))
+                #print ("itersum     jr", i+halfsize, i+size, jr)
+                # invert if requested
+                if SVSHAPE.invxyz[2]: j_r.reverse()
+                hz2 = halfsize // 2 # zero stops reversing 1-item lists
+                for jh in jr:   # loop over 1st order dimension
+                    z_end = jh == jr[-1]
+                    #print ("     itersum", size, i, jh, jh+size)
                     if SVSHAPE.skip in [0b00, 0b10]:
-                        result = ri[ji[jl]]        # lower half
+                        result = ri[ji[jh]]        # lower half
                     elif SVSHAPE.skip in [0b01, 0b11]:
-                        result = ri[ji[jh]] # upper half, reverse order
+                        result = ri[ji[jh+size]] # upper half
                     loopends = (z_end |
                                ((y_end and z_end)<<1) |
                                 ((y_end and x_end and z_end)<<2))
@@ -121,7 +207,10 @@ def iterate_dct_butterfly_indices(SVSHAPE):
                     yield result + SVSHAPE.offset, loopends
 
                 # now in-place swap
-                if SVSHAPE.mode == 0b01 and inplace_mode:
+                if SVSHAPE.mode == 0b11 and inplace_mode:
+                    j = list(range(i, i + halfsize))
+                    jr = list(range(i+halfsize, i + size))
+                    jr.reverse()
                     for ci, (jl, jh) in enumerate(zip(j[:hz2], jr[:hz2])):
                         jlh = jl+halfsize
                         #print ("inplace swap", jh, jlh)
@@ -150,6 +239,30 @@ def pprint_schedule(schedule, n):
                 idx += 1
         size *= 2
 
+def pprint_schedule_outer(schedule, n):
+    size = 2
+    idx = 0
+    while size <= n//2:
+        halfsize = size // 2
+        tablestep = n // size
+        print ("size %d halfsize %d tablestep %d" % \
+                (size, halfsize, tablestep))
+        y_r = list(range(0, halfsize))
+        for i in y_r:
+            prefix = "i %d\t" % i
+            jr = list(range(i+halfsize, i+n-halfsize, size))
+            for j in jr:
+                (jl, je), (jh, he) = schedule[idx]
+                print ("  %-3d\t%s j=%-2d jh=%-2d "
+                        "j[jl=%-2d] j[jh=%-2d]" % \
+                                (idx, prefix, j, j+halfsize,
+                                      jl, jh,
+                                ),
+                                "end", bin(je)[2:], bin(je)[2:])
+                idx += 1
+        size *= 2
+
+
 # totally cool *in-place* DCT algorithm using yield REMAPs
 def transform2(vec):
 
@@ -174,13 +287,11 @@ def transform2(vec):
     # computed every time.
     ctable = []
     size = n
-    VL = 0
     while size >= 2:
         halfsize = size // 2
         for i in range(n//size):
             for ci in range(halfsize):
                 ctable.append((math.cos((ci + 0.5) * math.pi / size) * 2.0))
-                VL += 1
         size //= 2
 
     ################
@@ -211,11 +322,9 @@ def transform2(vec):
     SVSHAPE1.invxyz = [1,0,0] # inversion if desired
 
     # enumerate over the iterator function, getting new indices
-    i0 = iterate_dct_butterfly_indices(SVSHAPE0)
-    i1 = iterate_dct_butterfly_indices(SVSHAPE1)
+    i0 = iterate_dct_inner_butterfly_indices(SVSHAPE0)
+    i1 = iterate_dct_inner_butterfly_indices(SVSHAPE1)
     for k, ((jl, jle), (jh, jhe)) in enumerate(zip(i0, i1)):
-        if k >= VL:
-            break
         t1, t2 = vec[jl], vec[jh]
         coeff = ctable[k]
         vec[jl] = t1 + t2
@@ -223,24 +332,42 @@ def transform2(vec):
         print ("coeff", size, i, "ci", ci,
                 "jl", jl, "jh", jh,
                "i/n", (ci+0.5)/size, coeff, vec[jl],
-                                            vec[jh])
+                                            vec[jh],
+                "end", bin(jle), bin(jhe))
+        if jle == 0b111: # all loops end
+            break
 
     print("transform2 pre-itersum", vec)
 
     # now things are in the right order for the outer butterfly.
-    n = len(vec)
-    size = n // 2
-    while size >= 2:
-        halfsize = size // 2
-        ir = list(range(0, halfsize))
-        print ("itersum", halfsize, size, ir)
-        for i in ir:
-            jr = list(range(i+halfsize, i+n-halfsize, size))
-            print ("itersum    jr", i+halfsize, i+size, jr)
-            for jh in jr:
-                vec[jh] += vec[jh+size]
-                print ("    itersum", size, i, jh, jh+size)
+
+    # j schedule
+    SVSHAPE0 = SVSHAPE()
+    SVSHAPE0.lims = [xdim, ydim, zdim]
+    SVSHAPE0.order = [0,1,2]  # experiment with different permutations, here
+    SVSHAPE0.mode = 0b10
+    SVSHAPE0.skip = 0b00
+    SVSHAPE0.offset = 0       # experiment with different offset, here
+    SVSHAPE0.invxyz = [0,0,0] # inversion if desired
+    # j+halfstep schedule
+    SVSHAPE1 = SVSHAPE()
+    SVSHAPE1.lims = [xdim, ydim, zdim]
+    SVSHAPE1.order = [0,1,2]  # experiment with different permutations, here
+    SVSHAPE1.mode = 0b10
+    SVSHAPE1.skip = 0b01
+    SVSHAPE1.offset = 0       # experiment with different offset, here
+    SVSHAPE1.invxyz = [0,0,0] # inversion if desired
+
+    # enumerate over the iterator function, getting new indices
+    i0 = iterate_dct_outer_butterfly_indices(SVSHAPE0)
+    i1 = iterate_dct_outer_butterfly_indices(SVSHAPE1)
+    for k, ((jl, jle), (jh, jhe)) in enumerate(zip(i0, i1)):
+        print ("itersum    jr", jl, jh,
+                "end", bin(jle), bin(jhe))
+        vec[jl] += vec[jh]
         size //= 2
+        if jle == 0b111: # all loops end
+            break
 
     print("transform2 result", vec)
 
@@ -249,22 +376,11 @@ def transform2(vec):
 
 def demo():
     # set the dimension sizes here
-    xdim = 8
+    n = 8
+    xdim = n
     ydim = 0 # not needed
     zdim = 0 # again, not needed
 
-    # set total. err don't know how to calculate how many there are...
-    # do it manually for now
-    VL = 0
-    size = 2
-    n = xdim
-    while size <= n:
-        halfsize = size // 2
-        tablestep = n // size
-        for i in range(0, n, size):
-            for j in range(i, i + halfsize):
-                VL += 1
-        size *= 2
 
     ################
     # INNER butterfly
@@ -292,12 +408,12 @@ def demo():
 
     # enumerate over the iterator function, getting new indices
     schedule = []
-    i0 = iterate_dct_butterfly_indices(SVSHAPE0)
-    i1 = iterate_dct_butterfly_indices(SVSHAPE1)
+    i0 = iterate_dct_inner_butterfly_indices(SVSHAPE0)
+    i1 = iterate_dct_inner_butterfly_indices(SVSHAPE1)
     for idx, (jl, jh) in enumerate(zip(i0, i1)):
-        if idx >= VL:
-            break
         schedule.append((jl, jh))
+        if jl[1] == 0b111: # end
+            break
 
     # ok now pretty-print the results, with some debug output
     print ("inner butterfly")
@@ -313,7 +429,7 @@ def demo():
     SVSHAPE0.lims = [xdim, ydim, zdim]
     SVSHAPE0.order = [0,1,2]  # experiment with different permutations, here
     SVSHAPE0.mode = 0b10
-    SVSHAPE0.skip = 0b00
+    SVSHAPE0.skip = 0b10
     SVSHAPE0.offset = 0       # experiment with different offset, here
     SVSHAPE0.invxyz = [1,0,0] # inversion if desired
     # j+halfstep schedule
@@ -321,22 +437,22 @@ def demo():
     SVSHAPE1.lims = [xdim, ydim, zdim]
     SVSHAPE1.order = [0,1,2]  # experiment with different permutations, here
     SVSHAPE1.mode = 0b10
-    SVSHAPE1.skip = 0b01
+    SVSHAPE1.skip = 0b11
     SVSHAPE1.offset = 0       # experiment with different offset, here
     SVSHAPE1.invxyz = [1,0,0] # inversion if desired
 
     # enumerate over the iterator function, getting new indices
     schedule = []
-    i0 = iterate_dct_butterfly_indices(SVSHAPE0)
-    i1 = iterate_dct_butterfly_indices(SVSHAPE1)
+    i0 = iterate_dct_outer_butterfly_indices(SVSHAPE0)
+    i1 = iterate_dct_outer_butterfly_indices(SVSHAPE1)
     for idx, (jl, jh) in enumerate(zip(i0, i1)):
-        if idx >= VL:
-            break
         schedule.append((jl, jh))
+        if jl[1] == 0b111: # end
+            break
 
     # ok now pretty-print the results, with some debug output
     print ("outer butterfly")
-    pprint_schedule(schedule, n)
+    pprint_schedule_outer(schedule, n)
 
 # run the demo
 if __name__ == '__main__':