7d87eceebac77801362179bdb4a8ca50140ded55
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
5 from nmigen
import Signal
, Module
, Elaboratable
6 from nmigen
.back
.pysim
import Simulator
, Delay
7 from nmigen
.cli
import rtlil
9 from ieee754
.part
.partsig
import PartitionedSignal
10 from ieee754
.part_mux
.part_mux
import PMux
12 from random
import randint
33 return map(''.join
, itertools
.product('01', repeat
=k
))
36 def create_ilang(dut
, traces
, test_name
):
37 vl
= rtlil
.convert(dut
, ports
=traces
)
38 with
open("%s.il" % test_name
, "w") as f
:
42 def create_simulator(module
, traces
, test_name
):
43 create_ilang(module
, traces
, test_name
)
44 return Simulator(module
)
47 # XXX this is for coriolis2 experimentation
48 class TestAddMod2(Elaboratable
):
49 def __init__(self
, width
, partpoints
):
50 self
.partpoints
= partpoints
51 self
.a
= PartitionedSignal(partpoints
, width
)
52 self
.b
= PartitionedSignal(partpoints
, width
)
53 self
.bsig
= Signal(width
)
54 self
.add_output
= Signal(width
)
55 self
.ls_output
= Signal(width
) # left shift
56 self
.ls_scal_output
= Signal(width
) # left shift
57 self
.rs_output
= Signal(width
) # right shift
58 self
.rs_scal_output
= Signal(width
) # right shift
59 self
.sub_output
= Signal(width
)
60 self
.eq_output
= Signal(len(partpoints
)+1)
61 self
.gt_output
= Signal(len(partpoints
)+1)
62 self
.ge_output
= Signal(len(partpoints
)+1)
63 self
.ne_output
= Signal(len(partpoints
)+1)
64 self
.lt_output
= Signal(len(partpoints
)+1)
65 self
.le_output
= Signal(len(partpoints
)+1)
66 self
.mux_sel
= Signal(len(partpoints
)+1)
67 self
.mux_out
= Signal(width
)
68 self
.carry_in
= Signal(len(partpoints
)+1)
69 self
.add_carry_out
= Signal(len(partpoints
)+1)
70 self
.sub_carry_out
= Signal(len(partpoints
)+1)
71 self
.neg_output
= Signal(width
)
73 def elaborate(self
, platform
):
80 sync
+= self
.lt_output
.eq(self
.a
< self
.b
)
81 sync
+= self
.ne_output
.eq(self
.a
!= self
.b
)
82 sync
+= self
.le_output
.eq(self
.a
<= self
.b
)
83 sync
+= self
.gt_output
.eq(self
.a
> self
.b
)
84 sync
+= self
.eq_output
.eq(self
.a
== self
.b
)
85 sync
+= self
.ge_output
.eq(self
.a
>= self
.b
)
87 add_out
, add_carry
= self
.a
.add_op(self
.a
, self
.b
,
89 sync
+= self
.add_output
.eq(add_out
)
90 sync
+= self
.add_carry_out
.eq(add_carry
)
92 sub_out
, sub_carry
= self
.a
.sub_op(self
.a
, self
.b
,
94 sync
+= self
.sub_output
.eq(sub_out
)
95 sync
+= self
.sub_carry_out
.eq(sub_carry
)
97 sync
+= self
.neg_output
.eq(-self
.a
)
99 sync
+= self
.ls_output
.eq(self
.a
<< self
.b
)
100 sync
+= self
.rs_output
.eq(self
.a
>> self
.b
)
101 ppts
= self
.partpoints
102 sync
+= self
.mux_out
.eq(PMux(m
, ppts
, self
.mux_sel
, self
.a
, self
.b
))
104 comb
+= self
.bsig
.eq(self
.b
.sig
)
105 sync
+= self
.ls_scal_output
.eq(self
.a
<< self
.bsig
)
106 sync
+= self
.rs_scal_output
.eq(self
.a
>> self
.bsig
)
111 class TestAddMod(Elaboratable
):
112 def __init__(self
, width
, partpoints
):
113 self
.partpoints
= partpoints
114 self
.a
= PartitionedSignal(partpoints
, width
)
115 self
.b
= PartitionedSignal(partpoints
, width
)
116 self
.bsig
= Signal(width
)
117 self
.add_output
= Signal(width
)
118 self
.ls_output
= Signal(width
) # left shift
119 self
.ls_scal_output
= Signal(width
) # left shift
120 self
.rs_output
= Signal(width
) # right shift
121 self
.rs_scal_output
= Signal(width
) # right shift
122 self
.sub_output
= Signal(width
)
123 self
.eq_output
= Signal(len(partpoints
)+1)
124 self
.gt_output
= Signal(len(partpoints
)+1)
125 self
.ge_output
= Signal(len(partpoints
)+1)
126 self
.ne_output
= Signal(len(partpoints
)+1)
127 self
.lt_output
= Signal(len(partpoints
)+1)
128 self
.le_output
= Signal(len(partpoints
)+1)
129 self
.mux_sel
= Signal(len(partpoints
)+1)
130 self
.mux_out
= Signal(width
)
131 self
.carry_in
= Signal(len(partpoints
)+1)
132 self
.add_carry_out
= Signal(len(partpoints
)+1)
133 self
.sub_carry_out
= Signal(len(partpoints
)+1)
134 self
.neg_output
= Signal(width
)
136 def elaborate(self
, platform
):
143 comb
+= self
.lt_output
.eq(self
.a
< self
.b
)
144 comb
+= self
.ne_output
.eq(self
.a
!= self
.b
)
145 comb
+= self
.le_output
.eq(self
.a
<= self
.b
)
146 comb
+= self
.gt_output
.eq(self
.a
> self
.b
)
147 comb
+= self
.eq_output
.eq(self
.a
== self
.b
)
148 comb
+= self
.ge_output
.eq(self
.a
>= self
.b
)
150 add_out
, add_carry
= self
.a
.add_op(self
.a
, self
.b
,
152 comb
+= self
.add_output
.eq(add_out
)
153 comb
+= self
.add_carry_out
.eq(add_carry
)
155 sub_out
, sub_carry
= self
.a
.sub_op(self
.a
, self
.b
,
157 comb
+= self
.sub_output
.eq(sub_out
)
158 comb
+= self
.sub_carry_out
.eq(sub_carry
)
160 comb
+= self
.neg_output
.eq(-self
.a
)
162 comb
+= self
.ls_output
.eq(self
.a
<< self
.b
)
164 comb
+= self
.rs_output
.eq(self
.a
>> self
.b
)
165 ppts
= self
.partpoints
167 comb
+= self
.mux_out
.eq(PMux(m
, ppts
, self
.mux_sel
, self
.a
, self
.b
))
169 comb
+= self
.bsig
.eq(self
.b
.sig
)
170 comb
+= self
.ls_scal_output
.eq(self
.a
<< self
.bsig
)
172 comb
+= self
.rs_scal_output
.eq(self
.a
>> self
.bsig
)
177 class TestPartitionPoints(unittest
.TestCase
):
180 part_mask
= Signal(4) # divide into 4-bits
181 module
= TestAddMod(width
, part_mask
)
183 test_name
= "part_sig_add"
189 sim
= create_simulator(module
, traces
, test_name
)
193 def test_ls_scal_fn(carry_in
, a
, b
, mask
):
195 bits
= count_bits(mask
)
196 newb
= b
& ((bits
-1))
197 print ("%x %x %x bits %d trunc %x" % \
198 (a
, b
, mask
, bits
, newb
))
202 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
203 sum = ((a
& mask
) << b
)
205 carry
= (sum & mask
) != sum
207 print("res", hex(a
), hex(b
), hex(sum), hex(mask
), hex(result
))
210 def test_rs_scal_fn(carry_in
, a
, b
, mask
):
212 bits
= count_bits(mask
)
213 newb
= b
& ((bits
-1))
214 print ("%x %x %x bits %d trunc %x" % \
215 (a
, b
, mask
, bits
, newb
))
219 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
220 sum = ((a
& mask
) >> b
)
222 carry
= (sum & mask
) != sum
224 print("res", hex(a
), hex(b
), hex(sum), hex(mask
), hex(result
))
227 def test_ls_fn(carry_in
, a
, b
, mask
):
229 bits
= count_bits(mask
)
230 fz
= first_zero(mask
)
231 newb
= b
& ((bits
-1)<<fz
)
232 print ("%x %x %x bits %d zero %d trunc %x" % \
233 (a
, b
, mask
, bits
, fz
, newb
))
237 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
240 sum = ((a
& mask
) << b
)
242 carry
= (sum & mask
) != sum
244 print("res", hex(a
), hex(b
), hex(sum), hex(mask
), hex(result
))
247 def test_rs_fn(carry_in
, a
, b
, mask
):
249 bits
= count_bits(mask
)
250 fz
= first_zero(mask
)
251 newb
= b
& ((bits
-1)<<fz
)
252 print ("%x %x %x bits %d zero %d trunc %x" % \
253 (a
, b
, mask
, bits
, fz
, newb
))
257 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
260 sum = ((a
& mask
) >> b
)
262 carry
= (sum & mask
) != sum
264 print("res", hex(a
), hex(b
), hex(sum), hex(mask
), hex(result
))
267 def test_add_fn(carry_in
, a
, b
, mask
):
268 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
269 sum = (a
& mask
) + (b
& mask
) + lsb
271 carry
= (sum & mask
) != sum
272 print(a
, b
, sum, mask
)
275 def test_sub_fn(carry_in
, a
, b
, mask
):
276 lsb
= mask
& ~
(mask
-1) if carry_in
else 0
277 sum = (a
& mask
) + (~b
& mask
) + lsb
279 carry
= (sum & mask
) != sum
282 def test_neg_fn(carry_in
, a
, b
, mask
):
283 lsb
= mask
& ~
(mask
- 1) # has only LSB of mask set
284 pos
= lsb
.bit_length() - 1 # find bit position
285 a
= (a
& mask
) >> pos
# shift it to the beginning
286 return ((-a
) << pos
) & mask
, 0 # negate and shift it back
288 def test_op(msg_prefix
, carry
, test_fn
, mod_attr
, *mask_list
):
291 a
, b
= randint(0, 1 << 16), randint(0, 1 << 16)
292 rand_data
.append((a
, b
))
293 for a
, b
in [(0x0000, 0x0000),
299 (0x0000, 0xFFFF)] + rand_data
:
302 carry_sig
= 0xf if carry
else 0
303 yield module
.carry_in
.eq(carry_sig
)
307 for i
, mask
in enumerate(mask_list
):
308 print ("i/mask", i
, hex(mask
))
309 res
, c
= test_fn(carry
, a
, b
, mask
)
311 lsb
= mask
& ~
(mask
- 1)
312 bit_set
= int(math
.log2(lsb
))
313 carry_result |
= c
<< int(bit_set
/4)
314 outval
= (yield getattr(module
, "%s_output" % mod_attr
))
315 # TODO: get (and test) carry output as well
316 print(a
, b
, outval
, carry
)
317 msg
= f
"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
318 f
" => 0x{y:X} != 0x{outval:X}"
319 self
.assertEqual(y
, outval
, msg
)
320 if hasattr(module
, "%s_carry_out" % mod_attr
):
321 c_outval
= (yield getattr(module
,
322 "%s_carry_out" % mod_attr
))
323 msg
= f
"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
324 f
" => 0x{carry_result:X} != 0x{c_outval:X}"
325 self
.assertEqual(carry_result
, c_outval
, msg
)
327 for (test_fn
, mod_attr
) in (
328 (test_ls_scal_fn
, "ls_scal"),
330 (test_rs_scal_fn
, "rs_scal"),
332 (test_add_fn
, "add"),
333 (test_sub_fn
, "sub"),
334 (test_neg_fn
, "neg"),
336 yield part_mask
.eq(0)
337 yield from test_op("16-bit", 1, test_fn
, mod_attr
, 0xFFFF)
338 yield from test_op("16-bit", 0, test_fn
, mod_attr
, 0xFFFF)
339 yield part_mask
.eq(0b10)
340 yield from test_op("8-bit", 0, test_fn
, mod_attr
,
342 yield from test_op("8-bit", 1, test_fn
, mod_attr
,
344 yield part_mask
.eq(0b1111)
345 yield from test_op("4-bit", 0, test_fn
, mod_attr
,
346 0xF000, 0x0F00, 0x00F0, 0x000F)
347 yield from test_op("4-bit", 1, test_fn
, mod_attr
,
348 0xF000, 0x0F00, 0x00F0, 0x000F)
350 def test_ne_fn(a
, b
, mask
):
351 return (a
& mask
) != (b
& mask
)
353 def test_lt_fn(a
, b
, mask
):
354 return (a
& mask
) < (b
& mask
)
356 def test_le_fn(a
, b
, mask
):
357 return (a
& mask
) <= (b
& mask
)
359 def test_eq_fn(a
, b
, mask
):
360 return (a
& mask
) == (b
& mask
)
362 def test_gt_fn(a
, b
, mask
):
363 return (a
& mask
) > (b
& mask
)
365 def test_ge_fn(a
, b
, mask
):
366 return (a
& mask
) >= (b
& mask
)
368 def test_binop(msg_prefix
, test_fn
, mod_attr
, *maskbit_list
):
369 for a
, b
in [(0x0000, 0x0000),
382 # convert to mask_list
384 for mb
in maskbit_list
:
391 # do the partitioned tests
392 for i
, mask
in enumerate(mask_list
):
393 if test_fn(a
, b
, mask
):
394 # OR y with the lowest set bit in the mask
397 outval
= (yield getattr(module
, "%s_output" % mod_attr
))
398 msg
= f
"{msg_prefix}: {mod_attr} 0x{a:X} == 0x{b:X}" + \
399 f
" => 0x{y:X} != 0x{outval:X}, masklist %s"
400 print((msg
% str(maskbit_list
)).format(locals()))
401 self
.assertEqual(y
, outval
, msg
% str(maskbit_list
))
403 for (test_fn
, mod_attr
) in ((test_eq_fn
, "eq"),
410 yield part_mask
.eq(0)
411 yield from test_binop("16-bit", test_fn
, mod_attr
, 0b1111)
412 yield part_mask
.eq(0b10)
413 yield from test_binop("8-bit", test_fn
, mod_attr
,
415 yield part_mask
.eq(0b1111)
416 yield from test_binop("4-bit", test_fn
, mod_attr
,
417 0b1000, 0b0100, 0b0010, 0b0001)
419 def test_muxop(msg_prefix
, *maskbit_list
):
420 for a
, b
in [(0x0000, 0x0000),
427 # convert to mask_list
429 for mb
in maskbit_list
:
436 # TODO: sel needs to go through permutations of mask_list
437 for p
in perms(len(mask_list
)):
441 for i
, v
in enumerate(p
):
443 sel |
= maskbit_list
[i
]
444 selmask |
= mask_list
[i
]
448 yield module
.mux_sel
.eq(sel
)
451 # do the partitioned tests
452 for i
, mask
in enumerate(mask_list
):
458 outval
= (yield module
.mux_out
)
459 msg
= f
"{msg_prefix}: mux " + \
460 f
"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
461 f
" => 0x{y:X} != 0x{outval:X}, masklist %s"
462 # print ((msg % str(maskbit_list)).format(locals()))
463 self
.assertEqual(y
, outval
, msg
% str(maskbit_list
))
465 yield part_mask
.eq(0)
466 yield from test_muxop("16-bit", 0b1111)
467 yield part_mask
.eq(0b10)
468 yield from test_muxop("8-bit", 0b1100, 0b0011)
469 yield part_mask
.eq(0b1111)
470 yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)
472 sim
.add_process(async_process
)
474 vcd_file
=open(test_name
+ ".vcd", "w"),
475 gtkw_file
=open(test_name
+ ".gtkw", "w"),
480 if __name__
== '__main__':