Work only with nibbles on the ascending cascade
[ieee754fpu.git] / src / ieee754 / part / formal / proof_partition.py
1 """Formal verification of partitioned operations
2
3 The approach is to take an arbitrary partition, by choosing its start point
4 and size at random. Use ``Assume`` to ensure it is a whole unbroken partition
5 (start and end points are one, with only zeros in between). Shift inputs and
6 outputs down to zero. Loop over all possible partition sizes and, if it's the
7 right size, compute the expected value, compare with the result, and assert.
8
9 We are turning the for-loops around (on their head), such that we start from
10 the *lengths* (and positions) and perform the ``Assume`` on the resultant
11 partition bits.
12
13 In other words, we have patterns as follows (assuming 32-bit words)::
14
15 8-bit offsets 0,1,2,3
16 16-bit offsets 0,1,2
17 24-bit offsets 0,1
18 32-bit
19
20 * for 8-bit the partition bit is 1 and the previous is also 1
21
22 * for 16-bit the partition bit at the offset must be 0 and be surrounded by 1
23
24 * for 24-bit the partition bits at the offset and at offset+1 must be 0 and at
25 offset+2 and offset-1 must be 1
26
27 * for 32-bit all 3 bits must be 0 and be surrounded by 1 (guard bits are added
28 at each end for this purpose)
29
30 """
31
32 import os
33 import unittest
34
35 from nmigen import Elaboratable, Signal, Module, Const
36 from nmigen.asserts import Assert, Cover
37
38 from nmutil.formaltest import FHDLTestCase
39 from nmutil.gtkw import write_gtkw
40
41 from ieee754.part_mul_add.partpoints import PartitionPoints
42
43
class PartitionedPattern(Elaboratable):
    """ Generate a unique pattern, depending on partition size.

    Each byte (the span between two adjacent partition points) holds the
    size of its partition, in bytes, in its upper half, and its position
    within the partition, counting from one, in its lower half.  Listed
    from the least significant byte of each partition:

    * 1-byte partitions: 0x11
    * 2-byte partitions: 0x21 0x22
    * 3-byte partitions: 0x31 0x32 0x33

    And so on.

    Useful as a test vector for testing the formal prover

    :param width: width of ``output``, in bits
    :param partition_points: mapping from bit position to the partition
        gate at that position (wrapped in ``PartitionPoints``).  The points
        are sorted here, so they need not be given in ascending order.
    """
    def __init__(self, width, partition_points):
        self.width = width
        self.partition_points = PartitionPoints(partition_points)
        # Number of bytes: one more than the number of partition points
        self.mwidth = len(self.partition_points)+1
        self.output = Signal(self.width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # Walk the points in ascending bit order, whatever order the
        # mapping was built in.
        points = sorted(self.partition_points.keys())
        # Add a guard bit at each end: byte i starts at positions[i], and
        # gates[i] is the partition bit at that position.
        positions = [0] + points + [self.width]
        gates = ([Const(1)] +
                 [self.partition_points[p] for p in points] +
                 [Const(1)])

        # Split each byte into its lower half (position counter) and its
        # upper half (partition size).
        lower = []
        upper = []
        for i in range(self.mwidth):
            start = positions[i]
            end = positions[i+1]
            middle = (start + end) // 2
            lower.append(self.output[start:middle])
            upper.append(self.output[middle:end])

        # Ascending cascade: begin counting at one, add one to the previous
        # byte, and start again at every partition boundary.
        comb += lower[0].eq(1)
        for i in range(1, self.mwidth):
            with m.If(~gates[i]):
                comb += lower[i].eq(lower[i-1] + 1)
            with m.Else():
                comb += lower[i].eq(1)

        # Descending cascade: the last byte of a partition counts up to the
        # partition size, so mirror its lower half into its upper half, then
        # propagate that value down through the rest of the partition.
        # The top byte always ends a partition (upper guard bit).
        comb += upper[-1].eq(lower[-1])
        for i in range(self.mwidth - 2, -1, -1):
            # gates[i+1] is the partition bit at the end of byte i
            with m.If(~gates[i+1]):
                comb += upper[i].eq(upper[i+1])
            with m.Else():
                comb += upper[i].eq(lower[i])

        return m
111
112
113 # This defines a module to drive the device under test and assert
114 # properties about its outputs
115 class Driver(Elaboratable):
116 def __init__(self):
117 # inputs and outputs
118 pass
119
120 @staticmethod
121 def elaborate(_):
122 m = Module()
123 comb = m.d.comb
124 width = 64
125 mwidth = 8
126 # Setup partition points and gates
127 points = PartitionPoints()
128 gates = Signal(mwidth-1)
129 step = int(width/mwidth)
130 for i in range(mwidth-1):
131 points[(i+1)*step] = gates[i]
132 # Instantiate the partitioned pattern producer
133 m.submodules.dut = dut = PartitionedPattern(width, points)
134 # Directly check some cases
135 with m.If(gates == 0):
136 comb += Assert(dut.output == 0x_88_87_86_85_84_83_82_81)
137 with m.If(gates == 0b1100101):
138 comb += Assert(dut.output == 0x_11_11_33_32_31_22_21_11)
139 with m.If(gates == 0b0001000):
140 comb += Assert(dut.output == 0x_44_43_42_41_44_43_42_41)
141 with m.If(gates == 0b0100001):
142 comb += Assert(dut.output == 0x_22_21_55_54_53_52_51_11)
143 with m.If(gates == 0b1000001):
144 comb += Assert(dut.output == 0x_11_66_65_64_63_62_61_11)
145 with m.If(gates == 0b0000001):
146 comb += Assert(dut.output == 0x_77_76_75_74_73_72_71_11)
147 # Make it interesting, by having three partitions
148 comb += Cover(sum(gates) == 3)
149 return m
150
151
class PartitionTestCase(FHDLTestCase):
    """Formal checks of ``PartitionedPattern`` through ``Driver``."""
    def test_formal(self):
        """Write GTKWave save files, then run the BMC and cover proofs."""
        traces = ['output[63:0]', 'gates[6:0]']
        # Build the VCD paths with os.path.join, so that an empty dirname
        # gives a relative path rather than one rooted at "/".
        trace_dir = os.path.join(os.path.dirname(__file__),
                                 'proof_partition_formal', 'engine_0')
        # (save file, VCD file) for the cover and BMC runs respectively
        views = (
            ('test_formal_cover.gtkw', 'trace0.vcd'),
            ('test_formal_bmc.gtkw', 'trace.vcd'),
        )
        for gtkw_file, vcd_file in views:
            write_gtkw(
                gtkw_file,
                os.path.join(trace_dir, vcd_file),
                traces,
                module='top.dut',
                zoom="formal"
            )
        module = Driver()
        self.assertFormal(module, mode="bmc", depth=1)
        self.assertFormal(module, mode="cover", depth=1)
174
175
if __name__ == '__main__':
    # Executed as a script: hand over to the unittest command-line runner
    unittest.main()