switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / part / formal / proof_partition.py
1 """Formal verification of partitioned operations
2
3 The approach is to take an arbitrary partition, by choosing its start point
4 and size at random. Use ``Assume`` to ensure it is a whole unbroken partition
5 (start and end points are one, with only zeros in between). Shift inputs and
6 outputs down to zero. Loop over all possible partition sizes and, if it's the
7 right size, compute the expected value, compare with the result, and assert.
8
9 We are turning the for-loops around (on their head), such that we start from
10 the *lengths* (and positions) and perform the ``Assume`` on the resultant
11 partition bits.
12
13 In other words, we have patterns as follows (assuming 32-bit words)::
14
15 8-bit offsets 0,1,2,3
16 16-bit offsets 0,1,2
17 24-bit offsets 0,1
18 32-bit
19
20 * for 8-bit the partition bit is 1 and the previous is also 1
21
22 * for 16-bit the partition bit at the offset must be 0 and be surrounded by 1
23
24 * for 24-bit the partition bits at the offset and at offset+1 must be 0 and at
25 offset+2 and offset-1 must be 1
26
27 * for 32-bit all 3 bits must be 0 and be surrounded by 1 (guard bits are added
28 at each end for this purpose)
29
30 """
31
32 import os
33 import unittest
34 import operator
35
36 from nmigen import Elaboratable, Signal, Module, Const, Repl
37 from nmigen.asserts import Assert, Cover
38 from nmigen.hdl.ast import Assume
39
40 from nmutil.formaltest import FHDLTestCase
41 from nmutil.gtkw import write_gtkw
42
43 from ieee754.part_mul_add.partpoints import PartitionPoints
44 from ieee754.part.partsig import PartitionedSignal
45
46
class PartitionedPattern(Elaboratable):
    """ Generate a unique pattern, depending on partition size.

    The output is split into steps at the partition points. The lower half
    of each step holds its position inside its partition (counting from
    one) and the upper half holds the partition size, in steps. With 8-bit
    steps, listing bytes from the least significant one:

    * 1-byte partitions: 0x11
    * 2-byte partitions: 0x21 0x22
    * 3-byte partitions: 0x31 0x32 0x33

    And so on.

    Useful as a test vector for testing the formal prover

    :param width: output width, in bits
    :param partition_points: partition points (bit position -> gate); a
                             gate of one starts a new partition there
    """
    def __init__(self, width, partition_points):
        self.width = width
        self.partition_points = PartitionPoints(partition_points)
        # Number of steps, i.e. the maximum number of partitions
        self.mwidth = len(self.partition_points)+1
        self.output = Signal(self.width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # Step boundaries. The gate list gets a constant-one guard at each
        # end: both ends of the signal always delimit a partition.
        positions = [0] + list(self.partition_points.keys()) + [self.width]
        gates = [Const(1)] + list(self.partition_points.values()) + [Const(1)]

        def lower(i):
            # Lower half of step i: its position inside the partition
            start, end = positions[i], positions[i + 1]
            return self.output[start:(start + end) // 2]

        def upper(i):
            # Upper half of step i: the partition size
            start, end = positions[i], positions[i + 1]
            return self.output[(start + end) // 2:end]

        # Lower halves: an incrementing cascade, from the bottom up,
        # beginning at one.
        comb += lower(0).eq(1)
        for i in range(1, self.mwidth):
            with m.If(~gates[i]):
                # Same partition as the step below: add one to its count.
                comb += lower(i).eq(lower(i - 1) + 1)
            with m.Else():
                # Partition boundary: start counting again.
                comb += lower(i).eq(1)

        # Upper halves, from the top down. The top step always ends a
        # partition (guard gate), so it mirrors its own count, which is the
        # partition size. This is done once, unconditionally, rather than
        # under an always-false condition that would assign the slice to
        # itself.
        top = self.mwidth - 1
        comb += upper(top).eq(lower(top))
        for i in range(top - 1, -1, -1):
            with m.If(~gates[i + 1]):
                # The step above is in the same partition: copy its size.
                comb += upper(i).eq(upper(i + 1))
            with m.Else():
                # This step ends a partition: mirror its own count.
                comb += upper(i).eq(lower(i))

        return m
114
115
def make_partitions(step, mwidth):
    """Create ``mwidth - 1`` equally spaced partition points.

    :param step: smallest partition width, in bits
    :param mwidth: maximum number of partitions
    :returns: a ``(points, gates)`` pair, where ``points`` maps bit
              positions ``step, 2*step, ...`` to successive bits of the
              ``gates`` signal"""
    gates = Signal(mwidth - 1)
    points = PartitionPoints()
    # Point k (1-based) sits at bit k*step and is controlled by gate k-1
    for k in range(1, mwidth):
        points[k * step] = gates[k - 1]
    return points, gates
127
128
# This defines a module to drive the device under test and assert
# properties about its outputs
class Driver(Elaboratable):
    """Formal driver for :class:`PartitionedPattern`.

    Checks a few hand-computed gate patterns directly. It then lets the
    solver pick an arbitrary partition (offset and width, in units of
    ``step`` bits), shifts the DUT output down so that partition starts at
    bit zero, and asserts that it holds the expected pattern.
    """
    def __init__(self):
        # inputs and outputs
        pass

    @staticmethod
    def elaborate(_):
        m = Module()
        comb = m.d.comb
        width = 64
        mwidth = 8
        # Setup partition points and gates
        step = width // mwidth
        points, gates = make_partitions(step, mwidth)
        # Instantiate the partitioned pattern producer
        m.submodules.dut = dut = PartitionedPattern(width, points)
        # Directly check some cases (constants assume 64 bits, 8-bit steps)
        with m.If(gates == 0):
            comb += Assert(dut.output == 0x_88_87_86_85_84_83_82_81)
        with m.If(gates == 0b1100101):
            comb += Assert(dut.output == 0x_11_11_33_32_31_22_21_11)
        with m.If(gates == 0b0001000):
            comb += Assert(dut.output == 0x_44_43_42_41_44_43_42_41)
        with m.If(gates == 0b0100001):
            comb += Assert(dut.output == 0x_22_21_55_54_53_52_51_11)
        with m.If(gates == 0b1000001):
            comb += Assert(dut.output == 0x_11_66_65_64_63_62_61_11)
        with m.If(gates == 0b0000001):
            comb += Assert(dut.output == 0x_77_76_75_74_73_72_71_11)
        # Choose a partition offset and width at random.
        p_offset = Signal(range(mwidth))
        p_width = Signal(range(mwidth+1))
        p_finish = Signal(range(mwidth+1))
        comb += p_finish.eq(p_offset + p_width)
        # Partition must not be empty, and fit within the signal.
        comb += Assume(p_width != 0)
        comb += Assume(p_offset + p_width <= mwidth)

        # Build the corresponding partition
        # Use Assume to constrain the pattern to conform to the given offset
        # and width. For each gate bit it is:
        # 1) one, if on the partition boundary
        # 2) zero, if it's inside the partition
        # 3) don't care, otherwise
        p_gates = Signal(mwidth+1)
        for i in range(mwidth+1):
            with m.If(i == p_offset):
                # Partitions begin with 1
                comb += Assume(p_gates[i] == 1)
            with m.If((i > p_offset) & (i < p_finish)):
                # The interior are all zeros
                comb += Assume(p_gates[i] == 0)
            with m.If(i == p_finish):
                # End with 1 again
                comb += Assume(p_gates[i] == 1)
        # Check some possible partitions generating a given pattern
        with m.If(p_gates == 0b0100110):
            comb += Assert(((p_offset == 1) & (p_width == 1)) |
                           ((p_offset == 2) & (p_width == 3)))
        # Remove the guard bits at each end (bit 0 and bit mwidth) and
        # assign the rest to the DUT gates
        comb += gates.eq(p_gates[1:mwidth])
        # Generate shifted down outputs (partition point i sits at i*step)
        p_output = Signal(width)
        for i in range(mwidth):
            with m.If(p_offset == i):
                comb += p_output.eq(dut.output[i * step:])
        # Some checks on the shifted down output, irrespective of offset:
        with m.If(p_width == 2):
            comb += Assert(p_output[:16] == 0x_22_21)
        with m.If(p_width == 4):
            comb += Assert(p_output[:32] == 0x_44_43_42_41)
        # test zero shift
        with m.If(p_offset == 0):
            comb += Assert(p_output == dut.output)
        # Output an example.
        # Make it interesting, by having four partitions.
        # Make the selected partition not start at the very beginning.
        comb += Cover((sum(gates) == 3) & (p_offset != 0) & (p_width == 3))
        # Generate and check expected values for all possible partition sizes.
        # Here, we assume partition sizes are multiple of the smaller size.
        # Each step is split like PartitionedPattern splits it: the lower
        # half holds the position, the upper half the partition width.
        half = step // 2
        for w in range(1, mwidth+1):
            with m.If(p_width == w):
                # calculate the expected output, for the given bit width
                bit_width = w * step
                expected = Signal(bit_width, name=f"expected_{w}")
                for b in range(w):
                    lane = b * step
                    # lower half is the position
                    comb += expected[lane:lane + half].eq(b+1)
                    # upper half is the partition width
                    comb += expected[lane + half:lane + step].eq(w)
                # truncate the output, compare and assert
                comb += Assert(p_output[:bit_width] == expected)
        return m
225
226
class GateGenerator(Elaboratable):
    """Produces partition gates at random

    The solver may pick any gate pattern, provided the partition described
    by `p_offset`, `p_width` and `p_finish` is a whole, unbroken partition:
    a one at its start and end boundaries and zeros in between.

    :param mwidth: number of partitions (the gates have ``mwidth - 1`` bits)
    """
    def __init__(self, mwidth):
        #: Number of partitions
        self.mwidth = mwidth
        #: Generated partition gates
        self.gates = Signal(mwidth-1)
        #: Generated partition start point
        self.p_offset = Signal(range(mwidth))
        #: Generated partition width
        self.p_width = Signal(range(mwidth+1))
        #: Generated partition end point
        self.p_finish = Signal(range(mwidth+1))

    def elaborate(self, _):
        m = Module()
        comb = m.d.comb
        n = self.mwidth
        start, size, end = self.p_offset, self.p_width, self.p_finish

        comb += end.eq(start + size)
        # The partition must be non-empty and must fit within the signal.
        comb += Assume(size != 0)
        comb += Assume(start + size <= n)

        # Gate pattern with an extra guard bit at each end. For every bit,
        # the assumptions below force:
        #   - one at the partition start and at its end,
        #   - zero strictly inside the partition,
        # and leave it unconstrained everywhere else.
        p_gates = Signal(n + 1)
        for i in range(n + 1):
            rules = (
                (i == start, 1),
                ((i > start) & (i < end), 0),
                (i == end, 1),
            )
            for cond, level in rules:
                with m.If(cond):
                    comb += Assume(p_gates[i] == level)
        # Strip the guard bits before driving the output gates
        comb += self.gates.eq(p_gates[1:])
        return m
277
278
class GeneratorDriver(Elaboratable):
    """Formal driver for :class:`PartitionedPattern`, fed by
    :class:`GateGenerator`.

    The generator picks a random partition and the matching gates. The DUT
    output is shifted down to that partition's start and compared with the
    expected pattern for its width.
    """
    def __init__(self):
        # inputs and outputs
        pass

    @staticmethod
    def elaborate(_):
        m = Module()
        comb = m.d.comb
        width = 64
        mwidth = 8
        # Setup partition points and gates
        step = width // mwidth
        points, gates = make_partitions(step, mwidth)
        # Instantiate the partitioned pattern producer and the DUT
        m.submodules.dut = dut = PartitionedPattern(width, points)
        m.submodules.gen = gen = GateGenerator(mwidth)
        comb += gates.eq(gen.gates)
        # Generate shifted down outputs (partition point i sits at i*step)
        p_offset = gen.p_offset
        p_width = gen.p_width
        p_output = Signal(width)
        for i in range(mwidth):
            with m.If(p_offset == i):
                comb += p_output.eq(dut.output[i*step:])
        # Generate and check expected values for all possible partition sizes.
        # Each step is split like PartitionedPattern splits it: the lower
        # half holds the position, the upper half the partition width.
        half = step // 2
        for w in range(1, mwidth+1):
            with m.If(p_width == w):
                # calculate the expected output, for the given bit width
                bit_width = w * step
                expected = Signal(bit_width, name=f"expected_{w}")
                for b in range(w):
                    lane = b * step
                    # lower half is the position
                    comb += expected[lane:lane + half].eq(b+1)
                    # upper half is the partition width
                    comb += expected[lane + half:lane + step].eq(w)
                # truncate the output, compare and assert
                comb += Assert(p_output[:bit_width] == expected)
        # Output an example.
        # Make it interesting, by having four partitions.
        # Make the selected partition not start at the very beginning.
        comb += Cover((sum(gates) == 3) & (p_offset != 0) & (p_width == 3))
        return m
322
323
class ComparisonOpDriver(Elaboratable):
    """Checks comparison operations on partitioned signals

    The operation is applied to full-width PartitionedSignal operands. It
    is also applied to plain signals holding only the randomly selected
    partition, shifted down to bit zero. Every output bit belonging to that
    partition must equal the single-bit result of the plain operation.

    :param op: operation to check
    :param width: partition full width, in bits
    :param mwidth: maximum number of equally sized partitions
    :param nops: number of input operands
    """
    def __init__(self, op, width=64, mwidth=8, nops=2):
        self.op = op
        """Operation to perform. Must accept two integer-like inputs and give
        a predicate-like output (1-bit partitions in case of
        PartitionedSignal types)"""
        self.width = width
        """Partition full width"""
        self.mwidth = mwidth
        """Maximum number of equally sized partitions"""
        self.nops = nops
        """Number of input operands"""

    def elaborate(self, _):
        m = Module()
        comb = m.d.comb
        width = self.width
        mwidth = self.mwidth
        nops = self.nops
        # setup partition points and gates (integer step, no float division)
        step = width // mwidth
        points, gates = make_partitions(step, mwidth)
        # setup inputs and outputs
        operands = list()
        for i in range(nops):
            inp = PartitionedSignal(points, width, name=f"i_{i+1}")
            inp.set_module(m)
            operands.append(inp)
        output = Signal(mwidth)
        # perform the operation on the partitioned signals
        comb += output.eq(self.op(*operands))
        # instantiate the partitioned gate generator and connect the gates
        m.submodules.gen = gen = GateGenerator(mwidth)
        comb += gates.eq(gen.gates)
        p_offset = gen.p_offset
        p_width = gen.p_width
        # generate shifted down inputs and outputs
        p_operands = list()
        for i in range(nops):
            p_i = Signal(width, name=f"p_{i+1}")
            p_operands.append(p_i)
            for pos in range(mwidth):
                with m.If(p_offset == pos):
                    comb += p_i.eq(operands[i].sig[pos * step:])
        # the output has one bit per step, so it shifts by step index
        p_output = Signal(mwidth)
        for pos in range(mwidth):
            with m.If(p_offset == pos):
                comb += p_output.eq(output[pos:])
        # generate and check expected values for all possible partition sizes
        for w in range(1, mwidth+1):
            with m.If(p_width == w):
                # calculate the expected output, for the given bit width,
                # truncating the inputs to the partition size
                input_bit_width = w * step
                output_bit_width = w
                expected = Signal(output_bit_width, name=f"expected_{w}")
                trunc_operands = list()
                for i in range(nops):
                    t_i = Signal(input_bit_width, name=f"t{w}_{i+1}")
                    trunc_operands.append(t_i)
                    comb += t_i.eq(p_operands[i][:input_bit_width])
                lsb = Signal(name=f"lsb_{w}")
                comb += lsb.eq(self.op(*trunc_operands))
                # the single-bit result is replicated over the partition
                comb += expected.eq(Repl(lsb, output_bit_width))
                # truncate the output, compare and assert
                comb += Assert(p_output[:output_bit_width] == expected)
        # output a test case
        comb += Cover((p_offset != 0) & (p_width == 3) & (sum(output) > 1) &
                      (p_output != 0))
        return m
394
395
class PartitionTestCase(FHDLTestCase):
    """Formal proofs for the partitioned pattern and PartitionedSignal ops.

    Note: ``assertFormal`` derives its working directory from the name of
    the calling test method. It must therefore be called directly from
    each ``test_*`` method, never from a shared helper.
    """

    @staticmethod
    def _write_gtkw(gtkw_prefix, spec_name, traces, bmc=True):
        """Write GTKWave save files for the traces of one formal proof.

        :param gtkw_prefix: prefix of the ``.gtkw`` files, written to the
                            current directory
        :param spec_name: directory, next to this file, where the formal
                          engine leaves its VCD traces
        :param traces: trace list, as accepted by ``write_gtkw``
        :param bmc: also write a save file for the BMC trace
        """
        style = {
            'dec': {'base': 'dec'},
            'bin': {'base': 'bin'}
        }
        # os.path.join, rather than string concatenation, so that an empty
        # dirname gives a relative path instead of one rooted at "/"
        engine_dir = os.path.join(os.path.dirname(__file__), spec_name,
                                  'engine_0')
        runs = [('cover', 'trace0.vcd')]
        if bmc:
            runs.append(('bmc', 'trace.vcd'))
        for run, vcd in runs:
            write_gtkw(
                f'{gtkw_prefix}_{run}.gtkw',
                os.path.join(engine_dir, vcd),
                traces, style,
                module='top',
                zoom=-3
            )

    def test_formal(self):
        traces = [
            ('p_offset[2:0]', 'dec'),
            ('p_width[3:0]', 'dec'),
            ('p_finish[3:0]', 'dec'),
            ('p_gates[8:0]', 'bin'),
            ('dut', {'submodule': 'dut'}, [
                ('gates[6:0]', 'bin'),
                'output[63:0]']),
            # expected_3 spans 3 partitions of 8 bits
            'p_output[63:0]', 'expected_3[23:0]']
        self._write_gtkw('proof_partition', 'proof_partition_formal', traces)
        module = Driver()
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)

    def test_generator(self):
        traces = [
            ('p_offset[2:0]', 'dec'),
            ('p_width[3:0]', 'dec'),
            ('p_finish[3:0]', 'dec'),
            ('p_gates[8:0]', 'bin'),
            ('dut', {'submodule': 'dut'}, [
                ('gates[6:0]', 'bin'),
                'output[63:0]']),
            # expected_3 spans 3 partitions of 8 bits
            'p_output[63:0]', 'expected_3[23:0]']
        self._write_gtkw('proof_partition_generator',
                         'proof_partition_generator', traces)
        module = GeneratorDriver()
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)

    def test_partsig_eq(self):
        traces = [
            ('p_offset[2:0]', 'dec'),
            ('p_width[3:0]', 'dec'),
            ('p_gates[8:0]', 'bin'),
            'i_1[63:0]', 'i_2[63:0]',
            ('eq_1', {'submodule': 'eq_1'}, [
                ('gates[6:0]', 'bin'),
                'a[63:0]', 'b[63:0]',
                ('output[7:0]', 'bin')]),
            'p_1[63:0]', 'p_2[63:0]',
            ('p_output[7:0]', 'bin'),
            't3_1[23:0]', 't3_2[23:0]', 'lsb_3',
            ('expected_3[2:0]', 'bin')]
        self._write_gtkw('proof_partsig_eq', 'proof_partition_partsig_eq',
                         traces)
        module = ComparisonOpDriver(operator.eq)
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)

    def test_partsig_ne(self):
        module = ComparisonOpDriver(operator.ne)
        self.assertFormal(module, mode="bmc", depth=1)

    def test_partsig_gt(self):
        module = ComparisonOpDriver(operator.gt)
        self.assertFormal(module, mode="bmc", depth=1)

    def test_partsig_ge(self):
        module = ComparisonOpDriver(operator.ge)
        self.assertFormal(module, mode="bmc", depth=1)

    def test_partsig_lt(self):
        module = ComparisonOpDriver(operator.lt)
        self.assertFormal(module, mode="bmc", depth=1)

    def test_partsig_le(self):
        module = ComparisonOpDriver(operator.le)
        self.assertFormal(module, mode="bmc", depth=1)

    def test_partsig_all(self):
        traces = [
            ('p_offset[2:0]', 'dec'),
            ('p_width[3:0]', 'dec'),
            ('p_gates[8:0]', 'bin'),
            'i_1[63:0]',
            ('eq_1', {'submodule': 'eq_1'}, [
                ('gates[6:0]', 'bin'),
                'a[63:0]', 'b[63:0]',
                ('output[7:0]', 'bin')]),
            'p_1[63:0]',
            ('p_output[7:0]', 'bin'),
            't3_1[23:0]', 'lsb_3',
            ('expected_3[2:0]', 'bin')]
        # only the cover trace is written for this proof
        self._write_gtkw('proof_partsig_all', 'proof_partition_partsig_all',
                         traces, bmc=False)

        def op_all(obj):
            return obj.all()

        module = ComparisonOpDriver(op_all, nops=1)
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)
557
558
if __name__ == "__main__":
    # Run the formal proofs when this file is executed as a script.
    unittest.main()