selectable_int: support int casts
[openpower-isa.git] / src / openpower / decoder / selectable_int.py
1 import unittest
2 import struct
3 from copy import copy
4 import functools
5 from openpower.decoder.power_fields import BitRange
6 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
7 neg, inv, lshift, rshift, lt, eq)
8 from openpower.util import log
9
10
def check_extsign(a, b):
    """Coerce operand ``b`` so it can be combined with SelectableInt ``a``.

    - a FieldSelectableInt is first reduced to the bits it selects;
    - a plain int becomes a SelectableInt of ``a``'s width;
    - a 256-bit SelectableInt (the "max extent" width produced by
      SelectableInt.__rsub__) is re-wrapped at ``a``'s width;
    - any other SelectableInt is returned unchanged.
    """
    operand = b.get_range() if isinstance(b, FieldSelectableInt) else b
    if isinstance(operand, int):
        return SelectableInt(operand, a.bits)
    if operand.bits == 256:
        return SelectableInt(operand.value, a.bits)
    return operand
19
20
@functools.total_ordering
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target SelectableInt.  ``br`` is a BitRange mapping each
    field bit index to the bit index in ``si`` that it aliases (a
    list/tuple/range of target indices is converted into a BitRange).
    Field bit 0 is the most significant bit of the field, matching the
    MSB0 indexing of SelectableInt.  Reads and writes go straight through
    to ``si``.
    """

    def __init__(self, si, br):
        if not isinstance(si, SelectableInt):
            raise ValueError(si)
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple, range)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """Assign into this field.

        An int or SelectableInt is written bit-by-bit into the selected
        target bits.  Anything else (presumably another FieldSelectableInt)
        re-points this field at a copy of ``b``'s target and bit range.
        """
        if isinstance(b, int):
            # convert integer to same SelectableInt of same bitlength as range
            blen = len(self.br)
            b = SelectableInt(b, blen)
            for i in range(b.bits):
                self[i] = b[i]
        elif isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        # apply a binary op to the extracted field value, then write the
        # result back via merge() (which modifies the shared target ``si``)
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        # unary counterpart of _op
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        log("getitem", key, self.br)
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        if isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # non-protocol spelling kept so any explicit callers keep working
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        # NOTE(review): SelectableInt.__mul__ widens the product to the sum
        # of both widths, and merge() writes back only MSB0 bits
        # 0..len(br)-1 of it (the most significant ones) - confirm intended
        return self._op(mul, b)

    def __div__(self, b):
        return self._op(truediv, b)

    def __truediv__(self, b):
        # NOTE(review): true division yields a float, which the
        # SelectableInt constructor rejects with ValueError
        return self._op(truediv, b)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __lt__(self, b):
        # compare the extracted value (signed, as SelectableInt.__lt__);
        # the 1-bit result must not be merged back into the target bits
        return self.get_range() < b

    def __eq__(self, b):
        return self.get_range() == b

    def __int__(self):
        # value of the selected bits with field bit 0 as the MSB
        return self.asint(msb0=True)

    def get_range(self):
        """Return the selected bits as a new len(br)-bit SelectableInt."""
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Write ``vi``'s bits into the selected target bits.

        Returns a shallow copy of this field; the copy shares ``si`` with
        ``self``, so the target is modified in place.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return f"{self.__class__.__name__}(si={self.si}, br={self.br})"

    def asint(self, msb0=False):
        """Return the selected bits as a plain int.

        With ``msb0=False`` field bit ``i`` has weight ``2**i``; with
        ``msb0=True`` field bit 0 is the most significant bit (the same
        value as ``get_range().value``).
        """
        res = 0
        brlen = len(self.br)
        for i, key in self.br.items():
            log("asint", i, key, self.si[key])
            bit = self.si[key].value
            res |= bit << ((brlen-i-1) if msb0 else i)
        return res
144
145
class FieldSelectableIntTestCase(unittest.TestCase):
    """Tests for FieldSelectableInt bit selection and arithmetic."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        # MSB0 bits 0, 2, 3 of 0b10101 are 1, 1, 0
        self.assertEqual(fs.get_range(), 0b110)
        c = fs + b
        log(c)
        # the sum is truncated to the 3-bit width of the field
        self.assertEqual(c.get_range(), (0b110 + 0b011) & 0b111)

    def test_select(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
183
184
class SelectableInt:
    """SelectableInt - a class that behaves exactly like python int

    this class is designed to mirror precisely the behaviour of python int.
    the only difference is that it must contain the context of the bitwidth
    (number of bits) associated with that integer.

    FieldSelectableInt can then operate on partial bits, and because there
    is a bit width associated with SelectableInt, slices operate correctly.

    Bit indices follow the POWER MSB0 convention: bit 0 is the most
    significant bit.  Slice bounds must satisfy 0 <= start < stop <= bits
    (an omitted start/stop means 0/bits).  The ordering comparisons
    (<, <=, >, >=) are signed; selectltu/selectgtu compare unsigned.
    """

    def __init__(self, value, bits=None):
        if isinstance(value, FieldSelectableInt):
            if bits is not None:
                raise ValueError(value)
            # take only the bits the field selects, not the whole target
            value = value.get_range()
        if isinstance(value, SelectableInt):
            if bits is not None:
                raise ValueError(value)
            bits = value.bits
            value = value.value
        else:
            if not isinstance(value, int):
                raise ValueError(value)
            if bits is None:
                raise ValueError(bits)
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits
        # records whether the incoming value had bits outside the width
        self.overflow = (value & ~mask) != 0

    def eq(self, b):
        self.value = b.value
        self.bits = b.bits

    def to_signed_int(self):
        """Interpret the value as two's complement of width ``bits``."""
        log ("to signed?", self.value & (1<<(self.bits-1)), self.value)
        if self.value & (1<<(self.bits-1)) != 0: # negative
            res = self.value - (1<<self.bits)
            log (" val -ve:", self.bits, res)
        else:
            res = self.value
            log (" val +ve:", res)
        return res

    def _op(self, op, b):
        # same-width binary op; result truncated to self.bits
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        # different case: mul result needs to fit the total bitsize
        if isinstance(b, FieldSelectableInt):
            b = b.get_range()
        elif isinstance(b, int):
            b = SelectableInt(b, self.bits)
        log("SelectableInt mul", hex(self.value), hex(b.value),
            self.bits, b.bits)
        return SelectableInt(self.value * b.value, self.bits + b.bits)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        # NOTE(review): truediv yields a float, which the constructor
        # rejects with ValueError
        return self._op(truediv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __abs__(self):
        log("abs", self.value & (1 << (self.bits-1)))
        if self.value & (1 << (self.bits-1)) != 0:
            return -self
        return self

    def __rsub__(self, b):
        log("rsub", b, self.value)
        if isinstance(b, int):
            b = SelectableInt(b, 256) # max extent
        #b = check_extsign(self, b)
        #assert b.bits == self.bits
        return SelectableInt(b.value - self.value, b.bits)

    def __radd__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value + self.value, self.bits)

    def __rxor__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        res = SelectableInt((~self.value) + 1, self.bits)
        log ("neg", hex(self.value), hex(res.value))
        return res

    def __lshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def _slice_bounds(self, key):
        """Normalise a slice's start/stop/step to plain ints.

        SelectableInt bounds are converted with asint(); an omitted start
        defaults to 0 and an omitted stop to ``self.bits``.  The step is
        returned as-is (None stays None).
        """
        kstart, kstop, kstep = key.start, key.stop, key.step
        if isinstance(kstart, SelectableInt):
            kstart = kstart.asint()
        if isinstance(kstop, SelectableInt):
            kstop = kstop.asint()
        if isinstance(kstep, SelectableInt):
            kstep = kstep.asint()
        if kstart is None:
            kstart = 0
        if kstop is None:
            kstop = self.bits
        return kstart, kstop, kstep

    def __getitem__(self, key):
        log ("SelectableInt.__getitem__", self, key, type(key))
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            log("getitem", key, self.bits, hex(self.value), value)
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            kstart, kstop, kstep = self._slice_bounds(key)
            assert kstep is None or kstep == 1
            assert kstart < kstop
            assert kstart >= 0
            assert kstop <= self.bits

            # convert MSB0 [kstart, kstop) into LSB0 shift positions
            stop = self.bits - kstart
            start = self.bits - kstop

            bits = stop - start
            log ("__getitem__ slice num bits", start, stop, bits)
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            log("getitem", stop, start, self.bits, hex(self.value), value)
            return SelectableInt(value, bits)

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value
            log("setitem", key, self.bits, hex(self.value), hex(value))

            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            kstart, kstop, kstep = self._slice_bounds(key)
            log ("__setitem__ slice ", kstart, kstop, kstep)
            assert kstep is None or kstep == 1
            assert kstart < kstop
            assert kstart >= 0
            assert kstop <= self.bits, \
                "key stop %d bits %d" % (kstop, self.bits)

            stop = self.bits - kstart
            start = self.bits - kstop

            bits = stop - start
            #log ("__setitem__ slice num bits", bits)
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            log("setitem", key, self.bits, hex(self.value), hex(value))
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)

    def __ge__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            return onebit(self.to_signed_int() >= other)
        assert False

    def __le__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            return onebit(self.to_signed_int() <= other)
        assert False

    def __gt__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            return onebit(self.to_signed_int() > other)
        assert False

    def __lt__(self, other):
        log ("SelectableInt lt", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            a = self.to_signed_int()
            res = onebit(a < other)
            log (" a < b", a, other, res)
            return res
        assert False

    def __eq__(self, other):
        # equality compares raw (unsigned) values; returns a 1-bit result
        log("__eq__", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        log (" eq", other, self.value, other == self.value)
        if isinstance(other, int):
            return onebit(other == self.value)
        assert False

    def narrow(self, bits):
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        value = f"value={hex(self.value)}, bits={self.bits}"
        return f"{self.__class__.__name__}({value})"

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value

    def __int__(self):
        return self.asint()

    def __float__(self):
        """convert to double-precision float. TODO, properly convert
        rather than a hack-job: must actually support Power IEEE754 FP
        """
        assert self.bits == 64 # must be 64-bit
        data = self.value.to_bytes(8, byteorder='little')
        return struct.unpack('<d', data)[0]
472
473
class SelectableIntMappingMeta(type):
    """Metaclass turning ``bits=`` / ``fields=`` class keywords into named
    bit-field accessors.

    ``bits`` is the default width; when omitted it is taken from the first
    base class that exposes a non-None ``bits``, and it must end up an int
    (else ValueError).  ``fields`` maps attribute names to sequences of bit
    indices (or nested dicts of them); each becomes a FieldProperty that,
    on an instance, yields ``Field`` view(s) onto that instance.
    """

    @functools.total_ordering
    class Field(FieldSelectableInt):
        # field view whose int value and comparisons use the selected bits
        # with field bit 0 as the MSB (asint(msb0=True))
        def __int__(self):
            return self.asint(msb0=True)

        # NOTE(review): int.__lt__/__eq__ return NotImplemented for non-int
        # ``b``, which makes Python try the reflected operation on ``b``
        def __lt__(self, b):
            return int(self).__lt__(b)

        def __eq__(self, b):
            return int(self).__eq__(b)

    class FieldProperty:
        # descriptor: class access returns the raw index spec; instance
        # access returns a Field (or dict of Fields) bound to the instance
        def __init__(self, field):
            self.__field = field

        def __repr__(self):
            return self.__field.__repr__()

        def __get__(self, instance, owner):
            if instance is None:
                return self.__field

            cls = SelectableIntMappingMeta.Field
            factory = lambda br: cls(si=instance, br=br)
            # nested dict specs produce a dict of Fields with the same keys
            if isinstance(self.__field, dict):
                return {k:factory(br=v) for (k, v) in self.__field.items()}
            else:
                return factory(br=self.__field)

    class BitsProperty:
        # descriptor: class access returns the declared width.  It defines
        # only __get__, so an instance's own ``bits`` attribute (set by
        # SelectableInt.__init__) normally shadows it on instances.
        def __init__(self, bits):
            self.__bits = bits

        def __get__(self, instance, owner):
            if instance is None:
                return self.__bits
            return instance.bits

        def __repr__(self):
            return self.__bits.__repr__()

    def __new__(metacls, name, bases, attrs, bits=None, fields=None):
        if fields is None:
            fields = {}

        # normalise each field spec: index sequences become tuples,
        # nested dicts are normalised recursively
        def field(item):
            (key, value) = item
            if isinstance(value, dict):
                value = dict(map(field, value.items()))
            else:
                value = tuple(value)
            return (key, value)

        fields = dict(map(field, fields.items()))
        # setdefault: an attribute explicitly defined in the class body wins
        for (key, value) in fields.items():
            attrs.setdefault(key, metacls.FieldProperty(value))

        if bits is None:
            for base in bases:
                bits = getattr(base, "bits", None)
                if bits is not None:
                    break

        if not isinstance(bits, int):
            raise ValueError(bits)
        attrs.setdefault("bits", metacls.BitsProperty(bits))

        cls = super().__new__(metacls, name, bases, attrs)
        # name-mangled to _SelectableIntMappingMeta__fields; read by __iter__
        cls.__fields = fields
        return cls

    def __iter__(cls):
        # iterating the class yields (field name, normalised index spec)
        for (key, value) in cls.__fields.items():
            yield (key, value)
549
550
class SelectableIntMapping(SelectableInt, metaclass=SelectableIntMappingMeta,
                           bits=0):
    """SelectableInt whose width and named fields come from class keywords.

    Subclasses declare ``bits=`` and ``fields=`` in the class statement.
    A plain int ``value`` given without ``bits`` takes the class width.
    """

    def __init__(self, value=0, bits=None):
        if bits is None and isinstance(value, int):
            bits = type(self).bits
        super().__init__(value, bits)
557
558
def onebit(bit):
    """Return the truthiness of ``bit`` as a 1-bit SelectableInt (1 or 0)."""
    return SelectableInt(int(bool(bit)), 1)
561
562
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    compares the raw (unsigned) ``.value`` of ``lhs`` against ``rhs``,
    which may be a SelectableInt or a plain int; returns a 1-bit
    SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_value)
569
570
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    compares the raw (unsigned) ``.value`` of ``lhs`` against ``rhs``,
    which may be a SelectableInt or a plain int; returns a 1-bit
    SelectableInt.
    """
    if not isinstance(rhs, SelectableInt):
        return onebit(lhs.value > rhs)
    return onebit(lhs.value > rhs.value)
577
578
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the index or range ``idx``.

    :param lhs: destination supporting ``lhs[i] = ...``
    :param idx: a single index, or a tuple ``(lower, upper)`` /
        ``(lower, upper, step)`` describing ``range(lower, upper, step)``;
        a missing or None step means 1.
    :param rhs: source supporting ``rhs[i]``.  For a range, destination
        index ``lower + k`` receives ``rhs[k]`` (k stepping by ``step``);
        for a single index, ``lhs[idx]`` receives ``rhs[0]``.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            step = 1  # range() rejects None as a step
        toidx = range(lower, upper, step)
        fromidx = range(0, upper-lower, step) # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
594
595
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, the first argument most significant.

    :param args: SelectableInt or FieldSelectableInt values; a field
        contributes only the bits it selects.  When ``repeat`` != 1, a
        single plain int argument is treated as a 1-bit SelectableInt.
    :param repeat: how many times the argument list is repeated; a
        SelectableInt is accepted and converted to its value.
    :returns: a new SelectableInt whose width is the sum of all widths.
    """
    if isinstance(repeat, SelectableInt):
        repeat = repeat.value
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1: # multiplies the incoming arguments
        args = list(args) * repeat
    first = args[0]
    if isinstance(first, FieldSelectableInt):
        first = first.get_range()
    # copy so widening res below never mutates the caller's first argument
    res = copy(first)
    for i in args[1:]:
        if isinstance(i, FieldSelectableInt):
            # only the selected bits, not the whole underlying target
            i = i.get_range()
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    log("concat", repeat, res)
    return res
613
614
class SelectableIntTestCase(unittest.TestCase):
    """Tests for SelectableInt arithmetic, indexing and conversions."""

    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = abs(f)
        h = abs(a)
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        # multiplication widens: the product holds a.bits + b.bits bits
        self.assertEqual(e.value, (a.value * b.value) & 0xFFFF)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(g.value, a.value)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        self.assertEqual(e.bits, a.bits + b.bits)
        self.assertEqual(a.bits, f.bits)
        self.assertEqual(a.bits, g.bits)
        self.assertEqual(a.bits, h.bits)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)

    def test_cmp(self):
        a = SelectableInt(10, bits=8)
        b = SelectableInt(5, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_unsigned(self):
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        # the < / > operators are signed: 0x80 is -128 as an 8-bit value
        self.assertTrue(a < b)
        self.assertFalse(a > b)
        # selectgtu / selectltu compare the raw unsigned values
        self.assertTrue(selectgtu(a, b))
        self.assertFalse(selectltu(a, b))
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_maxint(self):
        a = SelectableInt(0xffffffffffffffff, bits=64)
        b = SelectableInt(0, bits=64)
        result = a + b
        self.assertTrue(result.value == 0xffffffffffffffff)

    def test_double_1(self):
        """use http://weitz.de/ieee/,
        """
        for asint, asfloat in [(0x4000000000000000, 2.0),
                               (0x4056C00000000000, 91.0),
                               (0xff80000000000000, -1.4044477616111843e+306),
                               ]:
            a = SelectableInt(asint, bits=64)
            convert = float(a)
            log ("test_double_1", asint, asfloat, convert)
            self.assertTrue(asfloat == convert)
714
715
# run this module's unittest test cases when executed as a script
if __name__ == "__main__":
    unittest.main()