make truediv available to pseudocode
[openpower-isa.git] / src / openpower / decoder / selectable_int.py
1 import unittest
2 import struct
3 from copy import copy
4 import functools
5 from collections import OrderedDict
6 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
7 neg, inv, lshift, rshift, lt, eq)
8 from openpower.util import log
9
# Bit width given to operands that have no natural width (a plain int on the
# left of a subtraction becomes this wide in SelectableInt.__rsub__).
# check_extsign treats a SelectableInt of exactly this width as a marker
# meaning "re-create me at the other operand's width".
EFFECTIVELY_UNLIMITED = 1 << 10
11
def check_extsign(a, b):
    """Coerce operand ``b`` so it can be combined with SelectableInt ``a``.

    * a FieldSelectableInt is first flattened via ``get_range()``;
    * a plain int becomes a SelectableInt of ``a.bits`` width (masked);
    * a SelectableInt whose width is EFFECTIVELY_UNLIMITED is re-created
      at ``a.bits`` width (its value masked down to that width);
    * any other SelectableInt is returned unchanged.
    """
    if isinstance(b, FieldSelectableInt):
        b = b.get_range()
    if isinstance(b, int):
        return SelectableInt(b, a.bits)
    if b.bits == EFFECTIVELY_UNLIMITED:
        return SelectableInt(b.value, a.bits)
    return b
20
21
class BitRange(OrderedDict):
    """BitRange: remaps from straight indices (0,1,2..) to bit numbers

    A plain (non-slice) subscript is an ordinary ordered-mapping lookup.
    A slice subscript is applied to the list of mapped bit numbers, taken
    in insertion order, and that sub-list is returned.
    """

    def __getitem__(self, subscript):
        if not isinstance(subscript, slice):
            return OrderedDict.__getitem__(self, subscript)
        mapped_bits = list(self.values())
        return mapped_bits[subscript]
