half way through converting in-place dct to yield unit test
[openpower-isa.git] / src / openpower / decoder / isa / remap_dct_yield.py
1 # DCT "REMAP" scheduler
2 #
3 # Modifications made to create an in-place iterative DCT:
4 # Copyright (c) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
5 #
6 # SPDX: LGPLv3+
7 #
8 # Original fastdctlee.py by Nayuki:
9 # Copyright (c) 2020 Project Nayuki. (MIT License)
10 # https://www.nayuki.io/page/fast-discrete-cosine-transform-algorithms
11
12 import math
13
# Returns the integer whose value is the reverse of the lowest 'width'
# bits of the integer 'val'.
def reverse_bits(val, width):
    """Return the low ``width`` bits of ``val`` in mirrored order.

    Bit ``k`` of ``val`` (for ``0 <= k < width``) becomes bit
    ``width - 1 - k`` of the result; bits of ``val`` at or above ``width``
    are discarded.  A ``width`` of zero (or less) yields 0.
    """
    mirrored = 0
    for bit in range(width):
        if (val >> bit) & 1:
            mirrored |= 1 << (width - 1 - bit)
    return mirrored
21
22
# iterative version of [recursively-applied] half-rev.
# relies on the list lengths being power-of-two and the fact
# that bit-inversion of a list of binary numbers is the same
# as reversing the order of the list.
# this version is dead easy to implement in hardware.
# a big surprise is that the half-reversal can be done with
# such a simple XOR.  the inverse operation is slightly trickier.
def halfrev2(vec, pre_rev=True):
    """Permute ``vec`` by the iterative half-reversal, or by its inverse.

    With ``pre_rev`` true, element ``i`` of the result is
    ``vec[i ^ (i >> 1)]`` (the binary-reflected Gray code of ``i``).
    With ``pre_rev`` false, element ``i`` is ``vec[g]`` where ``g`` is the
    XOR of ``i`` with all of its right-shifts (the inverse Gray code), so
    ``halfrev2(halfrev2(v, True), False) == v``.

    ``len(vec)`` must be a power of two (or 0/1); otherwise a computed
    index can fall outside ``vec`` and ``IndexError`` is raised.
    Returns a new list; ``vec`` itself is not modified.
    """
    res = []
    for i in range(len(vec)):
        if pre_rev:
            # forward direction: a single XOR with the neighbouring bit
            src = i ^ (i >> 1)
        else:
            # inverse direction: XOR together i>>1, i>>2, ... until the
            # shifted value runs out of bits
            src = i
            for shift in range(1, i.bit_length()):
                src ^= i >> shift
        # both directions permute the *elements* of vec, so the two are
        # true inverses of each other for arbitrary list contents
        res.append(vec[src])
    return res
