whitespace - keep to under 80 chars
[openpower-isa.git] / src / openpower / decoder / selectable_int.py
1 import unittest
2 import struct
3 from copy import copy
4 import functools
5 from openpower.decoder.power_fields import BitRange
6 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
7 neg, inv, lshift, rshift, lt, eq)
8 from openpower.util import log
9
10
def check_extsign(a, b):
    """Coerce operand ``b`` into a SelectableInt usable alongside ``a``.

    * a FieldSelectableInt is first collapsed to its selected bits;
    * a plain int is wrapped at ``a``'s width;
    * a 256-bit SelectableInt (the "max extent" width that ``__rsub__``
      uses for ints) is re-wrapped at ``a``'s width;
    * any other SelectableInt is returned unchanged.
    """
    operand = b.get_range() if isinstance(b, FieldSelectableInt) else b
    if isinstance(operand, int):
        return SelectableInt(operand, a.bits)
    is_max_extent = operand.bits == 256
    return SelectableInt(operand.value, a.bits) if is_max_extent else operand
19
20
@functools.total_ordering
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target SelectableInt and ``br`` maps field bit index ->
    target bit index (a BitRange, or a list/tuple/range that is converted
    into one).  Field bits are read and written through SelectableInt's
    MSB0 indexing, so field bit 0 is the most significant bit of the field.

    Arithmetic and bitwise operators compute on the selected bits (at the
    field's width) and return a *new* FieldSelectableInt over a copy of
    the target with the result merged in; the original target is not
    modified.  ``<`` and ``==`` (and the orderings total_ordering derives
    from them) compare the selected bits and return a 1-bit SelectableInt.
    """

    def __init__(self, si, br):
        if not isinstance(si, SelectableInt):
            raise ValueError(si)
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple, range)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """Assign ``b`` into this field (in place).

        An int is converted to a SelectableInt of the field width and a
        SelectableInt is copied bit by bit into the selected target bits.
        Anything else (presumably another FieldSelectableInt) replaces this
        object's target and bit map with copies of ``b``'s.
        """
        if isinstance(b, int):
            # convert integer to same SelectableInt of same bitlength as range
            b = SelectableInt(b, len(self.br))
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary ``op`` to the selected bits and ``b``; merge result."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary ``op`` to the selected bits; merge the result."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        """Read field bit ``key`` (int) or field bits ``key`` (slice)."""
        log("getitem", key, self.br)
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        if isinstance(key, slice):
            # NOTE(review): relies on BitRange returning a sequence of
            # target indices for a slice lookup
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])

    def __setitem__(self, key, value):
        """Write field bit ``key`` (int) or field bits ``key`` (slice)."""
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # old (never-dispatched) spelling, kept so explicit callers still work
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        # SelectableInt holds integers only: "/" is integer division
        return self._op(floordiv, b)

    # Python 2 spelling, kept for backwards compatibility
    __div__ = __truediv__

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __lt__(self, b):
        # compare the selected bits; the 1-bit result must NOT be merged
        # back into the (possibly wider) field
        return self.get_range() < b

    def __eq__(self, b):
        return self.get_range() == b

    def get_range(self):
        """Return the selected bits as a new SelectableInt of field width."""
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Return a new FieldSelectableInt whose selected bits hold ``vi``.

        The target SelectableInt is copied first, so neither ``self`` nor
        its target is modified.  A result wider than the field (``*``
        returns a double-width product) keeps only its low bits, matching
        how SelectableInt truncates other arithmetic results.
        """
        width = len(self.br)
        if isinstance(vi, SelectableInt) and vi.bits > width:
            vi = SelectableInt(vi.value, width)
        fi = copy(self)
        fi.si = copy(self.si)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return f"{self.__class__.__name__}(si={self.si}, br={self.br})"

    def asint(self, msb0=False):
        """Return the selected bits as a python int.

        With ``msb0=False`` field bit i becomes bit i of the result (LSB0);
        with ``msb0=True`` field bit 0 becomes the most significant bit.
        """
        res = 0
        brlen = len(self.br)
        for i, key in self.br.items():
            bit = self.si[key].value
            #log("asint", i, key, bit)
            res |= bit << ((brlen-i-1) if msb0 else i)
        return res
143
144
class FieldSelectableIntTestCase(unittest.TestCase):
    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        c = fs + b
        log(c)
        # MSB0 bits 0, 2, 3 of a are 1, 1, 0 -> field value 0b110;
        # 0b110 + 0b011 = 0b1001, truncated to the 3-bit field width
        self.assertEqual(c.get_range(), 0b001)

    def test_select(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
182
183
class SelectableInt:
    """SelectableInt - a class that behaves exactly like python int

    this class is designed to mirror precisely the behaviour of python int.
    the only difference is that it must contain the context of the bitwidth
    (number of bits) associated with that integer.

    FieldSelectableInt can then operate on partial bits, and because there
    is a bit width associated with SelectableInt, slices operate correctly.
    Bit and slice indices use POWER MSB0 numbering (index 0 is the most
    significant bit); slice bounds must lie within 0..bits (negative
    bounds are not supported).

    Attributes:
        value: the unsigned value, masked to ``bits`` bits.
        bits: the width in bits.
        overflow: True when the value passed to __init__ had bits set
            outside the low ``bits`` bits.
    """

    def __init__(self, value, bits=None):
        if isinstance(value, SelectableInt):
            if bits is not None:
                raise ValueError(value)
            bits = value.bits
            value = value.value
        elif isinstance(value, FieldSelectableInt):
            if bits is not None:
                raise ValueError(value)
            bits = len(value.br)
            # use only the *selected* bits (field bit 0 is the MSB); the
            # target's whole value would mix in unrelated bits
            value = value.asint(msb0=True)
        else:
            if not isinstance(value, int):
                raise ValueError(value)
            if bits is None:
                raise ValueError(bits)
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits
        self.overflow = (value & ~mask) != 0

    def eq(self, b):
        """Take value and width from SelectableInt ``b`` (in place)."""
        self.value = b.value
        self.bits = b.bits

    def to_signed_int(self):
        """Return the value interpreted as two's complement of ``bits``."""
        log ("to signed?", self.value & (1<<(self.bits-1)), self.value)
        if self.value & (1<<(self.bits-1)) != 0: # negative
            res = self.value - (1<<self.bits)
            log (" val -ve:", self.bits, res)
        else:
            res = self.value
            log (" val +ve:", res)
        return res

    def _op(self, op, b):
        """Apply ``op`` to the unsigned values; result keeps self.bits."""
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        # different case: mul result needs to fit the total bitsize
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        log("SelectableInt mul", hex(self.value), hex(b.value),
            self.bits, b.bits)
        return SelectableInt(self.value * b.value, self.bits + b.bits)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        # operator.truediv would produce a float, which __init__ rejects,
        # so "/" is integer (floor) division of the unsigned values
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __abs__(self):
        log("abs", self.value & (1 << (self.bits-1)))
        if self.value & (1 << (self.bits-1)) != 0:
            return -self
        return self

    def __rsub__(self, b):
        log("rsub", b, self.value)
        if isinstance(b, int):
            b = SelectableInt(b, 256) # max extent
        #b = check_extsign(self, b)
        #assert b.bits == self.bits
        return SelectableInt(b.value - self.value, b.bits)

    def __radd__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value + self.value, self.bits)

    def __rand__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value & b.value, self.bits)

    def __ror__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value | b.value, self.bits)

    def __rxor__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        res = SelectableInt((~self.value) + 1, self.bits)
        log ("neg", hex(self.value), hex(res.value))
        return res

    def __lshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def _slice_bounds(self, key):
        """Normalise an MSB0 slice into python ints (start, stop, step).

        SelectableInt bounds are converted to int; a missing start means 0
        and a missing stop means ``self.bits``.  Only unit step is allowed.
        """
        kstart, kstop, kstep = key.start, key.stop, key.step
        if isinstance(kstart, SelectableInt): kstart = kstart.asint()
        if isinstance(kstop, SelectableInt): kstop = kstop.asint()
        if isinstance(kstep, SelectableInt): kstep = kstep.asint()
        if kstart is None:
            kstart = 0
        if kstop is None:
            kstop = self.bits
        assert kstep is None or kstep == 1
        assert kstart < kstop
        assert kstart >= 0
        assert kstop <= self.bits, \
            "key stop %d bits %d" % (kstop, self.bits)
        return kstart, kstop, kstep

    def __getitem__(self, key):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            log("getitem", key, self.bits, hex(self.value), value)
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            kstart, kstop, _ = self._slice_bounds(key)

            stop = self.bits - kstart
            start = self.bits - kstop

            bits = stop - start
            #log ("__getitem__ slice num bits", start, stop, bits)
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            log("getitem", stop, start, self.bits, hex(self.value), value)
            return SelectableInt(value, bits)
        raise TypeError("unsupported key type %r" % (key,))

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(value, FieldSelectableInt):
            value = value.get_range()
        if isinstance(key, int):
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value
            log("setitem", key, self.bits, hex(self.value), hex(value))

            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            kstart, kstop, kstep = self._slice_bounds(key)
            log ("__setitem__ slice ", kstart, kstop, kstep)

            stop = self.bits - kstart
            start = self.bits - kstop

            bits = stop - start
            #log ("__setitem__ slice num bits", bits)
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            log("setitem", key, self.bits, hex(self.value), hex(value))
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)
        else:
            raise TypeError("unsupported key type %r" % (key,))

    def __ge__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            return onebit(self.to_signed_int() >= other)
        assert False

    def __le__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            return onebit(self.to_signed_int() <= other)
        assert False

    def __gt__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            return onebit(self.to_signed_int() > other)
        assert False

    def __lt__(self, other):
        log ("SelectableInt lt", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            a = self.to_signed_int()
            res = onebit(a < other)
            log (" a < b", a, other, res)
            return res
        assert False

    def __eq__(self, other):
        log("__eq__", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        log (" eq", other, self.value, other == self.value)
        if isinstance(other, int):
            return onebit(other == self.value)
        assert False

    def narrow(self, bits):
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        value = f"value={hex(self.value)}, bits={self.bits}"
        return f"{self.__class__.__name__}({value})"

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value

    def __float__(self):
        """convert to double-precision float. TODO, properly convert
        rather than a hack-job: must actually support Power IEEE754 FP
        """
        assert self.bits == 64 # must be 64-bit
        data = self.value.to_bytes(8, byteorder='little')
        return struct.unpack('<d', data)[0]
467
468
class SelectableIntMappingMeta(type):
    """Metaclass that builds SelectableInt classes with named bit fields.

    Class keyword arguments:

    * ``bits`` -- total width.  If omitted it is taken from the first base
      exposing a non-None ``bits`` attribute; a non-int width raises
      ValueError.
    * ``fields`` -- mapping of field name -> iterable of bit indices (or a
      nested mapping of the same shape).  Index iterables are normalised
      to tuples.  Each top-level name becomes a class attribute which, on
      the class, evaluates to the normalised indices and, on an instance,
      to the selected bits as an int with field bit 0 as the MSB.
      NOTE(review): FieldProperty.__get__ on an instance appears to expect
      a flat index tuple; nested mappings look unsupported there.

    Iterating over a class created by this metaclass yields
    ``(name, indices)`` pairs for its fields, including fields recorded
    on mapping base classes.
    """

    @functools.total_ordering
    class Field(FieldSelectableInt):
        def __eq__(self, b):
            a = self.asint(msb0=True)
            return a.__eq__(b)

        def __lt__(self, b):
            a = self.asint(msb0=True)
            return a.__lt__(b)

    class FieldProperty:
        """Class attribute exposing one named field."""

        def __init__(self, field):
            self.__field = field

        def __get__(self, instance, owner):
            if instance is None:
                return self.__field
            res = FieldSelectableInt(si=instance, br=self.__field)
            return res.asint(msb0=True)

    class BitsProperty:
        """Class-level width; instances report their own ``bits``."""

        def __init__(self, bits):
            self.__bits = bits

        def __get__(self, instance, owner):
            if instance is None:
                return self.__bits
            return instance.bits

    def __new__(metacls, name, bases, attrs, bits=None, fields=None):
        if fields is None:
            fields = {}

        def field(item):
            (key, value) = item
            if isinstance(value, dict):
                value = dict(map(field, value.items()))
            else:
                value = tuple(value)
            return (key, value)

        normalised = dict(map(field, fields.items()))
        for (key, value) in normalised.items():
            attrs.setdefault(key, metacls.FieldProperty(value))

        if bits is None:
            for base in bases:
                bits = getattr(base, "bits", None)
                if bits is not None:
                    break

        if not isinstance(bits, int):
            raise ValueError(bits)
        attrs.setdefault("bits", metacls.BitsProperty(bits))

        cls = super().__new__(metacls, name, bases, attrs)

        # record the fields for __iter__ (previously never stored, so
        # iterating a mapping class raised AttributeError).  Lookup on the
        # fresh class falls through to any mapping base's record.
        try:
            recorded = dict(cls.__fields)
        except AttributeError:
            recorded = {}
        recorded.update(normalised)
        cls.__fields = recorded
        return cls

    def __iter__(cls):
        for (key, value) in cls.__fields.items():
            yield (key, value)
529
530
class SelectableIntMapping(SelectableInt, metaclass=SelectableIntMappingMeta,
                           bits=0):
    """SelectableInt whose width defaults to the class-declared ``bits``.

    Subclasses pass ``bits``/``fields`` as class keywords (handled by
    SelectableIntMappingMeta).  Building from a plain int without an
    explicit width uses the class-level width.
    """

    def __init__(self, value=0, bits=None):
        use_class_width = bits is None and isinstance(value, int)
        width = type(self).bits if use_class_width else bits
        super().__init__(value, width)
537
538
def onebit(bit):
    """Return a 1-bit SelectableInt: 1 when ``bit`` is truthy, else 0."""
    return SelectableInt(int(bool(bit)), 1)
541
542
def _as_unsigned(x):
    """Return operand ``x`` as an unsigned python value for comparison.

    A FieldSelectableInt yields its selected bits (field bit 0 as MSB);
    anything with a ``value`` attribute (SelectableInt) yields that;
    other values (e.g. plain ints) are returned unchanged.
    """
    if isinstance(x, FieldSelectableInt):
        return x.asint(msb0=True)
    return getattr(x, "value", x)


def selectltu(lhs, rhs):
    """ less-than (unsigned)
    """
    return onebit(_as_unsigned(lhs) < _as_unsigned(rhs))


def selectgtu(lhs, rhs):
    """ greater-than (unsigned)
    """
    return onebit(_as_unsigned(lhs) > _as_unsigned(rhs))
557
558
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the position(s) ``idx``.

    ``idx`` is either a single index (``lhs[idx] = rhs[0]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)``: destination indices are
    ``range(lower, upper, step)`` and source indices are
    ``range(0, upper - lower, step)``.  A missing/None step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            # range() rejects step=None, which made the 2-tuple form fail
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper-lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
574
575
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, the first argument being most significant.

    ``repeat`` replicates the argument list that many times before
    concatenating; as a special case a single int argument with
    ``repeat != 1`` is treated as a 1-bit value.  FieldSelectableInt
    arguments contribute only their selected bits.  Returns a new
    SelectableInt whose width is the sum of the argument widths.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat
    # a field contributes its selected bits, not its whole target register
    args = [a.get_range() if isinstance(a, FieldSelectableInt) else a
            for a in args]
    res = copy(args[0])
    for i in args[1:]:
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    log("concat", repeat, res)
    return res
593
594
class SelectableIntTestCase(unittest.TestCase):
    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = abs(f)
        h = abs(a)
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        self.assertEqual(e.value, (a.value * b.value) & 0xFF)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(g.value, a.value)
        self.assertEqual(h.value, a.value)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        # multiplication deliberately widens to hold the full product
        self.assertEqual(e.bits, a.bits + b.bits)
        self.assertEqual(a.bits, f.bits)
        self.assertEqual(a.bits, g.bits)
        self.assertEqual(a.bits, h.bits)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)

    def test_cmp(self):
        a = SelectableInt(10, bits=8)
        b = SelectableInt(5, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_unsigned(self):
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        # unsigned comparison: 0x80 (128) > 0x7f (127)
        self.assertTrue(selectgtu(a, b))
        self.assertFalse(selectltu(a, b))
        # the < / > operators are *signed*: 0x80 is -128 at 8 bits
        self.assertTrue(a < b)
        self.assertFalse(a > b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_maxint(self):
        a = SelectableInt(0xffffffffffffffff, bits=64)
        b = SelectableInt(0, bits=64)
        result = a + b
        self.assertTrue(result.value == 0xffffffffffffffff)

    def test_double_1(self):
        """use http://weitz.de/ieee/,
        """
        for asint, asfloat in [(0x4000000000000000, 2.0),
                               (0x4056C00000000000, 91.0),
                               (0xff80000000000000, -1.4044477616111843e+306),
                               ]:
            a = SelectableInt(asint, bits=64)
            convert = float(a)
            log ("test_double_1", asint, asfloat, convert)
            self.assertEqual(asfloat, convert)
694
695
if "__main__" == __name__:
    # run every unittest.TestCase defined in this module
    unittest.main(module="__main__")