f257d0243b897456d4aded6de62b5d12cae2e1db
[ieee754fpu.git] / src / ieee754 / part / formal / proof_partition.py
1 """Formal verification of partitioned operations
2
3 The approach is to take an arbitrary partition, by choosing its start point
4 and size at random. Use ``Assume`` to ensure it is a whole unbroken partition
5 (start and end points are one, with only zeros in between). Shift inputs and
6 outputs down to zero. Loop over all possible partition sizes and, if it's the
7 right size, compute the expected value, compare with the result, and assert.
8
9 We are turning the for-loops around (on their head), such that we start from
10 the *lengths* (and positions) and perform the ``Assume`` on the resultant
11 partition bits.
12
13 In other words, we have patterns as follows (assuming 32-bit words)::
14
15 8-bit offsets 0,1,2,3
16 16-bit offsets 0,1,2
17 24-bit offsets 0,1
18 32-bit
19
20 * for 8-bit the partition bit is 1 and the previous is also 1
21
22 * for 16-bit the partition bit at the offset must be 0 and be surrounded by 1
23
24 * for 24-bit the partition bits at the offset and at offset+1 must be 0 and at
25 offset+2 and offset-1 must be 1
26
27 * for 32-bit all 3 bits must be 0 and be surrounded by 1 (guard bits are added
28 at each end for this purpose)
29
30 """
31
32 import os
33 import unittest
34 import operator
35
36 from nmigen import Elaboratable, Signal, Module, Const, Repl
37 from nmigen.asserts import Assert, Cover
38 from nmigen.hdl.ast import Assume
39
40 from nmutil.formaltest import FHDLTestCase
41 from nmutil.gtkw import write_gtkw
42
43 from ieee754.part_mul_add.partpoints import PartitionPoints
44 from ieee754.part.partsig import PartitionedSignal
45
46
class PartitionedPattern(Elaboratable):
    """Emit a recognisable pattern that depends on the partition sizes.

    Each slot between two consecutive partition points is split into two
    halves (nibbles, when slots are 8 bits wide).  The lower half holds the
    one-based position of the slot inside its partition, the upper half
    holds the number of slots in that partition:

    * 1-byte partitions: 0x11
    * 2-byte partitions: 0x21 0x22
    * 3-byte partitions: 0x31 0x32 0x33

    and so forth.

    Intended as a test vector for exercising the formal prover.
    """

    def __init__(self, width, partition_points):
        self.width = width
        self.partition_points = PartitionPoints(partition_points)
        # there is one more slot than there are partition points
        self.mwidth = len(self.partition_points)+1
        self.output = Signal(self.width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        out = self.output

        # Surround the real partition points with a guard on either side,
        # so that both ends of the signal count as partition boundaries
        bounds = [0, *self.partition_points.keys(), self.width]
        enables = [Const(1), *self.partition_points.values(), Const(1)]

        def halves(slot):
            """Return (start, middle, end) bit positions of a slot"""
            lo = bounds[slot]
            hi = bounds[slot + 1]
            return lo, (lo + hi) // 2, hi

        # Upward pass, filling in the lower halves.
        # The very first slot always holds position one.
        below_lo, below_mid, _ = halves(0)
        comb += out[below_lo:below_mid].eq(1)
        for slot in range(1, self.mwidth):
            lo, mid, _ = halves(slot)
            with m.If(~enables[slot]):
                # Same partition as the slot below: count one further
                comb += out[lo:mid].eq(out[below_lo:below_mid] + 1)
            with m.Else():
                # A boundary sits below this slot: restart the count
                comb += out[lo:mid].eq(1)
            below_lo, below_mid = lo, mid

        # Downward pass, filling in the upper halves.
        # The topmost slot mirrors its own lower half.
        top_lo, top_hi = bounds[-2], bounds[-1]
        top_mid = (top_lo + top_hi) // 2
        comb += out[top_mid:top_hi].eq(out[top_lo:top_mid])
        above_mid, above_hi = top_mid, top_hi
        for upper in range(self.mwidth, 0, -1):
            lo, mid, hi = halves(upper - 1)
            with m.If(~enables[upper]):
                # Same partition as the slot above: copy its upper half
                comb += out[mid:hi].eq(out[above_mid:above_hi])
            with m.Else():
                # A boundary sits above this slot, making it the last one
                # of its partition: mirror its own lower half
                comb += out[mid:hi].eq(out[lo:mid])
            above_mid, above_hi = mid, hi

        return m
114
115
def make_partitions(step, mwidth):
    """Create evenly spaced partition points, one gate bit for each

    A point is placed at every multiple of ``step``, from ``step`` up to
    ``(mwidth - 1) * step``, and is controlled by the matching bit of a
    freshly created ``gates`` signal.

    :param step: smallest partition width
    :param mwidth: maximum number of partitions
    :returns: partition points, and corresponding gates"""
    # NOTE: keep this local named "gates", the signal takes its name from it
    gates = Signal(mwidth - 1)
    points = PartitionPoints()
    for index in range(mwidth - 1):
        position = step * (index + 1)
        points[position] = gates[index]
    return points, gates
127
128
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Formal test bench for ``PartitionedPattern``

    Instantiates a 64-bit pattern generator with eight 8-bit slots, then:

    * asserts the full output for a handful of fixed gate settings;
    * picks an arbitrary partition (``p_offset``, ``p_width``) and uses
      ``Assume`` to force the gates to contain exactly that partition;
    * shifts the output down by the partition offset and asserts that
      the partition holds the expected pattern, for every possible width.
    """

    def __init__(self):
        # inputs and outputs
        pass

    @staticmethod
    def elaborate(_):
        m = Module()
        comb = m.d.comb
        width = 64
        mwidth = 8
        # Setup partition points and gates
        step = width // mwidth
        # each slot carries two fields, each half a slot wide
        half = step // 2
        points, gates = make_partitions(step, mwidth)
        # Instantiate the partitioned pattern producer
        m.submodules.dut = dut = PartitionedPattern(width, points)
        # Directly check some cases
        with m.If(gates == 0):
            comb += Assert(dut.output == 0x_88_87_86_85_84_83_82_81)
        with m.If(gates == 0b1100101):
            comb += Assert(dut.output == 0x_11_11_33_32_31_22_21_11)
        with m.If(gates == 0b0001000):
            comb += Assert(dut.output == 0x_44_43_42_41_44_43_42_41)
        with m.If(gates == 0b0100001):
            comb += Assert(dut.output == 0x_22_21_55_54_53_52_51_11)
        with m.If(gates == 0b1000001):
            comb += Assert(dut.output == 0x_11_66_65_64_63_62_61_11)
        with m.If(gates == 0b0000001):
            comb += Assert(dut.output == 0x_77_76_75_74_73_72_71_11)
        # Choose a partition offset and width at random.
        # NOTE: these locals give the signals their names in the traces
        p_offset = Signal(range(mwidth))
        p_width = Signal(range(mwidth+1))
        p_finish = Signal(range(mwidth+1))
        comb += p_finish.eq(p_offset + p_width)
        # Partition must not be empty, and fit within the signal.
        comb += Assume(p_width != 0)
        comb += Assume(p_offset + p_width <= mwidth)

        # Build the corresponding partition
        # Use Assume to constrain the pattern to conform to the given offset
        # and width. For each gate bit it is:
        # 1) one, if on the partition boundary
        # 2) zero, if it's inside the partition
        # 3) don't care, otherwise
        p_gates = Signal(mwidth+1)
        for i in range(mwidth+1):
            with m.If(i == p_offset):
                # Partitions begin with 1
                comb += Assume(p_gates[i] == 1)
            with m.If((i > p_offset) & (i < p_finish)):
                # The interior are all zeros
                comb += Assume(p_gates[i] == 0)
            with m.If(i == p_finish):
                # End with 1 again
                comb += Assume(p_gates[i] == 1)
        # Check some possible partitions generating a given pattern
        with m.If(p_gates == 0b0100110):
            comb += Assert(((p_offset == 1) & (p_width == 1)) |
                           ((p_offset == 2) & (p_width == 3)))
        # Remove guard bits at each end and assign to the DUT gates.
        # The slice drops the lower guard; the upper guard does not fit
        # in the narrower ``gates`` signal and is dropped by the assignment.
        comb += gates.eq(p_gates[1:])
        # Generate shifted down outputs:
        p_output = Signal(width)
        positions = [0] + list(points.keys()) + [width]
        for i in range(mwidth):
            with m.If(p_offset == i):
                comb += p_output.eq(dut.output[positions[i]:])
        # Some checks on the shifted down output, irrespective of offset:
        with m.If(p_width == 2):
            comb += Assert(p_output[:16] == 0x_22_21)
        with m.If(p_width == 4):
            comb += Assert(p_output[:32] == 0x_44_43_42_41)
        # test zero shift
        with m.If(p_offset == 0):
            comb += Assert(p_output == dut.output)
        # Output an example.
        # Make it interesting, by having four partitions.
        # Make the selected partition not start at the very beginning.
        comb += Cover((sum(gates) == 3) & (p_offset != 0) & (p_width == 3))
        # Generate and check expected values for all possible partition sizes.
        # Here, we assume partition sizes are multiple of the smaller size.
        for w in range(1, mwidth+1):
            with m.If(p_width == w):
                # calculate the expected output, for the given bit width
                bit_width = w * step
                expected = Signal(bit_width, name=f"expected_{w}")
                for b in range(w):
                    lo = b * step
                    # lower nibble is the position
                    comb += expected[lo:lo+half].eq(b+1)
                    # upper nibble is the partition width
                    comb += expected[lo+half:lo+step].eq(w)
                # truncate the output, compare and assert
                comb += Assert(p_output[:bit_width] == expected)
        return m
225
226
class GateGenerator(Elaboratable):
    """Produces partition gates at random

    The gates are only constrained to contain one whole, unbroken
    partition, described by `p_offset`, `p_width` and `p_finish`.
    """

    def __init__(self, mwidth):
        #: Number of partitions
        self.mwidth = mwidth
        #: Generated partition gates
        self.gates = Signal(mwidth-1)
        #: Generated partition start point
        self.p_offset = Signal(range(mwidth))
        #: Generated partition width
        self.p_width = Signal(range(mwidth+1))
        #: Generated partition end point
        self.p_finish = Signal(range(mwidth+1))

    def elaborate(self, _):
        m = Module()
        comb = m.d.comb
        comb += self.p_finish.eq(self.p_offset + self.p_width)
        # The selected partition has at least one slot...
        comb += Assume(self.p_width != 0)
        # ... and does not run past the end of the signal.
        comb += Assume(self.p_offset + self.p_width <= self.mwidth)

        # Constrain the gates, guard bits included, so that they match
        # the selected partition. Bit by bit:
        # 1) set, on either boundary of the partition
        # 2) clear, strictly inside the partition
        # 3) unconstrained, everywhere else
        # NOTE: keep this local named "p_gates", the signal (and the
        # waveform traces) take their name from it
        p_gates = Signal(self.mwidth+1)
        for bit in range(self.mwidth+1):
            at_start = self.p_offset == bit
            inside = (self.p_offset < bit) & (self.p_finish > bit)
            at_finish = self.p_finish == bit
            with m.If(at_start):
                # a partition opens with a set bit
                comb += Assume(p_gates[bit] == 1)
            with m.If(inside):
                # no boundaries within the partition
                comb += Assume(p_gates[bit] == 0)
            with m.If(at_finish):
                # and it closes with a set bit, too
                comb += Assume(p_gates[bit] == 1)
        # Remove guard bits at each end, before assigning to the output gates
        comb += self.gates.eq(p_gates[1:])
        return m
277
278
class GeneratorDriver(Elaboratable):
    """Formal test bench for ``PartitionedPattern`` using ``GateGenerator``

    The gate generator selects an arbitrary partition and drives the
    gates accordingly. The pattern output is shifted down by the partition
    offset and, for every possible partition width, compared against the
    expected pattern.
    """

    def __init__(self):
        # inputs and outputs
        pass

    @staticmethod
    def elaborate(_):
        m = Module()
        comb = m.d.comb
        width = 64
        mwidth = 8
        # Setup partition points and gates
        step = width // mwidth
        # each slot carries two fields, each half a slot wide
        half = step // 2
        points, gates = make_partitions(step, mwidth)
        # Instantiate the partitioned pattern producer and the DUT
        m.submodules.dut = dut = PartitionedPattern(width, points)
        m.submodules.gen = gen = GateGenerator(mwidth)
        comb += gates.eq(gen.gates)
        # Generate shifted down outputs
        p_offset = gen.p_offset
        p_width = gen.p_width
        # NOTE: this local gives the signal its name in the traces
        p_output = Signal(width)
        for i in range(mwidth):
            with m.If(p_offset == i):
                comb += p_output.eq(dut.output[i*step:])
        # Generate and check expected values for all possible partition sizes.
        for w in range(1, mwidth+1):
            with m.If(p_width == w):
                # calculate the expected output, for the given bit width
                bit_width = w * step
                expected = Signal(bit_width, name=f"expected_{w}")
                for b in range(w):
                    lo = b * step
                    # lower nibble is the position
                    comb += expected[lo:lo+half].eq(b+1)
                    # upper nibble is the partition width
                    comb += expected[lo+half:lo+step].eq(w)
                # truncate the output, compare and assert
                comb += Assert(p_output[:bit_width] == expected)
        # Output an example.
        # Make it interesting, by having four partitions.
        # Make the selected partition not start at the very beginning.
        comb += Cover((sum(gates) == 3) & (p_offset != 0) & (p_width == 3))
        return m
322
323
class ComparisonOpDriver(Elaboratable):
    """Checks comparison operations on partitioned signals

    The operation is applied to two ``PartitionedSignal`` inputs. For an
    arbitrary partition, chosen by a ``GateGenerator``, the corresponding
    output bits are compared against the same operation applied to the
    plain (non-partitioned) slices of the inputs, replicated over every
    output bit of the partition.
    """

    def __init__(self, op, width, mwidth):
        # Operation to perform. Must accept two integer-like inputs and give
        # a predicate-like output (1-bit partitions in case of
        # PartitionedSignal types)
        self.op = op
        # Partition full width
        self.width = width
        # Maximum number of equally sized partitions
        self.mwidth = mwidth

    def elaborate(self, _):
        m = Module()
        comb = m.d.comb
        width = self.width
        mwidth = self.mwidth
        # setup partition points and gates
        step = width // mwidth
        points, gates = make_partitions(step, mwidth)
        # setup inputs and outputs
        a = PartitionedSignal(points, width)
        b = PartitionedSignal(points, width)
        # NOTE: the locals below give the signals their names in the traces
        output = Signal(mwidth)
        a.set_module(m)
        b.set_module(m)
        # perform the operation on the partitioned signals
        comb += output.eq(self.op(a, b))
        # instantiate the partitioned gate generator and connect the gates
        m.submodules.gen = gen = GateGenerator(mwidth)
        comb += gates.eq(gen.gates)
        p_offset = gen.p_offset
        p_width = gen.p_width
        # generate shifted down inputs and outputs
        p_output = Signal(mwidth)
        p_a = Signal(width)
        p_b = Signal(width)
        for pos in range(mwidth):
            with m.If(p_offset == pos):
                comb += p_output.eq(output[pos:])
                comb += p_a.eq(a.sig[pos * step:])
                comb += p_b.eq(b.sig[pos * step:])
        # generate and check expected values for all possible partition sizes
        for w in range(1, mwidth+1):
            with m.If(p_width == w):
                # calculate the expected output, for the given bit width,
                # truncating the inputs to the partition size
                input_bit_width = w * step
                output_bit_width = w
                expected = Signal(output_bit_width, name=f"expected_{w}")
                # distinct local names, so that the partitioned inputs
                # ``a`` and ``b`` above are not shadowed
                a_w = Signal(input_bit_width, name=f"a_{w}")
                b_w = Signal(input_bit_width, name=f"b_{w}")
                lsb = Signal(name=f"lsb_{w}")
                comb += a_w.eq(p_a[:input_bit_width])
                comb += b_w.eq(p_b[:input_bit_width])
                comb += lsb.eq(self.op(a_w, b_w))
                # every output bit of the partition repeats the result
                comb += expected.eq(Repl(lsb, output_bit_width))
                # truncate the output, compare and assert
                comb += Assert(p_output[:output_bit_width] == expected)
        # output a test case
        comb += Cover((p_offset != 0) & (p_width == 3) & (sum(output) > 1) &
                      (p_a != 0) & (p_b != 0) & (p_output != 0))
        return m
387
388
class PartitionTestCase(FHDLTestCase):
    """Runs the formal proofs of the partitioned pattern generator and of
    the comparison operators on partitioned signals

    Some tests also write GTKWave save files (``*.gtkw``), for viewing
    the traces written by the formal engine.
    """

    def _write_gtkw_pair(self, gtkw_prefix, spec_dir, traces):
        """Write the GTKWave save files for the cover and the BMC traces

        :param gtkw_prefix: common prefix of the two ``.gtkw`` file names,
                            ``_cover.gtkw`` and ``_bmc.gtkw`` get appended
        :param spec_dir: directory with the formal engine results, located
                         next to this source file
        :param traces: trace list, in the format taken by ``write_gtkw``
        """
        style = {
            'dec': {'base': 'dec'},
            'bin': {'base': 'bin'}
        }
        # os.path.join also copes with an empty dirname, where plain
        # concatenation would give a path anchored at the root directory
        here = os.path.dirname(__file__)
        for mode, vcd_name in (('cover', 'trace0.vcd'), ('bmc', 'trace.vcd')):
            write_gtkw(
                f'{gtkw_prefix}_{mode}.gtkw',
                os.path.join(here, spec_dir, 'engine_0', vcd_name),
                traces, style,
                module='top',
                zoom=-3
            )

    def test_formal(self):
        traces = [
            ('p_offset[2:0]', 'dec'),
            ('p_width[3:0]', 'dec'),
            ('p_finish[3:0]', 'dec'),
            ('p_gates[8:0]', 'bin'),
            ('dut', {'submodule': 'dut'}, [
                ('gates[6:0]', 'bin'),
                'output[63:0]']),
            # expected_3 spans three 8-bit slots
            'p_output[63:0]', 'expected_3[23:0]']
        self._write_gtkw_pair(
            'proof_partition', 'proof_partition_formal', traces)
        module = Driver()
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)

    def test_generator(self):
        traces = [
            ('p_offset[2:0]', 'dec'),
            ('p_width[3:0]', 'dec'),
            ('p_finish[3:0]', 'dec'),
            ('p_gates[8:0]', 'bin'),
            ('dut', {'submodule': 'dut'}, [
                ('gates[6:0]', 'bin'),
                'output[63:0]']),
            # expected_3 spans three 8-bit slots
            'p_output[63:0]', 'expected_3[23:0]']
        self._write_gtkw_pair(
            'proof_partition_generator', 'proof_partition_generator', traces)
        module = GeneratorDriver()
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)

    def test_partsig_eq(self):
        traces = [
            ('p_offset[2:0]', 'dec'),
            ('p_width[3:0]', 'dec'),
            ('p_gates[8:0]', 'bin'),
            ('eq_1', {'submodule': 'eq_1'}, [
                ('gates[6:0]', 'bin'),
                'a[63:0]', 'b[63:0]',
                ('output[7:0]', 'bin')]),
            'p_a[63:0]', 'p_b[63:0]',
            ('p_output[7:0]', 'bin'),
            # truncated inputs and expected output of the 3-slot case
            'a_3[23:0]', 'b_3[23:0]', ('expected_3[2:0]', 'bin')]
        self._write_gtkw_pair(
            'proof_partsig_eq', 'proof_partition_partsig_eq', traces)
        module = ComparisonOpDriver(operator.eq, 64, 8)
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)

    def test_partsig_ne(self):
        module = ComparisonOpDriver(operator.ne, 64, 8)
        self.assertFormal(module, mode="bmc", depth=1)

    def test_partsig_gt(self):
        module = ComparisonOpDriver(operator.gt, 64, 8)
        self.assertFormal(module, mode="bmc", depth=1)

    def test_partsig_ge(self):
        module = ComparisonOpDriver(operator.ge, 64, 8)
        self.assertFormal(module, mode="bmc", depth=1)

    def test_partsig_lt(self):
        module = ComparisonOpDriver(operator.lt, 64, 8)
        self.assertFormal(module, mode="bmc", depth=1)

    def test_partsig_le(self):
        module = ComparisonOpDriver(operator.le, 64, 8)
        self.assertFormal(module, mode="bmc", depth=1)
513
514
# Run all the test cases of this module, when executed as a script
if __name__ == '__main__':
    unittest.main()