power_insn: support PPC multi-records
[openpower-isa.git] / src / openpower / decoder / selectable_int.py
1 import unittest
2 import struct
3 from copy import copy
4 import functools
5 from collections import OrderedDict
6 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
7 neg, inv, lshift, rshift, lt, eq)
8 from openpower.util import log
9
10
def check_extsign(a, b):
    """Coerce operand ``b`` so that it can be combined with ``a``.

    * a FieldSelectableInt is first reduced to its selected bits;
    * a plain int becomes a SelectableInt of ``a``'s width;
    * a 256-bit SelectableInt (the "max extent" width that
      ``SelectableInt.__rsub__`` uses for int left-operands) is re-created
      at ``a``'s width;
    * any other SelectableInt is returned untouched.
    """
    if isinstance(b, FieldSelectableInt):
        b = b.get_range()
    if isinstance(b, int):
        return SelectableInt(b, a.bits)
    if b.bits == 256:
        return SelectableInt(b.value, a.bits)
    return b
19
20
class BitRange(OrderedDict):
    """BitRange: remaps from straight indices (0,1,2..) to bit numbers

    Keys are field indices and values are the bit numbers they select.
    A plain subscript is an ordinary mapping lookup.  A slice subscript is
    applied positionally to the values, in insertion order, and returns
    a list.
    """

    def __getitem__(self, subscript):
        if not isinstance(subscript, slice):
            return OrderedDict.__getitem__(self, subscript)
        bit_numbers = list(self.values())
        return bit_numbers[subscript]
30
31
@functools.total_ordering
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    A FieldSelectableInt is a *view*. It holds a target SelectableInt
    (``si``) and a BitRange (``br``) that maps each field index to a bit
    number in the target. Reads and writes go straight through to the
    target, using the target's own bit numbering, with no extra inversion
    here.

    Arithmetic and logic operators write their result back into the
    target via ``merge``. ``merge`` makes only a shallow copy, so the
    returned object shares ``si`` with ``self``.
    """

    def __init__(self, si, br):
        if not isinstance(si, (FieldSelectableInt, SelectableInt)):
            raise ValueError(si)

        # a plain sequence of bit numbers is shorthand for a BitRange
        if isinstance(br, (list, tuple, range)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br

        # a field of a field: compose the two maps so that this view
        # always points directly at the underlying SelectableInt
        if isinstance(si, FieldSelectableInt):
            fsi = si
            if len(br) > len(fsi.br):
                raise OverflowError(br)
            _br = BitRange()
            for (i, v) in br.items():
                _br[i] = fsi.br[v]
            br = _br
            si = fsi.si

        self.si = si  # target selectable int
        self.br = br  # map of indices

    def eq(self, b):
        """Assign ``b`` to this field.

        An int is converted to a SelectableInt of the field's width. A
        SelectableInt is written bit by bit into the field. Anything else
        is treated as another FieldSelectableInt: its target and range are
        copied, which re-points this view.
        """
        if isinstance(b, int):
            # convert integer to same SelectableInt of same bitlength as range
            blen = len(self.br)
            b = SelectableInt(b, blen)
            for i in range(b.bits):
                self[i] = b[i]
        elif isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary ``op`` to the field value and ``b``, then merge back."""
        vi = self.get_range()
        vi = op(vi, b)
        # results wider than the field (SelectableInt.__mul__ returns the
        # full double-width product) are truncated to the field's
        # low-order bits. Otherwise merge() would write back the MSBs.
        if vi.bits > len(self.br):
            vi = vi.narrow(len(self.br))
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary ``op`` to the field value, then merge back."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __len__(self):
        return len(self.br)

    def __getitem__(self, key):
        """Read field bit(s).

        An int (or SelectableInt) key returns one bit of the target. A
        slice key returns the concatenation of the selected bits. A
        tuple/list/range key returns a nested FieldSelectableInt view.
        """
        log("getitem", key, self.br)
        if isinstance(key, SelectableInt):
            key = key.value

        if isinstance(key, int):
            key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        elif isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])
        elif isinstance(key, (tuple, list, range)):
            return FieldSelectableInt(si=self, br=key)
        else:
            raise ValueError(key)

    def __setitem__(self, key, value):
        """Write field bit(s) through to the target."""
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            # slice key: self.br[...] returned a list of target bit numbers
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # historical (non-dunder) name, kept for any explicit callers
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    # Python 2 spelling, kept for any explicit callers
    __div__ = __truediv__

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __lt__(self, b):
        vi = self.get_range()
        return onebit(lt(vi, b))

    def __eq__(self, b):
        vi = self.get_range()
        return onebit(eq(vi, b))

    def get_range(self):
        """Return the selected bits as a new SelectableInt of field width."""
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Write the bits of ``vi`` into the target and return a view.

        ``copy`` is shallow, so the returned view shares (and this call
        mutates) the same target ``si`` as ``self``.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return f"{self.__class__.__name__}(si={self.si}, br={self.br})"

    def __bool__(self):
        # true if any selected bit is set
        for key in self.br.values():
            bit = self.si[key].value
            if bit:
                return True
        return False

    def __int__(self):
        return self.asint(msb0=True)

    def asint(self, msb0=False):
        """Pack the selected bits into an int.

        With ``msb0=True`` field index 0 becomes the most-significant bit.
        Otherwise field index 0 becomes bit 0 (least-significant).
        """
        res = 0
        brlen = len(self.br)
        for i, key in self.br.items():
            log("asint", i, key, self.si[key])
            bit = self.si[key].value
            res |= bit << ((brlen-i-1) if msb0 else i)
        return res
