Start work on improving formal verification of PartitionedSignal
[ieee754fpu.git] / src / ieee754 / part / formal / proof_partition.py
1 """Formal verification of partitioned operations
2
3 The approach is to take an arbitrary partition, by choosing its start point
4 and size at random. Use ``Assume`` to ensure it is a whole unbroken partition
5 (start and end points are one, with only zeros in between). Shift inputs and
6 outputs down to zero. Loop over all possible partition sizes and, if it's the
7 right size, compute the expected value, compare with the result, and assert.
8
9 We are turning the for-loops around (on their head), such that we start from
10 the *lengths* (and positions) and perform the ``Assume`` on the resultant
11 partition bits.
12
13 In other words, we have patterns as follows (assuming 32-bit words)::
14
15 8-bit offsets 0,1,2,3
16 16-bit offsets 0,1,2
17 24-bit offsets 0,1
18 32-bit
19
20 * for 8-bit the partition bit is 1 and the previous is also 1
21
22 * for 16-bit the partition bit at the offset must be 0 and be surrounded by 1
23
24 * for 24-bit the partition bits at the offset and at offset+1 must be 0 and at
25 offset+2 and offset-1 must be 1
26
27 * for 32-bit all 3 bits must be 0 and be surrounded by 1 (guard bits are added
28 at each end for this purpose)
29
30 """
31
32 import os
33 import unittest
34
35 from nmigen import Elaboratable, Signal, Module, Const
36 from nmigen.asserts import Assert, Cover
37
38 from nmutil.formaltest import FHDLTestCase
39 from nmutil.gtkw import write_gtkw
40
41 from ieee754.part_mul_add.partpoints import PartitionPoints
42
43
class PartitionedPattern(Elaboratable):
    """ Generate a unique pattern, depending on partition size.

    The intended pattern is:

    * 1-byte partitions: 0x11
    * 2-byte partitions: 0x21 0x22
    * 3-byte partitions: 0x31 0x32 0x33

    And so on.

    NOTE: the current implementation does not consult the partition gates
    yet. Every segment between two consecutive partition points (or an end
    of the signal) instead receives its 1-based index (1, 2, 3, ...). So for
    a 64-bit signal with a point every 8 bits, the output is always
    ``0x0807060504030201``.

    Useful as a test vector for testing the formal prover

    :param width: total width of ``output``, in bits
    :param partition_points: mapping from bit position to its gate,
        converted to a :class:`PartitionPoints`
    """
    def __init__(self, width, partition_points):
        self.width = width
        self.partition_points = PartitionPoints(partition_points)
        # number of segments: one more than the number of partition points
        self.mwidth = len(self.partition_points)+1
        self.output = Signal(self.width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # Segment boundaries in ascending bit order, with both ends of the
        # signal added. The points are sorted because PartitionPoints is a
        # mapping, and its keys need not have been inserted in bit order.
        positions = [0] + sorted(self.partition_points.keys()) + [self.width]
        # (start, end) bit ranges of each segment, lowest first
        segments = list(zip(positions, positions[1:]))

        # Begin counting at one
        first_start, first_end = segments[0]
        comb += self.output[first_start:first_end].eq(1)

        # Build an incrementing cascade: each segment gets the previous
        # segment's value plus one (truncated to the segment's width by
        # the assignment)
        for (last_start, last_end), (start, end) in zip(segments,
                                                         segments[1:]):
            comb += self.output[start:end].eq(
                self.output[last_start:last_end] + 1)
        return m
83
84
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Top-level formal harness for :class:`PartitionedPattern`.

    Builds a 64-bit pattern generator with a partition point every 8 bits,
    each one gated by a bit of a signal that this harness never drives.
    It then asserts the generator's output value.
    """
    def __init__(self):
        # inputs and outputs
        pass

    @staticmethod
    def elaborate(_):
        m = Module()
        comb = m.d.comb
        width = 64
        mwidth = 8
        # Setup partition points and gates
        points = PartitionPoints()
        # NOTE(review): gates is not driven anywhere in this harness;
        # presumably the formal flow treats it as a free input -- confirm.
        gates = Signal(mwidth-1)
        step = width // mwidth
        for i in range(mwidth-1):
            points[(i+1)*step] = gates[i]
        # Instantiate the partitioned pattern producer
        m.submodules.dut = dut = PartitionedPattern(width, points)
        # Directly check the output: segment i (counting from zero) holds
        # i+1. This evaluates to 0x0807060504030201 for width=64, mwidth=8.
        expected = sum((i + 1) << (i * step) for i in range(mwidth))
        comb += Assert(dut.output == expected)
        comb += Cover(1)
        return m
112
113
class PartitionTestCase(FHDLTestCase):
    """Formal checks (BMC and cover) of the :class:`Driver` harness."""

    def test_formal(self):
        """Run BMC on the Driver assertions, then a cover pass.

        Before the solver runs, ``write_gtkw`` is called once per expected
        trace file, so that a trace can be inspected with the DUT's
        ``output`` signal already listed.
        """
        traces = ['output[63:0]']
        # NOTE(review): presumably the directory FHDLTestCase.assertFormal
        # uses for this test (file name + test name); verify against
        # nmutil.formaltest. os.path.join avoids building a root-relative
        # path ("/proof_partition_formal/...") when dirname(__file__) is ''.
        engine_dir = os.path.join(os.path.dirname(__file__),
                                  'proof_partition_formal', 'engine_0')
        for gtkw_name, vcd_name in (('test_formal_cover.gtkw', 'trace0.vcd'),
                                    ('test_formal_bmc.gtkw', 'trace.vcd')):
            write_gtkw(
                gtkw_name,
                os.path.join(engine_dir, vcd_name),
                traces,
                module='top.dut',
                zoom="formal"
            )
        module = Driver()
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)
136
137
if __name__ == "__main__":
    # Running this file directly executes the test cases defined above.
    unittest.main()