start on inverse dct, turning recursive to iterative
[openpower-isa.git] / src / openpower / decoder / isa / fastdctlee.py
1 #
2 # Fast discrete cosine transform algorithm (Python)
3 #
4 # Modifications made to create an in-place iterative DCT:
5 # Copyright (c) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6 #
7 # License for modifications - SPDX: LGPLv3+
8 #
9 # Original fastdctlee.py by Nayuki:
10 # Copyright (c) 2020 Project Nayuki. (MIT License)
11 # https://www.nayuki.io/page/fast-discrete-cosine-transform-algorithms
12 #
13 # License for original fastdctlee.py by Nayuki:
14 #
15 # Permission is hereby granted, free of charge, to any person obtaining
16 # a copy of this software and associated documentation files (the
17 # "Software"), to deal in the Software without restriction, including
18 # without limitation the rights to use, copy, modify, merge, publish,
19 # distribute, sublicense, and/or sell copies of the Software, and to
20 # permit persons to whom the Software is furnished to do so, subject to
21 # the following conditions:
22 # - The above copyright notice and this permission notice shall be included in
23 # all copies or substantial portions of the Software.
24 # - The Software is provided "as is", without warranty of any kind, express or
25 # implied, including but not limited to the warranties of merchantability,
26 # fitness for a particular purpose and noninfringement. In no event shall the
27 # authors or copyright holders be liable for any claim, damages or other
28 # liability, whether in an action of contract, tort or otherwise,
29 # arising from, out of or in connection with the Software or the use
30 # or other dealings in the Software.
31 #
32 #
33 # The modifications made are firstly to create an iterative schedule,
34 # rather than the more normal recursive algorithm. Secondly, the
35 # two butterflys are also separated out: inner butterfly does COS +/-
36 # whilst outer butterfly does the iterative summing.
37 #
38 # However, to avoid data copying some additional tricks are played:
39 # - firstly, the data is LOADed in bit-reversed order (which is normally
40 # covered by the recursive algorithm due to the odd-even reconstruction)
41 # but then to reference the data in the correct order an array of
42 # bit-reversed indices is created, as a level of indirection.
43 # the data is bit-reversed but so are the indices, making it all A-Ok.
44 # - secondly, normally in DCT a 2nd target (copy) array is used where
45 # the top half is read in reverse order (7 6 5 4) and written out
46 # to the target 4 5 6 7. the plan was to do this in two stages:
47 # write in-place in order 4 5 6 7 then swap afterwards (7-4), (6-5).
48 # however by leaving the data *in-place* and having subsequent
49 # loops refer to the data *where it now is*, the swap is avoided
50 # - thirdly, arrange for the data to be *pre-swapped* (in an inverse
51 # order of how it would have got there, if that makes sense), such
52 # that *when* it gets swapped, it ends up in the right order.
53 # given that that will be a LD operation it's no big deal.
54 #
55 # End result is that once the first butterfly is done - bear in mind
56 # it's in-place - the data is in the right order so that a second
57 # dead-straightforward iterative sum can be done: again, in-place.
58 # Really surprising.
59
60 import math
61 from copy import deepcopy
62
# returns the integer whose value is the reverse of the lowest 'width'
# bits of the integer 'val'.
def reverse_bits(val, width):
    """Return the integer made of the lowest 'width' bits of 'val', reversed.

    Bits of 'val' at position 'width' and above are ignored.  A width of
    zero (or a negative width) gives 0, because the loop body never runs.
    """
    result = 0
    for _ in range(width):
        # make room for one more bit, then copy in the current LSB of val
        result = (result << 1) | (val & 1)
        val >>= 1
    return result