31
32
@functools.total_ordering
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    A view onto selected bits of an underlying SelectableInt (``self.si``).
    ``self.br`` maps field-local indices (0, 1, 2, ...) to bit numbers in
    the target, so reads and writes through the view go to the target.
    Constructing a view on top of another FieldSelectableInt composes the
    two index maps, so ``self.si`` is always the real SelectableInt.
    """

    def __init__(self, si, br):
        """
        :param si: target SelectableInt, or another FieldSelectableInt
        :param br: BitRange, or a list/tuple/range of target bit numbers
        :raises ValueError: if ``si`` is not a (Field)SelectableInt
        :raises OverflowError: if ``br`` has more entries than the enclosing
            FieldSelectableInt's map
        """
        if not isinstance(si, (FieldSelectableInt, SelectableInt)):
            raise ValueError(si)

        if isinstance(br, (list, tuple, range)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br

        if isinstance(si, FieldSelectableInt):
            # compose index maps: our index -> outer field index -> target bit
            fsi = si
            if len(br) > len(fsi.br):
                raise OverflowError(br)
            _br = BitRange()
            for (i, v) in br.items():
                _br[i] = fsi.br[v]
            br = _br
            si = fsi.si

        self.si = si  # target selectable int
        self.br = br  # map of indices

    def eq(self, b):
        """Assign ``b`` bit-by-bit into the field.

        A non-SelectableInt ``b`` is first converted to a SelectableInt as
        wide as the field.
        """
        if not isinstance(b, SelectableInt):
            b = SelectableInt(b, len(self.br))
        for i in range(b.bits):
            self[i] = b[i]

    def _op(self, op, b):
        """Apply binary ``op`` to the field's value and ``b``, then write the
        result back through :meth:`merge` (which mutates the target)."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Unary counterpart of :meth:`_op`."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __len__(self):
        return len(self.br)

    def __getitem__(self, key):
        """Read through field-local index ``key``.

        * int: the single mapped target bit (1-bit SelectableInt)
        * slice: the mapped target bits, concatenated
        * tuple/list/range: a nested FieldSelectableInt view
        :raises ValueError: for any other key type
        """
        log("getitem", key, self.br)
        if isinstance(key, SelectableInt):
            key = key.value

        if isinstance(key, int):
            key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        elif isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])
        elif isinstance(key, (tuple, list, range)):
            return FieldSelectableInt(si=self, br=key)
        else:
            raise ValueError(key)

    def __setitem__(self, key, value):
        """Write into the target through field-local index/slice ``key``.

        For a slice, a non-SelectableInt ``value`` is first sized to the
        number of selected bits.
        """
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # python's unary minus looks up __neg__; the old misspelt name is kept
    # as an alias for any caller that invokes it explicitly
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    # python 2 name for the "/" operator, kept for backward compatibility
    __div__ = __truediv__

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __lt__(self, b):
        vi = self.get_range()
        return onebit(lt(vi, b))

    def __eq__(self, b):
        vi = self.get_range()
        return onebit(eq(vi, b))

    def get_range(self):
        """Return the field's bits as a new SelectableInt of len(self) bits,
        field index 0 placed at SelectableInt index 0."""
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Write ``vi`` into the mapped target bits and return a view copy.

        NOTE: copy(self) is shallow, so ``fi.si`` is the very same object as
        ``self.si``: the write lands in the shared target and is visible
        through ``self`` (and the original SelectableInt) too.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return f"{self.__class__.__name__}(si={self.si}, br={self.br})"

    def __bool__(self):
        """True if any mapped target bit is set."""
        for key in self.br.values():
            bit = self.si[key].value
            if bit:
                return True
        return False

    def __int__(self):
        return self.asint(msb0=True)

    def asint(self, msb0=False):
        """Pack the field's bits into a python int.

        :param msb0: if False, field index i has weight 2**i (index 0 is the
            least significant); if True, index 0 is the most significant.
        """
        res = 0
        brlen = len(self.br)
        for i, key in self.br.items():
            bit = self.si[key].value
            res |= bit << ((brlen-i-1) if msb0 else i)
        return res
