use lists rather than list incomprehension
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 18 Jul 2021 10:09:29 +0000 (11:09 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 18 Jul 2021 10:09:29 +0000 (11:09 +0100)
src/openpower/decoder/isa/fastdct-test.py
src/openpower/decoder/isa/fastdctlee.py

index 07c655a6107dd48928cfdb24388eadb6a75101c6..516c101c22e00c497474eb10564dcbafdfc27326 100644 (file)
@@ -28,21 +28,21 @@ import fastdctlee, naivedct
 class FastDctTest(unittest.TestCase):
        
        def test_fast_dct_lee_vs_naive(self):
-               for i in range(1, 12):
+               for i in range(1, 9):
                        n = 2**i
                        vector = FastDctTest.random_vector(n)
                        expect = naivedct.transform(vector)
-                       actual = fastdctlee.transform(vector)
+                       actual = fastdctlee.transform2(vector)
                        self.assertListAlmostEqual(actual, expect)
                        expect = naivedct.inverse_transform(vector)
                        actual = fastdctlee.inverse_transform(vector)
                        self.assertListAlmostEqual(actual, expect)
        
        def test_fast_dct_lee_invertibility(self):
-               for i in range(1, 18):
+               for i in range(1, 10):
                        n = 2**i
                        vector = FastDctTest.random_vector(n)
-                       temp = fastdctlee.transform(vector)
+                       temp = fastdctlee.transform2(vector)
                        temp = fastdctlee.inverse_transform(temp)
                        temp = [(val * 2.0 / n) for val in temp]
                        self.assertListAlmostEqual(vector, temp)
index 38cc545101266d7f2f2b7b5f93e17fff85922f7d..1e6de3b6a315b64077985ef647402e138bca1413 100644 (file)
@@ -49,6 +49,32 @@ def transform(vector):
         return result
 
 
+def transform2(vector):
+    n = len(vector)
+    if n == 1:
+        return list(vector)
+    elif n == 0 or n % 2 != 0:
+        raise ValueError()
+    else:
+        half = n // 2
+        alpha = [0] * half
+        beta = [0] * half
+        for i in range(half):
+            t1, t2 = vector[i], vector[n-i-1]
+            k = (math.cos((i + 0.5) * math.pi / n) * 2.0)
+            alpha[i] = t1 + t2
+            beta[i] = (t1 - t2) * (1/k)
+        alpha = transform2(alpha)
+        beta  = transform2(beta )
+        result = [0] * n
+        for i in range(half):
+            result[i*2] = alpha[i]
+            result[i*2+1] = beta[i]
+        for i in range(half - 1):
+            result[i*2+1] += result[i*2+3]
+        return result
+
+
 # DCT type III, unscaled. Algorithm by Byeong Gi Lee, 1984.
 # See: https://www.nayuki.io/res/fast-discrete-cosine-transform-algorithms/lee-new-algo-discrete-cosine-transform.pdf
 def inverse_transform(vector, root=True, indent=0):