Generate the bit pattern of gates corresponding to a partition
[ieee754fpu.git] / src / ieee754 / part / formal / proof_partition.py
1 """Formal verification of partitioned operations
2
3 The approach is to take an arbitrary partition, by choosing its start point
4 and size at random. Use ``Assume`` to ensure it is a whole unbroken partition
5 (start and end points are one, with only zeros in between). Shift inputs and
6 outputs down to zero. Loop over all possible partition sizes and, if it's the
7 right size, compute the expected value, compare with the result, and assert.
8
9 We are turning the for-loops around (on their head), such that we start from
10 the *lengths* (and positions) and perform the ``Assume`` on the resultant
11 partition bits.
12
13 In other words, we have patterns as follows (assuming 32-bit words)::
14
15 8-bit offsets 0,1,2,3
16 16-bit offsets 0,1,2
17 24-bit offsets 0,1
18 32-bit
19
20 * for 8-bit the partition bit is 1 and the previous is also 1
21
22 * for 16-bit the partition bit at the offset must be 0 and be surrounded by 1
23
24 * for 24-bit the partition bits at the offset and at offset+1 must be 0 and at
25 offset+2 and offset-1 must be 1
26
27 * for 32-bit all 3 bits must be 0 and be surrounded by 1 (guard bits are added
28 at each end for this purpose)
29
30 """
31
32 import os
33 import unittest
34
35 from nmigen import Elaboratable, Signal, Module, Const
36 from nmigen.asserts import Assert, Cover
37 from nmigen.hdl.ast import Assume
38
39 from nmutil.formaltest import FHDLTestCase
40 from nmutil.gtkw import write_gtkw
41
42 from ieee754.part_mul_add.partpoints import PartitionPoints
43
44
class PartitionedPattern(Elaboratable):
    """ Generate a unique pattern, depending on partition size.

    * 1-byte partitions: 0x11
    * 2-byte partitions: 0x21 0x22
    * 3-byte partitions: 0x31 0x32 0x33

    And so on (lanes listed from least- to most-significant).

    The output is divided into lanes by the partition points. Each lane is
    split at its midpoint:

    * the lower half holds the lane's 1-based index within its partition,
      counting up from the least-significant lane of that partition;
    * the upper half holds the partition size, i.e. the index of the
      partition's most-significant lane.

    Counts are assigned into half-lane slices, so a count too large for the
    half-lane width is truncated.

    Useful as a test vector for testing the formal prover

    :param width: total width of ``output`` in bits
    :param partition_points: mapping of bit position -> gate value, passed
        to ``PartitionPoints``. A gate value of 1 marks a partition boundary
        at that position.

    Attributes:

    * ``output``: ``Signal(width)`` carrying the generated pattern
    * ``mwidth``: number of lanes (one more than the number of points)
    """
    def __init__(self, width, partition_points):
        self.width = width
        self.partition_points = PartitionPoints(partition_points)
        self.mwidth = len(self.partition_points)+1
        self.output = Signal(self.width, reset_less=True)

    def elaborate(self, platform):
        """Build the purely combinational pattern generator."""
        m = Module()
        comb = m.d.comb

        # Process lanes in bit order even if the mapping was built out of
        # order (PartitionPoints is a dict and keeps insertion order).
        points = sorted(self.partition_points)
        # Add a guard bit at each end: the first lane always starts a
        # partition and the last lane always ends one.
        positions = [0] + points + [self.width]
        gates = ([Const(1)] +
                 [self.partition_points[p] for p in points] +
                 [Const(1)])
        # Lane i spans positions[i]:positions[i+1]; gates[i] is the
        # boundary at its bottom edge and gates[i+1] at its top edge.
        lanes = [(positions[i],
                  (positions[i] + positions[i+1]) // 2,
                  positions[i+1])
                 for i in range(self.mwidth)]

        def low(i):
            # lower half of lane i (index within partition)
            start, middle, _ = lanes[i]
            return self.output[start:middle]

        def high(i):
            # upper half of lane i (partition size)
            _, middle, end = lanes[i]
            return self.output[middle:end]

        # Lower halves: an incrementing cascade, counting from one and
        # restarting at each partition boundary.
        comb += low(0).eq(1)
        for i in range(1, self.mwidth):
            with m.If(~gates[i]):
                # Same partition as the lane below: add one to its index.
                comb += low(i).eq(low(i-1) + 1)
            with m.Else():
                # Partition boundary: start again.
                comb += low(i).eq(1)

        # Upper halves: propagate downwards. The most-significant lane of
        # each partition mirrors its own index (which is the partition
        # size); every other lane copies the upper half of the lane above.
        last = self.mwidth - 1
        comb += high(last).eq(low(last))
        for i in range(last - 1, -1, -1):
            with m.If(~gates[i+1]):
                comb += high(i).eq(high(i+1))
            with m.Else():
                comb += high(i).eq(low(i))

        return m
