181741216a40d1e9b2f0d01b061ad97d46d9e5dc
[openpower-isa.git] / src / openpower / decoder / selectable_int.py
1 import unittest
2 import struct
3 from copy import copy
4 import functools
5 from collections import OrderedDict
6 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
7 neg, inv, lshift, rshift, lt, eq)
8 from openpower.util import log
9
10
def check_extsign(a, b):
    """coerce operand b so it can be combined with SelectableInt a

    * a FieldSelectableInt is first flattened with get_range()
    * a plain int becomes a SelectableInt of a's width
    * a 256-bit SelectableInt is re-created at a's width (256 is the
      width SelectableInt.__rsub__ gives to plain ints, "max extent")
    * any other SelectableInt is handed back untouched: the caller
      is expected to check the widths match
    """
    operand = b.get_range() if isinstance(b, FieldSelectableInt) else b
    if isinstance(operand, int):
        return SelectableInt(operand, a.bits)
    if operand.bits == 256:
        # "max extent" value: cut it down to the width of a
        return SelectableInt(operand.value, a.bits)
    return operand
19
20
class BitRange(OrderedDict):
    """BitRange: remaps from straight indices (0,1,2..) to bit numbers

    a plain key behaves exactly as for an OrderedDict.  a slice is
    applied to the list of stored bit numbers (in insertion order) and
    returns a list.
    """

    def __getitem__(self, subscript):
        if not isinstance(subscript, slice):
            return super().__getitem__(subscript)
        # slices select by position in the ordered list of bit numbers
        bit_numbers = list(self.values())
        return bit_numbers[subscript]
30
31
@functools.total_ordering
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    si is the underlying SelectableInt and br maps field indices
    (0, 1, 2...) onto bit numbers of si.  no copy of the selected bits
    is kept: reads and writes through the field go straight to si.
    """

    def __init__(self, si, br):
        """create a field over si selecting the bits listed in br

        si: a SelectableInt, or another FieldSelectableInt (in which case
            br is composed with the parent's map and the parent's target
            is used directly)
        br: a BitRange, or a list/tuple/range of bit numbers which is
            converted to one

        raises ValueError if si is of neither type, and OverflowError if
        br has more entries than the parent field.
        """
        if not isinstance(si, (FieldSelectableInt, SelectableInt)):
            raise ValueError(si)

        if isinstance(br, (list, tuple, range)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br

        if isinstance(si, FieldSelectableInt):
            # field-of-a-field: translate our indices through the parent
            fsi = si
            if len(br) > len(fsi.br):
                raise OverflowError(br)
            _br = BitRange()
            for (i, v) in br.items():
                _br[i] = fsi.br[v]
            br = _br
            si = fsi.si

        self.si = si  # target selectable int
        self.br = br  # map of indices

    def eq(self, b):
        """assign b into this field (in-place, nothing is returned)

        an int or SelectableInt is written bit-by-bit into the target.
        anything else is treated as another field: this object is
        re-pointed at copies of its si and br.
        """
        if isinstance(b, int):
            # convert integer to same SelectableInt of same bitlength as range
            blen = len(self.br)
            b = SelectableInt(b, blen)
            for i in range(b.bits):
                self[i] = b[i]
        elif isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """apply binary operator op to (selected bits, b), then merge"""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """apply unary operator op to the selected bits, then merge"""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __len__(self):
        return len(self.br)

    def __getitem__(self, key):
        """read from the field

        int (or SelectableInt) key: the single target bit it maps to
        slice: the mapped bits concatenated into one SelectableInt
        tuple/list/range: a new FieldSelectableInt over this one
        """
        log("getitem", key, self.br)
        if isinstance(key, SelectableInt):
            key = key.value

        if isinstance(key, int):
            key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        elif isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])
        elif isinstance(key, (tuple, list, range)):
            return FieldSelectableInt(si=self, br=key)
        else:
            raise ValueError(key)

    def __setitem__(self, key, value):
        """write value to the target bit(s) selected by an int or slice key"""
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            # slice: key is now a list of target bit numbers
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # historical name: python never calls it (unary minus uses __neg__)
    # but it is kept in case anything invokes it explicitly
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    # python 2 name for the "/" operator, kept for explicit callers
    __div__ = __truediv__

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __lt__(self, b):
        vi = self.get_range()
        return onebit(lt(vi, b))

    def __eq__(self, b):
        vi = self.get_range()
        return onebit(eq(vi, b))

    def get_range(self):
        """return the selected bits as a new len(br)-bit SelectableInt"""
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """write the bits of vi back through the map, return the field

        NOTE: copy() is shallow, so the returned object shares its si
        with this one: the write is visible through both.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return f"{self.__class__.__name__}(si={self.si}, br={self.br})"

    def __bool__(self):
        """True if any selected bit is set"""
        for key in self.br.values():
            bit = self.si[key].value
            if bit:
                return True
        return False

    def __int__(self):
        return self.asint(msb0=True)

    def asint(self, msb0=False):
        """pack the selected bits into a python int

        msb0=True: field index 0 is the most significant bit of the result
        msb0=False: field index i lands at bit position i (weight 1<<i)
        """
        res = 0
        brlen = len(self.br)
        for i, key in self.br.items():
            bit = self.si[key].value
            res |= bit << ((brlen-i-1) if msb0 else i)
        return res
