Allow the proof driver to check operations with integer output
authorCesar Strauss <cestrauss@gmail.com>
Sun, 17 Jan 2021 17:00:27 +0000 (14:00 -0300)
committerCesar Strauss <cestrauss@gmail.com>
Sun, 17 Jan 2021 17:14:48 +0000 (14:14 -0300)
The "part_out" parameter tells the driver whether the output of the
operation is supposed to be arithmetic or boolean.

Added checks for addition and subtraction, without carry.

src/ieee754/part/formal/proof_partition.py

index c818115..6d7b088 100644 (file)
@@ -321,25 +321,26 @@ class GeneratorDriver(Elaboratable):
         return m
 
 
-class ComparisonOpDriver(Elaboratable):
-    """Checks comparison operations on partitioned signals"""
-    def __init__(self, op, width=64, mwidth=8, nops=2):
+class OpDriver(Elaboratable):
+    """Checks operations on partitioned signals"""
+    def __init__(self, op, width=64, mwidth=8, nops=2, part_out=True):
         self.op = op
-        """Operation to perform. Must accept two integer-like inputs and give
-        a predicate-like output (1-bit partitions in case of
-        PartitionedSignal types)"""
+        """Operation to perform"""
         self.width = width
         """Partition full width"""
         self.mwidth = mwidth
         """Maximum number of equally sized partitions"""
         self.nops = nops
         """Number of input operands"""
+        self.part_out = part_out
+        """True if output is partition-like"""
     def elaborate(self, _):
         m = Module()
         comb = m.d.comb
         width = self.width
         mwidth = self.mwidth
         nops = self.nops
+        part_out = self.part_out
         # setup partition points and gates
         step = int(width/mwidth)
         points, gates = make_partitions(step, mwidth)
@@ -349,7 +350,13 @@ class ComparisonOpDriver(Elaboratable):
             inp = PartitionedSignal(points, width, name=f"i_{i+1}")
             inp.set_module(m)
             operands.append(inp)
-        output = Signal(mwidth)
+        if part_out:
+            out_width = mwidth
+            out_step = 1
+        else:
+            out_width = width
+            out_step = step
+        output = Signal(out_width)
         # perform the operation on the partitioned signals
         comb += output.eq(self.op(*operands))
         # instantiate the partitioned gate generator and connect the gates
@@ -365,31 +372,43 @@ class ComparisonOpDriver(Elaboratable):
             for pos in range(mwidth):
                 with m.If(p_offset == pos):
                     comb += p_i.eq(operands[i].sig[pos * step:])
-        p_output = Signal(mwidth)
+        p_output = Signal(out_width)
         for pos in range(mwidth):
             with m.If(p_offset == pos):
-                comb += p_output.eq(output[pos:])
+                comb += p_output.eq(output[pos * out_step:])
         # generate and check expected values for all possible partition sizes
+        all_operands_non_zero = Signal()
         for w in range(1, mwidth+1):
             with m.If(p_width == w):
                 # calculate the expected output, for the given bit width,
                 # truncating the inputs to the partition size
                 input_bit_width = w * step
-                output_bit_width = w
+                output_bit_width = w * out_step
                 expected = Signal(output_bit_width, name=f"expected_{w}")
                 trunc_operands = list()
                 for i in range(nops):
                     t_i = Signal(input_bit_width, name=f"t{w}_{i+1}")
                     trunc_operands.append(t_i)
                     comb += t_i.eq(p_operands[i][:input_bit_width])
-                lsb = Signal(name=f"lsb_{w}")
-                comb += lsb.eq(self.op(*trunc_operands))
-                comb += expected.eq(Repl(lsb, output_bit_width))
+                if part_out:
+                    # for partition-like outputs, calculate the LSB
+                    # and replicate it on the partition
+                    lsb = Signal(name=f"lsb_{w}")
+                    comb += lsb.eq(self.op(*trunc_operands))
+                    comb += expected.eq(Repl(lsb, output_bit_width))
+                else:
+                    # otherwise, just take the operation result
+                    comb += expected.eq(self.op(*trunc_operands))
                 # truncate the output, compare and assert
                 comb += Assert(p_output[:output_bit_width] == expected)
