# Proof of correctness for partitioned dynamic shifter
# Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>

from nmigen import Module, Signal, Elaboratable, Mux, Cat
from nmigen.asserts import Assert, AnyConst
from nmigen.cli import rtlil
from nmutil.formaltest import FHDLTestCase

from ieee754.part_mul_add.partpoints import PartitionPoints
from ieee754.part_shift.part_shift_dynamic import \
    PartitionedDynamicShift
# This defines a module to drive the device under test and assert
# properties about its outputs
class ShifterDriver(Elaboratable):
    """Formal-verification harness for ``PartitionedDynamicShift``.

    The operands ``a`` and ``b``, the shift direction and the partition
    gates are all driven from ``AnyConst``, so the solver explores every
    combination.  For every possible gate pattern the harness asserts that
    each partition of the DUT output equals the matching partition of
    ``a`` shifted (left, or right when ``shift_right`` is set) by the low
    bits of the lowest ``b`` interval in that partition, truncated to the
    partition width.

    :param width: total operand width in bits (default 32).
    :param mwidth: number of equally sized partitions (default 4).
    """

    def __init__(self, width=32, mwidth=4):
        self.width = width
        self.mwidth = mwidth

    def get_intervals(self, signal, points):
        """Slice ``signal`` at every partition point.

        :param signal: the value to split.
        :param points: a ``PartitionPoints`` mapping whose keys are the bit
            positions at which ``signal`` may be partitioned.
        :returns: a list of slices of ``signal``, lowest bits first: one per
            span between consecutive partition points, plus the final span
            ending at ``signal.width``.
        """
        # sort so the slices come out low-to-high even if the points were
        # not inserted in ascending order
        ends = sorted(points.keys()) + [signal.width]
        interval = []
        start = 0
        for end in ends:
            interval.append(signal[start:end])
            start = end
        return interval

    @staticmethod
    def _partition_groups(pattern, nparts):
        """Return ``(first, stop)`` interval-index ranges for a gate pattern.

        Bit ``j`` of ``pattern`` set means the boundary between interval
        ``j`` and interval ``j + 1`` is active; intervals that are not
        separated by an active boundary merge into a single partition.
        """
        groups = []
        first = 0
        for j in range(nparts - 1):
            if (pattern >> j) & 1:
                groups.append((first, j + 1))
                first = j + 1
        groups.append((first, nparts))
        return groups

    def _assert_shifts(self, m, points, a_intervals, b_intervals,
                       out_intervals, right):
        """Add one ``Assert`` per partition for every possible gate pattern.

        :param right: ``True`` to check a logical right shift, ``False``
            for a left shift.
        """
        nparts = len(out_intervals)
        # assumes points.as_sig() places the gate of the lowest partition
        # point in bit 0 -- TODO confirm against PartitionPoints.as_sig
        with m.Switch(points.as_sig()):
            for pattern in range(1 << (nparts - 1)):
                with m.Case(pattern):
                    for first, stop in self._partition_groups(pattern,
                                                              nparts):
                        o = Cat(out_intervals[first:stop])
                        a_part = Cat(a_intervals[first:stop])
                        part_width = len(o)
                        # just enough shift-amount bits to cover
                        # 0..part_width-1 (3 for 8 bits, 4 for 16,
                        # 5 for 24 or 32), taken from the lowest b interval
                        # of the partition
                        nbits = (part_width - 1).bit_length()
                        shamt = b_intervals[first][0:nbits]
                        if right:
                            shifted = a_part >> shamt
                        else:
                            shifted = a_part << shamt
                        mask = (1 << part_width) - 1
                        m.d.comb += Assert(o == (shifted & mask))

    def elaborate(self, platform):
        """Build the harness: DUT, unconstrained inputs and assertions."""
        m = Module()
        comb = m.d.comb
        width = self.width
        mwidth = self.mwidth

        # setup the inputs and outputs of the DUT as anyconst
        a = Signal(width)
        b = Signal(width)
        out = Signal(width)
        shift_right = Signal()

        # one gate per internal boundary, at evenly spaced bit positions
        points = PartitionPoints()
        gates = Signal(mwidth-1)
        step = width // mwidth
        for i in range(mwidth-1):
            points[(i+1)*step] = gates[i]

        comb += [a.eq(AnyConst(width)),
                 b.eq(AnyConst(width)),
                 shift_right.eq(AnyConst(1)),
                 gates.eq(AnyConst(mwidth-1))]

        m.submodules.dut = dut = PartitionedDynamicShift(width, points)

        a_intervals = self.get_intervals(a, points)
        b_intervals = self.get_intervals(b, points)
        out_intervals = self.get_intervals(out, points)

        comb += [dut.a.eq(a),
                 dut.b.eq(b),
                 dut.shift_right.eq(shift_right),
                 out.eq(dut.output)]

        with m.If(shift_right == 0):
            self._assert_shifts(m, points, a_intervals, b_intervals,
                                out_intervals, right=False)
        with m.Else():
            self._assert_shifts(m, points, a_intervals, b_intervals,
                                out_intervals, right=True)

        return m
class PartitionedDynamicShiftTestCase(FHDLTestCase):
    """Formal proof and RTLIL export for ``PartitionedDynamicShift``."""

    def test_shift(self):
        """Bounded model check (depth 4) of ShifterDriver's assertions."""
        module = ShifterDriver()
        self.assertFormal(module, mode="bmc", depth=4)

    def test_ilang(self):
        """Convert a partitioned shifter to RTLIL and write dynamic_shift.il."""
        # NOTE(review): the width/mwidth definitions were not visible in the
        # reviewed copy; 64 bits in 8 partitions is assumed -- confirm.
        width = 64
        mwidth = 8
        gates = Signal(mwidth-1)
        points = PartitionPoints()
        step = width // mwidth
        for i in range(mwidth-1):
            points[(i+1)*step] = gates[i]

        dut = PartitionedDynamicShift(width, points)
        vl = rtlil.convert(dut, ports=[gates, dut.a, dut.b, dut.output])
        with open("dynamic_shift.il", "w") as f:
            f.write(vl)
194 if __name__
== "__main__":