format code
[ieee754fpu.git] / src / ieee754 / part / test / test_partsig.py
index 38efaadce66fb574c9021b6e084272b44c8c9696..518555b3b619798dcbc6127f4aeabf2da8df2f07 100644 (file)
@@ -18,14 +18,15 @@ import math
 def first_zero(x):
     res = 0
     for i in range(16):
-        if x & (1<<i):
+        if x & (1 << i):
             return res
         res += 1
 
+
 def count_bits(x):
     res = 0
     for i in range(16):
-        if x & (1<<i):
+        if x & (1 << i):
             res += 1
     return res
 
@@ -53,10 +54,10 @@ class TestAddMod2(Elaboratable):
         self.b = SimdSignal(partpoints, width)
         self.bsig = Signal(width)
         self.add_output = Signal(width)
-        self.ls_output = Signal(width) # left shift
-        self.ls_scal_output = Signal(width) # left shift
-        self.rs_output = Signal(width) # right shift
-        self.rs_scal_output = Signal(width) # right shift
+        self.ls_output = Signal(width)  # left shift
+        self.ls_scal_output = Signal(width)  # left shift
+        self.rs_output = Signal(width)  # right shift
+        self.rs_scal_output = Signal(width)  # right shift
         self.sub_output = Signal(width)
         self.eq_output = Signal(len(partpoints)+1)
         self.gt_output = Signal(len(partpoints)+1)
@@ -200,10 +201,10 @@ class TestAddMod(Elaboratable):
         self.b = SimdSignal(partpoints, width)
         self.bsig = Signal(width)
         self.add_output = Signal(width)
-        self.ls_output = Signal(width) # left shift
-        self.ls_scal_output = Signal(width) # left shift
-        self.rs_output = Signal(width) # right shift
-        self.rs_scal_output = Signal(width) # right shift
+        self.ls_output = Signal(width)  # left shift
+        self.ls_scal_output = Signal(width)  # left shift
+        self.rs_output = Signal(width)  # right shift
+        self.rs_scal_output = Signal(width)  # right shift
         self.sub_output = Signal(width)
         self.eq_output = Signal(len(partpoints)+1)
         self.gt_output = Signal(len(partpoints)+1)
@@ -377,11 +378,11 @@ class TestCat(unittest.TestCase):
                     apart, bpart = [], []
                     ajump, bjump = alen // 4, blen // 4
                     for i in range(4):
-                        apart.append((a >> (ajump*i) & ((1<<ajump)-1)))
-                        bpart.append((b >> (bjump*i) & ((1<<bjump)-1)))
+                        apart.append((a >> (ajump*i) & ((1 << ajump)-1)))
+                        bpart.append((b >> (bjump*i) & ((1 << bjump)-1)))
 
-                    print ("apart bpart", hex(a), hex(b),
-                            list(map(hex, apart)), list(map(hex, bpart)))
+                    print("apart bpart", hex(a), hex(b),
+                          list(map(hex, apart)), list(map(hex, bpart)))
 
                     yield module.a.lower().eq(a)
                     yield module.b.lower().eq(b)
@@ -398,22 +399,22 @@ class TestCat(unittest.TestCase):
                     for i in runlengths:
                         # a first
                         for _ in range(i):
-                            print ("runlength", i,
-                                   "ai", ai,
-                                   "apart", hex(apart[ai]),
-                                   "j", j)
+                            print("runlength", i,
+                                  "ai", ai,
+                                  "apart", hex(apart[ai]),
+                                  "j", j)
                             y |= apart[ai] << j
-                            print ("    y", hex(y))
+                            print("    y", hex(y))
                             j += ajump
                             ai += 1
                         # now b
                         for _ in range(i):
-                            print ("runlength", i,
-                                   "bi", bi,
-                                   "bpart", hex(bpart[bi]),
-                                   "j", j)
+                            print("runlength", i,
+                                  "bi", bi,
+                                  "bpart", hex(bpart[bi]),
+                                  "j", j)
                             y |= bpart[bi] << j
