From 5f909a0266065b00b6cf0b1d3045fa08d1191340 Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Thu, 6 Feb 2020 14:57:57 +0000 Subject: [PATCH] make common function for testing comparators --- src/ieee754/part/partsig.py | 12 ++- src/ieee754/part/test/test_partsig.py | 126 +++++++++----------------- 2 files changed, 53 insertions(+), 85 deletions(-) diff --git a/src/ieee754/part/partsig.py b/src/ieee754/part/partsig.py index e3f056d2..e1627b7d 100644 --- a/src/ieee754/part/partsig.py +++ b/src/ieee754/part/partsig.py @@ -102,12 +102,22 @@ class PartitionedSignal: return self._compare(width, self, other, "eq", PartitionedEqGtGe.EQ) def __ne__(self, other): - return ~self.__eq__(other) + width = self.sig.shape()[0] + invert = ~self.sig # invert the input before compare EQ. + return self._compare(width, invert, other, "eq", PartitionedEqGtGe.EQ) def __gt__(self, other): width = self.sig.shape()[0] return self._compare(width, self, other, "gt", PartitionedEqGtGe.GT) + def __lt__(self, other): + width = self.sig.shape()[0] + return self._compare(width, other, self, "gt", PartitionedEqGtGe.GT) + def __ge__(self, other): width = self.sig.shape()[0] return self._compare(width, self, other, "ge", PartitionedEqGtGe.GE) + + def __le__(self, other): + width = self.sig.shape()[0] + return self._compare(width, other, self, "ge", PartitionedEqGtGe.GE) diff --git a/src/ieee754/part/test/test_partsig.py b/src/ieee754/part/test/test_partsig.py index 908a3e91..2dd76bc8 100644 --- a/src/ieee754/part/test/test_partsig.py +++ b/src/ieee754/part/test/test_partsig.py @@ -31,11 +31,17 @@ class TestAddMod(Elaboratable): self.eq_output = Signal(len(partpoints)+1) self.gt_output = Signal(len(partpoints)+1) self.ge_output = Signal(len(partpoints)+1) + self.ne_output = Signal(len(partpoints)+1) + self.lt_output = Signal(len(partpoints)+1) + self.le_output = Signal(len(partpoints)+1) def elaborate(self, platform): m = Module() self.a.set_module(m) self.b.set_module(m) + m.d.comb += self.lt_output.eq(self.a < self.b) + m.d.comb += self.ne_output.eq(self.a != self.b) + m.d.comb += self.le_output.eq(self.a <= self.b) m.d.comb += self.gt_output.eq(self.a > self.b) m.d.comb += self.eq_output.eq(self.a == self.b) m.d.comb += self.ge_output.eq(self.a >= self.b) @@ -83,83 +89,25 @@ class TestPartitionPoints(unittest.TestCase): yield part_mask.eq(0b1111) yield from test_add("4-bit", 0xF000, 0x0F00, 0x00F0, 0x000F) - def test_eq(msg_prefix, *maskbit_list): - for a, b in [(0x0000, 0x0000), - (0x1234, 0x1234), - (0xABCD, 0xABCD), - (0xFFFF, 0x0000), - (0x0000, 0x0000), - (0xFFFF, 0xFFFF), - (0x0000, 0xFFFF)]: - yield module.a.eq(a) - yield module.b.eq(b) - yield Delay(0.1e-6) - # convert to mask_list - mask_list = [] - for mb in maskbit_list: - v = 0 - for i in range(4): - if mb & (1< 0x{y:X} != 0x{outval:X}, masklist %s" - #print ((msg % str(maskbit_list)).format(locals())) - self.assertEqual(y, outval, msg % str(maskbit_list)) - yield part_mask.eq(0) - yield from test_eq("16-bit", 0b1111) - yield part_mask.eq(0b10) - yield from test_eq("8-bit", 0b1100, 0b0011) - yield part_mask.eq(0b1111) - yield from test_eq("4-bit", 0b1000, 0b0100, 0b0010, 0b0001) + def test_ne_fn(a, b, mask): + return (a & mask) != (b & mask) - def test_gt(msg_prefix, *maskbit_list): - for a, b in [(0x0000, 0x0000), - (0x1234, 0x1234), - (0xABCD, 0xABCD), - (0xFFFF, 0x0000), - (0x0000, 0x0000), - (0xFFFF, 0xFFFF), - (0x0000, 0xFFFF)]: - yield module.a.eq(a) - yield module.b.eq(b) - yield Delay(0.1e-6) - # convert to mask_list - mask_list = [] - for mb in maskbit_list: - v = 0 - for i in range(4): - if mb & (1< (b & mask): - # OR y with the lowest set bit in the mask - y |= (maskbit_list[i] & ~(maskbit_list[i]-1)) - # check the result - outval = (yield module.gt_output) - msg = f"{msg_prefix}: 0x{a:X} == 0x{b:X}" + \ - f" => 0x{y:X} != 0x{outval:X}, masklist %s" - #print ((msg % str(maskbit_list)).format(locals())) - self.assertEqual(y, outval, msg % str(maskbit_list)) - yield part_mask.eq(0) - yield from test_gt("16-bit", 0b1111) - yield part_mask.eq(0b10) - yield from test_gt("8-bit", 0b1100, 0b0011) - yield part_mask.eq(0b1111) - yield from test_gt("4-bit", 0b1000, 0b0100, 0b0010, 0b0001) + def test_lt_fn(a, b, mask): + return (a & mask) < (b & mask) - def test_ge(msg_prefix, *maskbit_list): + def test_le_fn(a, b, mask): + return (a & mask) <= (b & mask) + + def test_eq_fn(a, b, mask): + return (a & mask) == (b & mask) + + def test_gt_fn(a, b, mask): + return (a & mask) > (b & mask) + + def test_ge_fn(a, b, mask): + return (a & mask) >= (b & mask) + + def test_binop(msg_prefix, test_fn, mod_attr, *maskbit_list): for a, b in [(0x0000, 0x0000), (0x1234, 0x1234), (0xABCD, 0xABCD), @@ -181,21 +129,31 @@ class TestPartitionPoints(unittest.TestCase): y = 0 # do the partitioned tests for i, mask in enumerate(mask_list): - if (a & mask) >= (b & mask): + if test_fn(a, b, mask): # OR y with the lowest set bit in the mask y |= (maskbit_list[i] & ~(maskbit_list[i]-1)) # check the result - outval = (yield module.ge_output) - msg = f"{msg_prefix}: 0x{a:X} == 0x{b:X}" + \ + outval = (yield getattr(module, "%s_output" % mod_attr)) + msg = f"{msg_prefix}: {mod_attr} 0x{a:X} == 0x{b:X}" + \ f" => 0x{y:X} != 0x{outval:X}, masklist %s" #print ((msg % str(maskbit_list)).format(locals())) self.assertEqual(y, outval, msg % str(maskbit_list)) - yield part_mask.eq(0) - yield from test_ge("16-bit", 0b1111) - yield part_mask.eq(0b10) - yield from test_ge("8-bit", 0b1100, 0b0011) - yield part_mask.eq(0b1111) - yield from test_ge("4-bit", 0b1000, 0b0100, 0b0010, 0b0001) + + for (test_fn, mod_attr) in ((test_eq_fn, "eq"), + (test_gt_fn, "gt"), + (test_ge_fn, "ge"), + (test_lt_fn, "lt"), + (test_le_fn, "le"), + (test_ne_fn, "ne"), + ): + yield part_mask.eq(0) + yield from test_binop("16-bit", test_fn, mod_attr, 0b1111) + yield part_mask.eq(0b10) + yield from test_binop("8-bit", test_fn, mod_attr, + 0b1100, 0b0011) + yield part_mask.eq(0b1111) + yield from test_binop("4-bit", test_fn, mod_attr, + 0b1000, 0b0100, 0b0010, 0b0001) sim.add_process(async_process) sim.run() -- 2.30.2