reduce range of b in shift test
[ieee754fpu.git] / src / ieee754 / part / test / test_partsig.py
1 #!/usr/bin/env python3
2 # SPDX-License-Identifier: LGPL-2.1-or-later
3 # See Notices.txt for copyright information
4
5 from nmigen import Signal, Module, Elaboratable
6 from nmigen.back.pysim import Simulator, Delay
7 from nmigen.cli import rtlil
8
9 from ieee754.part.partsig import PartitionedSignal
10 from ieee754.part_mux.part_mux import PMux
11
12 from random import randint
13 import unittest
14 import itertools
15 import math
16
def first_zero(x, width=16):
    """Return the number of zero bits below the lowest set bit of ``x``.

    Despite the name, this is the trailing-zero count, i.e. the index of the
    lowest set bit, searched within the low ``width`` bits (16 by default),
    e.g. ``first_zero(0xFF00) == 8``.

    If none of the low ``width`` bits is set, ``width`` is returned instead of
    falling through to ``None`` (which would break callers that shift by it).
    """
    for i in range(width):
        if x & (1 << i):
            return i
    return width
23
def count_bits(x):
    """Return how many of the low 16 bits of ``x`` are set (a popcount)."""
    low16 = x & 0xFFFF
    return bin(low16).count("1")
30
31
def perms(k):
    """Lazily yield every ``k``-character string over the alphabet "01".

    Strings come out in counting order: "00..0", "00..1", ..., "11..1".
    """
    return ("".join(digits) for digits in itertools.product("01", repeat=k))
34
35
def create_ilang(dut, traces, test_name):
    """Convert *dut* to RTLIL, using *traces* as ports, into ``<test_name>.il``."""
    il_text = rtlil.convert(dut, ports=traces)
    with open(f"{test_name}.il", "w") as il_file:
        il_file.write(il_text)
40
41
def create_simulator(module, traces, test_name):
    """Write ``<test_name>.il`` for *module*, then return a pysim Simulator.

    The Simulator is given *traces* plus freshly opened ``<test_name>.vcd``
    and ``<test_name>.gtkw`` files to write into.
    """
    create_ilang(module, traces, test_name)
    # NOTE(review): these handles are not closed in this function; presumably
    # the Simulator finalizes/closes them when it finishes -- confirm.
    vcd_out = open(test_name + ".vcd", "w")
    gtkw_out = open(test_name + ".gtkw", "w")
    return Simulator(module, vcd_file=vcd_out, gtkw_file=gtkw_out,
                     traces=traces)
48
49
class TestAddMod(Elaboratable):
    """Test harness exposing PartitionedSignal operators as plain Signals.

    ``a`` and ``b`` are PartitionedSignals built from ``partpoints`` and
    ``width``. ``elaborate`` drives every ``*_output`` / ``*_carry_out``
    Signal combinatorially from the matching operator applied to ``a`` and
    ``b``, and ``mux_out`` from a PMux selecting between them by ``mux_sel``.
    """

    def __init__(self, width, partpoints):
        # NOTE: keep the plain ``self.x = Signal(...)`` form below; nmigen's
        # tracer presumably names each Signal after the attribute it is
        # stored in, so wrapping these in loops/setattr would lose the names.
        nflags = len(partpoints) + 1  # presumably one bit per partition
        self.partpoints = partpoints
        self.a = PartitionedSignal(partpoints, width)
        self.b = PartitionedSignal(partpoints, width)
        self.add_output = Signal(width)
        self.ls_output = Signal(width)  # left shift
        self.sub_output = Signal(width)
        self.eq_output = Signal(nflags)
        self.gt_output = Signal(nflags)
        self.ge_output = Signal(nflags)
        self.ne_output = Signal(nflags)
        self.lt_output = Signal(nflags)
        self.le_output = Signal(nflags)
        self.mux_sel = Signal(nflags)
        self.mux_out = Signal(width)
        self.carry_in = Signal(nflags)
        self.add_carry_out = Signal(nflags)
        self.sub_carry_out = Signal(nflags)
        self.neg_output = Signal(width)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        a, b = self.a, self.b
        # presumably lets the partitioned operators below add logic to m
        a.set_module(m)
        b.set_module(m)
        # compares, evaluated in the order lt, ne, le, gt, eq, ge
        for out, compare in ((self.lt_output, lambda x, y: x < y),
                             (self.ne_output, lambda x, y: x != y),
                             (self.le_output, lambda x, y: x <= y),
                             (self.gt_output, lambda x, y: x > y),
                             (self.eq_output, lambda x, y: x == y),
                             (self.ge_output, lambda x, y: x >= y)):
            comb += out.eq(compare(a, b))
        # add / sub each return a (result, carry-out) pair
        add_res, add_cout = a.add_op(a, b, self.carry_in)
        comb += self.add_output.eq(add_res)
        comb += self.add_carry_out.eq(add_cout)
        sub_res, sub_cout = a.sub_op(a, b, self.carry_in)
        comb += self.sub_output.eq(sub_res)
        comb += self.sub_carry_out.eq(sub_cout)
        # negate, left shift, then the partitioned a/b select
        comb += self.neg_output.eq(-a)
        comb += self.ls_output.eq(a << b)
        comb += self.mux_out.eq(PMux(m, self.partpoints, self.mux_sel, a, b))
        return m
