Start proof for PartitionedSignal equals operator
[ieee754fpu.git] / src / ieee754 / part / formal / proof_partition.py
1 """Formal verification of partitioned operations
2
3 The approach is to take an arbitrary partition, by choosing its start point
4 and size at random. Use ``Assume`` to ensure it is a whole unbroken partition
5 (start and end points are one, with only zeros in between). Shift inputs and
6 outputs down to zero. Loop over all possible partition sizes and, if it's the
7 right size, compute the expected value, compare with the result, and assert.
8
9 We are turning the for-loops around (on their head), such that we start from
10 the *lengths* (and positions) and perform the ``Assume`` on the resultant
11 partition bits.
12
13 In other words, we have patterns as follows (assuming 32-bit words)::
14
15 8-bit offsets 0,1,2,3
16 16-bit offsets 0,1,2
17 24-bit offsets 0,1
18 32-bit
19
20 * for 8-bit the partition bit is 1 and the previous is also 1
21
22 * for 16-bit the partition bit at the offset must be 0 and be surrounded by 1
23
24 * for 24-bit the partition bits at the offset and at offset+1 must be 0 and at
25 offset+2 and offset-1 must be 1
26
27 * for 32-bit all 3 bits must be 0 and be surrounded by 1 (guard bits are added
28 at each end for this purpose)
29
30 """
31
32 import os
33 import unittest
34 import operator
35
36 from nmigen import Elaboratable, Signal, Module, Const
37 from nmigen.asserts import Assert, Cover
38 from nmigen.hdl.ast import Assume
39
40 from nmutil.formaltest import FHDLTestCase
41 from nmutil.gtkw import write_gtkw
42
43 from ieee754.part_mul_add.partpoints import PartitionPoints
44 from ieee754.part.partsig import PartitionedSignal
45
46
class PartitionedPattern(Elaboratable):
    """ Generate a unique pattern, depending on partition size.

    * 1-byte partitions: 0x11
    * 2-byte partitions: 0x21 0x22
    * 3-byte partitions: 0x31 0x32 0x33

    And so on.

    Each slot (the span between two consecutive partition points) is split
    in half: the lower half holds the 1-based position of the slot within
    its partition, and the upper half holds the number of slots in that
    partition. The byte values above are for 8-bit slots, lowest byte
    first.

    Useful as a test vector for testing the formal prover

    :param width: total width of ``output``, in bits
    :param partition_points: mapping of bit position to gate (1 means a
        partition boundary at that position); converted to
        ``PartitionPoints``. Positions may be given in any order.
    """
    def __init__(self, width, partition_points):
        self.width = width
        self.partition_points = PartitionPoints(partition_points)
        # number of slots
        self.mwidth = len(self.partition_points)+1
        self.output = Signal(self.width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # The cascades below rely on neighbouring slots being adjacent, but
        # PartitionPoints is a mapping that keeps insertion order, so walk
        # the partition points in ascending bit order explicitly.
        points = sorted(self.partition_points)
        # Add a guard bit at each end, so gates[i] is the gate at
        # positions[i]
        positions = [0] + points + [self.width]
        gates = ([Const(1)] +
                 [self.partition_points[p] for p in points] +
                 [Const(1)])

        # Lower halves: count slots upwards, beginning at one.
        # The bottom slot always starts a partition.
        last_start = positions[0]
        last_middle = (positions[0] + positions[1]) // 2
        comb += self.output[last_start:last_middle].eq(1)
        for i in range(1, self.mwidth):
            start = positions[i]
            end = positions[i+1]
            middle = (start + end) // 2
            with m.If(~gates[i]):
                # Same partition as the slot below: add one to its count.
                comb += self.output[start:middle].eq(
                    self.output[last_start:last_middle] + 1)
            with m.Else():
                # Partition boundary: start counting again.
                comb += self.output[start:middle].eq(1)
            last_start = start
            last_middle = middle

        # Upper halves: the top slot always ends a partition, so its count
        # (which equals the partition size) is mirrored into its upper half.
        last_start = positions[-2]
        last_end = positions[-1]
        last_middle = (last_start + last_end) // 2
        comb += self.output[last_middle:last_end].eq(
            self.output[last_start:last_middle])
        # Propagate the partition size downwards through the remaining
        # slots; gates[i] is the boundary between slot i-1 and slot i.
        for i in range(self.mwidth - 1, 0, -1):
            start = positions[i-1]
            end = positions[i]
            middle = (start + end) // 2
            with m.If(~gates[i]):
                # Same partition as the slot above: copy its size.
                comb += self.output[middle:end].eq(
                    self.output[last_middle:last_end])
            with m.Else():
                # Top slot of a partition: its count is the size, so
                # mirror the nibbles again.
                comb += self.output[middle:end].eq(
                    self.output[start:middle])
            last_middle = middle
            last_end = end

        return m
114
115
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Formal harness for ``PartitionedPattern``.

    A 64-bit signal is split into eight 8-bit slots. The DUT's partition
    gates are driven from ``p_gates``, which is not driven here, so the
    solver may pick it freely, subject to ``Assume`` statements forcing a
    whole, unbroken partition starting at slot ``p_offset`` and spanning
    ``p_width`` slots. The DUT output is shifted down so that partition
    starts at bit zero, and then compared with the pattern expected for
    its width.

    NOTE: the local signal names (``p_offset``, ``p_gates``, ``p_output``,
    ``expected_<w>`` ...) double as waveform names; the GTKWave traces in
    ``PartitionTestCase.test_formal`` refer to them by name.
    """
    def __init__(self):
        # inputs and outputs
        # (none: every signal is created in elaborate())
        pass

    @staticmethod
    def elaborate(_):
        m = Module()
        comb = m.d.comb
        width = 64
        mwidth = 8
        # Setup partition points and gates
        # (mwidth equal slots of ``step`` bits; gates[i] enables the
        # partition point at bit (i+1)*step, i.e. just after slot i)
        points = PartitionPoints()
        gates = Signal(mwidth-1)
        step = int(width/mwidth)
        for i in range(mwidth-1):
            points[(i+1)*step] = gates[i]
        # Instantiate the partitioned pattern producer
        m.submodules.dut = dut = PartitionedPattern(width, points)
        # Directly check some cases
        # (gate bit i set means a boundary after slot i; in the expected
        # constants the lowest slot is the rightmost byte)
        with m.If(gates == 0):
            comb += Assert(dut.output == 0x_88_87_86_85_84_83_82_81)
        with m.If(gates == 0b1100101):
            comb += Assert(dut.output == 0x_11_11_33_32_31_22_21_11)
        with m.If(gates == 0b0001000):
            comb += Assert(dut.output == 0x_44_43_42_41_44_43_42_41)
        with m.If(gates == 0b0100001):
            comb += Assert(dut.output == 0x_22_21_55_54_53_52_51_11)
        with m.If(gates == 0b1000001):
            comb += Assert(dut.output == 0x_11_66_65_64_63_62_61_11)
        with m.If(gates == 0b0000001):
            comb += Assert(dut.output == 0x_77_76_75_74_73_72_71_11)
        # Choose a partition offset and width at random.
        # (offset and width are counted in slots, not bits)
        p_offset = Signal(range(mwidth))
        p_width = Signal(range(mwidth+1))
        p_finish = Signal(range(mwidth+1))
        comb += p_finish.eq(p_offset + p_width)
        # Partition must not be empty, and fit within the signal.
        # (the bound is applied to the full-width sum, not to the narrower
        # p_finish, so an overflowing sum cannot wrap around and pass)
        comb += Assume(p_width != 0)
        comb += Assume(p_offset + p_width <= mwidth)

        # Build the corresponding partition
        # Use Assume to constraint the pattern to conform to the given offset
        # and width. For each gate bit it is:
        # 1) one, if on the partition boundary
        # 2) zero, if it's inside the partition
        # 3) don't care, otherwise
        # p_gates has a guard bit at each end (bit 0 and bit mwidth), so
        # that partitions touching either end of the signal still have a
        # boundary bit on both sides.
        p_gates = Signal(mwidth+1)
        for i in range(mwidth+1):
            with m.If(i == p_offset):
                # Partitions begin with 1
                comb += Assume(p_gates[i] == 1)
            with m.If((i > p_offset) & (i < p_finish)):
                # The interior are all zeros
                comb += Assume(p_gates[i] == 0)
            with m.If(i == p_finish):
                # End with 1 again
                comb += Assume(p_gates[i] == 1)
        # Check some possible partitions generating a given pattern
        # (bits 1, 2 and 5 set: only offset 1/width 1 and offset 2/width 3
        # have a boundary at both ends and none inside)
        with m.If(p_gates == 0b0100110):
            comb += Assert(((p_offset == 1) & (p_width == 1)) |
                           ((p_offset == 2) & (p_width == 3)))
        # Remove guard bits at each end and assign to the DUT gates
        # (p_gates[1:] is mwidth bits wide; the assignment truncates it to
        # the mwidth-1 bits of gates, dropping the top guard bit)
        comb += gates.eq(p_gates[1:])
        # Generate shifted down outputs:
        # p_output holds dut.output with the selected partition moved to
        # bit zero (zero-extended at the top)
        p_output = Signal(width)
        positions = [0] + list(points.keys()) + [width]
        for i in range(mwidth):
            with m.If(p_offset == i):
                comb += p_output.eq(dut.output[positions[i]:])
        # Some checks on the shifted down output, irrespective of offset:
        with m.If(p_width == 2):
            comb += Assert(p_output[:16] == 0x_22_21)
        with m.If(p_width == 4):
            comb += Assert(p_output[:32] == 0x_44_43_42_41)
        # test zero shift
        with m.If(p_offset == 0):
            comb += Assert(p_output == dut.output)
        # Output an example.
        # Make it interesting, by having four partitions.
        # Make the selected partition not start at the very beginning.
        comb += Cover((sum(gates) == 3) & (p_offset != 0) & (p_width == 3))
        # Generate and check expected values for all possible partition sizes.
        # Here, we assume partition sizes are multiple of the smaller size.
        # NOTE: the 8/4 slice constants below assume 8-bit slots (step == 8)
        for w in range(1, mwidth+1):
            with m.If(p_width == w):
                # calculate the expected output, for the given bit width
                bit_width = w * step
                expected = Signal(bit_width, name=f"expected_{w}")
                for b in range(w):
                    # lower nibble is the position
                    comb += expected[b*8:b*8+4].eq(b+1)
                    # upper nibble is the partition width
                    comb += expected[b*8+4:b*8+8].eq(w)
                # truncate the output, compare and assert
                comb += Assert(p_output[:bit_width] == expected)
        return m
215
216
class GateGenerator(Elaboratable):
    """Produces partition gates at random

    The gates are left to the solver, except for ``Assume`` statements
    forcing one whole, unbroken partition: it starts at slot ``p_offset``,
    spans ``p_width`` slots and ends at slot boundary ``p_finish``. Gates
    outside that partition are unconstrained.

    `p_offset`, `p_width` and `p_finish` describe the selected partition
    """
    def __init__(self, mwidth):
        # Number of partitions
        self.mwidth = mwidth
        # Generated partition gates
        self.gates = Signal(mwidth-1)
        # Generated partition start point
        self.p_offset = Signal(range(mwidth))
        # Generated partition width
        self.p_width = Signal(range(mwidth+1))
        # Generated partition end point
        self.p_finish = Signal(range(mwidth+1))

    def elaborate(self, _):
        m = Module()
        n_slots = self.mwidth
        first, size, last = self.p_offset, self.p_width, self.p_finish

        m.d.comb += last.eq(first + size)
        # The chosen partition is non-empty and fits within the signal
        # (the bound is checked on the full-width sum, before truncation)
        m.d.comb += [
            Assume(size != 0),
            Assume(first + size <= n_slots),
        ]

        # One bit per slot boundary, plus a guard bit at each end, so that
        # partitions touching either end still have both boundaries.
        # The explicit name keeps the waveform name "p_gates".
        boundary = Signal(n_slots + 1, name="p_gates")
        for pos in range(n_slots + 1):
            bit = boundary[pos]
            # Both ends of the chosen partition are boundaries ...
            with m.If((pos == first) | (pos == last)):
                m.d.comb += Assume(bit == 1)
            # ... and there is no boundary strictly inside it.
            with m.If((pos > first) & (pos < last)):
                m.d.comb += Assume(bit == 0)
        # Drop the guard bits: only the inner partition points are output
        m.d.comb += self.gates.eq(boundary[1:])
        return m
267
268
class GeneratorDriver(Elaboratable):
    """Formal harness for ``PartitionedPattern``, with the partition
    selection and its ``Assume`` constraints supplied by ``GateGenerator``.

    A 64-bit signal is split into eight equal slots. The DUT output is
    shifted down so the selected partition starts at bit zero, then
    compared with the pattern expected for the partition's width.
    """
    def __init__(self):
        # No configuration: everything is built in elaborate()
        pass

    @staticmethod
    def elaborate(_):
        m = Module()
        comb = m.d.comb
        width = 64
        mwidth = 8
        # Setup partition points and gates: mwidth equal slots of step bits
        step = width // mwidth
        points, gates = make_partitions(step, mwidth)
        # Instantiate the partitioned pattern producer and the gate generator
        m.submodules.dut = dut = PartitionedPattern(width, points)
        m.submodules.gen = gen = GateGenerator(mwidth)
        comb += gates.eq(gen.gates)
        # Generate shifted down outputs
        # (local names such as p_output and expected_<w> become signal
        # names, referenced by the waveform traces in the test case)
        p_offset = gen.p_offset
        p_width = gen.p_width
        p_output = Signal(width)
        for i in range(mwidth):
            with m.If(p_offset == i):
                comb += p_output.eq(dut.output[i*step:])
        # Each slot of the pattern is split in two halves, at the same
        # "middle" PartitionedPattern uses: the lower half holds the
        # 1-based slot position, the upper half the partition width.
        half = step // 2
        # Generate and check expected values for all possible partition sizes.
        for w in range(1, mwidth+1):
            with m.If(p_width == w):
                # calculate the expected output, for the given bit width
                bit_width = w * step
                expected = Signal(bit_width, name=f"expected_{w}")
                for b in range(w):
                    lo = b * step
                    # lower half is the position
                    comb += expected[lo:lo+half].eq(b+1)
                    # upper half is the partition width
                    comb += expected[lo+half:lo+step].eq(w)
                # truncate the output, compare and assert
                comb += Assert(p_output[:bit_width] == expected)
        # Output an example.
        # Make it interesting, by having four partitions.
        # Make the selected partition not start at the very beginning.
        comb += Cover((sum(gates) == 3) & (p_offset != 0) & (p_width == 3))
        return m
315
316
def make_partitions(step, mwidth):
    """Create ``mwidth`` equally sized partition slots of ``step`` bits.

    :param step: width of each slot, in bits
    :param mwidth: number of slots
    :returns: ``(points, gates)``. ``points`` is a ``PartitionPoints``
        mapping bit position ``k*step`` (``k`` in ``1 .. mwidth-1``) to bit
        ``k-1`` of ``gates``. ``gates`` is the ``mwidth-1``-bit gate signal
        itself.
    """
    # Keep this local called "gates": nmigen names the signal after it,
    # and waveform traces refer to "gates".
    gates = Signal(mwidth - 1)
    points = PartitionPoints()
    for k in range(1, mwidth):
        points[k * step] = gates[k - 1]
    return points, gates
323
324
class ComparisonOpDriver(Elaboratable):
    """Checks comparison operations on partitioned signals

    Both operands and the partition gates are left undriven, for the solver
    to choose.

    NOTE(review): so far this only elaborates the operation and covers a
    non-zero result; no ``Assert`` compares the output with a reference.
    """
    def __init__(self, op, width, mwidth):
        # Operation to perform. Must accept two integer-like inputs and give
        # a predicate-like output (1-bit partitions in case of
        # PartitionedSignal types)
        self.op = op
        # Partition full width
        self.width = width
        # Maximum number of equally sized partitions
        self.mwidth = mwidth

    def elaborate(self, _):
        m = Module()
        # Equally sized partition slots; the gates stay free
        points, _ = make_partitions(int(self.width/self.mwidth), self.mwidth)
        # Both operands share the same partition points
        a = PartitionedSignal(points, self.width)
        b = PartitionedSignal(points, self.width)
        output = Signal(self.mwidth)
        for operand in (a, b):
            operand.set_module(m)
        # Apply the operation, and ask for an example where it is true
        # in at least one partition
        m.d.comb += [
            output.eq(self.op(a, b)),
            Cover(output != 0),
        ]
        return m
356
357
class PartitionTestCase(FHDLTestCase):
    """Run the formal proofs and write GTKWave save files for their traces.

    Each test writes ``<prefix>_cover.gtkw`` and ``<prefix>_bmc.gtkw`` into
    the current directory. These point at the VCD files the proof leaves
    under ``<this file's directory>/<spec_name>/engine_0/``. The test then
    runs the proof in BMC and in cover mode.

    NOTE(review): the spec directories (``proof_partition_<method>``)
    appear to be named by ``assertFormal`` after the calling test method.
    For that reason ``assertFormal`` is called directly from each
    ``test_*`` method, not from a shared helper. Confirm against
    ``nmutil.formaltest``.
    """

    @staticmethod
    def _pattern_traces():
        """Return the waveform traces for the PartitionedPattern harnesses.

        Names and bit ranges must match the signals of ``Driver`` and
        ``GeneratorDriver``; ``expected_3`` is 3 slots * 8 bits = 24 bits.
        """
        return [
            ('p_offset[2:0]', {'base': 'dec'}),
            ('p_width[3:0]', {'base': 'dec'}),
            ('p_finish[3:0]', {'base': 'dec'}),
            ('p_gates[8:0]', {'base': 'bin'}),
            ('dut', {'submodule': 'dut'}, [
                ('gates[6:0]', {'base': 'bin'}),
                'output[63:0]']),
            'p_output[63:0]', 'expected_3[23:0]']

    @staticmethod
    def _write_gtkw_files(prefix, spec_name, traces):
        """Write GTKWave save files for the cover and BMC traces of a proof.

        :param prefix: file name prefix; ``_cover.gtkw`` and ``_bmc.gtkw``
            are appended
        :param spec_name: directory, next to this file, holding the proof
            results
        :param traces: signal list, in ``write_gtkw`` format
        """
        # os.path.join rather than "+ '/...'": with an empty dirname the
        # latter would produce a path rooted at the filesystem root.
        engine_dir = os.path.join(os.path.dirname(__file__), spec_name,
                                  'engine_0')
        for mode, vcd in (('cover', 'trace0.vcd'), ('bmc', 'trace.vcd')):
            write_gtkw(
                f'{prefix}_{mode}.gtkw',
                os.path.join(engine_dir, vcd),
                traces,
                module='top',
                zoom=-3
            )

    def test_formal(self):
        self._write_gtkw_files('proof_partition', 'proof_partition_formal',
                               self._pattern_traces())
        module = Driver()
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)

    def test_generator(self):
        self._write_gtkw_files('proof_partition_generator',
                               'proof_partition_generator',
                               self._pattern_traces())
        module = GeneratorDriver()
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)

    def test_partsig_eq(self):
        traces = [
            ('eq_1', {'submodule': 'eq_1'}, [
                ('gates[6:0]', {'base': 'bin'}),
                'a[63:0]', 'b[63:0]',
                ('output[7:0]', {'base': 'bin'})])]
        self._write_gtkw_files('proof_partsig_eq',
                               'proof_partition_partsig_eq', traces)
        module = ComparisonOpDriver(operator.eq, 64, 8)
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)
444
445
# Allow running the proofs directly as a script (python proof_partition.py)
if __name__ == '__main__':
    unittest.main()