make common function for testing comparators
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 6 Feb 2020 14:57:57 +0000 (14:57 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Thu, 6 Feb 2020 14:57:57 +0000 (14:57 +0000)
src/ieee754/part/partsig.py
src/ieee754/part/test/test_partsig.py

index e3f056d23a41c7bc7e12a572197b42f78a288d9b..e1627b7d6d5891de0fd3c6f69acfd4f77eead0c5 100644 (file)
@@ -102,12 +102,22 @@ class PartitionedSignal:
         return self._compare(width, self, other, "eq", PartitionedEqGtGe.EQ)
 
     def __ne__(self, other):
-        return ~self.__eq__(other)
+        width = self.sig.shape()[0]
+        invert = ~self.sig # invert the input before compare EQ.
+        return self._compare(width, invert, other, "eq", PartitionedEqGtGe.EQ)
 
     def __gt__(self, other):
         width = self.sig.shape()[0]
         return self._compare(width, self, other, "gt", PartitionedEqGtGe.GT)
 
+    def __lt__(self, other):
+        width = self.sig.shape()[0]
+        return self._compare(width, other, self, "gt", PartitionedEqGtGe.GT)
+
     def __ge__(self, other):
         width = self.sig.shape()[0]
         return self._compare(width, self, other, "ge", PartitionedEqGtGe.GE)
+
+    def __le__(self, other):
+        width = self.sig.shape()[0]
+        return self._compare(width, other, self, "ge", PartitionedEqGtGe.GE)
index 908a3e91b04b98bcf71d474958c6c2fadcd86f57..2dd76bc8f22264507b39315040fc96db26195208 100644 (file)
@@ -31,11 +31,17 @@ class TestAddMod(Elaboratable):
         self.eq_output = Signal(len(partpoints)+1)
         self.gt_output = Signal(len(partpoints)+1)
         self.ge_output = Signal(len(partpoints)+1)
+        self.ne_output = Signal(len(partpoints)+1)
+        self.lt_output = Signal(len(partpoints)+1)
+        self.le_output = Signal(len(partpoints)+1)
 
     def elaborate(self, platform):
         m = Module()
         self.a.set_module(m)
         self.b.set_module(m)
+        m.d.comb += self.lt_output.eq(self.a < self.b)
+        m.d.comb += self.ne_output.eq(self.a != self.b)
+        m.d.comb += self.le_output.eq(self.a <= self.b)
         m.d.comb += self.gt_output.eq(self.a > self.b)
         m.d.comb += self.eq_output.eq(self.a == self.b)
         m.d.comb += self.ge_output.eq(self.a >= self.b)
@@ -83,83 +89,25 @@ class TestPartitionPoints(unittest.TestCase):
             yield part_mask.eq(0b1111)
             yield from test_add("4-bit", 0xF000, 0x0F00, 0x00F0, 0x000F)
 
-            def test_eq(msg_prefix, *maskbit_list):
-                for a, b in [(0x0000, 0x0000),
-                             (0x1234, 0x1234),
-                             (0xABCD, 0xABCD),
-                             (0xFFFF, 0x0000),
-                             (0x0000, 0x0000),
-                             (0xFFFF, 0xFFFF),
-                             (0x0000, 0xFFFF)]:
-                    yield module.a.eq(a)
-                    yield module.b.eq(b)
-                    yield Delay(0.1e-6)
-                    # convert to mask_list
-                    mask_list = []
-                    for mb in maskbit_list:
-                        v = 0
-                        for i in range(4):
-                            if mb & (1<<i):
-                                v |= 0xf << (i*4)
-                        mask_list.append(v)
-                    y = 0
-                    # do the partitioned tests
-                    for i, mask in enumerate(mask_list):
-                        if (a & mask) == (b & mask):
-                            # OR y with the lowest set bit in the mask
-                            y |= (maskbit_list[i] & ~(maskbit_list[i]-1))
-                    # check the result
-                    outval = (yield module.eq_output)
-                    msg = f"{msg_prefix}: 0x{a:X} == 0x{b:X}" + \
-                        f" => 0x{y:X} != 0x{outval:X}, masklist %s"
-                    #print ((msg % str(maskbit_list)).format(locals()))
-                    self.assertEqual(y, outval, msg % str(maskbit_list))
-            yield part_mask.eq(0)
-            yield from test_eq("16-bit", 0b1111)
-            yield part_mask.eq(0b10)
-            yield from test_eq("8-bit", 0b1100, 0b0011)
-            yield part_mask.eq(0b1111)
-            yield from test_eq("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
+            def test_ne_fn(a, b, mask):
+                return (a & mask) != (b & mask)
 
-            def test_gt(msg_prefix, *maskbit_list):
-                for a, b in [(0x0000, 0x0000),
-                             (0x1234, 0x1234),
-                             (0xABCD, 0xABCD),
-                             (0xFFFF, 0x0000),
-                             (0x0000, 0x0000),
-                             (0xFFFF, 0xFFFF),
-                             (0x0000, 0xFFFF)]:
-                    yield module.a.eq(a)
-                    yield module.b.eq(b)
-                    yield Delay(0.1e-6)
-                    # convert to mask_list
-                    mask_list = []
-                    for mb in maskbit_list:
-                        v = 0
-                        for i in range(4):
-                            if mb & (1<<i):
-                                v |= 0xf << (i*4)
-                        mask_list.append(v)
-                    y = 0
-                    # do the partitioned tests
-                    for i, mask in enumerate(mask_list):
-                        if (a & mask) > (b & mask):
-                            # OR y with the lowest set bit in the mask
-                            y |= (maskbit_list[i] & ~(maskbit_list[i]-1))
-                    # check the result
-                    outval = (yield module.gt_output)
-                    msg = f"{msg_prefix}: 0x{a:X} == 0x{b:X}" + \
-                        f" => 0x{y:X} != 0x{outval:X}, masklist %s"
-                    #print ((msg % str(maskbit_list)).format(locals()))
-                    self.assertEqual(y, outval, msg % str(maskbit_list))
-            yield part_mask.eq(0)
-            yield from test_gt("16-bit", 0b1111)
-            yield part_mask.eq(0b10)
-            yield from test_gt("8-bit", 0b1100, 0b0011)
-            yield part_mask.eq(0b1111)
-            yield from test_gt("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
+            def test_lt_fn(a, b, mask):
+                return (a & mask) < (b & mask)
 
-            def test_ge(msg_prefix, *maskbit_list):
+            def test_le_fn(a, b, mask):
+                return (a & mask) <= (b & mask)
+
+            def test_eq_fn(a, b, mask):
+                return (a & mask) == (b & mask)
+
+            def test_gt_fn(a, b, mask):
+                return (a & mask) > (b & mask)
+
+            def test_ge_fn(a, b, mask):
+                return (a & mask) >= (b & mask)
+
+            def test_binop(msg_prefix, test_fn, mod_attr, *maskbit_list):
                 for a, b in [(0x0000, 0x0000),
                              (0x1234, 0x1234),
                              (0xABCD, 0xABCD),
@@ -181,21 +129,31 @@ class TestPartitionPoints(unittest.TestCase):
                     y = 0
                     # do the partitioned tests
                     for i, mask in enumerate(mask_list):
-                        if (a & mask) >= (b & mask):
+                        if test_fn(a, b, mask):
                             # OR y with the lowest set bit in the mask
                             y |= (maskbit_list[i] & ~(maskbit_list[i]-1))
                     # check the result
-                    outval = (yield module.ge_output)
-                    msg = f"{msg_prefix}: 0x{a:X} == 0x{b:X}" + \
+                    outval = (yield getattr(module, "%s_output" % mod_attr))
+                    msg = f"{msg_prefix}: {mod_attr} 0x{a:X} == 0x{b:X}" + \
                         f" => 0x{y:X} != 0x{outval:X}, masklist %s"
                     #print ((msg % str(maskbit_list)).format(locals()))
                     self.assertEqual(y, outval, msg % str(maskbit_list))
-            yield part_mask.eq(0)
-            yield from test_ge("16-bit", 0b1111)
-            yield part_mask.eq(0b10)
-            yield from test_ge("8-bit", 0b1100, 0b0011)
-            yield part_mask.eq(0b1111)
-            yield from test_ge("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
+
+            for (test_fn, mod_attr) in ((test_eq_fn, "eq"),
+                                        (test_gt_fn, "gt"),
+                                        (test_ge_fn, "ge"),
+                                        (test_lt_fn, "lt"),
+                                        (test_le_fn, "le"),
+                                        (test_ne_fn, "ne"),
+                                        ):
+                yield part_mask.eq(0)
+                yield from test_binop("16-bit", test_fn, mod_attr, 0b1111)
+                yield part_mask.eq(0b10)
+                yield from test_binop("8-bit", test_fn, mod_attr,
+                                      0b1100, 0b0011)
+                yield part_mask.eq(0b1111)
+                yield from test_binop("4-bit", test_fn, mod_attr,
+                                      0b1000, 0b0100, 0b0010, 0b0001)
 
         sim.add_process(async_process)
         sim.run()