disable fadd f32 formal proofs by default -- they're too slow
[ieee754fpu.git] / src / ieee754 / div_rem_sqrt_rsqrt / test_core.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
4
5 from ieee754.div_rem_sqrt_rsqrt.core import (DivPipeCoreConfig,
6 DivPipeCoreSetupStage,
7 DivPipeCoreCalculateStage, DivPipeCoreFinalStage,
8 DivPipeCoreOperation, DivPipeCoreInputData,
9 DivPipeCoreInterstageData, DivPipeCoreOutputData)
10 from ieee754.div_rem_sqrt_rsqrt.algorithm import (FixedUDivRemSqrtRSqrt,
11 Fixed, Operation, div_rem,
12 fixed_sqrt, fixed_rsqrt)
13 import unittest
14 from nmigen import Module, Elaboratable, Signal
15 from nmigen.hdl.ir import Fragment
16 from nmigen.back import rtlil
17 from nmigen.back.pysim import Simulator, Delay, Tick
18 from itertools import chain
19 import inspect
20
21
def show_fixed(bits, fract_width, bit_width):
    """Render a raw fixed-point bit pattern for test diagnostics.

    The bits are decoded with ``Fixed.from_bits`` and rendered as
    ``"<str(value)>:<repr(value)>"``. The trailing ``False`` passed to
    ``from_bits`` presumably selects an unsigned interpretation (TODO
    confirm against ``Fixed.from_bits``).
    """
    value = Fixed.from_bits(bits, fract_width, bit_width, False)
    return str(value) + ":" + repr(value)
25
26
def get_core_op(alg_op):
    """Translate an algorithm-level ``Operation`` to a ``DivPipeCoreOperation``.

    Anything other than UDivRem or SqrtRem must be RSqrtRem; this is
    checked with ``assert``.
    """
    if alg_op is Operation.UDivRem:
        core_op = DivPipeCoreOperation.UDivRem
    elif alg_op is Operation.SqrtRem:
        core_op = DivPipeCoreOperation.SqrtRem
    else:
        assert alg_op is Operation.RSqrtRem
        core_op = DivPipeCoreOperation.RSqrtRem
    return core_op
34
35
class TestCaseData:
    """One input/expected-output pair for the division pipeline core.

    Every value is stored as a raw integer bit pattern. ``__str__`` decodes
    them as fixed-point numbers using the widths from ``core_config``.
    """

    __test__ = False  # make pytest ignore class

    def __init__(self,
                 dividend,
                 divisor_radicand,
                 alg_op,
                 quotient_root,
                 remainder,
                 core_config):
        # inputs
        self.dividend = dividend
        self.divisor_radicand = divisor_radicand
        self.alg_op = alg_op
        # expected outputs
        self.quotient_root = quotient_root
        self.remainder = remainder
        # configuration the expectations were computed for
        self.core_config = core_config

    @property
    def core_op(self):
        """The ``DivPipeCoreOperation`` corresponding to ``alg_op``."""
        return get_core_op(self.alg_op)

    def __str__(self):
        width = self.core_config.bit_width
        fract = self.core_config.fract_width
        # (label, rendered value) in display order; each fixed-point field
        # is decoded with its own (fraction bits, total bits) pair
        fields = [
            ("dividend", show_fixed(self.dividend, fract * 2, width + fract)),
            ("divisor_radicand",
             show_fixed(self.divisor_radicand, fract, width)),
            ("op", self.alg_op.name),
            ("quotient_root", show_fixed(self.quotient_root, fract, width)),
            ("remainder", show_fixed(self.remainder, fract * 3, width * 3)),
            ("config", self.core_config),
        ]
        body = ", ".join(f"{label}={value}" for label, value in fields)
        return "{" + body + "}"
78
79
def generate_test_case(core_config, dividend, divisor_radicand, alg_op):
    """Yield a single TestCaseData whose expectations come from the model.

    The expected quotient/root and remainder are produced by running
    ``FixedUDivRemSqrtRSqrt`` with the widths and radix from
    ``core_config``.
    """
    reference = FixedUDivRemSqrtRSqrt(dividend,
                                      divisor_radicand,
                                      alg_op,
                                      core_config.bit_width,
                                      core_config.fract_width,
                                      core_config.log2_radix)
    reference.calculate()
    yield TestCaseData(dividend, divisor_radicand, alg_op,
                       reference.quotient_root, reference.remainder,
                       core_config)
