Fix bug with comparisons in selectable_int.py
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, truediv, mod, or_, and_, xor, neg, inv)
5
6
def check_extsign(a, b):
    """Coerce operand *b* into a SelectableInt to be combined with *a*.

    * a FieldSelectableInt is first flattened to its selected bits;
    * a plain int becomes a SelectableInt of width ``a.bits``;
    * a SelectableInt that is exactly 256 bits wide is re-created at width
      ``a.bits`` (NOTE(review): 256 presumably marks an "unsized" value
      coming from the pseudocode translator -- confirm);
    * any other SelectableInt is returned unchanged.
    """
    if isinstance(b, FieldSelectableInt):
        b = b.get_range()
    if isinstance(b, int):
        return SelectableInt(b, a.bits)
    if b.bits == 256:
        return SelectableInt(b.value, a.bits)
    return b
15
16
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target SelectableInt; ``br`` maps field bit-index ->
    target bit-index (a list/tuple of target indices is converted into a
    BitRange).  Indexing reads/writes the mapped bits of the target.
    Arithmetic and logic operators extract the field (``get_range``),
    apply the operation, and return a new FieldSelectableInt whose own
    copy of the target has the result merged back in.
    """

    def __init__(self, si, br):
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple)):
            # position in the sequence is the field index, the element is
            # the target index
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """Assign *b* to this field in place.

        A SelectableInt is written bit-by-bit through the mapping (so the
        target is modified); otherwise *b* is taken to be another
        FieldSelectableInt and its target and mapping are copied.
        """
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary ``op(field, b)`` and merge the result back."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary ``op(field)`` and merge the result back."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        return self.si[key]

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        return self.si.__setitem__(key, value)

    def __neg__(self):
        return self._op1(neg)

    # "__negate__" is not a Python operator hook (unary minus calls
    # __neg__); keep the old name as an alias for any explicit callers.
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __div__(self, b):
        return self._op(truediv, b)

    # Python 3 dispatches "/" to __truediv__ only; __div__ is never called.
    __truediv__ = __div__

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Return the selected bits as a new SelectableInt.

        The result is ``len(br)`` bits wide; its bit ``k`` is taken from
        target bit ``br[k]``.
        """
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Return a new FieldSelectableInt with *vi* written into the field.

        The target is copied before writing so that operators such as
        ``fs + b`` do not silently modify ``fs.si``, which may be shared
        with other objects.
        """
        fi = copy(self)
        fi.si = copy(self.si)  # copy(self) alone would share the target
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
98
99
class FieldSelectableIntTestCase(unittest.TestCase):
    """Unit tests for arithmetic on a FieldSelectableInt."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        c = fs + b
        # the field is a[0], a[2], a[3] = 1, 1, 0 -> 0b110; adding 0b011
        # wraps within the 3-bit field width
        field = c.get_range()
        self.assertEqual(field.bits, 3)
        self.assertEqual(field.value, (0b110 + 0b011) & 0b111)
        # result bits 0, 0, 1 are written back to target bits 0, 2, 3;
        # target bits 1 and 4 (0 and 1) are untouched -> 0b00011
        self.assertEqual(c.si.bits, 5)
        self.assertEqual(c.si.value, 0b00011)