+                # ensure a test case with all non-zero operands
+                non_zero_op = Signal(nops)
+                for i in range(nops):
+                    comb += non_zero_op[i].eq(trunc_operands[i].any())
+                comb += all_operands_non_zero.eq(non_zero_op.all())
         # output a test case
         comb += Cover((p_offset != 0) & (p_width == 3) & (sum(output) > 1) &
-                      (p_output != 0))
+                      (p_output != 0) & all_operands_non_zero)
         return m
 
 
@@ -497,28 +516,28 @@ class PartitionTestCase(FHDLTestCase):
             module='top',
             zoom=-3
         )
-        module = ComparisonOpDriver(operator.eq)
+        module = OpDriver(operator.eq)
         self.assertFormal(module, mode="bmc", depth=1)
         self.assertFormal(module, mode="cover", depth=1)
 
     def test_partsig_ne(self):
-        module = ComparisonOpDriver(operator.ne)
+        module = OpDriver(operator.ne)
         self.assertFormal(module, mode="bmc", depth=1)
 
     def test_partsig_gt(self):
-        module = ComparisonOpDriver(operator.gt)
+        module = OpDriver(operator.gt)
         self.assertFormal(module, mode="bmc", depth=1)
 
     def test_partsig_ge(self):
-        module = ComparisonOpDriver(operator.ge)
+        module = OpDriver(operator.ge)
         self.assertFormal(module, mode="bmc", depth=1)
 
     def test_partsig_lt(self):
-        module = ComparisonOpDriver(operator.lt)
+        module = OpDriver(operator.lt)
         self.assertFormal(module, mode="bmc", depth=1)
 
     def test_partsig_le(self):
-        module = ComparisonOpDriver(operator.le)
+        module = OpDriver(operator.le)
         self.assertFormal(module, mode="bmc", depth=1)
 
     def test_partsig_all(self):
@@ -551,7 +570,7 @@ class PartitionTestCase(FHDLTestCase):
         def op_all(obj):
             return obj.all()
 
-        module = ComparisonOpDriver(op_all, nops=1)
+        module = OpDriver(op_all, nops=1)
         self.assertFormal(module, mode="bmc", depth=1)
         self.assertFormal(module, mode="cover", depth=1)
 
@@ -560,7 +579,7 @@ class PartitionTestCase(FHDLTestCase):
         def op_any(obj):
             return obj.any()
 
-        module = ComparisonOpDriver(op_any, nops=1)
+        module = OpDriver(op_any, nops=1)
         self.assertFormal(module, mode="bmc", depth=1)
 
     def test_partsig_xor(self):
@@ -569,7 +588,42 @@ class PartitionTestCase(FHDLTestCase):
             return obj.xor()
 
         # 8-bit partitions take a long time, for some reason
-        module = ComparisonOpDriver(op_xor, nops=1, width=32, mwidth=4)
+        module = OpDriver(op_xor, nops=1, width=32, mwidth=4)
+        self.assertFormal(module, mode="bmc", depth=1)
+
+    def test_partsig_add(self):
+        style = {
+            'dec': {'base': 'dec'},
+            'bin': {'base': 'bin'}
+        }
+        traces = [
+            ('p_offset[2:0]', 'dec'),
+            ('p_width[3:0]', 'dec'),
+            ('p_gates[8:0]', 'bin'),
+            'i_1[63:0]', 'i_2[63:0]',
+            ('add_1', {'submodule': 'add_1'}, [
+                ('gates[6:0]', 'bin'),
+                'a[63:0]', 'b[63:0]',
+                'output[63:0]']),
+            'p_1[63:0]', 'p_2[63:0]',
+            'p_output[63:0]',
+            't3_1[23:0]', 't3_2[23:0]',
+            'expected_3[23:0]']
+        write_gtkw(
+            'proof_partsig_add_cover.gtkw',
+            os.path.dirname(__file__) +
+            '/proof_partition_partsig_add/engine_0/trace0.vcd',
+            traces, style,
+            module='top',
+            zoom=-3
+        )
+
+        module = OpDriver(operator.add, part_out=False)
+        self.assertFormal(module, mode="bmc", depth=1)
+        self.assertFormal(module, mode="cover", depth=1)
+
+    def test_partsig_sub(self):
+        module = OpDriver(operator.sub, part_out=False)
         self.assertFormal(module, mode="bmc", depth=1)