-                            print ("    y", hex(y))
+                            print("    y", hex(y))
                             j += bjump
                             bi += 1
 
@@ -461,22 +462,22 @@ class TestRepl(unittest.TestCase):
                 alen = 16
                 # test values a
                 for a in [0x0000,
-                             0xDCBA,
-                             0x1234,
-                             0xABCD,
-                             0xFFFF,
-                             0x0000,
-                             0x1F1F,
-                             0xF1F1,
-                             ]:
+                          0xDCBA,
+                          0x1234,
+                          0xABCD,
+                          0xFFFF,
+                          0x0000,
+                          0x1F1F,
+                          0xF1F1,
+                          ]:
 
                     # convert a to partitions
                     apart = []
                     ajump = alen // 4
                     for i in range(4):
-                        apart.append((a >> (ajump*i) & ((1<<ajump)-1)))
+                        apart.append((a >> (ajump*i) & ((1 << ajump)-1)))
 
-                    print ("apart", hex(a), list(map(hex, apart)))
+                    print("apart", hex(a), list(map(hex, apart)))
 
                     yield module.a.lower().eq(a)
                     yield Delay(0.1e-6)
@@ -492,12 +493,12 @@ class TestRepl(unittest.TestCase):
                         # a twice because the test is Repl(a, 2)
                         for aidx in range(2):
                             for _ in range(i):
-                                print ("runlength", i,
-                                       "ai", ai,
-                                       "apart", hex(apart[ai[aidx]]),
-                                       "j", j)
+                                print("runlength", i,
+                                      "ai", ai,
+                                      "apart", hex(apart[ai[aidx]]),
+                                      "j", j)
                                 y |= apart[ai[aidx]] << j
-                                print ("    y", hex(y))
+                                print("    y", hex(y))
                                 j += ajump
                                 ai[aidx] += 1
 
@@ -531,8 +532,8 @@ class TestAssign(unittest.TestCase):
                             part_mask, scalar)
 
         test_name = "part_sig_ass_%d_%d_%s_%s" % (in_width, out_width,
-                     "signed" if out_signed else "unsigned",
-                     "scalar" if scalar else "partitioned")
+                                                  "signed" if out_signed else "unsigned",
+                                                  "scalar" if scalar else "partitioned")
 
         traces = [part_mask,
                   module.ass_out.lower()]
@@ -562,31 +563,32 @@ class TestAssign(unittest.TestCase):
                           0x00c0,
                           0x0c00,
                           0xc000,
-                             0x1234,
-                             0xDCBA,
-                             0xABCD,
-                             0x0000,
-                             0xFFFF,
-                        ] + randomvals:
+                          0x1234,
+                          0xDCBA,
+                          0xABCD,
+                          0x0000,
+                          0xFFFF,
+                          ] + randomvals:
                     # work out the runlengths for this mask.
                     # 0b011 returns [1,1,2] (for a mask of length 3)
                     mval = yield part_mask
                     runlengths = get_runlengths(mval, 3)
 
-                    print ("test a", hex(a), "mask", bin(mval), "widths",
-                            in_width, out_width,
-                            "signed", out_signed,
-                            "scalar", scalar)
+                    print("test a", hex(a), "mask", bin(mval), "widths",
+                          in_width, out_width,
+                          "signed", out_signed,
+                          "scalar", scalar)
 
                     # convert a to runlengths sub-sections
                     apart = []
                     ajump = alen // 4
                     ai = 0
                     for i in runlengths:
-                        subpart = (a >> (ajump*ai) & ((1<<(ajump*i))-1))
-                        msb = (subpart >> ((ajump*i)-1)) # will contain the sign
+                        subpart = (a >> (ajump*ai) & ((1 << (ajump*i))-1))
+                        # will contain the sign
+                        msb = (subpart >> ((ajump*i)-1))
                         apart.append((subpart, msb))
-                        print ("apart", ajump*i, hex(a), hex(subpart), msb)
+                        print("apart", ajump*i, hex(a), hex(subpart), msb)
                         if not scalar:
                             ai += i
 