42
43
# python "yield" can be iterated. use this to make it clear how
# the indices are generated by using natural-looking nested loops
def iterate_dct_butterfly_indices(SVSHAPE):
    """Generate the DCT butterfly REMAP index schedule for one operand.

    Yields ``(index, loopends)`` tuples forever: once the full schedule has
    been produced it starts again from the beginning, so callers stop
    after the number of elements they need.  Yields nothing when
    ``SVSHAPE.lims[0] < 2``.

    ``SVSHAPE`` attributes read here:

    * ``lims[0]`` -- transform size ``n`` (presumably a power of two --
      TODO confirm; other sizes can produce out-of-range indices).
    * ``mode`` -- ``0b01`` addresses the data in bit-reversed order and,
      unless ``skip`` is ``0b10``/``0b11``, tracks the in-place swaps via a
      second index-indirection list; other modes use linear addressing.
    * ``skip`` -- ``0b00``/``0b10`` yield the lower (``j``) element of each
      butterfly, ``0b01``/``0b11`` the upper (reverse-ordered ``jr``) one.
    * ``offset`` -- added to every yielded index.
    * ``invxyz`` -- inversion flags for the size loop, the block loop and
      the innermost butterfly loop respectively.

    ``loopends`` is a bitmask: bit 0 is set on the last element of the
    innermost loop, bit 1 when that is also the last block of the current
    size, bit 2 when it is additionally the last size.

    Raises ValueError (when iteration starts) if ``skip`` is not in
    0b00..0b11.
    """
    # get indices to iterate over, in the required order
    n = SVSHAPE.lims[0]
    # creating lists of indices to iterate over in each dimension
    # has to be done dynamically, because it depends on the size.
    # first, the size-based loop (which can be done statically): 2, 4 .. n
    x_r = []
    size = 2
    while size <= n:
        x_r.append(size)
        size *= 2
    # invert order if requested
    if SVSHAPE.invxyz[0]:
        x_r.reverse()

    if not x_r:
        return

    # decide once which half of each butterfly this schedule returns
    if SVSHAPE.skip in (0b00, 0b10):
        upper_half = False
    elif SVSHAPE.skip in (0b01, 0b11):
        upper_half = True
    else:
        raise ValueError("SVSHAPE.skip must be 0b00..0b11, got %r"
                         % (SVSHAPE.skip,))

    # reference (read/write) the in-place data in *reverse-bit-order*
    ri = list(range(n))
    if SVSHAPE.mode == 0b01:
        levels = n.bit_length() - 1
        ri = [ri[reverse_bits(i, levels)] for i in range(n)]

    # reference list for not needing to do data-swaps, just swap what
    # *indices* are referenced (two levels of indirection at the moment)
    # pre-reverse the data-swap list so that it *ends up* in the order 0123..
    ji = list(range(n))
    inplace_mode = SVSHAPE.mode == 0b01 and SVSHAPE.skip not in (0b10, 0b11)
    if inplace_mode:
        print ("inplace mode")
        ji = halfrev2(ji, True)

    print ("ri", ri)
    print ("ji", ji)

    # start an infinite (wrapping) loop
    while True:
        for size in x_r:                     # 3rd order dimension (size)
            x_end = size == x_r[-1]
            # y_r schedule (start of each butterfly block) depends on size
            halfsize = size // 2
            y_r = list(range(0, n, size))
            # invert if requested
            if SVSHAPE.invxyz[1]:
                y_r.reverse()
            for i in y_r:                    # 2nd order dimension
                y_end = i == y_r[-1]
                # two lists of half-range indices, e.g. j 0123, jr 7654
                j = list(range(i, i + halfsize))
                jr = list(range(i + halfsize, i + size))
                jr.reverse()
                # invert if requested: reverse both lists together so the
                # butterfly pairs (j[c], jr[c]) stay matched
                if SVSHAPE.invxyz[2]:
                    j.reverse()
                    jr.reverse()
                hz2 = halfsize // 2  # zero stops reversing 1-item lists
                # non-in-place ("swap") mode: ji is never permuted, so the
                # data swaps are left to the caller.  good luck...
                if SVSHAPE.mode == 0b01 and not inplace_mode:
                    print ("swap mode")
                    print ("xform jr", jr)
                for jl, jh in zip(j, jr):    # 1st order dimension
                    z_end = jl == j[-1]
                    result = ri[ji[jh if upper_half else jl]]
                    loopends = (z_end |
                                ((y_end and z_end) << 1) |
                                ((y_end and x_end and z_end) << 2))

                    yield result + SVSHAPE.offset, loopends

                # now in-place swap: reverse the order of the ji entries at
                # positions i+halfsize .. i+size-1 instead of moving data
                if inplace_mode:
                    for jl, jh in zip(j[:hz2], jr[:hz2]):
                        jlh = jl + halfsize
                        ji[jlh], ji[jh] = ji[jh], ji[jlh]
130
131
def pprint_schedule(schedule, n):
    """Print a butterfly schedule next to its nominal j / j+halfsize labels.

    ``schedule`` is a sequence of ``((jl, je), (jh, he))`` entries: the
    ``(index, loopends)`` tuples from the lower-half and upper-half index
    generators.  Entries are consumed in order, one per ``(i, j)`` pair of
    an ascending-size (2, 4, ... ``n``) traversal, so ``schedule`` must hold
    at least that many entries (``IndexError`` otherwise).

    Each entry line shows the entry number, the block start ``i``, the
    nominal ``j`` and ``j+halfsize``, the scheduled ``jl`` and ``jh``, and
    the loop-end bitmasks of *both* halves in binary.
    """
    size = 2
    idx = 0
    while size <= n:
        halfsize = size // 2
        tablestep = n // size
        print ("size %d halfsize %d tablestep %d" %
               (size, halfsize, tablestep))
        for i in range(0, n, size):
            prefix = "i %d\t" % i
            for j in range(i, i + halfsize):
                (jl, je), (jh, he) = schedule[idx]
                print (" %-3d\t%s j=%-2d jh=%-2d "
                       "j[jl=%-2d] j[jh=%-2d]" %
                       (idx, prefix, j, j+halfsize,
                        jl, jh,
                       ),
                       # lower-half loopends, then upper-half loopends
                       "end", bin(je)[2:], bin(he)[2:])
                idx += 1
        size *= 2
