selectable_int: fix multi-dimensional mappings
[openpower-isa.git] / src / openpower / decoder / selectable_int.py
1 import unittest
2 import struct
3 from copy import copy
4 import functools
5 from openpower.decoder.power_fields import BitRange
6 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
7 neg, inv, lshift, rshift, lt, eq)
8 from openpower.util import log
9
10
def check_extsign(a, b):
    """Coerce operand ``b`` into a SelectableInt usable alongside ``a``.

    - a FieldSelectableInt is first reduced to the bits it selects;
    - a plain int becomes a SelectableInt that is ``a.bits`` wide;
    - a 256-bit SelectableInt (the "max extent" width produced e.g. by
      ``SelectableInt.__rsub__``) is re-masked to ``a.bits``;
    - any other SelectableInt is returned as-is.
    """
    operand = b.get_range() if isinstance(b, FieldSelectableInt) else b
    if isinstance(operand, int):
        return SelectableInt(operand, a.bits)
    if operand.bits == 256:
        return SelectableInt(operand.value, a.bits)
    return operand
19
20
@functools.total_ordering
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target SelectableInt; ``br`` is a BitRange mapping field
    bit ``i`` onto a bit index of ``si``.  A list/tuple/range ``br`` is
    converted into a BitRange whose entry ``i`` is element ``i``.  Indices
    are handed straight to ``si`` (no POWER bit-inversion is applied here).

    Arithmetic/logic operators act on the selected bits and write the
    result back through merge(); the comparison operators only read.
    """

    def __init__(self, si, br):
        if not isinstance(si, SelectableInt):
            raise ValueError(si)
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple, range)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """Assign ``b`` into this field, in place.

        An int is converted to a SelectableInt as wide as the field, and a
        SelectableInt has its bits written one by one through the mapping.
        Anything else (presumably another FieldSelectableInt) replaces this
        object's target and mapping with copies of ``b``'s.
        """
        if isinstance(b, int):
            # convert integer to same SelectableInt of same bitlength as range
            blen = len(self.br)
            b = SelectableInt(b, blen)
            for i in range(b.bits):
                self[i] = b[i]
        elif isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        # apply binary ``op`` to the selected bits, then write them back
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        # apply unary ``op`` to the selected bits, then write them back
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        """Read field bit ``key`` (int) or a slice of field bits.

        An int key returns the mapped target bit; a slice returns the mapped
        target bits concatenated in field order.  Other key types fall
        through and return None.
        """
        log("getitem", key, self.br)
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        if isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])

    def __setitem__(self, key, value):
        """Write field bit ``key`` (int) or field bits ``key`` (slice).

        For a slice, an int ``value`` is first converted to a SelectableInt
        as wide as the number of selected bits.
        """
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # the original (non-standard) name; Python itself only calls __neg__
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __div__(self, b):
        return self._op(truediv, b)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __lt__(self, b):
        # compare the selected bits only: unlike the arithmetic operators,
        # the (1-bit) result must not be merged back into the target
        return self.get_range() < b

    def __eq__(self, b):
        return self.get_range() == b

    def get_range(self):
        """Return the selected bits as a new SelectableInt, len(br) wide."""
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Write ``vi``'s bits back through the mapping; return a copy of self.

        NOTE: copy() is shallow, so the returned object shares ``self.si``,
        which is therefore modified in place as well.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return f"{self.__class__.__name__}(si={self.si}, br={self.br})"

    def asint(self, msb0=False):
        """Return the selected bits as a plain int.

        With ``msb0=False`` field bit 0 is the least significant bit of the
        result; with ``msb0=True`` it is the most significant.
        """
        res = 0
        brlen = len(self.br)
        for i, key in self.br.items():
            log("asint", i, key, self.si[key])
            bit = self.si[key].value
            res |= bit << ((brlen-i-1) if msb0 else i)
        return res
144
145
class FieldSelectableIntTestCase(unittest.TestCase):
    """Unit tests for FieldSelectableInt bit selection and arithmetic."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        c = fs + b
        log(c)
        # the field selects a[0], a[2], a[3] (MSB0) = 0b110;
        # 0b110 + 0b011 = 0b1001, which wraps to 0b001 in 3 bits
        self.assertEqual(c.get_range(), 0b001)

    def test_select(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
182 self.assertEqual(fs.get_range(), 0b1011)
183
184
class SelectableInt:
    """SelectableInt - a class that behaves exactly like python int

    this class is designed to mirror precisely the behaviour of python int.
    the only difference is that it must contain the context of the bitwidth
    (number of bits) associated with that integer.

    FieldSelectableInt can then operate on partial bits, and because there
    is a bit width associated with SelectableInt, indices and slices are
    bounds-checked against that width.  Indexing follows the POWER
    convention: bit 0 is the MOST significant bit.  Negative indices and
    slice bounds are not supported (they trip an assertion).
    """

    def __init__(self, value, bits=None):
        """Create a ``bits``-wide value; ``value`` is masked to that width.

        ``value`` may be an int (``bits`` is then required), another
        SelectableInt, or a FieldSelectableInt (width and value come from
        the bits the field selects).  ``self.overflow`` records whether any
        bits were lost to the mask.

        Raises ValueError if ``bits`` is given together with a SelectableInt
        or FieldSelectableInt, is omitted for an int, or if ``value`` has an
        unsupported type.
        """
        if isinstance(value, FieldSelectableInt):
            if bits is not None:
                raise ValueError(value)
            # use the bits the field actually selects: the whole target
            # (value.si.value) is unrelated to the field's contents
            value = value.get_range()
        if isinstance(value, SelectableInt):
            if bits is not None:
                raise ValueError(value)
            bits = value.bits
            value = value.value
        else:
            if not isinstance(value, int):
                raise ValueError(value)
            if bits is None:
                raise ValueError(bits)
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits
        self.overflow = (value & ~mask) != 0

    def eq(self, b):
        """Overwrite this value and width, in place, with those of ``b``."""
        self.value = b.value
        self.bits = b.bits

    def to_signed_int(self):
        """Return the value interpreted as a two's-complement signed int."""
        log("to signed?", self.value & (1<<(self.bits-1)), self.value)
        if self.value & (1<<(self.bits-1)) != 0:  # negative
            res = self.value - (1<<self.bits)
            log(" val -ve:", self.bits, res)
        else:
            res = self.value
            log(" val +ve:", res)
        return res

    def _op(self, op, b):
        # same-width binary operation; the result keeps self.bits
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        # different case: mul result needs to fit the total bitsize
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        log("SelectableInt mul", hex(self.value), hex(b.value),
            self.bits, b.bits)
        return SelectableInt(self.value * b.value, self.bits + b.bits)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __abs__(self):
        log("abs", self.value & (1 << (self.bits-1)))
        if self.value & (1 << (self.bits-1)) != 0:
            return -self
        return self

    def __rsub__(self, b):
        log("rsub", b, self.value)
        if isinstance(b, int):
            b = SelectableInt(b, 256)  # max extent
        return SelectableInt(b.value - self.value, b.bits)

    def __radd__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value + self.value, self.bits)

    def __rxor__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        res = SelectableInt((~self.value) + 1, self.bits)
        log("neg", hex(self.value), hex(res.value))
        return res

    def __lshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def __getitem__(self, key):
        """Read bit ``key`` (int) or bit range ``key`` (slice), MSB0 order.

        Keys and slice bounds may be SelectableInts.  Returns a 1-bit
        SelectableInt for an int key, or one ``stop - start`` bits wide for
        a slice.  Other key types fall through and return None.
        """
        log("SelectableInt.__getitem__", self, key, type(key))
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            log("getitem", key, self.bits, hex(self.value), value)
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            # accept SelectableInt bounds, consistent with __setitem__
            kstart, kstop, kstep = key.start, key.stop, key.step
            if isinstance(kstart, SelectableInt):
                kstart = kstart.asint()
            if isinstance(kstop, SelectableInt):
                kstop = kstop.asint()
            if isinstance(kstep, SelectableInt):
                kstep = kstep.asint()
            assert kstep is None or kstep == 1
            assert kstart < kstop
            assert kstart >= 0
            assert kstop <= self.bits

            # convert MSB0 bounds to LSB0 shift positions
            stop = self.bits - kstart
            start = self.bits - kstop

            bits = stop - start
            log("__getitem__ slice num bits", start, stop, bits)
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            log("getitem", stop, start, self.bits, hex(self.value), value)
            return SelectableInt(value, bits)

    def __setitem__(self, key, value):
        """Write bit ``key`` (int) or bit range ``key`` (slice), MSB0 order.

        Keys and slice bounds may be SelectableInts.  A SelectableInt
        ``value`` must match the destination width (1 bit for an int key);
        an int ``value`` is masked to the destination bit(s).
        """
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value
            log("setitem", key, self.bits, hex(self.value), hex(value))

            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            kstart, kstop, kstep = key.start, key.stop, key.step
            if isinstance(kstart, SelectableInt):
                kstart = kstart.asint()
            if isinstance(kstop, SelectableInt):
                kstop = kstop.asint()
            if isinstance(kstep, SelectableInt):
                kstep = kstep.asint()
            log("__setitem__ slice ", kstart, kstop, kstep)
            assert kstep is None or kstep == 1
            assert kstart < kstop
            assert kstart >= 0
            assert kstop <= self.bits, \
                "key stop %d bits %d" % (kstop, self.bits)

            # convert MSB0 bounds to LSB0 shift positions
            stop = self.bits - kstart
            start = self.bits - kstop

            bits = stop - start
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            log("setitem", key, self.bits, hex(self.value), hex(value))
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)

    # Ordering comparisons are SIGNED (two's complement of self.bits) and
    # return a 1-bit SelectableInt; see selectltu/selectgtu for unsigned.

    def __ge__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            return onebit(self.to_signed_int() >= other)
        assert False

    def __le__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            return onebit(self.to_signed_int() <= other)
        assert False

    def __gt__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            return onebit(self.to_signed_int() > other)
        assert False

    def __lt__(self, other):
        log("SelectableInt lt", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            a = self.to_signed_int()
            res = onebit(a < other)
            log(" a < b", a, other, res)
            return res
        assert False

    def __eq__(self, other):
        # equality compares the raw (unsigned) bit patterns
        log("__eq__", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        log(" eq", other, self.value, other == self.value)
        if isinstance(other, int):
            return onebit(other == self.value)
        assert False

    def narrow(self, bits):
        """Return a copy truncated to its low ``bits`` bits."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        value = f"value={hex(self.value)}, bits={self.bits}"
        return f"{self.__class__.__name__}({value})"

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value

    def __float__(self):
        """convert to double-precision float. TODO, properly convert
        rather than a hack-job: must actually support Power IEEE754 FP
        """
        assert self.bits == 64  # must be 64-bit
        data = self.value.to_bytes(8, byteorder='little')
        return struct.unpack('<d', data)[0]
469
470
class SelectableIntMappingMeta(type):
    """Metaclass building SelectableInt subclasses with named bit fields.

    Class keyword arguments:
      bits   -- total width.  If omitted, it is taken from the first base
                exposing a non-None ``bits``; ValueError if it is not an int.
      fields -- dict of name -> iterable of bit indices, or name -> nested
                dict of the same shape (multi-dimensional mappings).  Each
                top-level name becomes a FieldProperty attribute unless the
                class body already defines it.

    Iterating over a class built this way yields its (name, spec) pairs.
    """

    @functools.total_ordering
    class Field(FieldSelectableInt):
        # a field compares as the int formed from its bits, field bit 0
        # being the most significant
        def __int__(self):
            return self.asint(msb0=True)

        def __lt__(self, b):
            return int(self).__lt__(b)

        def __eq__(self, b):
            return int(self).__eq__(b)

    class FieldProperty:
        """Descriptor exposing a field spec as Field objects on instances."""

        def __init__(self, field):
            self.__field = field

        def __repr__(self):
            return self.__field.__repr__()

        def __get__(self, instance, owner):
            # class access returns the raw (normalised) spec
            if instance is None:
                return self.__field

            cls = SelectableIntMappingMeta.Field

            def factory(field):
                # nested dicts describe multi-dimensional mappings: recurse
                # until a tuple of bit indices is reached
                if isinstance(field, dict):
                    return {k: factory(v) for (k, v) in field.items()}
                return cls(si=instance, br=field)

            return factory(self.__field)

    class BitsProperty:
        """Descriptor: class access gives the declared width, instance
        access gives ``instance.bits``."""

        def __init__(self, bits):
            self.__bits = bits

        def __get__(self, instance, owner):
            if instance is None:
                return self.__bits
            return instance.bits

        def __repr__(self):
            return self.__bits.__repr__()

    def __new__(metacls, name, bases, attrs, bits=None, fields=None):
        if fields is None:
            fields = {}

        def field(item):
            # normalise one (name, spec) pair: leaves become tuples of bit
            # indices, dicts are normalised recursively
            (key, value) = item
            if isinstance(value, dict):
                value = dict(map(field, value.items()))
            else:
                value = tuple(value)
            return (key, value)

        fields = dict(map(field, fields.items()))
        for (key, value) in fields.items():
            attrs.setdefault(key, metacls.FieldProperty(value))

        if bits is None:
            for base in bases:
                bits = getattr(base, "bits", None)
                if bits is not None:
                    break

        if not isinstance(bits, int):
            raise ValueError(bits)
        attrs.setdefault("bits", metacls.BitsProperty(bits))

        cls = super().__new__(metacls, name, bases, attrs)
        cls.__fields = fields
        return cls

    def __iter__(cls):
        for (key, value) in cls.__fields.items():
            yield (key, value)
545 yield (key, value)
546
547
class SelectableIntMapping(SelectableInt, metaclass=SelectableIntMappingMeta,
                           bits=0):
    """Base class for SelectableInts with named fields.

    Constructing from a plain int without an explicit width uses the
    width declared on the class.
    """

    def __init__(self, value=0, bits=None):
        if bits is None and isinstance(value, int):
            bits = type(self).bits
        super().__init__(value, bits)
554
555
def onebit(bit):
    """Return a 1-bit SelectableInt: 1 when ``bit`` is truthy, else 0."""
    return SelectableInt(int(bool(bit)), 1)
559
def selectltu(lhs, rhs):
    """Unsigned less-than: 1-bit SelectableInt of ``lhs.value < rhs``.

    ``rhs`` may be a SelectableInt (its raw value is used) or an int.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_value)
566
567
def selectgtu(lhs, rhs):
    """Unsigned greater-than: 1-bit SelectableInt of ``lhs.value > rhs``.

    ``rhs`` may be a SelectableInt (its raw value is used) or an int.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value > rhs_value)
574
575
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the positions given by ``idx``.

    ``idx`` is either a single index (``lhs[idx] = rhs[0]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)``: destination indices are
    ``range(lower, upper, step)`` and source indices
    ``range(0, upper - lower, step)``, paired up in order.  A missing or
    None step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            step = 1  # range() rejects None as a step
        toidx = range(lower, upper, step)
        fromidx = range(0, upper-lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
591
592
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts; the first argument is most significant.

    ``repeat`` repeats the whole argument list that many times; a single
    int argument combined with ``repeat != 1`` is treated as a 1-bit value.
    A FieldSelectableInt contributes only the bits it selects.  The result
    is a copy of the first (converted) argument, widened in place.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        tmp = []
        for i in range(repeat):
            tmp += args
        args = tmp

    def as_si(arg):
        # a field contributes its selected bits, not its whole target
        if isinstance(arg, FieldSelectableInt):
            return arg.get_range()
        return arg

    res = copy(as_si(args[0]))
    for i in args[1:]:
        i = as_si(i)
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    log("concat", repeat, res)
    return res
610
611
class SelectableIntTestCase(unittest.TestCase):
    """Unit tests for SelectableInt arithmetic, indexing and comparison."""

    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = abs(f)
        h = abs(a)
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        self.assertEqual(e.value, (a.value * b.value) & 0xFFFF)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(g.value, a.value)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        # multiplication widens: the product is a.bits + b.bits wide
        self.assertEqual(e.bits, a.bits + b.bits)
        self.assertEqual(a.bits, f.bits)
        self.assertEqual(a.bits, g.bits)
        self.assertEqual(a.bits, h.bits)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)

    def test_cmp(self):
        a = SelectableInt(10, bits=8)
        b = SelectableInt(5, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_unsigned(self):
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        # unsigned ordering goes through selectgtu/selectltu ...
        self.assertTrue(selectgtu(a, b))
        self.assertFalse(selectltu(a, b))
        # ... whereas the operators are signed: 0x80 is -128 in 8 bits
        self.assertTrue(a < b)
        self.assertFalse(a > b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_maxint(self):
        a = SelectableInt(0xffffffffffffffff, bits=64)
        b = SelectableInt(0, bits=64)
        result = a + b
        self.assertTrue(result.value == 0xffffffffffffffff)

    def test_double_1(self):
        """use http://weitz.de/ieee/,
        """
        for asint, asfloat in [(0x4000000000000000, 2.0),
                               (0x4056C00000000000, 91.0),
                               (0xff80000000000000, -1.4044477616111843e+306),
                               ]:
            a = SelectableInt(asint, bits=64)
            convert = float(a)
            log("test_double_1", asint, asfloat, convert)
            self.assertTrue(asfloat == convert)
711
712
if __name__ == "__main__":
    # Run every unittest.TestCase defined in this module.
    unittest.main()