"""Formal verification of partitioned operations

The approach is to take an arbitrary partition, by choosing its start point
and size at random. Use ``Assume`` to ensure it is a whole unbroken partition
(its start and end partition bits are one, with only zeros in between). Shift
inputs and outputs down to zero. Loop over all possible partition sizes and,
if it is the right size, compute the expected value, compare it with the
result, and assert.

We turn the for-loops on their head, such that we start from the *lengths*
(and positions) and perform the ``Assume`` on the resultant partition bits.

In other words, we have patterns as follows (assuming 32-bit words)::

    8-bit offsets 0,1,2,3
    16-bit offsets 0,1,2
    24-bit offsets 0,1
    32-bit

* for 8-bit the partition bit is 1 and the previous is also 1

* for 16-bit the partition bit at the offset must be 0 and be surrounded by 1

* for 24-bit the partition bits at the offset and at offset+1 must be 0 and at
  offset+2 and offset-1 must be 1

* for 32-bit all 3 bits must be 0 and be surrounded by 1 (guard bits are added
  at each end for this purpose)

"""
32 import os
33 import unittest
35 from nmigen import Elaboratable, Signal, Module, Const
36 from nmigen.asserts import Assert, Cover
37 from nmigen.hdl.ast import Assume
39 from nmutil.formaltest import FHDLTestCase
40 from nmutil.gtkw import write_gtkw
class PartitionedPattern(Elaboratable):
    """ Generate a unique pattern, depending on partition size.

    * 1-byte partitions: 0x11
    * 2-byte partitions: 0x21 0x22
    * 3-byte partitions: 0x31 0x32 0x33

    And so on.

    Useful as a test vector for testing the formal prover

    The output is divided into lanes delimited by the partition points. The
    lower half of each lane holds its position inside its partition (counting
    from one). The upper half holds the size, in lanes, of the partition it
    belongs to.

    :param width: width of ``output``, in bits
    :param partition_points: passed to ``PartitionPoints``. Its keys are used
        as lane boundary bit positions and its values as the gates (a set gate
        starts a new partition at that position). NOTE(review): the keys are
        assumed to iterate in ascending order; confirm for ``PartitionPoints``.
    """
    def __init__(self, width, partition_points):
        self.width = width
        self.partition_points = PartitionPoints(partition_points)
        # Number of lanes: one more than the number of partition points.
        self.mwidth = len(self.partition_points)+1
        self.output = Signal(self.width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # Add a guard bit at each end, so that both the first and the last
        # lane are always partition boundaries.
        positions = [0] + list(self.partition_points.keys()) + [self.width]
        gates = [Const(1)] + list(self.partition_points.values()) + [Const(1)]

        # Begin counting at one: the lower half of lane 0 is always 1.
        last_start = positions[0]
        last_end = positions[1]
        last_middle = (last_start+last_end)//2
        comb += self.output[last_start:last_middle].eq(1)

        # Build an incrementing cascade in the lower half of each lane.
        for i in range(1, self.mwidth):
            start = positions[i]
            end = positions[i+1]
            middle = (start + end) // 2
            # Propagate from the previous byte, adding one to it.
            with m.If(~gates[i]):
                comb += self.output[start:middle].eq(
                    self.output[last_start:last_middle] + 1)
            with m.Else():
                # ... unless it's a partition boundary. If so, start again.
                comb += self.output[start:middle].eq(1)
            last_start = start
            last_middle = middle

        # Mirror the nibbles on the last byte. The last lane always ends a
        # partition, so its count equals the size of that partition.
        last_start = positions[-2]
        last_end = positions[-1]
        last_middle = (last_start+last_end)//2
        comb += self.output[last_middle:last_end].eq(
            self.output[last_start:last_middle])

        # Walk downwards, copying the partition size into the upper half of
        # every lane. Here gates[i] is the boundary *above* lane i-1.
        for i in range(self.mwidth, 0, -1):
            start = positions[i-1]
            end = positions[i]
            middle = (start + end) // 2
            # Propagate from the previous byte.
            with m.If(~gates[i]):
                comb += self.output[middle:end].eq(
                    self.output[last_middle:last_end])
            with m.Else():
                # ... unless it's a partition boundary.
                # If so, mirror the nibbles again.
                comb += self.output[middle:end].eq(
                    self.output[start:middle])
            last_middle = middle
            last_end = end

        return m
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Drive a 64-bit ``PartitionedPattern`` and check its output formally.

    The DUT gates are not driven directly. A partition offset and width (in
    lanes) are left free for the solver to choose, and ``Assume`` constrains
    the guarded gate vector ``p_gates`` to contain exactly that partition. The
    DUT output is then shifted down so that the chosen partition starts at bit
    zero, and compared with the expected pattern for its width.
    """
    def __init__(self):
        # No inputs or outputs: every signal is created in elaborate().
        pass

    @staticmethod
    def elaborate(_):
        m = Module()
        comb = m.d.comb
        width = 64
        mwidth = 8
        # Setup partition points and gates
        points = PartitionPoints()
        gates = Signal(mwidth-1)
        # Width of one lane, in bits
        step = width // mwidth
        for i in range(mwidth-1):
            points[(i+1)*step] = gates[i]
        # Instantiate the partitioned pattern producer
        m.submodules.dut = dut = PartitionedPattern(width, points)
        # Directly check some cases
        with m.If(gates == 0):
            comb += Assert(dut.output == 0x_88_87_86_85_84_83_82_81)
        with m.If(gates == 0b1100101):
            comb += Assert(dut.output == 0x_11_11_33_32_31_22_21_11)
        with m.If(gates == 0b0001000):
            comb += Assert(dut.output == 0x_44_43_42_41_44_43_42_41)
        with m.If(gates == 0b0100001):
            comb += Assert(dut.output == 0x_22_21_55_54_53_52_51_11)
        with m.If(gates == 0b1000001):
            comb += Assert(dut.output == 0x_11_66_65_64_63_62_61_11)
        with m.If(gates == 0b0000001):
            comb += Assert(dut.output == 0x_77_76_75_74_73_72_71_11)
        # Choose a partition offset and width at random.
        p_offset = Signal(range(mwidth))
        p_width = Signal(range(mwidth+1))
        p_finish = Signal(range(mwidth+1))
        comb += p_finish.eq(p_offset + p_width)
        # Partition must not be empty, and fit within the signal.
        comb += Assume(p_width != 0)
        comb += Assume(p_offset + p_width <= mwidth)

        # Build the corresponding partition
        # Use Assume to constrain the pattern to conform to the given offset
        # and width. For each gate bit it is:
        # 1) one, if on the partition boundary
        # 2) zero, if it's inside the partition
        # 3) don't care, otherwise
        # p_gates has one extra guard bit at each end (bits 0 and mwidth).
        p_gates = Signal(mwidth+1)
        for i in range(mwidth+1):
            with m.If(i == p_offset):
                # Partitions begin with 1
                comb += Assume(p_gates[i] == 1)
            with m.If((i > p_offset) & (i < p_finish)):
                # The interior are all zeros
                comb += Assume(p_gates[i] == 0)
            with m.If(i == p_finish):
                # End with 1 again
                comb += Assume(p_gates[i] == 1)
        # Check some possible partitions generating a given pattern
        with m.If(p_gates == 0b0100110):
            comb += Assert(((p_offset == 1) & (p_width == 1)) |
                           ((p_offset == 2) & (p_width == 3)))
        # Remove guard bits at each end and assign to the DUT gates
        # (the assignment truncates p_gates[1:] to the width of gates).
        comb += gates.eq(p_gates[1:])
        # Generate shifted down outputs:
        p_output = Signal(width)
        positions = [0] + list(points.keys()) + [width]
        for i in range(mwidth):
            with m.If(p_offset == i):
                comb += p_output.eq(dut.output[positions[i]:])
        # Some checks on the shifted down output, irrespective of offset:
        with m.If(p_width == 2):
            comb += Assert(p_output[:16] == 0x_22_21)
        with m.If(p_width == 4):
            comb += Assert(p_output[:32] == 0x_44_43_42_41)
        # test zero shift
        with m.If(p_offset == 0):
            comb += Assert(p_output == dut.output)
        # Output an example.
        # Make it interesting, by having four partitions.
        # Make the selected partition not start at the very beginning.
        comb += Cover((sum(gates) == 3) & (p_offset != 0) & (p_width == 3))
        # Generate and check expected values for all possible partition sizes.
        # Here, we assume partition sizes are multiples of the smallest size.
        for w in range(1, mwidth+1):
            with m.If(p_width == w):
                # calculate the expected output, for the given bit width
                bit_width = w * step
                expected = Signal(bit_width, name=f"expected_{w}")
                for b in range(w):
                    lane = b * step
                    half = lane + step // 2
                    # lower nibble is the position
                    comb += expected[lane:half].eq(b+1)
                    # upper nibble is the partition width
                    comb += expected[half:lane+step].eq(w)
                # truncate the output, compare and assert
                comb += Assert(p_output[:bit_width] == expected)
        return m
class PartitionTestCase(FHDLTestCase):
    """Run the partition proof on ``Driver`` in BMC and cover modes."""

    def test_formal(self):
        """Write GTKWave save files for the proof traces, then run the proof.

        One save file is written for the cover trace (``trace0.vcd``) and one
        for the BMC trace (``trace.vcd``). Both use the same signal list.
        """
        traces = [
            ('p_offset[2:0]', {'base': 'dec'}),
            ('p_width[3:0]', {'base': 'dec'}),
            ('p_finish[3:0]', {'base': 'dec'}),
            ('p_gates[8:0]', {'base': 'bin'}),
            ('dut', {'submodule': 'dut'}, [
                ('gates[6:0]', {'base': 'bin'}),
                'output[63:0]']),
            'p_output[63:0]', 'expected_3[21:0]']
        # Directory where the VCD traces are expected. NOTE(review):
        # presumably the working directory assertFormal uses for this test;
        # confirm against nmutil.formaltest. os.path.join avoids producing a
        # path rooted at "/" when os.path.dirname(__file__) is empty.
        vcd_dir = os.path.join(os.path.dirname(__file__),
                               'proof_partition_formal', 'engine_0')
        for gtkw_name, vcd_name in (
                ('proof_partition_cover.gtkw', 'trace0.vcd'),
                ('proof_partition_bmc.gtkw', 'trace.vcd')):
            write_gtkw(
                gtkw_name,
                os.path.join(vcd_dir, vcd_name),
                traces,
                module='top',
                zoom=-3
            )
        module = Driver()
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)
if __name__ == '__main__':
    # Allow running the formal proof directly as a script.
    unittest.main()