96
97
def shifted_ints(total_bits, int_bits):
    """ Generate a sequence like a generalized binary version of A037124.

    See https://oeis.org/A037124

    Generates the sequence of all non-negative integers ``n`` in ascending
    order with no repeats where ``n < (1 << total_bits) and n == (v << i)``
    where ``i`` is a non-negative integer and ``v`` is a non-negative
    integer less than ``1 << int_bits``.

    When ``int_bits`` is 0 the only admissible ``v`` is 0, so the sequence
    is just ``[0]``.

    Raises ValueError (from the shift) if either argument is negative.
    """
    limit = 1 << total_bits
    dense_limit = 1 << int_bits
    if int_bits == 0:
        # Only v == 0 is allowed. The stepping rule below would otherwise
        # emit 1, 3, 7, ..., none of which have the form v << i.
        yield 0
        return
    n = 0
    while n < limit:
        yield n
        if n < dense_limit:
            # every value below 1 << int_bits is itself a valid v (i == 0)
            n += 1
        else:
            # keep only the top int_bits bits of n significant: advance by
            # the weight of the lowest of those bits
            n += 1 << (n.bit_length() - int_bits)
114 n += 1 << (n.bit_length() - int_bits)
115
116
def partitioned_ints(bit_width):
    """ Get ints with all 1s on one side and 0s on the other.

    For each split point ``k`` in ``range(bit_width)`` two values are
    produced, in this order:

    * bits ``k`` .. ``bit_width - 1`` set (ones at the top), then
    * bits ``0`` .. ``k`` set (ones at the bottom).

    As a result, the all-ones value appears twice (first and last) and 0
    is never produced. The sequence is empty for ``bit_width <= 0``.
    """
    for split in range(bit_width):
        all_ones = (1 << bit_width) - 1
        yield all_ones & ~((1 << split) - 1)
        yield all_ones >> (bit_width - 1 - split)
122
123
class TestShiftedInts(unittest.TestCase):
    """Pin ``shifted_ints(12, 4)`` to a hand-written table."""

    def test(self):
        # Every value below 0x10 is present. Above that, each power-of-two
        # band [2**k, 2**(k+1)) keeps only multiples of 2**(k-3), i.e.
        # values with at most four significant bits.
        table = [
            0x000, 0x001, 0x002, 0x003, 0x004, 0x005, 0x006, 0x007,
            0x008, 0x009, 0x00A, 0x00B, 0x00C, 0x00D, 0x00E, 0x00F,
            0x010, 0x012, 0x014, 0x016, 0x018, 0x01A, 0x01C, 0x01E,
            0x020, 0x024, 0x028, 0x02C, 0x030, 0x034, 0x038, 0x03C,
            0x040, 0x048, 0x050, 0x058, 0x060, 0x068, 0x070, 0x078,
            0x080, 0x090, 0x0A0, 0x0B0, 0x0C0, 0x0D0, 0x0E0, 0x0F0,
            0x100, 0x120, 0x140, 0x160, 0x180, 0x1A0, 0x1C0, 0x1E0,
            0x200, 0x240, 0x280, 0x2C0, 0x300, 0x340, 0x380, 0x3C0,
            0x400, 0x480, 0x500, 0x580, 0x600, 0x680, 0x700, 0x780,
            0x800, 0x900, 0xA00, 0xB00, 0xC00, 0xD00, 0xE00, 0xF00,
        ]
        produced = list(shifted_ints(12, 4))
        self.assertEqual(produced, table)
