575d1d8713d9fbbd8897cf4dc05ea5018e20161b
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
5 neg, inv, lshift, rshift)
6
7
def check_extsign(a, b):
    """Coerce operand ``b`` so it can be combined with SelectableInt ``a``.

    * a FieldSelectableInt is first collapsed to its selected bits;
    * a plain int is wrapped as a SelectableInt of width ``a.bits``;
    * a SelectableInt exactly 256 bits wide is re-created at width ``a.bits``
      (its value masked to that width).  NOTE(review): 256 looks like a
      placeholder width for not-yet-sized values -- confirm with callers;
    * any other SelectableInt is returned unchanged.

    No sign extension is performed here despite the name.
    """
    operand = b.get_range() if isinstance(b, FieldSelectableInt) else b
    if isinstance(operand, int):
        return SelectableInt(operand, a.bits)
    if operand.bits == 256:
        return SelectableInt(operand.value, a.bits)
    return operand
16
17
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target SelectableInt and ``br`` maps field bit positions
    (0, 1, 2, ...) onto bit positions of ``si``.  A list or tuple of target
    positions may be given for ``br``; entry ``i`` becomes the mapping for
    field bit ``i``.

    Arithmetic and logic operators act on the selected bits (see
    get_range) and write the result back through the map (see merge).
    """

    def __init__(self, si, br):
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """Assign ``b`` to this field.

        A SelectableInt is written bit-by-bit through the field into
        ``self.si``.  Anything else is assumed to be another
        FieldSelectableInt, and its target and bit map are copied.
        """
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary ``op`` to the selected bits and merge the result."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary ``op`` to the selected bits and merge the result."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        """Read one field bit (int key) or several, concatenated (slice key).

        Field indices are mapped through ``self.br`` before ``self.si`` is
        indexed.  Other key types fall through and return None.
        """
        print("getitem", key, self.br)
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        if isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])

    def __setitem__(self, key, value):
        """Write field bit(s) ``key`` through to the mapped bits of ``si``.

        For a multi-bit key a non-SelectableInt ``value`` is first wrapped to
        the key's width, and then ``value[i]`` (MSB-first indexing) is stored
        at the i-th mapped position.
        """
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # The original hook was spelt ``__negate__`` (a name Python never calls)
    # and referenced an undefined ``negate``; keep the name as an alias.
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    # Python 2 spelling of ``/``, kept for any explicit callers
    __div__ = __truediv__

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Return the selected bits as a new SelectableInt of len(br) bits.

        Field bit ``k`` is read from ``self.si[self.br[k]]``.
        """
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Write the bits of ``vi`` back through the bit map; return a field.

        NOTE(review): ``copy(self)`` is shallow, so the returned object shares
        ``si`` with ``self`` and the write lands in the original target (for
        example, ``fs + b`` modifies ``fs.si``).  Confirm that callers rely on
        this in-place behaviour.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
