1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
4 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
5 # of Horizon 2020 EU Programme 957073.
7 from nmutil
.plain_data
import fields
, replace
10 from nmutil
.formaltest
import FHDLTestCase
11 from nmutil
.sim_util
import do_sim
, hash_256
12 from nmigen
.sim
import Tick
, Delay
13 from nmigen
.hdl
.ast
import Signal
14 from nmigen
.hdl
.dsl
import Module
15 from soc
.fu
.div
.experiment
.goldschmidt_div_sqrt
import (
16 GoldschmidtDivHDL
, GoldschmidtDivHDLState
, GoldschmidtDivParams
,
17 GoldschmidtDivState
, ParamsNotAccurateEnough
, goldschmidt_div
,
18 FixedPoint
, RoundDir
, goldschmidt_sqrt_rsqrt
)
21 class TestFixedPoint(FHDLTestCase
):
22 def test_str_roundtrip(self
):
23 for frac_wid
in range(8):
24 for bits
in range(-1 << 9, 1 << 9):
25 with self
.subTest(bits
=hex(bits
), frac_wid
=frac_wid
):
26 value
= FixedPoint(bits
, frac_wid
)
27 round_trip_value
= FixedPoint
.cast(str(value
))
28 self
.assertEqual(value
, round_trip_value
)
34 except (ValueError, ZeroDivisionError) as e
:
35 return None, e
.__class
__.__name
__
38 for frac_wid
in range(8):
39 for bits
in range(1 << 9):
40 for round_dir
in RoundDir
:
41 radicand
= FixedPoint(bits
, frac_wid
)
42 expected_f
= math
.sqrt(float(radicand
))
43 expected
= self
.trap(lambda: FixedPoint
.with_frac_wid(
44 expected_f
, frac_wid
, round_dir
))
45 with self
.subTest(radicand
=repr(radicand
),
46 round_dir
=str(round_dir
),
47 expected
=repr(expected
)):
48 result
= self
.trap(lambda: radicand
.sqrt(round_dir
))
49 self
.assertEqual(result
, expected
)
52 for frac_wid
in range(8):
53 for bits
in range(1, 1 << 9):
54 for round_dir
in RoundDir
:
55 radicand
= FixedPoint(bits
, frac_wid
)
56 expected_f
= 1 / math
.sqrt(float(radicand
))
57 expected
= self
.trap(lambda: FixedPoint
.with_frac_wid(
58 expected_f
, frac_wid
, round_dir
))
59 with self
.subTest(radicand
=repr(radicand
),
60 round_dir
=str(round_dir
),
61 expected
=repr(expected
)):
62 result
= self
.trap(lambda: radicand
.rsqrt(round_dir
))
63 self
.assertEqual(result
, expected
)
66 class TestGoldschmidtDiv(FHDLTestCase
):
68 with self
.assertRaises(ParamsNotAccurateEnough
):
69 GoldschmidtDivParams(io_width
=3, extra_precision
=2,
70 table_addr_bits
=3, table_data_bits
=5,
74 with self
.assertRaises(ParamsNotAccurateEnough
):
75 GoldschmidtDivParams(io_width
=4, extra_precision
=1,
76 table_addr_bits
=1, table_data_bits
=5,
80 def cases(io_width
, cases
=None):
81 assert isinstance(io_width
, int) and io_width
>= 1
84 assert isinstance(d
, int) \
85 and 0 < d
< (1 << io_width
), "invalid case"
86 assert isinstance(n
, int) \
87 and 0 <= n
< (d
<< io_width
), "invalid case"
90 assert io_width
* 2 <= 256, \
91 "can't generate big enough numbers for test cases"
92 for i
in range(10000):
93 d
= hash_256(f
'd {i}') % (1 << io_width
)
96 n
= hash_256(f
'n {i}') % (d
<< io_width
)
99 for d
in range(1, 1 << io_width
):
100 for n
in range(d
<< io_width
):
103 def tst(self
, io_width
, cases
=None):
104 assert isinstance(io_width
, int)
105 params
= GoldschmidtDivParams
.get(io_width
)
106 with self
.subTest(params
=str(params
)):
107 for n
, d
in self
.cases(io_width
, cases
):
108 expected_q
, expected_r
= divmod(n
, d
)
109 with self
.subTest(n
=hex(n
), d
=hex(d
),
110 expected_q
=hex(expected_q
),
111 expected_r
=hex(expected_r
)):
115 assert isinstance(state
, GoldschmidtDivState
)
116 trace
.append((replace(state
)))
117 q
, r
= goldschmidt_div(n
, d
, params
, trace
=trace_fn
)
118 with self
.subTest(q
=hex(q
), r
=hex(r
), trace
=repr(trace
)):
119 self
.assertEqual((q
, r
), (expected_q
, expected_r
))
121 def tst_sim(self
, io_width
, cases
=None, pipe_reg_indexes
=(),
123 assert isinstance(io_width
, int)
124 params
= GoldschmidtDivParams
.get(io_width
)
126 dut
= GoldschmidtDivHDL(params
, pipe_reg_indexes
=pipe_reg_indexes
,
128 m
.submodules
.dut
= dut
129 # make sync domain get added
130 m
.d
.sync
+= Signal().eq(0)
134 for n
, d
in self
.cases(io_width
, cases
):
139 def check_interals(n
, d
):
140 # check internals only if dut is completely combinatorial
141 # so we don't have to figure out how to read values in
142 # previous clock cycles
143 if dut
.total_pipeline_registers
!= 0:
147 def ref_trace_fn(state
):
148 assert isinstance(state
, GoldschmidtDivState
)
149 ref_trace
.append((replace(state
)))
150 goldschmidt_div(n
=n
, d
=d
, params
=params
, trace
=ref_trace_fn
)
151 self
.assertEqual(len(dut
.trace
), len(ref_trace
))
152 for index
, state
in enumerate(dut
.trace
):
153 ref_state
= ref_trace
[index
]
154 last_op
= None if index
== 0 else params
.ops
[index
- 1]
155 with self
.subTest(index
=index
, state
=repr(state
),
156 ref_state
=repr(ref_state
),
157 last_op
=str(last_op
)):
158 for field
in fields(GoldschmidtDivHDLState
):
159 sig
= getattr(state
, field
)
160 if not isinstance(sig
, Signal
):
162 ref_value
= getattr(ref_state
, field
)
163 ref_value_str
= repr(ref_value
)
164 if isinstance(ref_value
, int):
165 ref_value_str
= hex(ref_value
)
167 with self
.subTest(field_name
=field
,
169 sig_shape
=repr(sig
.shape()),
171 ref_value
=ref_value_str
):
172 if isinstance(ref_value
, int):
173 self
.assertEqual(value
, ref_value
)
175 assert isinstance(ref_value
, FixedPoint
)
176 self
.assertEqual(value
, ref_value
.bits
)
180 for _
in range(dut
.total_pipeline_registers
):
182 for n
, d
in self
.cases(io_width
, cases
):
184 expected_q
, expected_r
= divmod(n
, d
)
185 with self
.subTest(n
=hex(n
), d
=hex(d
),
186 expected_q
=hex(expected_q
),
187 expected_r
=hex(expected_r
)):
190 with self
.subTest(q
=hex(q
), r
=hex(r
)):
191 self
.assertEqual((q
, r
), (expected_q
, expected_r
))
192 yield from check_interals(n
, d
)
196 with self
.subTest(params
=str(params
)):
197 with
do_sim(self
, m
, (dut
.n
, dut
.d
, dut
.q
, dut
.r
)) as sim
:
199 sim
.add_process(inputs_proc
)
200 sim
.add_process(check_outputs
)
203 def test_1_through_4(self
):
204 for io_width
in range(1, 4 + 1):
205 with self
.subTest(io_width
=io_width
):
226 def test_sim_5(self
):
229 def test_sim_8(self
):
232 def test_sim_16(self
):
235 def test_sim_32(self
):
238 def test_sim_64(self
):
241 def tst_params(self
, io_width
):
242 assert isinstance(io_width
, int)
243 params
= GoldschmidtDivParams
.get(io_width
)
247 def test_params_1(self
):
250 def test_params_2(self
):
253 def test_params_3(self
):
256 def test_params_4(self
):
259 def test_params_5(self
):
262 def test_params_6(self
):
265 def test_params_7(self
):
268 def test_params_8(self
):
271 def test_params_9(self
):
274 def test_params_10(self
):
277 def test_params_11(self
):
280 def test_params_12(self
):
283 def test_params_13(self
):
286 def test_params_14(self
):
289 def test_params_15(self
):
292 def test_params_16(self
):
295 def test_params_17(self
):
298 def test_params_18(self
):
301 def test_params_19(self
):
304 def test_params_20(self
):
307 def test_params_21(self
):
310 def test_params_22(self
):
313 def test_params_23(self
):
316 def test_params_24(self
):
319 def test_params_25(self
):
322 def test_params_26(self
):
325 def test_params_27(self
):
328 def test_params_28(self
):
331 def test_params_29(self
):
334 def test_params_30(self
):
337 def test_params_31(self
):
340 def test_params_32(self
):
343 def test_params_33(self
):
346 def test_params_34(self
):
349 def test_params_35(self
):
352 def test_params_36(self
):
355 def test_params_37(self
):
358 def test_params_38(self
):
361 def test_params_39(self
):
364 def test_params_40(self
):
367 def test_params_41(self
):
370 def test_params_42(self
):
373 def test_params_43(self
):
376 def test_params_44(self
):
379 def test_params_45(self
):
382 def test_params_46(self
):
385 def test_params_47(self
):
388 def test_params_48(self
):
391 def test_params_49(self
):
394 def test_params_50(self
):
397 def test_params_51(self
):
400 def test_params_52(self
):
403 def test_params_53(self
):
406 def test_params_54(self
):
409 def test_params_55(self
):
412 def test_params_56(self
):
415 def test_params_57(self
):
418 def test_params_58(self
):
421 def test_params_59(self
):
424 def test_params_60(self
):
427 def test_params_61(self
):
430 def test_params_62(self
):
433 def test_params_63(self
):
436 def test_params_64(self
):
440 class TestGoldschmidtSqrtRSqrt(FHDLTestCase
):
441 def tst(self
, io_width
, frac_wid
, extra_precision
,
442 table_addr_bits
, table_data_bits
, iter_count
):
443 assert isinstance(io_width
, int)
444 assert isinstance(frac_wid
, int)
445 assert isinstance(extra_precision
, int)
446 assert isinstance(table_addr_bits
, int)
447 assert isinstance(table_data_bits
, int)
448 assert isinstance(iter_count
, int)
449 with self
.subTest(io_width
=io_width
, frac_wid
=frac_wid
,
450 extra_precision
=extra_precision
,
451 table_addr_bits
=table_addr_bits
,
452 table_data_bits
=table_data_bits
,
453 iter_count
=iter_count
):
454 for bits
in range(1 << io_width
):
455 radicand
= FixedPoint(bits
, frac_wid
)
456 expected_sqrt
= radicand
.sqrt(RoundDir
.DOWN
)
457 expected_rsqrt
= FixedPoint(0, frac_wid
)
459 expected_rsqrt
= radicand
.rsqrt(RoundDir
.DOWN
)
460 with self
.subTest(radicand
=repr(radicand
),
461 expected_sqrt
=repr(expected_sqrt
),
462 expected_rsqrt
=repr(expected_rsqrt
)):
463 sqrt
, rsqrt
= goldschmidt_sqrt_rsqrt(
464 radicand
=radicand
, io_width
=io_width
,
466 extra_precision
=extra_precision
,
467 table_addr_bits
=table_addr_bits
,
468 table_data_bits
=table_data_bits
,
469 iter_count
=iter_count
)
470 with self
.subTest(sqrt
=repr(sqrt
), rsqrt
=repr(rsqrt
)):
471 self
.assertEqual((sqrt
, rsqrt
),
472 (expected_sqrt
, expected_rsqrt
))
475 self
.tst(io_width
=16, frac_wid
=8, extra_precision
=20,
476 table_addr_bits
=4, table_data_bits
=28, iter_count
=4)
479 if __name__
== "__main__":