140
141
def get_test_cases(core_config,
                   dividends=None,
                   divisors=None,
                   radicands=None):
    """Generate TestCaseData for every operation ``core_config`` supports.

    If ``dividends``, ``divisors`` or ``radicands`` is None, a default list
    is built: ``shifted_ints`` over the operand width, followed by
    ``partitioned_ints`` over the same width.

    * dividends: width ``bit_width + fract_width``, using
      ``max(3, log2_radix)`` significant bits
    * divisors: width ``bit_width``, using ``max(3, log2_radix)`` bits
    * radicands: width ``bit_width``, using 5 bits

    Explicitly supplied values must be lists.

    Operations are visited in reverse ``Operation`` order so that UDivRem
    cases come last. UDivRem uses every (dividend, divisor) pair. The other
    operations use each radicand with a dividend of 0.
    """
    def default_operands(width, int_bits):
        # sparse-but-wide shifted values, then the one-sided bit masks
        return [*shifted_ints(width, int_bits), *partitioned_ints(width)]

    if dividends is not None:
        assert isinstance(dividends, list)
    else:
        dividends = default_operands(
            core_config.bit_width + core_config.fract_width,
            max(3, core_config.log2_radix))
    if divisors is not None:
        assert isinstance(divisors, list)
    else:
        divisors = default_operands(core_config.bit_width,
                                    max(3, core_config.log2_radix))
    if radicands is not None:
        assert isinstance(radicands, list)
    else:
        radicands = default_operands(core_config.bit_width, 5)

    for alg_op in reversed(Operation):  # put UDivRem at end
        if get_core_op(alg_op) not in core_config.supported:
            continue
        if alg_op is Operation.UDivRem:
            operand_pairs = ((dividend, divisor)
                             for dividend in dividends
                             for divisor in divisors)
        else:
            operand_pairs = ((0, radicand) for radicand in radicands)
        for dividend, divisor_radicand in operand_pairs:
            yield from generate_test_case(core_config,
                                          dividend,
                                          divisor_radicand,
                                          alg_op)
181
182
class DivPipeCoreTestPipeline(Elaboratable):
    """Chain every DivPipeCore stage into one pipeline for testing.

    The setup stage, ``core_config.n_stages`` calculate stages and the
    final stage are connected in series through ``n_stages + 1``
    interstage records, running from ``self.i`` to ``self.o``.
    """

    def __init__(self, core_config, sync):
        """
        :param core_config: configuration shared by every stage.
        :param sync: if true, stage outputs are assigned in the ``sync``
            domain (registered); otherwise in ``comb``.
        """
        self.setup_stage = DivPipeCoreSetupStage(core_config)
        self.calculate_stages = [
            DivPipeCoreCalculateStage(core_config, index)
            for index in range(core_config.n_stages)]
        self.final_stage = DivPipeCoreFinalStage(core_config)
        self.interstage_signals = [
            DivPipeCoreInterstageData(core_config, reset_less=True)
            for _ in range(core_config.n_stages + 1)]
        self.i = DivPipeCoreInputData(core_config, reset_less=True)
        self.o = DivPipeCoreOutputData(core_config, reset_less=True)
        self.sync = sync

    def elaborate(self, platform):
        m = Module()
        all_stages = [self.setup_stage, *self.calculate_stages,
                      self.final_stage]
        # stage k reads signal_chain[k] and drives signal_chain[k + 1]
        signal_chain = [self.i, *self.interstage_signals, self.o]
        for index, stage in enumerate(all_stages):
            stage_in = signal_chain[index]
            stage_out = signal_chain[index + 1]
            stage.setup(m, stage_in)
            statements = stage_out.eq(stage.process(stage_in))
            if self.sync:
                m.d.sync += statements
            else:
                m.d.comb += statements
        return m

    def traces(self):
        """Yield the signals to record in the waveform dump.

        Only the pipeline input and output are traced; the interstage
        records are deliberately left out.
        """
        for record in (self.i, self.o):
            yield from record