@@ -606,26 +608,27 @@ class TestAssign(unittest.TestCase):
                         signext = 0
                         if out_signed and ojump > ajump:
                             if amsb:
-                                signext = (-1 << ajump*i) & ((1<<(ojump*i))-1)
+                                signext = (-1 << ajump *
+                                           i) & ((1 << (ojump*i))-1)
                                 av |= signext
                         # truncate if needed
                         if ojump < ajump:
-                                av &= ((1<<(ojump*i))-1)
-                        print ("runlength", i,
-                               "ai", ai,
-                               "apart", hex(av), amsb,
-                               "signext", hex(signext),
-                               "j", j)
+                            av &= ((1 << (ojump*i))-1)
+                        print("runlength", i,
+                              "ai", ai,
+                              "apart", hex(av), amsb,
+                              "signext", hex(signext),
+                              "j", j)
                         y |= av << j
-                        print ("    y", hex(y))
+                        print("    y", hex(y))
                         j += ojump*i
                         ai += 1
 
-                    y &= (1<<out_width)-1
+                    y &= (1 << out_width)-1
 
                     # check the result
                     outval = (yield module.ass_out.lower())
-                    outval &= (1<<out_width)-1
+                    outval &= (1 << out_width)-1
                     msg = f"{msg_prefix}: assign " + \
                         f"mask 0x{mval:X} input 0x{a:X}" + \
                         f" => expected 0x{y:X} != actual 0x{outval:X}"
@@ -633,8 +636,8 @@ class TestAssign(unittest.TestCase):
 
             # run the actual tests, here - 16/8/4 bit partitions
             for (mask, name) in ((0, "16-bit"),
-                                  (0b10, "8-bit"),
-                                  (0b111, "4-bit")):
+                                 (0b10, "8-bit"),
+                                 (0b111, "4-bit")):
                 with self.subTest(name + " " + test_name):
                     yield part_mask.eq(mask)
                     yield Settle()
@@ -704,14 +707,14 @@ class TestSimdSignal(unittest.TestCase):
                           0xF000,
                           0x00FF,
                           0xFF00,
-                             0x1234,
-                             0xABCD,
-                             0xFFFF,
-                             0x8000,
-                             0xBEEF, 0xFEED,
-                                ]+randomvals:
+                          0x1234,
+                          0xABCD,
+                          0xFFFF,
+                          0x8000,
+                          0xBEEF, 0xFEED,
+                          ]+randomvals:
                     with self.subTest("%s %s %s" % (msg_prefix,
-                                    test_fn.__name__, hex(a))):
+                                                    test_fn.__name__, hex(a))):
                         yield module.a.lower().eq(a)
                         yield Delay(0.1e-6)
                         # convert to mask_list
@@ -737,7 +740,7 @@ class TestSimdSignal(unittest.TestCase):
 
             for (test_fn, mod_attr) in ((test_xor_fn, "xor"),
                                         (test_all_fn, "all"),
-                                        (test_bool_fn, "any"), # same as bool
+                                        (test_bool_fn, "any"),  # same as bool
                                         (test_bool_fn, "bool"),
                                         #(test_ne_fn, "ne"),
                                         ):
@@ -745,17 +748,17 @@ class TestSimdSignal(unittest.TestCase):
                 yield from test_horizop("16-bit", test_fn, mod_attr, 0b1111)
                 yield part_mask.eq(0b10)
                 yield from test_horizop("8-bit", test_fn, mod_attr,
-                                      0b1100, 0b0011)
+                                        0b1100, 0b0011)
                 yield part_mask.eq(0b1111)
                 yield from test_horizop("4-bit", test_fn, mod_attr,
-                                      0b1000, 0b0100, 0b0010, 0b0001)
+                                        0b1000, 0b0100, 0b0010, 0b0001)
 
             def test_ls_scal_fn(carry_in, a, b, mask):
                 # reduce range of b
                 bits = count_bits(mask)
                 newb = b & ((bits-1))
