Allow a variable number of operands in the proof driver
[ieee754fpu.git] / src / ieee754 / part / formal / proof_partition.py
index c86f262dd471eb5edbc1b5fced6de8c508285385..df1b6f871c15c1cf16e338bb918b3e976d919434 100644 (file)
@@ -33,7 +33,7 @@ import os
 import unittest
 import operator
 
-from nmigen import Elaboratable, Signal, Module, Const
+from nmigen import Elaboratable, Signal, Module, Const, Repl
 from nmigen.asserts import Assert, Cover
 from nmigen.hdl.ast import Assume
 
@@ -323,7 +323,7 @@ class GeneratorDriver(Elaboratable):
 
 class ComparisonOpDriver(Elaboratable):
     """Checks comparison operations on partitioned signals"""
-    def __init__(self, op, width, mwidth):
+    def __init__(self, op, width=64, mwidth=8, nops=2):
         self.op = op
         """Operation to perform. Must accept two integer-like inputs and give
         a predicate-like output (1-bit partitions in case of
@@ -332,49 +332,87 @@ class ComparisonOpDriver(Elaboratable):
         """Partition full width"""
         self.mwidth = mwidth
         """Maximum number of equally sized partitions"""
-
+        self.nops = nops
+        """Number of input operands"""
     def elaborate(self, _):
         m = Module()
         comb = m.d.comb
         width = self.width
         mwidth = self.mwidth
+        nops = self.nops
         # setup partition points and gates
         step = int(width/mwidth)
         points, gates = make_partitions(step, mwidth)
         # setup inputs and outputs
-        a = PartitionedSignal(points, width)
-        b = PartitionedSignal(points, width)
+        operands = list()
+        for i in range(nops):
+            inp = PartitionedSignal(points, width, name=f"i_{i+1}")
+            inp.set_module(m)
+            operands.append(inp)
         output = Signal(mwidth)
-        a.set_module(m)
-        b.set_module(m)
         # perform the operation on the partitioned signals
-        comb += output.eq(self.op(a, b))
+        comb += output.eq(self.op(*operands))
         # instantiate the partitioned gate generator and connect the gates
         m.submodules.gen = gen = GateGenerator(mwidth)
         comb += gates.eq(gen.gates)
         p_offset = gen.p_offset
         p_width = gen.p_width
+        # generate shifted down inputs and outputs
+        p_operands = list()
+        for i in range(nops):
+            p_i = Signal(width, name=f"p_{i+1}")
+            p_operands.append(p_i)
+            for pos in range(mwidth):
+                with m.If(p_offset == pos):
+                    comb += p_i.eq(operands[i].sig[pos * step:])
+        p_output = Signal(mwidth)
+        for pos in range(mwidth):
+            with m.If(p_offset == pos):
+                comb += p_output.eq(output[pos:])
+        # generate and check expected values for all possible partition sizes
+        for w in range(1, mwidth+1):
+            with m.If(p_width == w):
+                # calculate the expected output, for the given bit width,
+                # truncating the inputs to the partition size
+                input_bit_width = w * step
+                output_bit_width = w
+                expected = Signal(output_bit_width, name=f"expected_{w}")
+                trunc_operands = list()
+                for i in range(nops):
+                    t_i = Signal(input_bit_width, name=f"t{w}_{i+1}")
+                    trunc_operands.append(t_i)
+                    comb += t_i.eq(p_operands[i][:input_bit_width])
+                lsb = Signal(name=f"lsb_{w}")
+                comb += lsb.eq(self.op(*trunc_operands))
+                comb += expected.eq(Repl(lsb, output_bit_width))
+                # truncate the output, compare and assert
+                comb += Assert(p_output[:output_bit_width] == expected)
         # output a test case
-        comb += Cover((p_offset != 0) & (p_width == 3) & (sum(output) > 1))
+        comb += Cover((p_offset != 0) & (p_width == 3) & (sum(output) > 1) &
+                      (p_output != 0))
         return m
 
 
 class PartitionTestCase(FHDLTestCase):
     def test_formal(self):
+        style = {
+            'dec': {'base': 'dec'},
+            'bin': {'base': 'bin'}
+        }
         traces = [
-            ('p_offset[2:0]', {'base': 'dec'}),
-            ('p_width[3:0]', {'base': 'dec'}),
-            ('p_finish[3:0]', {'base': 'dec'}),
-            ('p_gates[8:0]', {'base': 'bin'}),
+            ('p_offset[2:0]', 'dec'),
+            ('p_width[3:0]', 'dec'),
+            ('p_finish[3:0]', 'dec'),
+            ('p_gates[8:0]', 'bin'),
             ('dut', {'submodule': 'dut'}, [
-                ('gates[6:0]', {'base': 'bin'}),
+                ('gates[6:0]', 'bin'),
                 'output[63:0]']),
             'p_output[63:0]', 'expected_3[21:0]']
         write_gtkw(
             'proof_partition_cover.gtkw',
             os.path.dirname(__file__) +
             '/proof_partition_formal/engine_0/trace0.vcd',
-            traces,
+            traces, style,
             module='top',
             zoom=-3
         )
@@ -382,7 +420,7 @@ class PartitionTestCase(FHDLTestCase):
             'proof_partition_bmc.gtkw',
             os.path.dirname(__file__) +
             '/proof_partition_formal/engine_0/trace.vcd',
-            traces,
+            traces, style,
             module='top',
             zoom=-3
         )
@@ -391,20 +429,25 @@ class PartitionTestCase(FHDLTestCase):
         self.assertFormal(module, mode="cover", depth=1)
 
     def test_generator(self):
+        style = {
+            'dec': {'base': 'dec'},
+            'bin': {'base': 'bin'}
+        }
         traces = [
-            ('p_offset[2:0]', {'base': 'dec'}),
-            ('p_width[3:0]', {'base': 'dec'}),
-            ('p_finish[3:0]', {'base': 'dec'}),
-            ('p_gates[8:0]', {'base': 'bin'}),
+            ('p_offset[2:0]', 'dec'),
+            ('p_width[3:0]', 'dec'),
+            ('p_finish[3:0]', 'dec'),
+            ('p_gates[8:0]', 'bin'),
             ('dut', {'submodule': 'dut'}, [
-                ('gates[6:0]', {'base': 'bin'}),
+                ('gates[6:0]', 'bin'),
                 'output[63:0]']),
-            'p_output[63:0]', 'expected_3[21:0]']
+            'p_output[63:0]', 'expected_3[21:0]',
+            'a_3[23:0]', 'b_3[32:0]', 'expected_3[2:0]']
         write_gtkw(
             'proof_partition_generator_cover.gtkw',
             os.path.dirname(__file__) +
             '/proof_partition_generator/engine_0/trace0.vcd',
-            traces,
+            traces, style,
             module='top',
             zoom=-3
         )
@@ -412,7 +455,7 @@ class PartitionTestCase(FHDLTestCase):
             'proof_partition_generator_bmc.gtkw',
             os.path.dirname(__file__) +
             '/proof_partition_generator/engine_0/trace.vcd',
-            traces,
+            traces, style,
             module='top',
             zoom=-3
         )
@@ -421,19 +464,28 @@ class PartitionTestCase(FHDLTestCase):
         self.assertFormal(module, mode="cover", depth=1)
 
     def test_partsig_eq(self):
+        style = {
+            'dec': {'base': 'dec'},
+            'bin': {'base': 'bin'}
+        }
         traces = [
-            ('p_offset[2:0]', {'base': 'dec'}),
-            ('p_width[3:0]', {'base': 'dec'}),
-            ('p_gates[8:0]', {'base': 'bin'}),
+            ('p_offset[2:0]', 'dec'),
+            ('p_width[3:0]', 'dec'),
+            ('p_gates[8:0]', 'bin'),
+            'i_1[63:0]', 'i_2[63:0]',
             ('eq_1', {'submodule': 'eq_1'}, [
-                ('gates[6:0]', {'base': 'bin'}),
+                ('gates[6:0]', 'bin'),
                 'a[63:0]', 'b[63:0]',
-                ('output[7:0]', {'base': 'bin'})])]
+                ('output[7:0]', 'bin')]),
+            'p_1[63:0]', 'p_2[63:0]',
+            ('p_output[7:0]', 'bin'),
+            't3_1[23:0]', 't3_2[23:0]', 'lsb_3',
+            ('expected_3[2:0]', 'bin')]
         write_gtkw(
             'proof_partsig_eq_cover.gtkw',
             os.path.dirname(__file__) +
             '/proof_partition_partsig_eq/engine_0/trace0.vcd',
-            traces,
+            traces, style,
             module='top',
             zoom=-3
         )
@@ -441,11 +493,65 @@ class PartitionTestCase(FHDLTestCase):
             'proof_partsig_eq_bmc.gtkw',
             os.path.dirname(__file__) +
             '/proof_partition_partsig_eq/engine_0/trace.vcd',
-            traces,
+            traces, style,
+            module='top',
+            zoom=-3
+        )
+        module = ComparisonOpDriver(operator.eq)
+        self.assertFormal(module, mode="bmc", depth=1)
+        self.assertFormal(module, mode="cover", depth=1)
+
+    def test_partsig_ne(self):
+        module = ComparisonOpDriver(operator.ne)
+        self.assertFormal(module, mode="bmc", depth=1)
+
+    def test_partsig_gt(self):
+        module = ComparisonOpDriver(operator.gt)
+        self.assertFormal(module, mode="bmc", depth=1)
+
+    def test_partsig_ge(self):
+        module = ComparisonOpDriver(operator.ge)
+        self.assertFormal(module, mode="bmc", depth=1)
+
+    def test_partsig_lt(self):
+        module = ComparisonOpDriver(operator.lt)
+        self.assertFormal(module, mode="bmc", depth=1)
+
+    def test_partsig_le(self):
+        module = ComparisonOpDriver(operator.le)
+        self.assertFormal(module, mode="bmc", depth=1)
+
+    def test_partsig_all(self):
+        style = {
+            'dec': {'base': 'dec'},
+            'bin': {'base': 'bin'}
+        }
+        traces = [
+            ('p_offset[2:0]', 'dec'),
+            ('p_width[3:0]', 'dec'),
+            ('p_gates[8:0]', 'bin'),
+            'i_1[63:0]',
+            ('eq_1', {'submodule': 'eq_1'}, [
+                ('gates[6:0]', 'bin'),
+                'a[63:0]', 'b[63:0]',
+                ('output[7:0]', 'bin')]),
+            'p_1[63:0]',
+            ('p_output[7:0]', 'bin'),
+            't3_1[23:0]', 'lsb_3',
+            ('expected_3[2:0]', 'bin')]
+        write_gtkw(
+            'proof_partsig_all_cover.gtkw',
+            os.path.dirname(__file__) +
+            '/proof_partition_partsig_all/engine_0/trace0.vcd',
+            traces, style,
             module='top',
             zoom=-3
         )
-        module = ComparisonOpDriver(operator.eq, 64, 8)
+
+        def op_all(obj):
+            return obj.all()
+
+        module = ComparisonOpDriver(op_all, nops=1)
         self.assertFormal(module, mode="bmc", depth=1)
         self.assertFormal(module, mode="cover", depth=1)