112
113
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Formal harness for :class:`PartitionedPattern`.

    Instantiates a 64-bit ``PartitionedPattern`` split into 8 lanes, asserts
    the exact output for a handful of gate settings, and uses ``Assume`` to
    derive the gate bits from an arbitrary (offset, width) partition.
    """
    def __init__(self):
        # inputs and outputs
        pass

    @staticmethod
    def elaborate(_):
        m = Module()
        comb = m.d.comb
        width = 64
        mwidth = 8
        # Setup partition points and gates: one gate bit at each internal
        # lane boundary.
        points = PartitionPoints()
        gates = Signal(mwidth-1)
        step = width // mwidth
        for i in range(mwidth-1):
            points[(i+1)*step] = gates[i]
        # Instantiate the partitioned pattern producer
        m.submodules.dut = dut = PartitionedPattern(width, points)
        # Directly check some cases
        with m.If(gates == 0):
            comb += Assert(dut.output == 0x_88_87_86_85_84_83_82_81)
        with m.If(gates == 0b1100101):
            comb += Assert(dut.output == 0x_11_11_33_32_31_22_21_11)
        with m.If(gates == 0b0001000):
            comb += Assert(dut.output == 0x_44_43_42_41_44_43_42_41)
        with m.If(gates == 0b0100001):
            comb += Assert(dut.output == 0x_22_21_55_54_53_52_51_11)
        with m.If(gates == 0b1000001):
            comb += Assert(dut.output == 0x_11_66_65_64_63_62_61_11)
        with m.If(gates == 0b0000001):
            comb += Assert(dut.output == 0x_77_76_75_74_73_72_71_11)
        # Make it interesting, by having four partitions (three gates set).
        comb += Cover(sum(gates) == 3)
        # Choose a partition offset and width at random.
        p_offset = Signal(range(mwidth))
        p_width = Signal(range(mwidth+1))
        p_finish = Signal(range(mwidth+1))
        comb += p_finish.eq(p_offset + p_width)
        # Partition must not be empty, and fit within the signal.
        comb += Assume(p_width != 0)
        comb += Assume(p_offset + p_width <= mwidth)

        # Build the corresponding partition
        # Use Assume to constrain the pattern to conform to the given offset
        # and width. p_gates has a guard bit at each end (bits 0 and
        # mwidth). For each gate bit it is:
        # 1) one, if on the partition boundary
        # 2) zero, if it's inside the partition
        # 3) don't care, otherwise
        p_gates = Signal(mwidth+1)
        for i in range(mwidth+1):
            with m.If(i == p_offset):
                # Partitions begin with 1
                comb += Assume(p_gates[i] == 1)
            with m.If((i > p_offset) & (i < p_finish)):
                # The interior are all zeros
                comb += Assume(p_gates[i] == 0)
            with m.If(i == p_finish):
                # End with 1 again
                comb += Assume(p_gates[i] == 1)
        # Check some possible partitions generating a given pattern
        with m.If(p_gates == 0b0100110):
            comb += Assert(((p_offset == 1) & (p_width == 1)) |
                           ((p_offset == 2) & (p_width == 3)))
        # Remove guard bits at each end and assign to the DUT gates.
        # Slice explicitly: p_gates[1:] is one bit wider than gates and
        # would drop the top guard bit only through silent truncation.
        comb += gates.eq(p_gates[1:mwidth])
        return m
183
184
class PartitionTestCase(FHDLTestCase):
    """Run the :class:`Driver` proof in BMC and cover modes."""
    def test_formal(self):
        traces = [
            ('dut', {'submodule': 'dut'}, [
                'output[63:0]',
                ('gates[6:0]', {'base': 'bin'})]),
            ('p_offset[2:0]', {'base': 'dec'}),
            ('p_width[3:0]', {'base': 'dec'}),
            ('p_finish[3:0]', {'base': 'dec'}),
            ('p_gates[8:0]', {'base': 'bin'})]
        # Trace files are expected under proof_partition_formal/engine_0
        # next to this file. os.path.join avoids producing an absolute
        # "/proof_partition_formal/..." path when __file__ has no directory
        # component.
        engine_dir = os.path.join(os.path.dirname(__file__),
                                  'proof_partition_formal', 'engine_0')
        # The cover view reads trace0.vcd; the BMC view reads trace.vcd.
        views = (('proof_partition_cover.gtkw', 'trace0.vcd'),
                 ('proof_partition_bmc.gtkw', 'trace.vcd'))
        for gtkw_name, vcd_name in views:
            write_gtkw(
                gtkw_name,
                os.path.join(engine_dir, vcd_name),
                traces,
                module='top',
                zoom=-3
            )
        module = Driver()
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)
214
215
if __name__ == '__main__':
    # Running this file directly runs the unittest test cases in this module.
    unittest.main()