Add subtraction to partsig.py
[ieee754fpu.git] / src / ieee754 / part / test / test_partsig.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
4
5 from nmigen import Signal, Module, Elaboratable
6 from nmigen.back.pysim import Simulator, Delay, Tick, Passive
7 from nmigen.cli import verilog, rtlil
8
9 from ieee754.part.partsig import PartitionedSignal
10 from ieee754.part_mux.part_mux import PMux
11
12 from random import randint
13 import unittest
14 import itertools
15
16
def perms(k):
    """Return a lazy iterator over every k-character string of '0'/'1'.

    Strings come out in ascending binary order, e.g. perms(2) gives
    "00", "01", "10", "11"; perms(0) gives a single empty string.
    """
    bit_choices = "01"
    return ("".join(bits) for bits in itertools.product(bit_choices, repeat=k))
19
20
def create_ilang(dut, traces, test_name):
    """Convert `dut` to RTLIL text and write it to "<test_name>.il".

    `traces` is passed to rtlil.convert as the design's port list.
    """
    il_text = rtlil.convert(dut, ports=traces)
    il_path = "%s.il" % test_name
    with open(il_path, "w") as il_file:
        il_file.write(il_text)
25
26
def create_simulator(module, traces, test_name):
    """Dump `module` as RTLIL, then build a Simulator recording waveforms.

    Writes "<test_name>.il" via create_ilang, and opens "<test_name>.vcd"
    and "<test_name>.gtkw" for the Simulator to write into.
    Returns the constructed Simulator.
    """
    create_ilang(module, traces, test_name)
    # NOTE(review): these handles are handed to Simulator and never closed
    # here; confirm the Simulator closes them, otherwise output may be lost.
    vcd_handle = open(test_name + ".vcd", "w")
    gtkw_handle = open(test_name + ".gtkw", "w")
    return Simulator(module,
                     vcd_file=vcd_handle,
                     gtkw_file=gtkw_handle,
                     traces=traces)
33
class TestAddMod(Elaboratable):
    """Test harness wiring PartitionedSignal operators onto plain Signals.

    Two PartitionedSignal inputs, ``a`` and ``b``, are fed combinatorially
    through several operators, and each result drives an ordinary Signal
    that the simulator can read back:
    - the comparisons (<, !=, <=, >, ==, >=);
    - add_op, and sub_op when it exists;
    - PMux.

    :param width: bit-width of ``a``/``b`` and of the add/sub/mux outputs.
    :param partpoints: partition specification passed to PartitionedSignal
        and PMux. ``len(partpoints) + 1`` is used as the width of the
        comparison outputs, the mux select and the carry signals.
    """

    def __init__(self, width, partpoints):
        self.partpoints = partpoints
        self.a = PartitionedSignal(partpoints, width)
        self.b = PartitionedSignal(partpoints, width)
        self.add_output = Signal(width)
        self.sub_output = Signal(width)
        # comparison results, len(partpoints)+1 bits wide
        self.eq_output = Signal(len(partpoints)+1)
        self.gt_output = Signal(len(partpoints)+1)
        self.ge_output = Signal(len(partpoints)+1)
        self.ne_output = Signal(len(partpoints)+1)
        self.lt_output = Signal(len(partpoints)+1)
        self.le_output = Signal(len(partpoints)+1)
        # PMux select input and output
        self.mux_sel = Signal(len(partpoints)+1)
        self.mux_out = Signal(width)
        # carry in (shared by add and sub) and the two carry outputs
        self.carry_in = Signal(len(partpoints)+1)
        self.add_carry_out = Signal(len(partpoints)+1)
        self.sub_carry_out = Signal(len(partpoints)+1)

    def elaborate(self, platform):
        """Build a Module driving every output Signal combinatorially."""
        m = Module()
        self.a.set_module(m)
        self.b.set_module(m)
        m.d.comb += self.lt_output.eq(self.a < self.b)
        m.d.comb += self.ne_output.eq(self.a != self.b)
        m.d.comb += self.le_output.eq(self.a <= self.b)
        m.d.comb += self.gt_output.eq(self.a > self.b)
        m.d.comb += self.eq_output.eq(self.a == self.b)
        m.d.comb += self.ge_output.eq(self.a >= self.b)
        # add
        add_out, add_carry = self.a.add_op(self.a, self.b,
                                           self.carry_in)
        m.d.comb += self.add_output.eq(add_out)
        m.d.comb += self.add_carry_out.eq(add_carry)
        if hasattr(self.a, "sub_op"): # TODO, remove this
            # sub
            sub_out, sub_carry = self.a.sub_op(self.a, self.b,
                                               self.carry_in)
            m.d.comb += self.sub_output.eq(sub_out)
            # the subtractor's own carry (previously mis-wired to add_carry)
            m.d.comb += self.sub_carry_out.eq(sub_carry)
        # partitioned mux: mux_sel picks between a and b
        ppts = self.partpoints
        m.d.comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))

        return m