-                print ("%x %x %x bits %d trunc %x" % \
-                        (a, b, mask, bits, newb))
+                print("%x %x %x bits %d trunc %x" %
+                      (a, b, mask, bits, newb))
                 b = newb
                 # TODO: carry
                 carry_in = 0
@@ -771,8 +774,8 @@ class TestSimdSignal(unittest.TestCase):
                 # reduce range of b
                 bits = count_bits(mask)
                 newb = b & ((bits-1))
-                print ("%x %x %x bits %d trunc %x" % \
-                        (a, b, mask, bits, newb))
+                print("%x %x %x bits %d trunc %x" %
+                      (a, b, mask, bits, newb))
                 b = newb
                 # TODO: carry
                 carry_in = 0
@@ -788,15 +791,15 @@ class TestSimdSignal(unittest.TestCase):
                 # reduce range of b
                 bits = count_bits(mask)
                 fz = first_zero(mask)
-                newb = b & ((bits-1)<<fz)
-                print ("%x %x %x bits %d zero %d trunc %x" % \
-                        (a, b, mask, bits, fz, newb))
+                newb = b & ((bits-1) << fz)
+                print("%x %x %x bits %d zero %d trunc %x" %
+                      (a, b, mask, bits, fz, newb))
                 b = newb
                 # TODO: carry
                 carry_in = 0
                 lsb = mask & ~(mask-1) if carry_in else 0
                 b = (b & mask)
-                b = b >>fz
+                b = b >> fz
                 sum = ((a & mask) << b)
                 result = mask & sum
                 carry = (sum & mask) != sum
@@ -808,15 +811,15 @@ class TestSimdSignal(unittest.TestCase):
                 # reduce range of b
                 bits = count_bits(mask)
                 fz = first_zero(mask)
-                newb = b & ((bits-1)<<fz)
-                print ("%x %x %x bits %d zero %d trunc %x" % \
-                        (a, b, mask, bits, fz, newb))
+                newb = b & ((bits-1) << fz)
+                print("%x %x %x bits %d zero %d trunc %x" %
+                      (a, b, mask, bits, fz, newb))
                 b = newb
                 # TODO: carry
                 carry_in = 0
                 lsb = mask & ~(mask-1) if carry_in else 0
                 b = (b & mask)
-                b = b >>fz
+                b = b >> fz
                 sum = ((a & mask) >> b)
                 result = mask & sum
                 carry = (sum & mask) != sum
@@ -868,7 +871,7 @@ class TestSimdSignal(unittest.TestCase):
                     y = 0
                     carry_result = 0
                     for i, mask in enumerate(mask_list):
-                        print ("i/mask", i, hex(mask))
+                        print("i/mask", i, hex(mask))
                         res, c = test_fn(carry, a, b, mask)
                         y |= res
                         lsb = mask & ~(mask - 1)
@@ -893,15 +896,15 @@ class TestSimdSignal(unittest.TestCase):
             # output attribute (mod_attr) will contain the result to be
             # compared against the expected output from test_fn
             for (test_fn, mod_attr) in (
-                                        (test_ls_scal_fn, "ls_scal"),
-                                        (test_ls_fn, "ls"),
-                                        (test_rs_scal_fn, "rs_scal"),
-                                        (test_rs_fn, "rs"),
-                                        (test_add_fn, "add"),
-                                        (test_sub_fn, "sub"),
-                                        (test_neg_fn, "neg"),
-                                        (test_signed_fn, "signed"),
-                                        ):
+                (test_ls_scal_fn, "ls_scal"),
+                (test_ls_fn, "ls"),
+                (test_rs_scal_fn, "rs_scal"),
+                (test_rs_fn, "rs"),
+                (test_add_fn, "add"),
+                (test_sub_fn, "sub"),
+                (test_neg_fn, "neg"),
+                (test_signed_fn, "signed"),
+            ):
                 yield part_mask.eq(0)
                 yield from test_op("16-bit", 1, test_fn, mod_attr, 0xFFFF)
                 yield from test_op("16-bit", 0, test_fn, mod_attr, 0xFFFF)