101
102
class TestPartitionPoints(unittest.TestCase):
    """Check TestAddMod's partitioned operators against Python models.

    Each operator is simulated with ``part_mask`` set for 16-bit, 8-bit and
    4-bit layouts, and the hardware output is compared with a per-lane
    reference computed from the mask list given for that layout.
    """

    @staticmethod
    def _expand_maskbits(maskbit_list):
        """Expand 4-bit lane selectors into 16-bit masks.

        Bit ``i`` of each selector stands for nibble ``i``, so ``0b0011``
        becomes ``0x00FF`` and ``0b1000`` becomes ``0xF000``.
        """
        masks = []
        for mb in maskbit_list:
            v = 0
            for i in range(4):
                if mb & (1 << i):
                    v |= 0xf << (i*4)
            masks.append(v)
        return masks

    def test(self):
        width = 16
        part_mask = Signal(4)  # divide into 4-bits
        module = TestAddMod(width, part_mask)

        sim = create_simulator(module,
                               [part_mask,
                                module.a.sig,
                                module.b.sig,
                                module.add_output,
                                module.eq_output],
                               "part_sig_add")

        def async_process():

            def test_ls_fn(carry_in, a, b, mask):
                # reduce range of b: keep only the low log2(lane width)
                # bits of this lane's slice of b (bits-1 is an all-ones
                # mask because the lane widths are powers of two)
                bits = count_bits(mask)
                fz = first_zero(mask)
                print("%x %x %x bits %d zero %d" %
                      (a, b, mask, bits, fz))
                b = b & ((bits-1) << fz)
                # the lane's shift amount sits at bit fz of b: move it down
                # to bit 0 before shifting, otherwise upper lanes would
                # shift by hundreds of bits and always expect 0
                # TODO: carry (no carry-in/out modelled for shift yet)
                shifted = (a & mask) << (b >> fz)
                result = mask & shifted
                print(a, b, shifted, mask)
                return result, 0

            def test_add_fn(carry_in, a, b, mask):
                # carry-in enters at the lane's lowest bit
                lsb = mask & ~(mask-1) if carry_in else 0
                total = (a & mask) + (b & mask) + lsb
                result = mask & total
                carry = (total & mask) != total
                print(a, b, total, mask)
                return result, carry

            def test_sub_fn(carry_in, a, b, mask):
                # a - b computed as a + ~b + carry_in (carry_in=1: no borrow)
                lsb = mask & ~(mask-1) if carry_in else 0
                total = (a & mask) + (~b & mask) + lsb
                result = mask & total
                carry = (total & mask) != total
                return result, carry

            def test_neg_fn(carry_in, a, b, mask):
                # per-lane two's-complement negation: 0 - a, with carry-in
                # set so the subtract does not take off an extra 1
                return test_sub_fn(1, 0, a, mask)

            def test_op(msg_prefix, carry, test_fn, mod_attr, *mask_list):
                # randint's upper bound is inclusive: stay within `width` bits
                top = (1 << width) - 1
                rand_data = [(randint(0, top), randint(0, top))
                             for _ in range(100)]
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)] + rand_data:
                    yield module.a.eq(a)
                    yield module.b.eq(b)
                    carry_sig = 0xf if carry else 0
                    yield module.carry_in.eq(carry_sig)
                    yield Delay(0.1e-6)
                    y = 0
                    carry_result = 0
                    for mask in mask_list:
                        res, c = test_fn(carry, a, b, mask)
                        y |= res
                        # expected carry bit index = lane's lowest bit / 4
                        lsb = mask & ~(mask - 1)
                        bit_set = int(math.log2(lsb))
                        carry_result |= c << (bit_set // 4)
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    print(a, b, outval, carry)
                    msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
                        f" => 0x{y:X} != 0x{outval:X}"
                    self.assertEqual(y, outval, msg)
                    carry_attr = "%s_carry_out" % mod_attr
                    if hasattr(module, carry_attr):
                        c_outval = (yield getattr(module, carry_attr))
                        msg = f"{msg_prefix}: 0x{a:X} {mod_attr} 0x{b:X}" + \
                            f" => 0x{carry_result:X} != 0x{c_outval:X}"
                        self.assertEqual(carry_result, c_outval, msg)

            for (test_fn, mod_attr) in (
                    (test_ls_fn, "ls"),
                    (test_add_fn, "add"),
                    (test_sub_fn, "sub"),
                    (test_neg_fn, "neg"),
            ):
                yield part_mask.eq(0)
                yield from test_op("16-bit", 1, test_fn, mod_attr, 0xFFFF)
                yield from test_op("16-bit", 0, test_fn, mod_attr, 0xFFFF)
                yield part_mask.eq(0b10)
                yield from test_op("8-bit", 0, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                yield from test_op("8-bit", 1, test_fn, mod_attr,
                                   0xFF00, 0x00FF)
                yield part_mask.eq(0b1111)
                yield from test_op("4-bit", 0, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)
                yield from test_op("4-bit", 1, test_fn, mod_attr,
                                   0xF000, 0x0F00, 0x00F0, 0x000F)

            def test_ne_fn(a, b, mask):
                return (a & mask) != (b & mask)

            def test_lt_fn(a, b, mask):
                return (a & mask) < (b & mask)

            def test_le_fn(a, b, mask):
                return (a & mask) <= (b & mask)

            def test_eq_fn(a, b, mask):
                return (a & mask) == (b & mask)

            def test_gt_fn(a, b, mask):
                return (a & mask) > (b & mask)

            def test_ge_fn(a, b, mask):
                return (a & mask) >= (b & mask)

            def test_binop(msg_prefix, test_fn, mod_attr, *maskbit_list):
                mask_list = self._expand_maskbits(maskbit_list)
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF),
                             (0xABCD, 0xABCE),
                             (0x8000, 0x0000),
                             (0xBEEF, 0xFEED)]:
                    yield module.a.eq(a)
                    yield module.b.eq(b)
                    yield Delay(0.1e-6)
                    y = 0
                    # do the partitioned tests: a true compare in a lane
                    # sets that lane's selector bits in the expected output
                    for maskbit, mask in zip(maskbit_list, mask_list):
                        if test_fn(a, b, mask):
                            y |= maskbit
                    # check the result
                    outval = (yield getattr(module, "%s_output" % mod_attr))
                    msg = f"{msg_prefix}: {mod_attr} 0x{a:X} == 0x{b:X}" + \
                        f" => 0x{y:X} != 0x{outval:X}, masklist %s"
                    print(msg % str(maskbit_list))
                    self.assertEqual(y, outval, msg % str(maskbit_list))

            for (test_fn, mod_attr) in ((test_eq_fn, "eq"),
                                        (test_gt_fn, "gt"),
                                        (test_ge_fn, "ge"),
                                        (test_lt_fn, "lt"),
                                        (test_le_fn, "le"),
                                        (test_ne_fn, "ne"),
                                        ):
                yield part_mask.eq(0)
                yield from test_binop("16-bit", test_fn, mod_attr, 0b1111)
                yield part_mask.eq(0b10)
                yield from test_binop("8-bit", test_fn, mod_attr,
                                      0b1100, 0b0011)
                yield part_mask.eq(0b1111)
                yield from test_binop("4-bit", test_fn, mod_attr,
                                      0b1000, 0b0100, 0b0010, 0b0001)

            def test_muxop(msg_prefix, *maskbit_list):
                mask_list = self._expand_maskbits(maskbit_list)
                for a, b in [(0x0000, 0x0000),
                             (0x1234, 0x1234),
                             (0xABCD, 0xABCD),
                             (0xFFFF, 0x0000),
                             (0x0000, 0x0000),
                             (0xFFFF, 0xFFFF),
                             (0x0000, 0xFFFF)]:
                    # step sel through every combination of lanes
                    for p in perms(len(mask_list)):
                        sel = 0
                        selmask = 0
                        for digit, maskbit, mask in zip(p, maskbit_list,
                                                        mask_list):
                            if digit == '1':
                                sel |= maskbit
                                selmask |= mask

                        yield module.a.eq(a)
                        yield module.b.eq(b)
                        yield module.mux_sel.eq(sel)
                        yield Delay(0.1e-6)
                        # selected lanes take a, the others take b
                        y = 0
                        for mask in mask_list:
                            y |= (a if selmask & mask else b) & mask
                        # check the result
                        outval = (yield module.mux_out)
                        msg = f"{msg_prefix}: mux " + \
                            f"0x{sel:X} ? 0x{a:X} : 0x{b:X}" + \
                            f" => 0x{y:X} != 0x{outval:X}, masklist %s"
                        self.assertEqual(y, outval, msg % str(maskbit_list))

            yield part_mask.eq(0)
            yield from test_muxop("16-bit", 0b1111)
            yield part_mask.eq(0b10)
            yield from test_muxop("8-bit", 0b1100, 0b0011)
            yield part_mask.eq(0b1111)
            yield from test_muxop("4-bit", 0b1000, 0b0100, 0b0010, 0b0001)

        sim.add_process(async_process)
        sim.run()
336
337
if __name__ == "__main__":
    # run as a script: hand control to unittest's command-line runner
    unittest.main()