78
79
class TestPartitionPoints(unittest.TestCase):
    """Simulation test of PartitionedSignal add/sub, comparisons and PMux.

    A 16-bit TestAddMod is exercised under three partition layouts: one
    16-bit lane, two 8-bit lanes and four 4-bit lanes. Each output is
    compared against a pure-Python reference model.
    """

    def test(self):
        width = 16
        part_mask = Signal(4) # divide into 4-bits
        module = TestAddMod(width, part_mask)

        sim = create_simulator(module,
                               [part_mask,
                                module.a.sig,
                                module.b.sig,
                                module.add_output,
                                module.eq_output],
                               "part_sig_add")

        def async_process():

            def maskbits_to_masks(maskbit_list):
                """Expand per-nibble enable bits into 16-bit value masks.

                Bit i of each entry selects bits [4*i, 4*i+3] of the value,
                e.g. 0b0011 -> 0x00FF.
                """
                mask_list = []
                for mb in maskbit_list:
                    v = 0
                    for i in range(4):
                        if mb & (1 << i):
                            v |= 0xf << (i*4)
                    mask_list.append(v)
                return mask_list

            def test_add_fn(carry_in, a, b, mask):
                """Reference model for one lane: (a + b + carry) within mask."""
                # the carry enters at the lowest set bit of the lane's mask
                lsb = mask & ~(mask-1) if carry_in else 0
                return mask & ((a & mask) + (b & mask) + lsb)

            def test_sub_fn(carry_in, a, b, mask):
                """Reference model for one lane: (a + ~b + carry) within mask."""
                lsb = mask & ~(mask-1) if carry_in else 0
                return mask & ((a & mask) + (~b & mask) + lsb)

            def test_op(msg_prefix, carry, test_fn, mod_attr, *mask_list):
                """Drive fixed and random operands and check <mod_attr>_output.

                The expected value ORs test_fn's result over every lane mask
                in mask_list.
                """
                op_sym = {"add": "+", "sub": "-"}.get(mod_attr, mod_attr)
                # randint is inclusive: keep operands within `width` bits
                max_val = (1 << width) - 1
                rand_data = []
                for _ in range(100):
                    a, b = randint(0, max_val), randint(0, max_val)
                    rand_data.append((a, b))
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)] + rand_data:
                    yield module.a.eq(a)
                    yield module.b.eq(b)
                    carry_sig = 0xf if carry else 0
                    yield module.carry_in.eq(carry_sig)
                    yield Delay(0.1e-6)
                    y = 0
                    for mask in mask_list:
                        y |= test_fn(carry, a, b, mask)
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    # TODO: get (and test) carry output as well
                    print(a, b, outval, carry)
                    msg = f"{msg_prefix}: 0x{a:X} {op_sym} 0x{b:X}" + \
                          f" (carry {carry}) => 0x{y:X} != 0x{outval:X}"
                    self.assertEqual(y, outval, msg)

            for (test_fn, mod_attr) in ((test_add_fn, "add"),
                                        (test_sub_fn, "sub"),
                                        ):
                yield part_mask.eq(0)
                yield from test_op("16-bit", 1, test_fn, mod_attr, 0xFFFF)
                yield from test_op("16-bit", 0, test_fn, mod_attr, 0xFFFF)
                yield part_mask.eq(0b10)
                yield from test_op("8-bit", 0, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                yield from test_op("8-bit", 1, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                yield part_mask.eq(0b1111)
                yield from test_op("4-bit", 0, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)
                yield from test_op("4-bit", 1, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)

            def test_ne_fn(a, b, mask):
                return (a & mask) != (b & mask)

            def test_lt_fn(a, b, mask):
                return (a & mask) < (b & mask)

            def test_le_fn(a, b, mask):
                return (a & mask) <= (b & mask)

            def test_eq_fn(a, b, mask):
                return (a & mask) == (b & mask)

            def test_gt_fn(a, b, mask):
                return (a & mask) > (b & mask)

            def test_ge_fn(a, b, mask):
                return (a & mask) >= (b & mask)

            def test_binop(msg_prefix, test_fn, mod_attr, *maskbit_list):
                """Check the <mod_attr>_output comparison against test_fn per lane."""
                mask_list = maskbits_to_masks(maskbit_list)
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF),
                             (0xABCD, 0xABCE),
                             (0x8000, 0x0000),
                             (0xBEEF, 0xFEED)]:
                    yield module.a.eq(a)
                    yield module.b.eq(b)
                    yield Delay(0.1e-6)
                    y = 0
                    # do the partitioned tests
                    for maskbits, mask in zip(maskbit_list, mask_list):
                        if test_fn(a, b, mask):
                            # a true result sets every enable bit of that lane
                            y |= maskbits
                    # check the result
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
                          f" => 0x{y:X} != 0x{outval:X}, masklist {maskbit_list}"
                    print(msg)
                    self.assertEqual(y, outval, msg)

            for (test_fn, mod_attr) in ((test_eq_fn, "eq"),
                                        (test_gt_fn, "gt"),
                                        (test_ge_fn, "ge"),
                                        (test_lt_fn, "lt"),
                                        (test_le_fn, "le"),
                                        (test_ne_fn, "ne"),
                                        ):
                yield part_mask.eq(0)
                yield from test_binop("16-bit", test_fn, mod_attr, 0b1111)
                yield part_mask.eq(0b10)
                yield from test_binop("8-bit", test_fn, mod_attr,
                                      0b1100, 0b0011)
                yield part_mask.eq(0b1111)
                yield from test_binop("4-bit", test_fn, mod_attr,
                                      0b1000, 0b0100, 0b0010, 0b0001)

            def test_muxop(msg_prefix, *maskbit_list):
                """Check PMux: a lane takes `a` when selected, otherwise `b`."""
                mask_list = maskbits_to_masks(maskbit_list)
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)]:
                    # try every combination of lanes selecting `a`
                    for p in perms(len(mask_list)):
                        sel = 0
                        selmask = 0
                        for i, v in enumerate(p):
                            if v == '1':
                                sel |= maskbit_list[i]
                                selmask |= mask_list[i]

                        yield module.a.eq(a)
                        yield module.b.eq(b)
                        yield module.mux_sel.eq(sel)
                        yield Delay(0.1e-6)
                        y = 0
                        # reference model: pick a or b lane by lane
                        for mask in mask_list:
                            if selmask & mask:
                                y |= (a & mask)
                            else:
                                y |= (b & mask)
                        # check the result
                        outval = (yield module.mux_out)
                        msg = f"{msg_prefix}: mux " + \
                              f"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
                              f" => 0x{y:X} != 0x{outval:X}, masklist {maskbit_list}"
                        self.assertEqual(y, outval, msg)

            yield part_mask.eq(0)
            yield from test_muxop("16-bit", 0b1111)
            yield part_mask.eq(0b10)
            yield from test_muxop("8-bit", 0b1100, 0b0011)
            yield part_mask.eq(0b1111)
            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)

        sim.add_process(async_process)
        sim.run()
269 sim.add_process(async_process)
270 sim.run()
271
272 if __name__ == '__main__':
273 unittest.main()
274