152
# totally cool *in-place* DCT algorithm using yield REMAPs
def transform2(vec):
    """Compute the in-place iterative DCT of ``vec`` and return it.

    ``len(vec)`` is presumably required to be a power of two (TODO
    confirm).  The input is first copied into half-swapped, bit-reversed
    order; the inner butterflies are then applied in place, driven by two
    ``iterate_dct_butterfly_indices`` schedules (lower and upper half), and
    finally the outer butterfly accumulates neighbouring partial sums.

    The caller's sequence is not modified; a new list is returned.  Inputs
    of length 0 or 1 come back as an unchanged copy.  Progress is printed
    to stdout for debugging.
    """
    # Initialization
    n = len(vec)
    print ()
    print ("transform2", n)
    levels = n.bit_length() - 1

    # reference (read/write) the in-place data in *reverse-bit-order*
    ri = list(range(n))
    ri = [ri[reverse_bits(i, levels)] for i in range(n)]

    # and pretend we LDed data in half-swapped *and* bit-reversed order as well
    # TODO: merge these two
    vec = halfrev2(vec, False)
    vec = [vec[ri[i]] for i in range(n)]

    # create a cos table: not strictly necessary but here for illustrative
    # purposes, to demonstrate the point that it really *is* iterative.
    # this table could be cached and used multiple times rather than
    # computed every time.  entries run largest size first, matching the
    # inverted size loop (invxyz[0]=1) of the schedules below.
    ctable = []
    size = n
    while size >= 2:
        halfsize = size // 2
        for i in range(n//size):
            for ci in range(halfsize):
                ctable.append((math.cos((ci + 0.5) * math.pi / size) * 2.0))
        size //= 2
    # one butterfly per cos-table entry
    VL = len(ctable)

    ################
    # INNER butterfly
    ################
    xdim = n
    ydim = 0
    zdim = 0

    # set up an SVSHAPE
    class SVSHAPE:
        pass
    # j schedule
    SVSHAPE0 = SVSHAPE()
    SVSHAPE0.lims = [xdim, ydim, zdim]
    SVSHAPE0.order = [0,1,2]  # experiment with different permutations, here
    SVSHAPE0.mode = 0b01
    SVSHAPE0.skip = 0b00
    SVSHAPE0.offset = 0       # experiment with different offset, here
    SVSHAPE0.invxyz = [1,0,0] # inversion if desired
    # j+halfstep schedule
    SVSHAPE1 = SVSHAPE()
    SVSHAPE1.lims = [xdim, ydim, zdim]
    SVSHAPE1.order = [0,1,2]  # experiment with different permutations, here
    SVSHAPE1.mode = 0b01
    SVSHAPE1.skip = 0b01
    SVSHAPE1.offset = 0       # experiment with different offset, here
    SVSHAPE1.invxyz = [1,0,0] # inversion if desired

    # enumerate over the iterator function, getting new indices.
    # the schedules wrap forever, so stop after VL butterflies.
    i0 = iterate_dct_butterfly_indices(SVSHAPE0)
    i1 = iterate_dct_butterfly_indices(SVSHAPE1)
    for k, ((jl, jle), (jh, jhe)) in enumerate(zip(i0, i1)):
        if k >= VL:
            break
        t1, t2 = vec[jl], vec[jh]
        coeff = ctable[k]
        vec[jl] = t1 + t2
        vec[jh] = (t1 - t2) * (1/coeff)
        # report the butterfly actually performed (k indexes ctable)
        print ("coeff", k,
               "jl", jl, "jh", jh,
               coeff, vec[jl],
               vec[jh])

    print("transform2 pre-itersum", vec)

    # now things are in the right order for the outer butterfly.
    size = n // 2
    while size >= 2:
        halfsize = size // 2
        ir = list(range(0, halfsize))
        print ("itersum", halfsize, size, ir)
        for i in ir:
            jr = list(range(i+halfsize, i+n-halfsize, size))
            print ("itersum jr", i+halfsize, i+size, jr)
            for jh in jr:
                vec[jh] += vec[jh+size]
                print (" itersum", size, i, jh, jh+size)
        size //= 2

    print("transform2 result", vec)

    return vec
248
249
def demo():
    """Print the inner and outer butterfly index schedules for 8 points.

    For each butterfly kind a pair of SVSHAPE descriptors is built (skip
    0b00 for the j half, 0b01 for the j+halfstep half), one vector-length's
    worth of ``(index, loopends)`` pairs is drawn from
    ``iterate_dct_butterfly_indices`` and the result is shown with
    ``pprint_schedule``.
    """
    # dimension sizes: only the first is used (ydim / zdim not needed)
    xdim, ydim, zdim = 8, 0, 0
    n = xdim

    # total butterfly count: for every power-of-two size up to n there are
    # len(range(0, n, size)) blocks of size//2 butterflies each
    VL = 0
    size = 2
    while size <= n:
        VL += len(range(0, n, size)) * (size // 2)
        size *= 2

    class SVSHAPE:
        pass

    def make_shape(mode, skip, invx):
        # one SVSHAPE descriptor; only the mode / skip / x-inversion vary
        shape = SVSHAPE()
        shape.lims = [xdim, ydim, zdim]
        shape.order = [0, 1, 2]      # experiment with different permutations
        shape.mode = mode
        shape.skip = skip
        shape.offset = 0             # experiment with different offset
        shape.invxyz = [invx, 0, 0]  # inversion if desired
        return shape

    def collect(mode, invx):
        # draw VL (lower, upper) pairs from the two wrapping generators
        lower = iterate_dct_butterfly_indices(make_shape(mode, 0b00, invx))
        upper = iterate_dct_butterfly_indices(make_shape(mode, 0b01, invx))
        pairs = []
        for both_halves in zip(lower, upper):
            if len(pairs) >= VL:
                break
            pairs.append(both_halves)
        return pairs

    # INNER butterfly: pretty-print the results, with some debug output
    inner = collect(0b01, 0)
    print ("inner butterfly")
    pprint_schedule(inner, n)
    print ("")

    # OUTER butterfly: same treatment, size loop inverted
    outer = collect(0b10, 1)
    print ("outer butterfly")
    pprint_schedule(outer, n)
340
# run the demo when this file is executed as a script; importing the
# module only defines the functions above
if __name__ == '__main__':
    demo()