187
188
class FieldSelectableIntTestCase(unittest.TestCase):
    """Unit tests for FieldSelectableInt views onto a SelectableInt."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        # the field reads bits 0, 2, 3 (MSB0) of 0b10101 -> 0b110 == 6
        self.assertEqual(fs.get_range(), 0b110)
        c = fs + b
        log(c)
        # field arithmetic wraps at the field width: (6 + 3) mod 8 == 1
        self.assertEqual(c.get_range(), 0b001)
        self.assertEqual(c.get_range().bits, 3)

    def test_select(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
226
227
class SelectableInt:
    """SelectableInt - a class that behaves exactly like python int

    this class is designed to mirror precisely the behaviour of python int.
    the only difference is that it must contain the context of the bitwidth
    (number of bits) associated with that integer.

    FieldSelectableInt can then operate on partial bits, and because there
    is a bit width associated with SelectableInt, slices operate correctly
    including negative start/end points.

    Bit indexing follows the POWER ISA (MSB0) annotation order: index 0 is
    the most-significant bit and index ``bits-1`` the least-significant.

    Attributes:
        value: the value as an unsigned int, masked to ``bits`` bits
        bits: the bit width
        overflow: True if the value given to the constructor had bits set
            outside the mask (i.e. it was truncated)
    """

    def __init__(self, value, bits=None):
        if isinstance(value, SelectableInt):
            if bits is not None:
                # check if the bitlength is different. TODO, allow override?
                if bits != value.bits:
                    raise ValueError(value)
            bits = value.bits
            value = value.value
        elif isinstance(value, FieldSelectableInt):
            if bits is not None:
                raise ValueError(value)
            bits = len(value.br)
            # use only the selected bits. The whole target's value
            # (value.si.value) truncated to ``bits`` would be its LSBs,
            # which is not the field unless the field is at the bottom.
            value = value.get_range().value
        else:
            if not isinstance(value, int):
                raise ValueError(value)
            if bits is None:
                raise ValueError(bits)
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits
        self.overflow = (value & ~mask) != 0

    def eq(self, b):
        self.value = b.value
        self.bits = b.bits

    def to_signed_int(self):
        """Return the value interpreted as a two's-complement signed int."""
        log ("to signed?", self.value & (1<<(self.bits-1)), self.value)
        if self.value & (1<<(self.bits-1)) != 0:  # negative
            res = self.value - (1<<self.bits)
            log (" val -ve:", self.bits, res)
        else:
            res = self.value
            log (" val +ve:", res)
        return res

    def _op(self, op, b):
        """Apply binary ``op`` to same-width operands; result has self.bits."""
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        # different case: mul result needs to fit the total bitsize
        if isinstance(b, FieldSelectableInt):
            b = b.get_range()
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        log("SelectableInt mul", hex(self.value), hex(b.value),
            self.bits, b.bits)
        return SelectableInt(self.value * b.value, self.bits + b.bits)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __abs__(self):
        log("abs", self.value & (1 << (self.bits-1)))
        if self.value & (1 << (self.bits-1)) != 0:
            return -self
        return self

    def __rsub__(self, b):
        log("rsub", b, self.value)
        if isinstance(b, int):
            b = SelectableInt(b, 256)  # max extent
        #b = check_extsign(self, b)
        #assert b.bits == self.bits
        return SelectableInt(b.value - self.value, b.bits)

    def __radd__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value + self.value, self.bits)

    def __rxor__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        res = SelectableInt((~self.value) + 1, self.bits)
        log ("neg", hex(self.value), hex(res.value))
        return res

    def __lshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def _slice_bounds(self, key):
        """Normalise slice ``key`` to plain-int MSB0 bounds (start, stop).

        SelectableInt bounds are converted with ``asint()``. An omitted
        start/stop defaults to 0/``self.bits``, and a negative bound counts
        back from ``self.bits`` (as for Python sequences). Only unit steps
        are supported. Out-of-range bounds fail an assertion.
        """
        kstart, kstop, kstep = key.start, key.stop, key.step
        if isinstance(kstart, SelectableInt):
            kstart = kstart.asint()
        if isinstance(kstop, SelectableInt):
            kstop = kstop.asint()
        if isinstance(kstep, SelectableInt):
            kstep = kstep.asint()
        if kstart is None:
            kstart = 0
        elif kstart < 0:
            kstart += self.bits
        if kstop is None:
            kstop = self.bits
        elif kstop < 0:
            kstop += self.bits
        assert kstep is None or kstep == 1
        assert kstart < kstop
        assert kstart >= 0
        assert kstop <= self.bits, \
            "key stop %d bits %d" % (kstop, self.bits)
        return kstart, kstop

    def __getitem__(self, key):
        """Read bit(s) using MSB0 numbering.

        An int key returns a 1-bit SelectableInt. A slice returns a
        SelectableInt of the slice's width. Any other iterable of indices
        returns the concatenation of those bits.
        """
        log ("SelectableInt.__getitem__", self, key, type(key))
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            log("getitem", key, self.bits, hex(self.value), value)
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            kstart, kstop = self._slice_bounds(key)

            # convert MSB0 bounds to LSB0 shift positions
            stop = self.bits - kstart
            start = self.bits - kstop

            bits = stop - start
            log ("__getitem__ slice num bits", start, stop, bits)
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            log("getitem", stop, start, self.bits, hex(self.value), value)
            return SelectableInt(value, bits)
        else:
            bits = []
            key = tuple(key)
            for bit in key:
                if not isinstance(bit, (int, SelectableInt)):
                    raise ValueError(key)
                bits.append(self[bit])
            return selectconcat(*bits)

    def __setitem__(self, key, value):
        """Write bit(s) using MSB0 numbering (int or slice keys only)."""
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value
            log("setitem", key, self.bits, hex(self.value), hex(value))

            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            kstart, kstop = self._slice_bounds(key)
            log ("__setitem__ slice ", kstart, kstop, key.step)

            # convert MSB0 bounds to LSB0 shift positions
            stop = self.bits - kstart
            start = self.bits - kstop

            bits = stop - start
            #log ("__setitem__ slice num bits", bits)
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            log("setitem", key, self.bits, hex(self.value), hex(value))
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)
        else:
            raise ValueError(key)

    # the ordering comparisons below are *signed* (two's complement);
    # see selectltu/selectgtu for unsigned comparison

    def __ge__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            return onebit(self.to_signed_int() >= other)
        assert False

    def __le__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            return onebit(self.to_signed_int() <= other)
        assert False

    def __gt__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            return onebit(self.to_signed_int() > other)
        assert False

    def __lt__(self, other):
        log ("SelectableInt lt", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            a = self.to_signed_int()
            res = onebit(a < other)
            log (" a < b", a, other, res)
            return res
        assert False

    def __eq__(self, other):
        # equality compares the raw (unsigned) value
        log("__eq__", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        log (" eq", other, self.value, other == self.value)
        if isinstance(other, int):
            return onebit(other == self.value)
        assert False

    def narrow(self, bits):
        """Return the low-order ``bits`` bits as a new SelectableInt."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        value = f"value={hex(self.value)}, bits={self.bits}"
        return f"{self.__class__.__name__}({value})"

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value

    def __int__(self):
        return self.asint()

    def __float__(self):
        """convert to double-precision float. TODO, properly convert
        rather than a hack-job: must actually support Power IEEE754 FP
        """
        assert self.bits == 64  # must be 64-bit
        data = self.value.to_bytes(8, byteorder='little')
        return struct.unpack('<d', data)[0]
527
528
def onebit(bit):
    """Return the truthiness of ``bit`` as a 1-bit SelectableInt (0 or 1)."""
    as_int = int(bool(bit))
    return SelectableInt(as_int, 1)
531
532
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    Compare ``lhs < rhs`` treating both as unsigned and return the result
    as a 1-bit SelectableInt. ``lhs`` is a SelectableInt or a
    FieldSelectableInt. ``rhs`` may also be a plain int. A
    FieldSelectableInt operand contributes only its selected bits, as in
    the signed comparison operators.
    """
    if isinstance(lhs, FieldSelectableInt):
        lhs = lhs.get_range()
    if isinstance(rhs, FieldSelectableInt):
        rhs = rhs.get_range()
    if isinstance(rhs, SelectableInt):
        rhs = rhs.value
    return onebit(lhs.value < rhs)
539
540
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    Compare ``lhs > rhs`` treating both as unsigned and return the result
    as a 1-bit SelectableInt. ``lhs`` is a SelectableInt or a
    FieldSelectableInt. ``rhs`` may also be a plain int. A
    FieldSelectableInt operand contributes only its selected bits, as in
    the signed comparison operators.
    """
    if isinstance(lhs, FieldSelectableInt):
        lhs = lhs.get_range()
    if isinstance(rhs, FieldSelectableInt):
        rhs = rhs.get_range()
    if isinstance(rhs, SelectableInt):
        rhs = rhs.value
    return onebit(lhs.value > rhs)
547
548
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the position(s) in ``idx``.

    ``idx`` is either a single index, which does ``lhs[idx] = rhs[0]``, or
    a tuple ``(lower, upper)`` / ``(lower, upper, step)``. A tuple targets
    ``range(lower, upper, step)`` in ``lhs`` and takes its sources from
    ``range(0, upper - lower, step)`` in ``rhs``. A missing or None step
    means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            step = 1  # range() rejects step=None
        toidx = range(lower, upper, step)
        fromidx = range(0, upper-lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
564
565
def selectconcat(*args, repeat=1):
    """Concatenate bit values, first argument most-significant.

    Each argument is a SelectableInt or a FieldSelectableInt; a
    FieldSelectableInt contributes only its selected bits, not its whole
    target. With ``repeat`` != 1 the argument list is repeated that many
    times. In that mode a single plain-int argument is treated as a 1-bit
    value, e.g. ``selectconcat(1, repeat=8)`` gives 0xff (8 bits). Returns
    a new SelectableInt whose width is the sum of the parts' widths.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat

    parts = []
    for arg in args:
        if isinstance(arg, FieldSelectableInt):
            arg = arg.get_range()
        assert isinstance(arg, SelectableInt), "can only concat SIs, sorry"
        parts.append(arg)

    # copy so the caller's first argument is not mutated
    res = copy(parts[0])
    for part in parts[1:]:
        res.bits += part.bits
        res.value = (res.value << part.bits) | part.value
    log("concat", repeat, res)
    return res
583
584
class SelectableIntTestCase(unittest.TestCase):
    """Unit tests for SelectableInt and the module-level helpers."""

    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = abs(f)
        h = abs(a)
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        # mul widens: the result holds the full product
        self.assertEqual(e.value, a.value * b.value)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        # abs of a negative (MSB set) value negates it back
        self.assertEqual(g.value, a.value)
        self.assertEqual(h.value, a.value)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        # __mul__ returns self.bits + b.bits wide results
        self.assertEqual(e.bits, a.bits + b.bits)
        self.assertEqual(a.bits, f.bits)
        self.assertEqual(a.bits, g.bits)
        self.assertEqual(a.bits, h.bits)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)

    def test_cmp(self):
        a = SelectableInt(10, bits=8)
        b = SelectableInt(5, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_unsigned(self):
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        # unsigned: 0x80 (128) > 0x7f (127)
        self.assertTrue(selectgtu(a, b))
        self.assertFalse(selectltu(a, b))
        # the rich comparison operators are signed: 0x80 is -128
        self.assertTrue(a < b)
        self.assertFalse(a > b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_maxint(self):
        a = SelectableInt(0xffffffffffffffff, bits=64)
        b = SelectableInt(0, bits=64)
        result = a + b
        self.assertTrue(result.value == 0xffffffffffffffff)

    def test_double_1(self):
        """use http://weitz.de/ieee/,
        """
        for asint, asfloat in [(0x4000000000000000, 2.0),
                               (0x4056C00000000000, 91.0),
                               (0xff80000000000000, -1.4044477616111843e+306),
                               ]:
            a = SelectableInt(asint, bits=64)
            convert = float(a)
            log ("test_double_1", asint, asfloat, convert)
            self.assertEqual(asfloat, convert)
684
685
if __name__ == "__main__":
    # run this module's unittest TestCase classes when executed as a script
    unittest.main()