fix some borked imports
[ieee754fpu.git] / src / ieee754 / part_shift / formal / proof_shift_dynamic.py
1 # Proof of correctness for partitioned dynamic shifter
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3
4 from nmigen import Module, Signal, Elaboratable, Mux, Cat
5 from nmigen.asserts import Assert, AnyConst
6 from nmutil.formaltest import FHDLTestCase
7 from nmigen.cli import rtlil
8
9 from ieee754.part_mul_add.partpoints import PartitionPoints
10 from ieee754.part_shift.part_shift_dynamic import \
11 PartitionedDynamicShift
12 import unittest
13
14
# This defines a module to drive the device under test and assert
# properties about its outputs
class ShifterDriver(Elaboratable):
    """Formal harness for :class:`PartitionedDynamicShift`.

    ``a``, ``b``, the shift direction and the partition gates are all
    driven from ``AnyConst``, so the solver explores every input value.
    ``width`` is cut into ``mwidth`` equal intervals.  Each of the
    ``mwidth - 1`` gate bits makes the boundary above interval ``i`` an
    active partition point.

    For every gate pattern and both shift directions, each active
    partition of the output is asserted to equal the matching partition
    of ``a``, shifted by the low ``ceil(log2(partition width))`` bits of
    the lowest ``b`` interval in that partition.  The result is truncated
    to the partition width.

    :param width: bit width of ``a``, ``b`` and the output (default 32).
    :param mwidth: number of equal intervals (default 4).  It must be at
        least 2 and must divide ``width``.
    :raises ValueError: if the parameters do not describe a layout this
        harness can check.
    """

    def __init__(self, width=32, mwidth=4):
        if mwidth < 2 or width % mwidth != 0:
            raise ValueError(
                f"mwidth ({mwidth}) must be >= 2 and divide width ({width})")
        step = width // mwidth
        # A partition's shift amount is read from the low bits of its
        # lowest b interval.  The widest shift amount (for the
        # whole-width partition) must therefore fit inside one interval.
        # Otherwise the slice would be silently clamped and the proof
        # weakened.
        if (width - 1).bit_length() > step:
            raise ValueError(
                f"interval width {step} is too narrow to hold a shift "
                f"amount for width {width}")
        self.width = width
        self.mwidth = mwidth

    def get_intervals(self, signal, points):
        """Split *signal* into slices at every position in *points*.

        A cut is made at every key of *points*, whether or not that
        point's gate is active; the gate values are not examined.  Keys
        are sorted, so the slices come out lowest bits first regardless
        of the order in which the points were inserted.

        :param signal: nmigen value to slice (its ``width`` is the upper
            bound of the last slice).
        :param points: mapping whose keys are bit positions.
        :returns: list of ``signal[lo:hi]`` slices covering the signal.
        """
        bounds = [0] + sorted(points) + [signal.width]
        return [signal[lo:hi] for lo, hi in zip(bounds, bounds[1:])]

    @staticmethod
    def _partition_runs(pattern, n_intervals):
        """Group interval indices into partitions for one gate *pattern*.

        Bit ``i`` of *pattern* set means the boundary between intervals
        ``i`` and ``i + 1`` is an active partition point.  Returns
        half-open ``(lo, hi)`` interval-index ranges, lowest first.
        """
        runs = []
        lo = 0
        for i in range(n_intervals - 1):
            if (pattern >> i) & 1:
                runs.append((lo, i + 1))
                lo = i + 1
        runs.append((lo, n_intervals))
        return runs

    def _assert_partitions(self, m, points, a_intervals, b_intervals,
                           out_intervals, step, right):
        """Emit the per-partition shift assertions for one direction.

        Adds an ``m.Switch`` on ``points.as_sig()`` containing one
        ``m.Case`` per possible gate pattern.  *right* selects ``>>``
        (True) or ``<<`` (False) as the reference operation.
        """
        comb = m.d.comb
        n_intervals = len(out_intervals)
        with m.Switch(points.as_sig()):
            for pattern in range(1 << (n_intervals - 1)):
                with m.Case(pattern):
                    for lo, hi in self._partition_runs(pattern, n_intervals):
                        pwidth = (hi - lo) * step
                        # ceil(log2(pwidth)) low bits of the partition's
                        # lowest b interval form the shift amount,
                        # e.g. 8 -> 3 bits, 16 -> 4, 24 -> 5, 32 -> 5.
                        shamt = b_intervals[lo][0:(pwidth - 1).bit_length()]
                        a_part = Cat(a_intervals[lo:hi])
                        if right:
                            shifted = a_part >> shamt
                        else:
                            shifted = a_part << shamt
                        mask = (1 << pwidth) - 1
                        comb += Assert(Cat(out_intervals[lo:hi]) ==
                                       (shifted & mask))

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        width = self.width
        mwidth = self.mwidth
        step = width // mwidth

        # setup the inputs and outputs of the DUT as anyconst
        a = Signal(width)
        b = Signal(width)
        shift_right = Signal()
        out = Signal(width)
        points = PartitionPoints()
        gates = Signal(mwidth-1)
        # gates[i] controls the partition point at the top of interval i
        for i in range(mwidth-1):
            points[(i+1)*step] = gates[i]

        comb += [a.eq(AnyConst(width)),
                 b.eq(AnyConst(width)),
                 shift_right.eq(AnyConst(1)),
                 gates.eq(AnyConst(mwidth-1))]

        m.submodules.dut = dut = PartitionedDynamicShift(width, points)

        a_intervals = self.get_intervals(a, points)
        b_intervals = self.get_intervals(b, points)
        out_intervals = self.get_intervals(out, points)

        comb += [dut.a.eq(a),
                 dut.b.eq(b),
                 dut.shift_right.eq(shift_right),
                 out.eq(dut.output)]

        with m.If(shift_right == 0):
            self._assert_partitions(m, points, a_intervals, b_intervals,
                                    out_intervals, step, right=False)
        with m.Else():
            self._assert_partitions(m, points, a_intervals, b_intervals,
                                    out_intervals, step, right=True)

        return m
173
class PartitionedDynamicShiftTestCase(FHDLTestCase):
    """Formal proof and RTLIL-export check for PartitionedDynamicShift."""

    def test_shift(self):
        """Bounded model check of the default (32-bit, 4-way) harness."""
        module = ShifterDriver()
        self.assertFormal(module, mode="bmc", depth=4)

    def test_ilang(self):
        """Convert a 64-bit, 8-way shifter to RTLIL and write it to disk.

        This only checks that conversion succeeds.  The result is written
        to ``dynamic_shift.il`` in the current directory for manual
        inspection.
        """
        width = 64
        mwidth = 8
        step = width // mwidth
        gates = Signal(mwidth-1)
        points = PartitionPoints()
        # gates[i] controls the partition point at the top of interval i
        for i in range(mwidth-1):
            points[(i+1)*step] = gates[i]
        dut = PartitionedDynamicShift(width, points)
        # List every DUT input explicitly, including the shift direction.
        vl = rtlil.convert(dut, ports=[gates, dut.a, dut.b,
                                       dut.shift_right, dut.output])
        with open("dynamic_shift.il", "w", encoding="utf-8") as f:
            f.write(vl)


if __name__ == "__main__":
    unittest.main()