180
181
class FieldSelectableIntTestCase(unittest.TestCase):
    """Tests for FieldSelectableInt views onto a SelectableInt."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        # the field reads bits 0, 2 and 3 of a (MSB0 numbering): 1, 1, 0
        self.assertEqual(fs.get_range(), 0b110)
        c = fs + b
        log(c)
        # the sum wraps to the 3-bit width of the field
        self.assertEqual(c.get_range(), (0b110 + 0b011) & 0b111)

    def test_select(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
219
220
@functools.total_ordering
class SelectableInt:
    """SelectableInt - a class that behaves like python int, plus a bitwidth

    this class is designed to mirror the behaviour of python int.  the
    difference is that it carries the context of the bitwidth (number of
    bits) associated with that integer: ``value`` is always stored masked
    to ``bits`` bits (so it is never negative), and ``overflow`` records
    whether the value given to the constructor had bits above that width.

    FieldSelectableInt can then operate on partial bits.  Indexing and
    slicing use POWER MSB0 numbering (bit 0 is the most significant bit).
    Slice bounds must lie within 0..bits; an omitted start or stop defaults
    to 0 or ``bits`` respectively.
    """

    def __init__(self, value, bits=None):
        """
        :param value: int, SelectableInt or FieldSelectableInt
        :param bits: width in bits.  Required for an int; for a
            (Field)SelectableInt it may be omitted, but if given must equal
            the source's width.
        :raises ValueError: width mismatch, non-int value, or missing width
        """
        if isinstance(value, FieldSelectableInt):
            value = value.get_range()
        if isinstance(value, SelectableInt):
            if bits is not None:
                # check if the bitlength is different. TODO, allow override?
                if bits != value.bits:
                    raise ValueError(value)
            bits = value.bits
            value = value.value
        else:
            if not isinstance(value, int):
                raise ValueError(value)
            if bits is None:
                raise ValueError(bits)
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits
        self.overflow = (value & ~mask) != 0

    def eq(self, b):
        """Overwrite this object's value and width in place with ``b``'s."""
        self.value = b.value
        self.bits = b.bits

    def to_signed_int(self):
        """Return the value interpreted as a two's-complement signed int."""
        log ("to signed?", self.value & (1<<(self.bits-1)), self.value)
        if self.value & (1<<(self.bits-1)) != 0: # negative
            res = self.value - (1<<self.bits)
            log (" val -ve:", self.bits, res)
        else:
            res = self.value
            log (" val +ve:", res)
        return res

    def _op(self, op, b):
        """Apply binary ``op`` to the (unsigned) values of self and ``b``.

        ``b`` may be an int (taken at this width) or a (Field)SelectableInt
        coerced by check_extsign; the widths must then agree.  The result
        is masked to self.bits.
        """
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        """Multiply; the result is self.bits + b.bits wide so that the full
        product always fits."""
        # different case: mul result needs to fit the total bitsize
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        elif isinstance(b, FieldSelectableInt):
            # a field view has no .bits/.value of its own
            b = b.get_range()
        log("SelectableInt mul", hex(self.value), hex(b.value),
            self.bits, b.bits)
        return SelectableInt(self.value * b.value, self.bits + b.bits)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        """Division giving the exact integer quotient of the stored values.

        Python's ``/`` on ints yields a float, which a SelectableInt cannot
        hold (``_op(truediv, ...)`` therefore always raised ValueError), and
        a float would also lose precision on 64-bit operands.  The stored
        values are never negative, so floor division equals truncation.
        """
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __abs__(self):
        """Two's-complement absolute value at the same width."""
        log("abs", self.value & (1 << (self.bits-1)))
        if self.value & (1 << (self.bits-1)) != 0:
            return -self
        return self

    def __rsub__(self, b):
        """``b - self``: an int ``b`` is taken EFFECTIVELY_UNLIMITED wide and
        the result keeps ``b``'s width."""
        log("rsub", b, self.value)
        if isinstance(b, int):
            b = SelectableInt(b, EFFECTIVELY_UNLIMITED) # max extent
        return SelectableInt(b.value - self.value, b.bits)

    def __radd__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value + self.value, self.bits)

    def __rxor__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        """Two's-complement negation at the same width."""
        res = SelectableInt((~self.value) + 1, self.bits)
        log ("neg", hex(self.value), hex(res.value))
        return res

    def _shift_amount(self, b):
        """Return shift count ``b`` as a python int.

        A plain int is used as-is, as python does.  It is deliberately NOT
        squeezed into self.bits first: that used to turn e.g. a shift of a
        3-bit value by 8 into a shift by 8 & 7 == 0.
        """
        if isinstance(b, int):
            return b
        return check_extsign(self, b).value

    def __lshift__(self, b):
        return SelectableInt(self.value << self._shift_amount(b), self.bits)

    def __rshift__(self, b):
        return SelectableInt(self.value >> self._shift_amount(b), self.bits)

    def __getitem__(self, key):
        """Read bits using MSB0 numbering (bit 0 is the most significant).

        * int (or SelectableInt) index: a 1-bit SelectableInt
        * slice ``[start:stop]`` with step None or 1: a SelectableInt of
          ``stop - start`` bits
        * any other iterable of indices: the selected bits concatenated
        """
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            start = key.start
            if isinstance(start, SelectableInt):
                start = start.value
            stop = key.stop
            if isinstance(stop, SelectableInt):
                stop = stop.value
            step = key.step
            if isinstance(step, SelectableInt):
                step = step.value
            # an omitted bound means "from the first bit" / "to the last bit"
            if start is None:
                start = 0
            if stop is None:
                stop = self.bits

            assert step is None or step == 1
            assert start < stop
            assert start >= 0
            assert stop <= self.bits

            # convert MSB0 bounds to LSB0 shift positions
            (start, stop) = (
                (self.bits - stop),
                (self.bits - start),
            )
            bits = stop - start
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            return SelectableInt(value, bits)
        else:
            bits = []
            for bit in key:
                if not isinstance(bit, (int, SelectableInt)):
                    raise ValueError(key)
                bits.append(self[bit])
            return selectconcat(*bits)

    def __setitem__(self, key, value):
        """Write bits using MSB0 numbering (bit 0 is the most significant).

        * int index: ``value`` is an int or 1-bit SelectableInt
        * slice ``[start:stop]``: ``value`` is an int (masked to the slice
          width) or a SelectableInt exactly as wide as the slice
        * iterable of indices: ``value`` (int that fits, or SelectableInt)
          is scattered bit by bit, its bit 0 going to the first index
        """
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            kstart, kstop, kstep = key.start, key.stop, key.step
            if isinstance(kstart, SelectableInt): kstart = kstart.asint()
            if isinstance(kstop, SelectableInt): kstop = kstop.asint()
            if isinstance(kstep, SelectableInt): kstep = kstep.asint()
            # an omitted bound means "from the first bit" / "to the last bit"
            if kstart is None:
                kstart = 0
            if kstop is None:
                kstop = self.bits
            assert kstep is None or kstep == 1
            assert kstart < kstop
            assert kstart >= 0
            assert kstop <= self.bits, \
                "key stop %d bits %d" % (kstop, self.bits)

            # convert MSB0 bounds to LSB0 shift positions
            stop = self.bits - kstart
            start = self.bits - kstop

            bits = stop - start
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)
        else:
            bits = []
            for bit in key:
                if not isinstance(bit, (int, SelectableInt)):
                    raise ValueError(key)
                bits.append(bit)

            if isinstance(value, int):
                if value.bit_length() > len(bits):
                    raise ValueError(value)
                value = SelectableInt(value=value, bits=len(bits))
            if not isinstance(value, SelectableInt):
                raise ValueError(value)

            for (src, dst) in enumerate(bits):
                self[dst] = value[src]

    def __lt__(self, other):
        """*Signed* (two's-complement) less-than; returns a 1-bit value.
        Use selectltu/selectgtu for unsigned comparisons."""
        log ("SelectableInt __lt__", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            a = self.to_signed_int()
            res = onebit(a < other)
            log (" a < b", a, other, res)
            return res
        assert False

    def __eq__(self, other):
        """Compare stored (unsigned) values; returns a 1-bit value."""
        log("SelectableInt __eq__", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        log (" eq", other, self.value, other == self.value)
        if isinstance(other, int):
            return onebit(other == self.value)
        assert False

    def narrow(self, bits):
        """Return a copy truncated to the low ``bits`` bits."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        value = f"value={hex(self.value)}, bits={self.bits}"
        return f"{self.__class__.__name__}({value})"

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value

    def __int__(self):
        return self.asint()

    def __float__(self):
        """convert to double-precision float. TODO, properly convert
        rather than a hack-job: must actually support Power IEEE754 FP

        The raw bits are reinterpreted as an IEEE754 single (32-bit values)
        or double (64-bit values); any other width fails the assertion.
        """
        if self.bits == 32:
            data = self.value.to_bytes(4, byteorder='little')
            return struct.unpack('<f', data)[0]
        assert self.bits == 64 # must be 64-bit
        data = self.value.to_bytes(8, byteorder='little')
        return struct.unpack('<d', data)[0]
512
513
def onebit(bit):
    """Return a 1-bit SelectableInt holding 1 if ``bit`` is truthy, else 0."""
    return SelectableInt(int(bool(bit)), 1)
516
517
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    SelectableInt operands are compared by their raw stored ``value``
    (which is never negative); other operands are compared as given.
    Returns a 1-bit SelectableInt.
    """
    def _raw(operand):
        return operand.value if isinstance(operand, SelectableInt) else operand

    return onebit(_raw(lhs) < _raw(rhs))
526
527
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    SelectableInt operands are compared by their raw stored ``value``
    (which is never negative); other operands are compared as given.
    Returns a 1-bit SelectableInt.
    """
    lhs, rhs = (x.value if isinstance(x, SelectableInt) else x
                for x in (lhs, rhs))
    return onebit(lhs > rhs)
536
537
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the position(s) in ``idx``.

    :param idx: either a single index (``lhs[idx] = rhs[0]``), or a tuple
        ``(lower, upper)`` / ``(lower, upper, step)``.  In the tuple case
        ``lhs[lower], lhs[lower+step], ...`` (stopping before ``upper``)
        receive ``rhs[0], rhs[step], ...`` respectively.  A missing or
        ``None`` step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            # range() rejects a None step, so default to unit stride
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper-lower, step) # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
553
554
def selectconcat(*args, repeat=1):
    """Concatenate bit-fields, first argument most significant.

    Each further argument's bits are appended below (to the right of) the
    bits accumulated so far.

    :param args: SelectableInt / FieldSelectableInt values.  A single plain
        int 0 or 1 is accepted as a 1-bit value.
    :param repeat: int or SelectableInt; the whole argument list is repeated
        this many times before concatenating, e.g.
        ``selectconcat(bit, repeat=8)`` replicates ``bit`` eight times.
    :returns: a shallow copy of the first argument, widened by the rest
        (the arguments themselves are not modified).
    """
    if isinstance(repeat, SelectableInt):
        repeat = repeat.value
    if len(args) == 1 and isinstance(args[0], int) and args[0] in (0, 1):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1: # multiplies the incoming arguments
        args = list(args) * repeat
    head, tail = args[0], args[1:]
    if tail and isinstance(head, FieldSelectableInt):
        # a field view has no bits/value of its own to extend: snapshot it
        # (previously this crashed on res.bits; later args were already
        # flattened the same way below)
        head = head.get_range()
    res = copy(head)
    for i in tail:
        if isinstance(i, FieldSelectableInt):
            i = i.get_range()
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    log("concat", repeat, res)
    return res
574
575
class SelectableIntTestCase(unittest.TestCase):
    """Tests for SelectableInt arithmetic, bit access and comparisons."""

    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = abs(f)
        h = abs(a)
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        self.assertEqual(e.value, a.value * b.value)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(g.value, a.value)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        # multiplication deliberately widens so the full product fits
        self.assertEqual(e.bits, a.bits + b.bits)
        self.assertEqual(a.bits, f.bits)
        self.assertEqual(a.bits, g.bits)
        self.assertEqual(a.bits, h.bits)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)

    def test_cmp(self):
        a = SelectableInt(10, bits=8)
        b = SelectableInt(5, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_unsigned(self):
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        # unsigned comparisons go through selectgtu/selectltu ...
        self.assertTrue(selectgtu(a, b))
        self.assertFalse(selectltu(a, b))
        # ... whereas the rich comparison operators are signed:
        # at 8 bits 0x80 is -128, which is less than 0x7f
        self.assertTrue(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_maxint(self):
        a = SelectableInt(0xffffffffffffffff, bits=64)
        b = SelectableInt(0, bits=64)
        result = a + b
        self.assertTrue(result.value == 0xffffffffffffffff)

    def test_double_1(self):
        """use http://weitz.de/ieee/,
        """
        for asint, asfloat in [(0x4000000000000000, 2.0),
                               (0x4056C00000000000, 91.0),
                               (0xff80000000000000, -1.4044477616111843e+306),
                               ]:
            a = SelectableInt(asint, bits=64)
            convert = float(a)
            log ("test_double_1", asint, asfloat, convert)
            self.assertEqual(asfloat, convert)
675
676
if __name__ == "__main__":
    # executing this file directly runs the unittest.TestCase classes above
    unittest.main()