116
117
class FieldSelectableIntTestCase(unittest.TestCase):
    """Tests for FieldSelectableInt bit selection and field arithmetic."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        c = fs + b
        print(c)
        # the field reads a[0], a[2], a[3] (MSB-first) = 1, 1, 0 -> 0b110;
        # 0b110 + 0b011 = 0b1001, which wraps to the 3-bit field as 0b001
        self.assertEqual(c.get_range(), 0b001)

    def test_select(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
155
class SelectableInt:
    """SelectableInt - a class that behaves exactly like python int

    this class is designed to mirror precisely the behaviour of python int.
    the only difference is that it must contain the context of the bitwidth
    (number of bits) associated with that integer.

    FieldSelectableInt can then operate on partial bits, and because there
    is a bit width associated with SelectableInt, slices operate correctly
    including negative start/end points.

    Bits are indexed POWER-style: bit 0 is the MSB.  ``value`` is always
    masked to ``bits`` bits; ``overflow`` records whether the value passed to
    the constructor had any bits set above that width.  ``<``, ``<=``, ``>``
    and ``>=`` compare as two's-complement signed numbers; ``==`` compares
    the raw (unsigned) values.
    """

    def __init__(self, value, bits):
        if isinstance(value, SelectableInt):
            value = value.value
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits
        self.overflow = (value & ~mask) != 0

    def eq(self, b):
        """Copy value and width from ``b`` in place."""
        self.value = b.value
        self.bits = b.bits

    def to_signed_int(self):
        """Return the value interpreted as a two's-complement signed int."""
        print("to signed?", self.value & (1 << (self.bits - 1)), self.value)
        if self.value & (1 << (self.bits - 1)) != 0:  # negative
            res = self.value - (1 << self.bits)
            print(" val -ve:", self.bits, res)
        else:
            res = self.value
            print(" val +ve:", res)
        return res

    def _op(self, op, b):
        """Return ``op(self, b)`` wrapped to ``self.bits``.

        ``b`` may be an int, SelectableInt or FieldSelectableInt.  It is
        coerced by check_extsign() and must end up as wide as ``self``.
        """
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def _rop(self, op, b):
        """Reflected form of _op(): return ``op(b, self)``."""
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(b.value, self.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        # different case: mul result needs to fit the total bitsize, so the
        # product is self.bits + b.bits wide rather than wrapped to self.bits
        if isinstance(b, FieldSelectableInt):
            b = b.get_range()
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        print("SelectableInt mul", hex(self.value), hex(b.value),
              self.bits, b.bits)
        return SelectableInt(self.value * b.value, self.bits + b.bits)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        # operator.truediv on two ints always returns a float, which cannot
        # be masked to a bit width (SelectableInt(float, n) raises TypeError).
        # Stored values are non-negative, so floor division gives the exact
        # truncated quotient without float precision loss.
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __abs__(self):
        print("abs", self.value & (1 << (self.bits - 1)))
        if self.value & (1 << (self.bits - 1)) != 0:
            return -self
        return self

    def __rsub__(self, b):
        return self._rop(sub, b)

    def __radd__(self, b):
        return self._rop(add, b)

    def __rxor__(self, b):
        return self._rop(xor, b)

    def __rand__(self, b):
        return self._rop(and_, b)

    def __ror__(self, b):
        return self._rop(or_, b)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        # two's-complement negation, wrapped to self.bits
        res = SelectableInt((~self.value) + 1, self.bits)
        print("neg", hex(self.value), hex(res.value))
        return res

    def __lshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def _slice_bounds(self, key):
        """Normalise slice ``key`` to (start, stop) MSB-0 bit indices.

        ``None`` means "from bit 0" / "through the last bit", negative
        bounds count back from ``self.bits`` as for python sequences, and
        SelectableInt bounds are taken by value.  Only unit steps are allowed.
        """
        assert key.step is None or key.step == 1
        start, stop = key.start, key.stop
        if isinstance(start, SelectableInt):
            start = start.value
        if isinstance(stop, SelectableInt):
            stop = stop.value
        if start is None:
            start = 0
        if stop is None:
            stop = self.bits
        if start < 0:
            start += self.bits
        if stop < 0:
            stop += self.bits
        assert start < stop
        assert start >= 0
        assert stop <= self.bits
        return start, stop

    def __getitem__(self, key):
        """Return bit ``key`` (int) or bits ``key`` (slice) as a SelectableInt."""
        if isinstance(key, SelectableInt):
            key = key.value
        print("getitem", key, self.bits, hex(self.value))
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            first, last = self._slice_bounds(key)
            # convert MSB-0 bit numbers into LSB-0 shift positions
            stop = self.bits - first
            start = self.bits - last

            bits = stop - start
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            return SelectableInt(value, bits)

    def __setitem__(self, key, value):
        """Overwrite bit ``key`` (int) or bits ``key`` (slice) with ``value``."""
        if isinstance(key, SelectableInt):
            key = key.value
        print("setitem", key, self.bits, hex(self.value))
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            first, last = self._slice_bounds(key)
            # convert MSB-0 bit numbers into LSB-0 shift positions
            stop = self.bits - first
            start = self.bits - last

            bits = stop - start
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)

    def _signed_operand(self, other):
        """Convert ``other`` to a python int for a signed comparison."""
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        assert isinstance(other, int)
        return other

    def __ge__(self, other):
        other = self._signed_operand(other)
        return onebit(self.to_signed_int() >= other)

    def __le__(self, other):
        other = self._signed_operand(other)
        return onebit(self.to_signed_int() <= other)

    def __gt__(self, other):
        other = self._signed_operand(other)
        return onebit(self.to_signed_int() > other)

    def __lt__(self, other):
        print("SelectableInt lt", self, other)
        other = self._signed_operand(other)
        a = self.to_signed_int()
        res = onebit(a < other)
        print(" a < b", a, other, res)
        return res

    def __eq__(self, other):
        print("__eq__", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        print(" eq", other, self.value, other == self.value)
        if isinstance(other, int):
            return onebit(other == self.value)
        assert False

    def narrow(self, bits):
        """Return a copy truncated to the low ``bits`` bits."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                            self.bits)

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value
408
409
def onebit(bit):
    """Return a 1-bit SelectableInt: 1 when ``bit`` is truthy, otherwise 0."""
    return SelectableInt(int(bool(bit)), 1)
412
413
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    Compares the raw value of SelectableInt ``lhs`` against ``rhs`` (a
    SelectableInt or plain int) and returns the answer as a 1-bit result.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_value)
420
421
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    Compares the raw value of SelectableInt ``lhs`` against ``rhs`` (a
    SelectableInt or plain int) and returns the answer as a 1-bit result.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value > rhs_value)
428
429
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the positions given by ``idx``.

    ``idx`` is either a single index (``lhs[idx] = rhs[0]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)`` naming destination indices
    ``range(lower, upper, step)``.  The source indices run from 0 with the
    same step.  A missing or ``None`` step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            # range() rejects a None step; default to a contiguous run
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
445
446
def selectconcat(*args, repeat=1):
    """Concatenate bit fields, with the first argument most significant.

    Each argument must be a SelectableInt or a FieldSelectableInt.  A
    FieldSelectableInt contributes only its selected bits (``get_range()``),
    in any position.  With ``repeat != 1`` the argument list is repeated that
    many times; a single int argument is then treated as a 1-bit value (e.g.
    ``selectconcat(1, repeat=8)`` gives 0xff in 8 bits).  The inputs are not
    modified: a new SelectableInt is returned.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat
    # collapse fields to their selected bits; previously the first position
    # crashed on .bits and later positions concatenated the whole target
    parts = [a.get_range() if isinstance(a, FieldSelectableInt) else a
             for a in args]
    res = copy(parts[0])  # copy so the caller's first argument is untouched
    for part in parts[1:]:
        assert isinstance(part, SelectableInt), "can only concat SIs, sorry"
        res.bits += part.bits
        res.value = (res.value << part.bits) | part.value
    print("concat", repeat, res)
    return res
464
465
class SelectableIntTestCase(unittest.TestCase):
    """Tests for SelectableInt arithmetic, indexing and comparison."""

    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = abs(f)
        h = abs(a)
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        # multiplication is not wrapped: the product is a.bits + b.bits wide
        self.assertEqual(e.value, a.value * b.value)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(g.value, a.value)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        self.assertEqual(e.bits, a.bits + b.bits)
        self.assertEqual(a.bits, f.bits)
        self.assertEqual(a.bits, g.bits)
        self.assertEqual(a.bits, h.bits)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)

    def test_cmp(self):
        a = SelectableInt(10, bits=8)
        b = SelectableInt(5, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_unsigned(self):
        # the < / > operators are signed; unsigned ordering goes through
        # selectgtu / selectltu
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        self.assertTrue(selectgtu(a, b))
        self.assertFalse(selectltu(a, b))
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_signed(self):
        # 0x80 is -128 as an 8-bit two's-complement number, so it is below 0x7f
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        self.assertTrue(a < b)
        self.assertFalse(a > b)
547
548
if __name__ == '__main__':
    # running this module directly executes all of its TestCase classes
    unittest.main()