70
71
72 # reverse top half of a list, recursively. the recursion can be
73 # applied *after* or *before* the reversal of the top half. these
74 # are inverses of each other.
75 # this function is unused except to test the iterative version (halfrev2)
def halfrev(l, pre_rev=True):
    """Recursively reverse the top half of list 'l', returning a new list.

    The list is split in two; the upper half is reversed and the same
    treatment is applied to both halves.  With pre_rev=True the halves are
    processed recursively *before* the upper half is reversed, with
    pre_rev=False *after*.  The two orderings are inverses of each other.

    The input list is never modified.  Lists of length 0 or 1 are returned
    as a (shallow) copy.
    """
    n = len(l)
    # n <= 1 rather than n == 1: an empty list would otherwise split into
    # two empty lists and recurse forever.
    if n <= 1:
        return list(l)
    # slicing copies, so the in-place reverse() below cannot touch 'l'
    ll, lh = l[:n//2], l[n//2:]
    if pre_rev:
        ll, lh = halfrev(ll, pre_rev), halfrev(lh, pre_rev)
    lh.reverse()
    if not pre_rev:
        ll, lh = halfrev(ll, pre_rev), halfrev(lh, pre_rev)
    return ll + lh
87
88
89 # iterative version of [recursively-applied] half-rev.
90 # relies on the list lengths being power-of-two and the fact
91 # that bit-inversion of a list of binary numbers is the same
92 # as reversing the order of the list
93 # this version is dead easy to implement in hardware.
94 # a big surprise is that the half-reversal can be done with
95 # such a simple XOR. the inverse operation is slightly trickier
def halfrev2(vec, pre_rev=True):
    """Iterative equivalent of halfrev: return a reordered copy of 'vec'.

    Element i of the result is taken from vec[src] where

    * pre_rev=True:  src = i ^ (i >> 1)
    * pre_rev=False: src = i ^ (i >> 1) ^ (i >> 2) ^ ... (every non-zero
      right-shift of i), which undoes the pre_rev=True ordering.

    As the comment above this function says, the list length is expected
    to be a power of two; other lengths may index outside 'vec'.
    """
    res = []
    for i in range(len(vec)):
        if pre_rev:
            # look the element up in vec: appending the bare index only
            # gave the right answer when vec happened to be 0, 1, 2, ...
            res.append(vec[i ^ (i >> 1)])
        else:
            ri = i
            # XOR in every right-shift of i that is still non-zero
            for shift in range(1, i.bit_length()):
                ri ^= (i >> shift)
            res.append(vec[ri])
    return res
108
109
110 # DCT type II, unscaled. Algorithm by Byeong Gi Lee, 1984.
111 # See: http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.118.3056&rep=rep1&type=pdf#page=34
112 # original (recursive) algorithm by Nayuki
def transform(vector, indent=0):
    """Recursive DCT type II (unscaled) of 'vector'; returns a new list.

    NOTE: a second function also called 'transform' (the debug-printing
    variant) is defined further down this file and replaces this one when
    the module is loaded, so this definition is effectively reference code.

    vector -- sequence of numbers; the length is halved at every level of
              recursion, so only power-of-two lengths succeed
    indent -- unused here; accepted so that the signature matches the
              debug-printing variant

    Raises ValueError if a (sub-)vector is empty or has an odd length
    greater than one.
    """
    n = len(vector)
    if n == 1:
        return list(vector)
    if n == 0 or n % 2 != 0:
        raise ValueError("length must be a non-zero power of two "
                         "(reached a vector of length %d)" % n)
    half = n // 2
    # alpha: sums of mirrored pairs; beta: differences scaled by 1/(2*cos)
    alpha = [(vector[i] + vector[-(i + 1)]) for i in range(half)]
    beta = [(vector[i] - vector[-(i + 1)]) /
            (math.cos((i + 0.5) * math.pi / n) * 2.0)
            for i in range(half)]
    alpha = transform(alpha)
    beta = transform(beta)
    # interleave: even outputs come from alpha, odd outputs are sums of
    # neighbouring beta terms (the last odd output is beta[-1] alone)
    result = []
    for i in range(half - 1):
        result.append(alpha[i])
        result.append(beta[i] + beta[i + 1])
    result.append(alpha[-1])
    result.append(beta[-1])
    return result
135
136
137 # modified recursive algorithm, based on Nayuki original, which simply
138 # prints out an awful lot of debug data. used to work out the ordering
139 # for the iterative version by analysing the indices printed out
def transform(vector, indent=0):
    """Recursive DCT type II (unscaled) that prints a trace of every step.

    Computes the same alpha/beta decomposition as the plain recursive
    version above it in this file, but prints the input, the index pairs
    and cosine coefficients, the alpha/beta values, and the merged result
    at each level.  Being the later definition of 'transform' in the
    file, this is the one the module name is bound to.

    vector -- sequence of numbers (power-of-two length)
    indent -- recursion depth, used only to indent the printed trace

    Returns a new list.  Raises ValueError if a (sub-)vector is empty or
    has an odd length greater than one.
    """
    idt = " " * indent
    n = len(vector)
    if n == 1:
        return list(vector)
    if n == 0 or n % 2 != 0:
        raise ValueError("length must be a non-zero power of two "
                         "(reached a vector of length %d)" % n)
    half = n // 2
    alpha = [0] * half
    beta = [0] * half
    print (idt, "xf", vector)
    print (idt, "coeff", n, "->", end=" ")
    for i in range(half):
        # pair element i with its mirror image n-i-1
        t1, t2 = vector[i], vector[n-i-1]
        k = (math.cos((i + 0.5) * math.pi / n) * 2.0)
        print (i, n-i-1, "i/n", (i+0.5)/n, ":", k, end= " ")
        alpha[i] = t1 + t2
        beta[i] = (t1 - t2) * (1/k)
    print ()
    print (idt, "n", n, "alpha", end=" ")
    for i in range(0, n, 2):
        print (i, i//2, alpha[i//2], end=" ")
    print()
    print (idt, "n", n, "beta", end=" ")
    for i in range(0, n, 2):
        print (i, beta[i//2], end=" ")
    print()
    alpha = transform(alpha, indent+1)
    beta = transform(beta , indent+1)
    # interleave alpha (even slots) and beta (odd slots) ...
    result = [0] * n
    for i in range(half):
        result[i*2] = alpha[i]
        result[i*2+1] = beta[i]
    print(idt, "merge", result)
    # ... then add to each odd slot the odd slot that follows it
    for i in range(half - 1):
        result[i*2+1] += result[i*2+3]
    print(idt, "result", result)
    return result
179
180
181 # totally cool *in-place* DCT algorithm
def transform2(vec):
    """Iterative, in-place-style DCT built from two separate butterflies.

    The input is first copied into half-swapped and bit-reversed order
    (standing in for a reordering LOAD), so the caller's list is not
    modified.  An "inner" butterfly then applies the cosine coefficients
    level by level, and an "outer" butterfly performs the iterative sums.
    Debug output is printed throughout.

    vec -- list of numbers.
           NOTE(review): the index arithmetic appears to assume a
           power-of-two length; nothing here checks it.

    Returns the transformed data as a new list.
    """
    # Initialization
    n = len(vec)
    print ()
    print ("transform2", n)
    levels = n.bit_length() - 1

    # reference (read/write) the in-place data in *reverse-bit-order*
    ri = [reverse_bits(i, levels) for i in range(n)]

    # reference list for not needing to do data-swaps, just swap what
    # *indices* are referenced (two levels of indirection at the moment)
    # pre-reverse the data-swap list so that it *ends up* in the order 0123..
    ji = list(range(n))
    ji = halfrev2(ji, True)

    # and pretend we LDed data in half-swapped *and* bit-reversed order as well
    # TODO: merge these two
    vec = halfrev2(vec, False)
    vec = [vec[ri[i]] for i in range(n)]

    print ("ri", ri)
    print ("ji", ji)

    # create a cos table: not strictly necessary but here for illustrative
    # purposes, to demonstrate the point that it really *is* iterative.
    # this table could be cached and used multiple times rather than
    # computed every time.  entries are appended in exactly the order the
    # inner butterfly below consumes them (index k).
    ctable = []
    size = n
    while size >= 2:
        halfsize = size // 2
        for i in range(n//size):
            for ci in range(halfsize):
                ctable.append((math.cos((ci + 0.5) * math.pi / size) * 2.0))
        size //= 2

    # start the inner butterfly
    size = n
    k = 0
    while size >= 2:
        halfsize = size // 2
        ir = list(range(0, n, size))
        print (" xform", size, ir)
        for i in ir:
            # two lists of half-range indices, e.g. j 0123, jr 7654
            j = list(range(i, i + halfsize))
            jr = list(range(i+halfsize, i + size))
            jr.reverse()
            print (" xform jr", j, jr)
            for ci, (jl, jh) in enumerate(zip(j, jr)):
                t1, t2 = vec[ri[ji[jl]]], vec[ri[ji[jh]]]
                coeff = ctable[k]
                k += 1
                # normally DCT would use jl+halfsize not jh, here.
                # to be able to work in-place, the idea is to perform a
                # swap afterwards.
                vec[ri[ji[jl]]] = t1 + t2
                vec[ri[ji[jh]]] = (t1 - t2) * (1/coeff)
                print ("coeff", size, i, "ci", ci,
                       "jl", ri[ji[jl]], "jh", ri[ji[jh]],
                       "i/n", (ci+0.5)/size, coeff, vec[ri[ji[jl]]],
                       vec[ri[ji[jh]]])
            # instead of using jl+halfsize, perform a swap here.
            # use half of j/jr because actually jl+halfsize = reverse(j)
            hz2 = halfsize // 2 # can be zero which stops reversing 1-item lists
            for jl, jh in zip(j[:hz2], jr[:hz2]):
                jlh = jl+halfsize
                # swap indices, NOT the data
                ji[jlh], ji[jh] = ji[jh], ji[jlh]
                print (" swap", size, i, ji[jlh], ji[jh])
        size //= 2

    print("post-swapped", ri)
    print("ji-swapped", ji)
    print("transform2 pre-itersum", vec)

    # now things are in the right order for the outer butterfly.
    size = n // 2
    while size >= 2:
        halfsize = size // 2
        ir = list(range(0, halfsize))
        print ("itersum", halfsize, size, ir)
        for i in ir:
            jr = list(range(i+halfsize, i+n-halfsize, size))
            print ("itersum jr", i+halfsize, i+size, jr)
            for jh in jr:
                vec[jh] += vec[jh+size]
                print (" itersum", size, i, jh, jh+size)
        size //= 2

    print("transform2 result", vec)

    return vec
282
283
284 # DCT type III, unscaled. Algorithm by Byeong Gi Lee, 1984.
285 # See: https://www.nayuki.io/res/fast-discrete-cosine-transform-algorithms/lee-new-algo-discrete-cosine-transform.pdf
def inverse_transform(vector, root=True, indent=0):
    """Recursive DCT type III (unscaled), printing the index schedule.

    vector -- sequence of numbers (power-of-two length)
    root   -- True for the outermost call: the input is copied and
              element 0 of the copy is halved.  Recursive calls pass
              False and overwrite the list they are given.
    indent -- recursion depth, used only to indent the printed trace

    Returns the transformed list: a new list when root is True, otherwise
    the list that was passed in.  Raises ValueError if a (sub-)vector is
    empty or has an odd length greater than one.
    """
    idt = " " * indent
    n = len(vector)
    # validate before touching vector[0], so that an empty input raises
    # ValueError instead of IndexError
    if n == 0 or (n > 1 and n % 2 != 0):
        raise ValueError("length must be a non-zero power of two "
                         "(reached a vector of length %d)" % n)
    if root:
        vector = list(vector)
        vector[0] /= 2
    if n == 1:
        # a plain list, the same type every other path returns
        return vector
    half = n // 2
    # alpha: even-indexed inputs; beta: sums of neighbouring odd inputs
    alpha = [vector[0]]
    beta = [vector[1]]
    for i in range(2, n, 2):
        alpha.append(vector[i])
        beta.append(vector[i - 1] + vector[i + 1])
    print (idt, "n", n, "alpha 0", end=" ")
    for i in range(2, n, 2):
        print (i, end=" ")
    print ("beta 1", end=" ")
    for i in range(2, n, 2):
        print ("%d+%d" % (i-1, i+1), end=" ")
    print()
    # the recursive calls update alpha and beta in place (root=False)
    inverse_transform(alpha, False, indent+1)
    inverse_transform(beta , False, indent+1)
    for i in range(half):
        x, y = alpha[i], beta[i]
        coeff = (math.cos((i + 0.5) * math.pi / n) * 2)
        y /= coeff
        vector[i] = x + y
        vector[n-(i+1)] = x - y
        print (idt, " v[%d] = alpha[%d]+beta[%d]" % (i, i, i))
        print (idt, " v[%d] = alpha[%d]-beta[%d]" % (n-i-1, i, i))
    return vector
321
322
# start of an iterative *inverse* DCT, derived from transform2
# (work in progress)
def inverse_transform_iter(vec):
    """Work-in-progress iterative inverse DCT, mirroring transform2.

    Steps, in order: halve element 0; run the "outer butterfly"
    (iterative sums); re-order the data into bit-reversed order; then run
    the "inner butterfly" (cosine coefficients) from size 2 up to n.
    Debug output is printed throughout.

    NOTE(review): this is an early conversion from the recursive version
    and has not been shown here to match inverse_transform - confirm
    before relying on the result.

    vec -- list of numbers; it is copied, so the caller's list is left
           untouched.
           NOTE(review): presumably a power-of-two length is required.

    Returns the result as a new list.
    """
    # Initialization
    n = len(vec)
    print ()
    print ("transform2 inv", n, vec)
    levels = n.bit_length() - 1

    # work on a private copy: the halving and the itersum pass below
    # update the list in place and must not clobber the caller's data
    vec = list(vec)

    # reference (read/write) the in-place data in *reverse-bit-order*
    ri = [reverse_bits(i, levels) for i in range(n)]

    # identity for now: the index-swap scheme that transform2 uses has
    # not been brought over to the inverse yet
    ji = list(range(n))

    print ("ri", ri)
    print ("ji", ji)

    # first divide element 0 by 2
    vec[0] /= 2.0

    print("transform2-inv pre-itersum", vec)

    # first the outer butterfly (iterative sum thing)
    size = n // 2
    while size >= 2:
        halfsize = size // 2
        ir = list(range(0, halfsize))
        print ("itersum", halfsize, size, ir)
        for i in ir:
            jr = list(range(i+halfsize, i+n-halfsize, size))
            print ("itersum jr", i+halfsize, i+size, jr)
            for jh in jr:
                x = vec[jh]
                y = vec[jh+size]
                vec[jh+size] = x + y
                print (" itersum", size, i, jh, jh+size,
                       x, y, "jh+sz", vec[jh+size])
        size //= 2

    print("transform2-inv post-itersum", vec)

    # and pretend we LDed data in bit-reversed order as well; from here
    # on the data is addressed directly, so ri becomes the identity
    vec = [vec[ri[i]] for i in range(n)]
    ri = list(range(n))

    print("transform2-inv post-reorder", vec)

    # start the inner butterfly (coefficients)
    size = 2
    while size <= n:
        halfsize = size // 2
        ir = list(range(0, n, size))
        print (" xform", size, ir)
        for i in ir:
            # two lists of half-range indices, e.g. j 0123, jr 7654
            j = list(range(i, i + halfsize))
            jr = list(range(i+halfsize, i + size))
            jr.reverse()
            print (" xform jr", j, jr)
            # results go into a scratch copy: the reads use jl+halfsize
            # while the writes use jh, so a later read would otherwise
            # see a slot that an earlier write had already replaced.
            # (a shallow copy is enough for a flat list of numbers)
            vec2 = list(vec)
            for ci, (jl, jh) in enumerate(zip(j, jr)):
                t1, t2 = vec[jl], vec[jl+halfsize]
                coeff = (math.cos((ci + 0.5) * math.pi / size) * 2.0)
                vec2[jl] = t1 + t2/coeff
                vec2[jh] = t1 - t2/coeff
                print ("coeff", size, i, "ci", ci,
                       "jl", ri[ji[jl]], "jh", ri[ji[jh]],
                       "i/n", (ci+0.5)/size, coeff,
                       "t1,t2", t1, t2,
                       "+/i", vec2[jl], vec2[jh])
            vec = vec2
        size *= 2

    print("post-swapped", ri)
    print("ji-swapped", ji)
    print("transform2 result", vec)

    return vec
443
444
def inverse_transform2(vector, root=True, indent=0):
    """Recursive DCT type III (unscaled) that prints the values it computes.

    Same decomposition as inverse_transform, but the trace shows the
    actual alpha/beta lists and each output value together with the
    coefficient used to produce it.

    vector -- sequence of numbers (power-of-two length)
    root   -- True for the outermost call: the input is copied and
              element 0 of the copy is halved.  Recursive calls pass
              False and overwrite the list they are given.
    indent -- recursion depth, used only to indent the printed trace

    Returns the transformed list: a new list when root is True, otherwise
    the list that was passed in.  Raises ValueError if a (sub-)vector is
    empty or has an odd length greater than one.
    """
    idt = " " * indent
    n = len(vector)
    # validate before touching vector[0], so that an empty input raises
    # ValueError instead of IndexError
    if n == 0 or (n > 1 and n % 2 != 0):
        raise ValueError("length must be a non-zero power of two "
                         "(reached a vector of length %d)" % n)
    if root:
        vector = list(vector)
        vector[0] /= 2
    if n == 1:
        return vector
    print (idt, "inverse_xform2", vector)
    half = n // 2
    # alpha: even-indexed inputs; beta: sums of neighbouring odd inputs
    alpha = [vector[0]]
    beta = [vector[1]]
    for i in range(2, n, 2):
        alpha.append(vector[i])
        beta.append(vector[i - 1] + vector[i + 1])
    print (idt, " alpha", alpha)
    print (idt, " beta", beta)
    # the recursive calls update alpha and beta in place (root=False)
    inverse_transform2(alpha, False, indent+1)
    inverse_transform2(beta , False, indent+1)
    for i in range(half):
        x, y = alpha[i], beta[i]
        coeff = (math.cos((i + 0.5) * math.pi / n) * 2)
        vector[i] = x + y / coeff
        vector[n-(i+1)] = x - y / coeff
        print (idt, " v[%d] = %f+%f/%f=%f" % (i, x, y, coeff, vector[i]))
        print (idt, " v[%d] = %f-%f/%f=%f" % (n-i-1, x, y,
                                               coeff, vector[n-i-1]))
    return vector
476
477
def inverse_transform2_explore(vector, root=True, indent=0):
    """Symbolic version of inverse_transform2: records operations as tuples.

    Instead of doing arithmetic, every output element becomes a nested
    tuple describing how it would be computed:

        ("add<depth>", a, b)      -- a + b
        ("sub<depth>", a, b)      -- a - b
        ("cos<depth>", y, i, n)   -- y scaled by the coefficient for (i, n)

    where <depth> is the recursion depth ('indent').  Unlike
    inverse_transform2, element 0 is not halved.

    vector -- sequence of arbitrary items standing in for the inputs
    root   -- True for the outermost call, which works on a copy;
              recursive calls pass False and overwrite their argument
    indent -- recursion depth, embedded in the operation names

    Returns the list of operation trees.  Raises ValueError if a
    (sub-)vector is empty or has an odd length greater than one.
    """
    n = len(vector)
    if root:
        vector = list(vector)
    if n == 1:
        return vector
    if n == 0 or n % 2 != 0:
        raise ValueError("length must be a non-zero power of two "
                         "(reached a vector of length %d)" % n)
    half = n // 2
    alpha = [vector[0]]
    beta = [vector[1]]
    for i in range(2, n, 2):
        alpha.append(vector[i])
        beta.append(("add%d" % indent, vector[i - 1], vector[i + 1]))
    # the recursive calls update alpha and beta in place (root=False)
    inverse_transform2_explore(alpha, False, indent+1)
    inverse_transform2_explore(beta , False, indent+1)
    for i in range(half):
        x = alpha[i]
        y = ("cos%d" % indent, beta[i], i, n)
        vector[i] = ("add%d" % indent, x, y)
        vector[n-(i + 1)] = ("sub%d" % indent, x, y)
    return vector
501
502
503
504 # does the outer butterfly in a recursive fashion, used in an
505 # intermediary development of the in-place DCT.
def transform_itersum(vector, indent=0):
    """Recursive form of the "outer butterfly" (iterative-sum) stage only.

    The vector is split into a lower and an upper half, each half is
    processed recursively, the two results are interleaved (lower half in
    the even slots, upper half in the odd slots), and finally each odd
    slot except the last has the following odd slot added to it.  The
    merge and the result are printed at every level.

    vector -- list of numbers (power-of-two length)
    indent -- recursion depth, used only to indent the printed trace

    Returns a new list.  Raises ValueError if a (sub-)vector is empty or
    has an odd length greater than one.
    """
    idt = " " * indent
    n = len(vector)
    if n == 1:
        return list(vector)
    if n == 0 or n % 2 != 0:
        raise ValueError("length must be a non-zero power of two "
                         "(reached a vector of length %d)" % n)
    half = n // 2
    # no cosine butterfly at this stage: the halves are taken as they are
    alpha = transform_itersum(list(vector[:half]), indent+1)
    beta = transform_itersum(list(vector[half:]), indent+1)
    result = [0] * n
    for i in range(half):
        result[i*2] = alpha[i]
        result[i*2+1] = beta[i]
    print(idt, "iter-merge", result)
    for i in range(half - 1):
        result[i*2+1] += result[i*2+3]
    print(idt, "iter-result", result)
    return result
532
533
534 # prints out an "add" schedule for the outer butterfly, recursively,
535 # matching what transform_itersum does.
def itersum_explore(vector, indent=0):
    """Symbolic twin of transform_itersum: records the ADDs instead of adding.

    Follows the same split / recurse / interleave structure, but where
    transform_itersum performs 'a += b' this stores the tuple
    ("add", a, b), so the returned list shows which inputs feed each
    output.  The merge and the result are printed at every level.

    vector -- list of arbitrary items standing in for the inputs
    indent -- recursion depth, used only to indent the printed trace

    Returns a new list.  Raises ValueError if a (sub-)vector is empty or
    has an odd length greater than one.
    """
    idt = " " * indent
    n = len(vector)
    if n == 1:
        return list(vector)
    if n == 0 or n % 2 != 0:
        raise ValueError("length must be a non-zero power of two "
                         "(reached a vector of length %d)" % n)
    half = n // 2
    alpha = itersum_explore(list(vector[:half]), indent+1)
    beta = itersum_explore(list(vector[half:]), indent+1)
    result = [0] * n
    for i in range(half):
        result[i*2] = alpha[i]
        result[i*2+1] = beta[i]
    print(idt, "iter-merge", result)
    for i in range(half - 1):
        # result[i*2+3] has not been rewritten yet at this point, so the
        # recorded operand is the value straight from the merge
        result[i*2+1] = ("add", result[i*2+1], result[i*2+3])
    print(idt, "iter-result", result)
    return result
562
563
564 # prints out the exact same outer butterfly but does so iteratively.
565 # by comparing the output from itersum_explore and itersum_explore2
566 # and by drawing out the resultant ADDs as a graph it was possible
567 # to deduce what the heck was going on.
def itersum_explore2(vec, indent=0):
    """Iterative schedule of the outer-butterfly ADDs, recorded symbolically.

    Walks the same loop nest as the outer butterfly in transform2, but
    instead of 'vec[jh] += vec[jh+size]' it stores the tuple
    ("add", vec[jh], vec[jh+size]) in vec[jh].  The schedule is printed
    as it goes.

    vec    -- list of arbitrary items; it is MODIFIED IN PLACE
    indent -- unused; kept so the signature matches itersum_explore

    Returns the same list object that was passed in.
    """
    n = len(vec)
    size = n // 2
    while size >= 2:
        halfsize = size // 2
        ir = list(range(0, halfsize))
        print ("itersum", halfsize, size, ir)
        for i in ir:
            # every 'size'-th element, starting at i+halfsize
            jr = list(range(i+halfsize, i+n-halfsize, size))
            print ("itersum jr", i+halfsize, i+size, jr)
            for jh in jr:
                vec[jh] = ("add", vec[jh], vec[jh+size])
                print (" itersum", size, i, jh, jh+size)
        size //= 2

    return vec
584
if __name__ == '__main__':
    # recursive outer-butterfly schedule, fed with bit-reversed indices
    n = 16
    vec = list(range(n))
    levels = n.bit_length() - 1
    vec = [vec[reverse_bits(i, levels)] for i in range(n)]
    ops = itersum_explore(vec)
    for i, x in enumerate(ops):
        print (i, x)

    # the same schedule produced iteratively, for comparison by eye
    n = 16
    vec = list(range(n))
    ops = itersum_explore2(vec)
    for i, x in enumerate(ops):
        print (i, x)

    # halfrev test
    vec = list(range(16))
    print ("orig vec", vec)
    vecr = halfrev(vec, True)
    print ("reversed", vecr)
    for i, v in enumerate(vecr):
        # "04b" zero-pads the binary columns ("%04s" pads strings with
        # spaces, not zeros)
        print ("%2d %2d %s %s %s" % (i, v,
                    format(i, "04b"), format(v ^ i, "04b"), format(v, "04b")))
    # pre_rev=False must undo pre_rev=True
    vecrr = halfrev(vecr, False)
    assert vec == vecrr
    vecrr = halfrev(vec, False)
    print ("pre-reversed", vecrr)
    for i, v in enumerate(vecrr):
        print ("%2d %2d %s %s %s" % (i, v,
                    format(i, "04b"), format(v ^ i, "04b"), format(v, "04b")))
    # the iterative version has to agree with the recursive one
    il = halfrev2(vec, False)
    print ("iterative rev", il)
    assert il == vecrr
    il = halfrev2(vec, True)
    print ("iterative rev-true", il)
    assert il == vecr

    # symbolic dump of the recursive inverse-DCT operations
    n = 4
    vec = list(range(n))
    ops = inverse_transform2_explore(vec)
    for i, x in enumerate(ops):
        print (i, x)
627