fix some borked imports
[ieee754fpu.git] / src / ieee754 / part_shift / formal / proof_shift_scalar.py
1 # Proof of correctness for partitioned scalar shifter
2 # Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
3
4 from nmigen import Module, Signal, Elaboratable, Mux, Cat
5 from nmigen.asserts import Assert, AnyConst, Assume
6 from nmutil.formaltest import FHDLTestCase
7 from nmigen.cli import rtlil
8
9 from ieee754.part_mul_add.partpoints import PartitionPoints
10 from ieee754.part_shift.part_shift_scalar import PartitionedScalarShift
11 import unittest
12
13
# This defines a module to drive the device under test and assert
# properties about its outputs
class ShifterDriver(Elaboratable):
    """Formal-verification harness for ``PartitionedScalarShift``.

    A 24-bit DUT is given two partition points (bits 8 and 16), each one
    enabled by one bit of ``gates``.  Data, shift amount, shift direction
    and gates are all unconstrained (``AnyConst``).  For every gate
    combination, each output partition delimited by the enabled points
    must equal the matching input partition shifted (left when
    ``shift_right`` is 0, right otherwise) by the scalar shift amount
    masked to ``(partition_width - 1).bit_length()`` bits, truncated to
    the partition's width.
    """

    def __init__(self):
        # The harness has no configurable state; every signal is created
        # in elaborate().
        pass

    def get_intervals(self, signal, points):
        """Split *signal* into consecutive slices at the given points.

        :param signal: the value to slice; must expose ``.width``.
        :param points: partition points -- a ``PartitionPoints`` mapping
            (its keys are used) or any iterable of bit indices.
        :returns: list of slices ``signal[a:b]`` covering bits
            ``0 .. signal.width``, lowest bits first.
        """
        intervals = []
        start = 0
        # Sort so the slices come out in bit order regardless of the
        # order in which the points were inserted.
        for end in sorted(points) + [signal.width]:
            intervals.append(signal[start:end])
            start = end
        return intervals

    def elaborate(self, platform):
        """Build the DUT, drive it with AnyConst inputs and add assertions."""
        m = Module()
        comb = m.d.comb
        width = 24
        shifterwidth = 5
        mwidth = 3

        # setup the inputs and outputs of the DUT as anyconst
        data = Signal(width)
        out = Signal(width)
        shifter = Signal(shifterwidth)
        points = PartitionPoints()
        gates = Signal(mwidth - 1)
        shift_right = Signal()
        step = width // mwidth
        # gates[i] enables the point at (i+1)*step, so bit i of ``gates``
        # corresponds to the i-th point in ascending order.
        for i in range(mwidth - 1):
            points[(i + 1) * step] = gates[i]

        comb += [data.eq(AnyConst(width)),
                 shift_right.eq(AnyConst(1)),
                 shifter.eq(AnyConst(shifterwidth)),
                 gates.eq(AnyConst(mwidth - 1))]

        m.submodules.dut = dut = PartitionedScalarShift(width, points)

        comb += [dut.data.eq(data),
                 dut.shifter.eq(shifter),
                 dut.shift_right.eq(shift_right),
                 out.eq(dut.output)]

        with m.If(shift_right == 0):
            self._assert_partitions(m, gates, points, data, out, shifter,
                                    lambda value, amount: value << amount)
        with m.Else():
            self._assert_partitions(m, gates, points, data, out, shifter,
                                    lambda value, amount: value >> amount)
        return m

    def _assert_partitions(self, m, gates, points, data, out, shifter,
                           shift):
        """Assert per-partition shift results for every gate combination.

        :param gates: signal whose bit i enables the i-th point (ascending).
        :param shift: callable ``(value, amount) -> shifted value``.
        """
        keys = sorted(points)
        with m.Switch(gates):
            for gate_val in range(1 << len(keys)):
                with m.Case(gate_val):
                    # points whose gate bit is set split the word
                    active = [key for bit, key in enumerate(keys)
                              if (gate_val >> bit) & 1]
                    partitions = zip(self.get_intervals(data, active),
                                     self.get_intervals(out, active))
                    for data_part, out_part in partitions:
                        part_width = len(data_part)
                        # enough shift-amount bits to index the partition:
                        # 8 -> 0x7, 16 -> 0xf, 24 -> 0x1f
                        amount_mask = (1 << (part_width - 1).bit_length()) - 1
                        part_mask = (1 << part_width) - 1
                        m.d.comb += Assert(
                            out_part ==
                            shift(data_part, shifter & amount_mask) &
                            part_mask)
120
class PartitionedScalarShiftTestCase(FHDLTestCase):
    """Formal proof and RTLIL export for ``PartitionedScalarShift``."""

    def test_shift(self):
        """Run a bounded model check of the ShifterDriver assertions."""
        module = ShifterDriver()
        self.assertFormal(module, mode="bmc", depth=4)

    def test_ilang(self):
        """Convert a 24-bit, 3-partition shifter to RTLIL (scalar_shift.il)."""
        width = 24
        mwidth = 3
        gates = Signal(mwidth - 1)
        points = PartitionPoints()
        step = width // mwidth
        for i in range(mwidth - 1):
            points[(i + 1) * step] = gates[i]
        dut = PartitionedScalarShift(width, points)
        # List every DUT input, including shift_right, so that each one is
        # exported as a port of the generated module.
        vl = rtlil.convert(dut, ports=[gates, dut.data,
                                       dut.shifter,
                                       dut.shift_right,
                                       dut.output])
        with open("scalar_shift.il", "w") as f:
            f.write(vl)


if __name__ == "__main__":
    unittest.main()