216
217
def trace_process(process, prefix="trace:", silent=False):
    """ Wrap a simulator process so its traffic with the simulator is logged.

    ``process`` is either a generator function, which is called with no
    arguments the first time the wrapper runs, or an already-created
    generator. Returns a zero-argument generator function that can be
    passed to ``Simulator.add_process`` in place of the original process.

    Unless ``silent`` is true, the following are printed:

    * each command yielded by the inner process, as ``<prefix> <command>``;
    * each value sent back, as ``<prefix> -> <response>``;
    * an exception escaping the inner process, as
      ``<prefix> raised: <exception>``, before it is re-raised unchanged.

    Otherwise the wrapper is transparent:

    * exceptions thrown into it are forwarded to the inner process;
    * closing it closes the inner process;
    * the inner process's return value becomes the wrapper's.
    """
    def log(*args):
        if not silent:
            print(prefix, *args)

    def generator():
        if inspect.isgeneratorfunction(process):
            proc = process()
        else:
            proc = process
        # Resume the inner process with proc.send(value) normally, or with
        # proc.throw(value) when an exception was thrown into this wrapper.
        resume, value = proc.send, None
        while True:
            try:
                command = resume(value)
            except StopIteration as stop:
                return stop.value
            except Exception as e:
                log("raised:", e)
                raise
            log(command)
            try:
                value = yield command
            except GeneratorExit:
                proc.close()
                raise
            except Exception as thrown:
                resume, value = proc.throw, thrown
            else:
                resume = proc.send
                log("->", value)
    return generator