187
188
class FieldSelectableIntTestCase(unittest.TestCase):
    """unit tests for FieldSelectableInt (bit numbering is MSB0)"""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        # bits 0, 2 and 3 of a (MSB0) are 1, 1, 0: the field holds 0b110
        self.assertEqual(fs.get_range(), 0b110)
        c = fs + b
        log(c)
        # the sum is truncated to the 3 bits of the field
        self.assertEqual(c.get_range(), (0b110 + 0b011) & 0b111)

    def test_select(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
226
227
@functools.total_ordering
class SelectableInt:
    """SelectableInt - a class that behaves exactly like python int

    this class is designed to mirror precisely the behaviour of python int.
    the only difference is that it must contain the context of the bitwidth
    (number of bits) associated with that integer.

    FieldSelectableInt can then operate on partial bits, and because there
    is a bit width associated with SelectableInt, slices operate correctly
    including negative start/end points.

    bit indexing follows the POWER convention: bit 0 is the MSB.
    """

    def __init__(self, value, bits=None):
        """create from an int, a SelectableInt or a FieldSelectableInt

        value: int (bits is then mandatory), SelectableInt (bits, if
               given, must match) or FieldSelectableInt (bits must not
               be given: the width is the number of selected bits)
        bits:  width in bits

        value is truncated to the width; self.overflow records whether
        any bits were discarded.  raises ValueError on a bad combination.
        """
        if isinstance(value, SelectableInt):
            if bits is not None:
                # check if the bitlength is different. TODO, allow override?
                if bits != value.bits:
                    raise ValueError(value)
            bits = value.bits
            value = value.value
        elif isinstance(value, FieldSelectableInt):
            if bits is not None:
                raise ValueError(value)
            bits = len(value.br)
            # take the *selected* bits, not the low bits of the whole target
            value = value.get_range().value
        else:
            if not isinstance(value, int):
                raise ValueError(value)
            if bits is None:
                raise ValueError(bits)
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits
        self.overflow = (value & ~mask) != 0

    def eq(self, b):
        """in-place assignment: take over the value and width of b"""
        self.value = b.value
        self.bits = b.bits

    def to_signed_int(self):
        """return the value as a python int, treating the MSB as the sign"""
        log ("to signed?", self.value & (1<<(self.bits-1)), self.value)
        if self.value & (1<<(self.bits-1)) != 0: # negative
            res = self.value - (1<<self.bits)
            log (" val -ve:", self.bits, res)
        else:
            res = self.value
            log (" val +ve:", res)
        return res

    def _op(self, op, b):
        """apply binary operator op, result has the same width as self"""
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def _slice_bounds(self, key):
        """validate slice key, return (shift, width) of the selected bits

        bounds may be ints or SelectableInts; a missing start means 0 and
        a missing stop means self.bits.  only step 1 is supported.  shift
        is the distance of the field's last bit from the LSB of self.value.
        """
        kstart, kstop, kstep = key.start, key.stop, key.step
        if isinstance(kstart, SelectableInt):
            kstart = kstart.asint()
        if isinstance(kstop, SelectableInt):
            kstop = kstop.asint()
        if isinstance(kstep, SelectableInt):
            kstep = kstep.asint()
        if kstart is None:
            kstart = 0
        if kstop is None:
            kstop = self.bits
        assert kstep is None or kstep == 1
        assert kstart < kstop
        assert kstart >= 0
        assert kstop <= self.bits, \
            "key stop %d bits %d" % (kstop, self.bits)
        # MSB0 numbering: bit (kstop-1) is the lowest-weight bit selected
        return self.bits - kstop, kstop - kstart

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        # different case: mul result needs to fit the total bitsize
        if isinstance(b, FieldSelectableInt):
            b = b.get_range()
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        log("SelectableInt mul", hex(self.value), hex(b.value),
            self.bits, b.bits)
        return SelectableInt(self.value * b.value, self.bits + b.bits)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        # NOTE(review): truediv on two ints gives a float, which __init__
        # rejects with ValueError, so as written this can never succeed.
        # left alone until the intended semantics are confirmed.
        return self._op(truediv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __abs__(self):
        """negate if the MSB (sign bit) is set, otherwise return self"""
        log("abs", self.value & (1 << (self.bits-1)))
        if self.value & (1 << (self.bits-1)) != 0:
            return -self
        return self

    def __rsub__(self, b):
        """b - self.  a plain int b gives a 256-bit ("max extent") result,
        which check_extsign later cuts down to the other operand's width
        """
        log("rsub", b, self.value)
        if isinstance(b, int):
            b = SelectableInt(b, 256) # max extent
        #b = check_extsign(self, b)
        #assert b.bits == self.bits
        return SelectableInt(b.value - self.value, b.bits)

    def __radd__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value + self.value, self.bits)

    def __rxor__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        """two's complement negation at the current width"""
        res = SelectableInt((~self.value) + 1, self.bits)
        log ("neg", hex(self.value), hex(res.value))
        return res

    def __lshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def __getitem__(self, key):
        """read bits (MSB0 numbering)

        int / SelectableInt key: that bit, as a 1-bit SelectableInt
        slice: bits start..stop-1 as a (stop-start)-bit SelectableInt
        any other iterable of bit numbers: those bits, concatenated
        """
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order!  see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            start, bits = self._slice_bounds(key)
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            return SelectableInt(value, bits)
        else:
            bits = []
            for bit in key:
                if not isinstance(bit, (int, SelectableInt)):
                    raise ValueError(key)
                bits.append(self[bit])
            return selectconcat(*bits)

    def __setitem__(self, key, value):
        """write bits (MSB0 numbering), same key forms as __getitem__

        a SelectableInt value must have exactly the width being written.
        for an iterable of bit numbers an int value must fit in that many
        bits, otherwise ValueError is raised.
        """
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            start, bits = self._slice_bounds(key)
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)
        else:
            bits = []
            for bit in key:
                if not isinstance(bit, (int, SelectableInt)):
                    raise ValueError(key)
                bits.append(bit)

            if isinstance(value, int):
                if value.bit_length() > len(bits):
                    raise ValueError(value)
                value = SelectableInt(value=value, bits=len(bits))
            if not isinstance(value, SelectableInt):
                raise ValueError(value)

            for (src, dst) in enumerate(bits):
                self[dst] = value[src]

    def __lt__(self, other):
        """*signed* less-than, returned as a 1-bit SelectableInt.
        see selectltu / selectgtu for the unsigned comparisons
        """
        log ("SelectableInt __lt__", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            a = self.to_signed_int()
            res = onebit(a < other)
            log (" a < b", a, other, res)
            return res
        assert False

    def __eq__(self, other):
        """equality on the stored (unsigned) value, as a 1-bit SelectableInt"""
        log("SelectableInt __eq__", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        log (" eq", other, self.value, other == self.value)
        if isinstance(other, int):
            return onebit(other == self.value)
        assert False

    def narrow(self, bits):
        """return a copy truncated to the low-order `bits` bits"""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        value = f"value={hex(self.value)}, bits={self.bits}"
        return f"{self.__class__.__name__}({value})"

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value

    def __int__(self):
        return self.asint()

    def __float__(self):
        """convert to double-precision float.  TODO, properly convert
        rather than a hack-job: must actually support Power IEEE754 FP

        the stored bit pattern is reinterpreted as an IEEE754 single
        (32-bit) or double (64-bit); any other width fails the assert.
        """
        if self.bits == 32:
            data = self.value.to_bytes(4, byteorder='little')
            return struct.unpack('<f', data)[0]
        assert self.bits == 64 # must be 64-bit
        data = self.value.to_bytes(8, byteorder='little')
        return struct.unpack('<d', data)[0]
511
512
def onebit(bit):
    """wrap the truthiness of bit in a 1-bit SelectableInt (1 or 0)"""
    return SelectableInt(int(bool(bit)), 1)
515
516
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    lhs must have a .value; rhs is either a SelectableInt or something
    directly comparable with an int.  result is a 1-bit SelectableInt.
    """
    threshold = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < threshold)
523
524
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    lhs must have a .value; rhs is either a SelectableInt or something
    directly comparable with an int.  result is a 1-bit SelectableInt.
    """
    threshold = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value > threshold)
531
532
533 # XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """copy items of rhs into lhs, in-place

    idx is either a single index (lhs[idx] = rhs[0]) or a tuple
    (lower, upper) / (lower, upper, step) describing the destination
    range.  the source positions run over range(0, upper-lower, step),
    so with a step other than 1 the source is strided too.  a missing
    or None step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = 1
        else:
            lower, upper, step = idx
            if step is None:
                step = 1
        # range() does not accept None as a step, hence the default above
        toidx = range(lower, upper, step)
        fromidx = range(0, upper-lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
548
549
def selectconcat(*args, repeat=1):
    """concatenate SelectableInts, first argument most significant

    repeat: the whole argument list is repeated this many times.  as a
    special case a single plain int with repeat != 1 is treated as a
    1-bit value.

    the result starts as a copy of the first argument and is widened
    for each further one.  for a FieldSelectableInt in any position but
    the first, its underlying si is what gets appended.
    """
    replicate = repeat != 1
    if replicate and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if replicate:  # multiplies the incoming arguments
        args = [item for _ in range(repeat) for item in args]
    head, tail = args[0], args[1:]
    res = copy(head)
    for part in tail:
        if isinstance(part, FieldSelectableInt):
            part = part.si
        assert isinstance(part, SelectableInt), "can only concat SIs, sorry"
        res.bits += part.bits
        res.value = (res.value << part.bits) | part.value
    log("concat", repeat, res)
    return res
567
568
class SelectableIntTestCase(unittest.TestCase):
    """unit tests for SelectableInt and the select* helper functions"""

    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = abs(f)
        h = abs(a)
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        # multiply keeps the full product: the result is double-width
        self.assertEqual(e.value, a.value * b.value)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(g.value, a.value)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        self.assertEqual(e.bits, a.bits + b.bits)
        self.assertEqual(a.bits, f.bits)
        self.assertEqual(a.bits, h.bits)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        # repr() must round-trip through eval() for every 16-bit value
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)

    def test_cmp(self):
        a = SelectableInt(10, bits=8)
        b = SelectableInt(5, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_unsigned(self):
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        # unsigned comparison goes through the dedicated helpers...
        self.assertTrue(selectgtu(a, b))
        self.assertFalse(selectltu(a, b))
        # ...because the operators are signed: 0x80 is -128 in 8 bits
        self.assertTrue(a < b)
        self.assertFalse(a > b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_maxint(self):
        a = SelectableInt(0xffffffffffffffff, bits=64)
        b = SelectableInt(0, bits=64)
        result = a + b
        self.assertTrue(result.value == 0xffffffffffffffff)

    def test_double_1(self):
        """use http://weitz.de/ieee/,
        """
        for asint, asfloat in [(0x4000000000000000, 2.0),
                               (0x4056C00000000000, 91.0),
                               (0xff80000000000000, -1.4044477616111843e+306),
                               ]:
            a = SelectableInt(asint, bits=64)
            convert = float(a)
            log ("test_double_1", asint, asfloat, convert)
            self.assertTrue(asfloat == convert)
668
669
# run the unit tests above when this file is executed as a script
if __name__ == "__main__":
    unittest.main()