Allow the formal engine to perform a same-cycle result in the ALU
[soc.git] / src / soc / fu / div / test / test_fsm.py
1 import unittest
2 from soc.fu.div.fsm import DivState, DivStateInit, DivStateNext
3 from nmigen import Elaboratable, Module, Signal, unsigned
4 from nmigen.cli import rtlil
5 try:
6 from nmigen.sim.pysim import Simulator, Delay, Tick
7 except ImportError:
8 from nmigen.back.pysim import Simulator, Delay, Tick
9
10
class CheckEvent(Elaboratable):
    """Marker signal that is toggled to flag, in the VCD trace, the moments
    at which the test bench checks signals.
    """

    def __init__(self):
        # inverted by every call to ``trigger``
        self.event = Signal()

    def trigger(self):
        """Simulator command generator: invert ``event`` once."""
        inverted = ~self.event
        yield self.event.eq(inverted)

    def elaborate(self, platform):
        module = Module()
        # ``event`` must be referenced by the design somehow, otherwise the
        # nmigen simulator does not know about it
        module.d.comb += Signal().eq(self.event)
        return module
26
27
class DivStateCombTest(Elaboratable):
    """Chain ``quotient_width`` copies of the FSM next-state function behind
    one initial-state function, forming a purely combinational divider whose
    intermediate states (``states[0]`` .. ``states[quotient_width]``) can all
    be inspected.
    """

    def __init__(self, quotient_width):
        self.check_event = CheckEvent()
        self.quotient_width = quotient_width
        self.dividend = Signal(unsigned(quotient_width * 2))
        self.divisor = Signal(unsigned(quotient_width))
        self.quotient = Signal(unsigned(quotient_width))
        self.remainder = Signal(unsigned(quotient_width))
        self.expected_quotient = Signal(unsigned(quotient_width))
        self.expected_remainder = Signal(unsigned(quotient_width))
        self.expected_valid = Signal()
        # states[0] is driven by the initial-state block, states[k] is the
        # state after k next-state steps
        self.states = [DivState(quotient_width=quotient_width,
                                name=f"state{k}")
                       for k in range(quotient_width + 1)]
        self.init = DivStateInit(quotient_width)
        # one next-state block between each consecutive pair of states
        self.nexts = [DivStateNext(quotient_width)
                      for _ in range(quotient_width)]

    def elaborate(self, platform):
        m = Module()
        m.submodules.check_event = self.check_event
        m.submodules.init = self.init
        m.d.comb += [
            self.init.dividend.eq(self.dividend),
            self.states[0].eq(self.init.o),
        ]

        # thread states[idx] through next-state block idx into states[idx+1]
        for idx, stage in enumerate(self.nexts):
            setattr(m.submodules, f"next{idx}", stage)
            m.d.comb += [
                stage.divisor.eq(self.divisor),
                stage.i.eq(self.states[idx]),
                self.states[idx + 1].eq(stage.o),
            ]

        final_state = self.states[-1]
        m.d.comb += [
            self.quotient.eq(final_state.quotient),
            self.remainder.eq(final_state.remainder),
        ]

        # reference model: the division is valid only when the quotient fits
        # in quotient_width bits and the divisor is non-zero
        fits = self.dividend < (self.divisor << self.quotient_width)
        nonzero = self.divisor != 0
        m.d.comb += self.expected_valid.eq(fits & nonzero)
        with m.If(self.expected_valid):
            m.d.comb += [
                self.expected_quotient.eq(self.dividend // self.divisor),
                self.expected_remainder.eq(self.dividend % self.divisor),
            ]
        return m
76
77
class DivStateFSMTest(Elaboratable):
    """Drive the division state functions as a clocked FSM: ``state`` is a
    register that advances one next-state step per clock, and is reloaded
    from the initial-state block whenever it is done or ``clear`` is set.
    """

    def __init__(self, quotient_width):
        self.check_done_event = CheckEvent()
        self.check_event = CheckEvent()
        self.quotient_width = quotient_width
        self.dividend = Signal(unsigned(quotient_width * 2))
        self.divisor = Signal(unsigned(quotient_width))
        self.quotient = Signal(unsigned(quotient_width))
        self.remainder = Signal(unsigned(quotient_width))
        self.expected_quotient = Signal(unsigned(quotient_width))
        self.expected_remainder = Signal(unsigned(quotient_width))
        self.expected_valid = Signal()
        # registered FSM state, and the combinational value it loads next
        self.state = DivState(quotient_width=quotient_width, name="state")
        self.next_state = DivState(quotient_width=quotient_width,
                                   name="next_state")
        self.init = DivStateInit(quotient_width)
        self.next = DivStateNext(quotient_width)
        self.state_done = Signal()
        self.next_state_done = Signal()
        # resets to 1; while set, every clock edge reloads the initial state
        self.clear = Signal(reset=1)

    def elaborate(self, platform):
        m = Module()
        m.submodules.check_event = self.check_event
        m.submodules.check_done_event = self.check_done_event
        m.submodules.init = self.init
        m.submodules.next = self.next

        state, next_state = self.state, self.next_state
        m.d.comb += [
            self.init.dividend.eq(self.dividend),
            self.next.divisor.eq(self.divisor),
            self.quotient.eq(state.quotient),
            self.remainder.eq(state.remainder),
            self.next.i.eq(state),
            self.state_done.eq(state.done),
            self.next_state_done.eq(next_state.done),
        ]

        # restart from the initial state, otherwise take one division step
        with m.If(state.done | self.clear):
            m.d.comb += next_state.eq(self.init.o)
        with m.Else():
            m.d.comb += next_state.eq(self.next.o)

        m.d.sync += state.eq(next_state)

        # reference model: valid only when the quotient fits in
        # quotient_width bits and the divisor is non-zero
        fits = self.dividend < (self.divisor << self.quotient_width)
        nonzero = self.divisor != 0
        m.d.comb += self.expected_valid.eq(fits & nonzero)
        with m.If(self.expected_valid):
            m.d.comb += [
                self.expected_quotient.eq(self.dividend // self.divisor),
                self.expected_remainder.eq(self.dividend % self.divisor),
            ]
        return m
130
131
def get_cases(quotient_width):
    """Return the sorted operand values exercised by the tests.

    The values are -3 through 3 plus the three integers around
    ``mask >> 1`` (one below, at and one above), each masked to
    ``quotient_width`` bits, where ``mask`` is ``quotient_width`` ones.
    Duplicates produced at small widths are kept.
    """
    mask = (1 << quotient_width) - 1
    near_zero = range(-3, 4)
    near_middle = [(mask >> 1) + delta for delta in (-1, 0, 1)]
    return sorted([value & mask for value in [*near_zero, *near_middle]])
141
142
class TestDivState(unittest.TestCase):
    """Simulation tests for the division state functions.

    Each test also writes its design as RTLIL (``*.il``) and its waveforms
    as VCD/GTKW files into the current working directory.
    """

    def test_div_state_comb(self, quotient_width=8):
        """Check a combinational chain of ``quotient_width`` next-state steps.

        Every (dividend_high, dividend_low, divisor) combination drawn from
        ``get_cases`` is applied. Each intermediate ``done`` flag is checked,
        and, when the quotient fits in ``quotient_width`` bits, the quotient
        and remainder are compared against Python's ``//`` and ``%``.
        """
        test_cases = get_cases(quotient_width)
        mask = ~(~0 << quotient_width)
        dut = DivStateCombTest(quotient_width)
        vl = rtlil.convert(dut,
                           ports=[dut.dividend,
                                  dut.divisor,
                                  dut.quotient,
                                  dut.remainder])
        with open("div_fsm_comb_pipeline.il", "w") as f:
            f.write(vl)
        # simulate a freshly-constructed instance, not the one just converted
        dut = DivStateCombTest(quotient_width)

        def check(dividend, divisor):
            with self.subTest(dividend=f"{dividend:#x}",
                              divisor=f"{divisor:#x}"):
                yield from dut.check_event.trigger()
                for i in range(quotient_width + 1):
                    # done must be correct and eventually true
                    # even if a div-by-zero or overflow occurred
                    done = yield dut.states[i].done
                    self.assertEqual(done, i == quotient_width)
                if divisor != 0:
                    quotient = dividend // divisor
                    remainder = dividend % divisor
                    if quotient <= mask:
                        with self.subTest(quotient=f"{quotient:#x}",
                                          remainder=f"{remainder:#x}"):
                            self.assertTrue((yield dut.expected_valid))
                            self.assertEqual((yield dut.expected_quotient),
                                             quotient)
                            self.assertEqual((yield dut.expected_remainder),
                                             remainder)
                            self.assertEqual((yield dut.quotient), quotient)
                            self.assertEqual((yield dut.remainder), remainder)
                    else:
                        self.assertFalse((yield dut.expected_valid))
                else:
                    self.assertFalse((yield dut.expected_valid))

        def process(gen):
            # each case occupies a 1us slot: the generator changes the
            # operands half-way through, the checker samples at the end
            for dividend_high in test_cases:
                for dividend_low in test_cases:
                    dividend = dividend_low + \
                        (dividend_high << quotient_width)
                    for divisor in test_cases:
                        if gen:
                            yield Delay(0.5e-6)
                            yield dut.dividend.eq(dividend)
                            yield dut.divisor.eq(divisor)
                            yield Delay(0.5e-6)
                        else:
                            yield Delay(1e-6)
                            yield from check(dividend, divisor)

        def gen_process():
            yield from process(gen=True)

        def check_process():
            yield from process(gen=False)

        sim = Simulator(dut)
        with sim.write_vcd(vcd_file="div_fsm_comb_pipeline.vcd",
                           gtkw_file="div_fsm_comb_pipeline.gtkw"):

            sim.add_process(gen_process)
            sim.add_process(check_process)
            sim.run()

    def test_div_state_fsm(self, quotient_width=8):
        """Check the state functions driven as a clocked FSM.

        For each case the operands are applied before the clock edge on
        which the FSM loads its initial state, followed by
        ``quotient_width`` step edges. The checker samples ``done`` 0.1us
        after each of those ``quotient_width + 1`` edges, and the quotient
        and remainder after the last one.
        """
        test_cases = get_cases(quotient_width)
        mask = ~(~0 << quotient_width)
        dut = DivStateFSMTest(quotient_width)
        vl = rtlil.convert(dut,
                           ports=[dut.dividend,
                                  dut.divisor,
                                  dut.quotient,
                                  dut.remainder])
        with open("div_fsm.il", "w") as f:
            f.write(vl)

        def check(dividend, divisor):
            with self.subTest(dividend=f"{dividend:#x}",
                              divisor=f"{divisor:#x}"):
                for i in range(quotient_width + 1):
                    yield Tick()
                    yield Delay(0.1e-6)
                    yield from dut.check_done_event.trigger()
                    with self.subTest():
                        # done must be correct and eventually true
                        # even if a div-by-zero or overflow occurred
                        done = yield dut.state.done
                        self.assertEqual(done, i == quotient_width)
                yield from dut.check_event.trigger()
                now = None
                try:
                    # FIXME(programmerjake): replace with public API
                    # see https://github.com/nmigen/nmigen/issues/443
                    now = sim._engine.now
                except AttributeError:
                    pass
                if divisor != 0:
                    quotient = dividend // divisor
                    remainder = dividend % divisor
                    if quotient <= mask:
                        with self.subTest(quotient=f"{quotient:#x}",
                                          remainder=f"{remainder:#x}",
                                          now=f"{now}"):
                            self.assertTrue((yield dut.expected_valid))
                            self.assertEqual((yield dut.expected_quotient),
                                             quotient)
                            self.assertEqual((yield dut.expected_remainder),
                                             remainder)
                            self.assertEqual((yield dut.quotient), quotient)
                            self.assertEqual((yield dut.remainder), remainder)
                    else:
                        self.assertFalse((yield dut.expected_valid))
                else:
                    self.assertFalse((yield dut.expected_valid))

        def process(gen):
            if gen:
                # keep the FSM reloading its initial state until the first
                # operands have been clocked in
                yield dut.clear.eq(1)
            else:
                yield from dut.check_event.trigger()
                yield from dut.check_done_event.trigger()
            for dividend_high in test_cases:
                for dividend_low in test_cases:
                    dividend = dividend_low + \
                        (dividend_high << quotient_width)
                    for divisor in test_cases:
                        if gen:
                            # The FSM loads the initial state (computed from
                            # `dividend`) on the edge where `clear` or
                            # `state.done` is set, so the operands must be
                            # applied *before* that edge. Otherwise the case
                            # would start from the previous case's dividend.
                            yield dut.dividend.eq(dividend)
                            yield dut.divisor.eq(divisor)
                            yield Tick()  # load the initial state
                            yield Delay(0.2e-6)
                            yield dut.clear.eq(0)
                            for _ in range(quotient_width):
                                yield Tick()  # one division step per edge
                            # hold the operands until the checker (0.1us
                            # after the final step edge) has sampled them
                            yield Delay(0.2e-6)
                        else:
                            yield from check(dividend, divisor)

        def gen_process():
            yield from process(gen=True)

        def check_process():
            yield from process(gen=False)

        sim = Simulator(dut)
        with sim.write_vcd(vcd_file="div_fsm.vcd",
                           gtkw_file="div_fsm.gtkw"):

            sim.add_clock(1e-6)
            sim.add_process(gen_process)
            sim.add_process(check_process)
            sim.run()
300
301
if __name__ == "__main__":
    # run every TestCase defined in this module when executed as a script
    unittest.main(module="__main__")