240
241
class TestDivPipeCore(unittest.TestCase):
    """Convert and simulate DivPipeCoreTestPipeline for several configs.

    Each ``test_*`` method builds a DivPipeCoreConfig and passes it to
    ``handle_config``. The ``*_comb`` variants build the pipeline with
    ``sync=False``, so the stages are combinational.
    """

    def handle_config(self,
                      core_config,
                      test_cases=None,
                      sync=True):
        """Convert one configuration to RTLIL, then simulate it.

        Writes ``<base>.il``, ``<base>.vcd`` and ``<base>.gtkw`` to the
        current working directory. ``<base>`` encodes:

        * the bit width, fractional width and radix;
        * ``_comb`` when ``sync`` is false;
        * the supported operations, when not all of them are supported.

        :param core_config: the configuration under test.
        :param test_cases: iterable of TestCaseData; defaults to
            ``get_test_cases(core_config)``.
        :param sync: if true the pipeline stages are clocked, otherwise
            combinational.
        """
        if test_cases is None:
            test_cases = get_test_cases(core_config)
        cases = list(test_cases)

        name_parts = [
            f"test_div_pipe_core_bit_width_{core_config.bit_width}",
            f"_fract_width_{core_config.fract_width}",
            f"_radix_{1 << core_config.log2_radix}",
        ]
        if not sync:
            name_parts.append("_comb")
        if core_config.supported != frozenset(DivPipeCoreOperation):
            short_names = {
                DivPipeCoreOperation.UDivRem: "div",
                DivPipeCoreOperation.SqrtRem: "sqrt",
                DivPipeCoreOperation.RSqrtRem: "rsqrt",
            }
            # walk the enum itself (not the supported set) so the suffix
            # order is deterministic
            name_parts.extend(f"_{short_names[op]}"
                              for op in DivPipeCoreOperation
                              if op in core_config.supported)
            name_parts.append("_only")
        base_name = "".join(name_parts)

        with self.subTest(part="synthesize"):
            synth_dut = DivPipeCoreTestPipeline(core_config, sync)
            rtlil_text = rtlil.convert(synth_dut,
                                       ports=[*synth_dut.i, *synth_dut.o])
            with open(f"{base_name}.il", "w") as il_file:
                il_file.write(rtlil_text)

        # simulate a fresh instance, separate from the one converted above
        dut = DivPipeCoreTestPipeline(core_config, sync)
        sim = Simulator(dut)
        with sim.write_vcd(vcd_file=open(f"{base_name}.vcd", "w"),
                           gtkw_file=open(f"{base_name}.gtkw", "w"),
                           traces=[*dut.traces()]):

            def generate_process():
                # drive one test case per step: after a clock edge when
                # sync, otherwise every 1e-6 starting at 1e-6
                if not sync:
                    yield Delay(1e-6)
                for case in cases:
                    if sync:
                        yield Tick()
                    yield dut.i.dividend.eq(case.dividend)
                    yield dut.i.divisor_radicand.eq(case.divisor_radicand)
                    yield dut.i.operation.eq(int(case.core_op))
                    yield Delay(0.9e-6 if sync else 1e-6)

            def check_process():
                # Line up with the generator. In sync mode, wait out the
                # pipeline: setup + n_stages calculate + final stages, each
                # registered when sync is true.
                if sync:
                    yield Tick()
                    for _ in range(core_config.n_stages):
                        yield Tick()
                    yield Tick()
                else:
                    yield Delay(0.5e-6)

                # now in step with the generator: sample one result per case
                for case in cases:
                    if sync:
                        yield Tick()
                        yield Delay(0.9e-6)
                    else:
                        yield Delay(1e-6)
                    actual_quotient_root = yield dut.o.quotient_root
                    actual_remainder = yield dut.o.remainder
                    description = str(case)
                    with self.subTest(test_case=description):
                        self.assertEqual(actual_quotient_root,
                                         case.quotient_root,
                                         description)
                        self.assertEqual(actual_remainder, case.remainder,
                                         description)

            if sync:
                sim.add_clock(2e-6)
            quiet = True  # set to False to print every simulator command
            sim.add_process(trace_process(generate_process, "generate:",
                                          silent=quiet))
            sim.add_process(trace_process(check_process, "check:",
                                          silent=quiet))
            sim.run()

    def test_bit_width_2_fract_width_1_radix_2_comb(self):
        config = DivPipeCoreConfig(bit_width=2, fract_width=1, log2_radix=1)
        self.handle_config(config, sync=False)

    def test_bit_width_2_fract_width_1_radix_2(self):
        config = DivPipeCoreConfig(bit_width=2, fract_width=1, log2_radix=1)
        self.handle_config(config)

    def test_bit_width_8_fract_width_4_radix_2_comb(self):
        config = DivPipeCoreConfig(bit_width=8, fract_width=4, log2_radix=1)
        self.handle_config(config, sync=False)

    def test_bit_width_8_fract_width_4_radix_2(self):
        config = DivPipeCoreConfig(bit_width=8, fract_width=4, log2_radix=1)
        self.handle_config(config)

    def test_bit_width_8_fract_width_4_radix_4_comb(self):
        config = DivPipeCoreConfig(bit_width=8, fract_width=4, log2_radix=2)
        self.handle_config(config, sync=False)

    def test_bit_width_8_fract_width_4_radix_4(self):
        config = DivPipeCoreConfig(bit_width=8, fract_width=4, log2_radix=2)
        self.handle_config(config)

    def test_bit_width_8_fract_width_4_radix_4_div_only(self):
        config = DivPipeCoreConfig(bit_width=8, fract_width=4, log2_radix=2,
                                   supported=(DivPipeCoreOperation.UDivRem,))
        self.handle_config(config)

    def test_bit_width_8_fract_width_4_radix_4_comb_div_only(self):
        config = DivPipeCoreConfig(bit_width=8, fract_width=4, log2_radix=2,
                                   supported=(DivPipeCoreOperation.UDivRem,))
        self.handle_config(config, sync=False)

    @unittest.skip("really slow")
    def test_bit_width_32_fract_width_24_radix_8_comb(self):
        config = DivPipeCoreConfig(bit_width=32, fract_width=24, log2_radix=3)
        self.handle_config(config, sync=False)

    @unittest.skip("really slow")
    def test_bit_width_32_fract_width_24_radix_8(self):
        config = DivPipeCoreConfig(bit_width=32, fract_width=24, log2_radix=3)
        self.handle_config(config)

    @unittest.skip("really slow")
    def test_bit_width_32_fract_width_28_radix_8_comb(self):
        config = DivPipeCoreConfig(bit_width=32, fract_width=28, log2_radix=3)
        self.handle_config(config, sync=False)

    @unittest.skip("really slow")
    def test_bit_width_32_fract_width_28_radix_8(self):
        config = DivPipeCoreConfig(bit_width=32, fract_width=28, log2_radix=3)
        self.handle_config(config)

    # FIXME: add more test_* functions
398
399
if __name__ == '__main__':
    # Run every unittest.TestCase in this module when executed as a script.
    unittest.main()