112
113
class SelectableInt:
    """Fixed-width unsigned integer with POWER (MSB0) bit numbering.

    ``value`` is always masked to ``bits`` bits, so arithmetic wraps
    modulo ``2**bits``.  Bit index 0 is the most-significant bit (see
    POWER 3.0B p4 1.3.2); slices ``[start:stop]`` are half-open in that
    same numbering.  Comparisons are unsigned and return a 1-bit
    SelectableInt.
    """

    def __init__(self, value, bits):
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def eq(self, b):
        """Copy *b*'s value and width into this object (in place)."""
        self.value = b.value
        self.bits = b.bits

    def _coerce(self, b):
        """Convert operand *b* via check_extsign; widths must then match."""
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return b

    def __add__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value + b.value, self.bits)

    def __sub__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value - b.value, self.bits)

    def __mul__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value * b.value, self.bits)

    def __div__(self, b):
        b = self._coerce(b)
        # integer (floor) division: "/" on ints yields a float, which
        # cannot be masked to a bit width
        return SelectableInt(self.value // b.value, self.bits)

    # Python 3 dispatches "/" to __truediv__ and "//" to __floordiv__;
    # __div__ itself is never called by the interpreter.
    __truediv__ = __div__
    __floordiv__ = __div__

    def __mod__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value % b.value, self.bits)

    def __or__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value | b.value, self.bits)

    def __and__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value & b.value, self.bits)

    def __xor__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        return SelectableInt(~self.value + 1, self.bits)

    def __getitem__(self, key):
        if isinstance(key, SelectableInt):
            key = key.value  # allow SelectableInt indices, as FieldSelectableInt does
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            assert key.step is None or key.step == 1
            assert key.start < key.stop
            assert key.start >= 0
            assert key.stop <= self.bits

            # convert MSB0 [start:stop) into LSB0 shift positions
            stop = self.bits - key.start
            start = self.bits - key.stop

            bits = stop - start
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            return SelectableInt(value, bits)

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value  # allow SelectableInt indices, as FieldSelectableInt does
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)  # MSB0 -> LSB0 position
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            assert key.step is None or key.step == 1
            assert key.start < key.stop
            assert key.start >= 0
            assert key.stop <= self.bits

            # convert MSB0 [start:stop) into LSB0 shift positions
            stop = self.bits - key.start
            start = self.bits - key.stop

            bits = stop - start
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)

    def _cmp_value(self, other):
        """Return *other* as a plain int to compare against ``self.value``.

        FieldSelectableInts are flattened to their selected bits and
        SelectableInts are width-checked via check_extsign.
        """
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        assert isinstance(other, int), \
            "cannot compare SelectableInt with %r" % (other,)
        return other

    def __ge__(self, other):
        # previously compared against other.value after other had already
        # been reduced to an int, raising AttributeError
        return onebit(self.value >= self._cmp_value(other))

    def __le__(self, other):
        return onebit(self.value <= self._cmp_value(other))

    def __gt__(self, other):
        return onebit(self.value > self._cmp_value(other))

    def __lt__(self, other):
        return onebit(self.value < self._cmp_value(other))

    def __eq__(self, other):
        return onebit(self.value == self._cmp_value(other))

    def narrow(self, bits):
        """Return a copy truncated to the low *bits* bits (bits <= width)."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                             self.bits)

    def __len__(self):
        return self.bits
299
def onebit(bit):
    """Return a 1-bit SelectableInt holding 1 if *bit* is truthy, else 0."""
    return SelectableInt(int(bool(bit)), 1)
302
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    Compares ``lhs.value`` with *rhs* (its ``.value`` when *rhs* is a
    SelectableInt, otherwise *rhs* as given) and returns a 1-bit result.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_value)
309
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    Compares ``lhs.value`` with *rhs* (its ``.value`` when *rhs* is a
    SelectableInt, otherwise *rhs* as given) and returns a 1-bit result.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value > rhs_value)
316
317
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of *rhs* into *lhs* at the position(s) given by *idx*.

    *idx* is either a single index (``lhs[idx] = rhs[0]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)`` giving a half-open range
    of destination indices.  Source indices start at 0 and advance by the
    same step.  A missing or ``None`` step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            step = 1  # range() rejects a None step
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
333
334
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts; the first argument is most significant.

    Each FieldSelectableInt argument contributes only its selected bits
    (``get_range()``), wherever it appears in *args*.  When ``repeat`` is
    not 1 the whole argument list is repeated that many times, and a
    single plain-int argument is first treated as a 1-bit value (so
    ``selectconcat(1, repeat=8)`` is 0xff, 8 bits).

    Returns a new SelectableInt whose width is the sum of the widths.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat
    assert args, "selectconcat needs at least one argument"
    bits = 0
    value = 0
    for i in args:
        if isinstance(i, FieldSelectableInt):
            i = i.get_range()
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        bits += i.bits
        value = (value << i.bits) | i.value
    return SelectableInt(value, bits)
352
353
class SelectableIntTestCase(unittest.TestCase):
    """Unit tests for SelectableInt arithmetic, bit access and comparisons."""

    def _check_ordering(self, big, small):
        # shared by test_cmp / test_unsigned: big must compare strictly
        # greater than small, and the two must be unequal
        self.assertTrue(big > small)
        self.assertFalse(big < small)
        self.assertTrue(big != small)
        self.assertFalse(big == small)

    def test_arith(self):
        lhs = SelectableInt(5, 8)
        rhs = SelectableInt(9, 8)
        total = lhs + rhs
        diff = lhs - rhs
        prod = lhs * rhs
        negated = -lhs
        # results wrap to the 8-bit operand width
        self.assertEqual(total.value, lhs.value + rhs.value)
        self.assertEqual(diff.value, (lhs.value - rhs.value) & 0xFF)
        self.assertEqual(prod.value, (lhs.value * rhs.value) & 0xFF)
        self.assertEqual(negated.value, (-lhs.value) & 0xFF)
        for result in (total, diff, prod):
            self.assertEqual(result.bits, lhs.bits)

    def test_logic(self):
        lhs = SelectableInt(0x0F, 8)
        rhs = SelectableInt(0xA5, 8)
        both = lhs & rhs
        either = lhs | rhs
        differ = lhs ^ rhs
        flipped = ~lhs
        self.assertEqual(both.value, lhs.value & rhs.value)
        self.assertEqual(either.value, lhs.value | rhs.value)
        self.assertEqual(differ.value, lhs.value ^ rhs.value)
        self.assertEqual(flipped.value, 0xF0)

    def test_get(self):
        word = SelectableInt(0xa2, 8)
        # These should be big endian (bit 0 is the MSB)
        self.assertEqual(word[7], 0)
        self.assertEqual(word[0:4], 10)
        self.assertEqual(word[4:8], 2)

    def test_set(self):
        word = SelectableInt(0x5, 8)
        word[7] = SelectableInt(0, 1)
        self.assertEqual(word, 4)
        word[4:8] = 9
        self.assertEqual(word, 9)
        word[0:4] = 3
        self.assertEqual(word, 0x39)
        word[0:4] = word[4:8]
        self.assertEqual(word, 0x99)

    def test_concat(self):
        for bit, expected in ((0x1, 0xff), (0x0, 0x00)):
            joined = selectconcat(SelectableInt(bit, 1), repeat=8)
            self.assertEqual(joined, expected)
            self.assertEqual(joined.bits, 8)

    def test_repr(self):
        # repr() must round-trip through eval() for every 16-bit value
        for i in range(1 << 16):
            original = SelectableInt(i, 16)
            restored = eval(repr(original))
            self.assertEqual(original, restored)

    def test_cmp(self):
        self._check_ordering(SelectableInt(10, bits=8),
                             SelectableInt(5, bits=8))

    def test_unsigned(self):
        # 0x80 has the top bit set but must still compare as unsigned
        self._check_ordering(SelectableInt(0x80, bits=8),
                             SelectableInt(0x7f, bits=8))
431
if __name__ == "__main__":
    # run every unittest.TestCase defined in this module
    unittest.main()