selectable_int: simplify SelectableIntMapping class
[openpower-isa.git] / src / openpower / decoder / selectable_int.py
1 import unittest
2 import struct
3 from copy import copy
4 from openpower.decoder.power_fields import BitRange
5 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
6 neg, inv, lshift, rshift)
7 from openpower.util import log
8
9
def check_extsign(a, b):
    """Coerce operand b so it can be combined with SelectableInt a.

    * a FieldSelectableInt is first reduced to the SelectableInt that
      its get_range() returns
    * a plain int is wrapped as a SelectableInt of a's width
    * a 256-bit operand is re-wrapped (truncated) to a's width
    * anything else is handed back untouched, whatever its width

    NOTE(review): 256 bits looks like a marker for "constant of no
    particular width" - confirm against the callers.
    """
    operand = b.get_range() if isinstance(b, FieldSelectableInt) else b
    if isinstance(operand, int):
        return SelectableInt(operand, a.bits)
    if operand.bits == 256:
        return SelectableInt(operand.value, a.bits)
    return operand
18
19
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    si is the target object being viewed and br maps each field-local
    bit index onto a bit index of that target.  reads and writes made
    through this object are forwarded to the mapped bits of the target.
    """

    def __init__(self, si, br):
        """si: target selectable int.  br: a BitRange, or a plain
        list/tuple/range of target indices which is converted into one
        (position in the sequence becomes the field-local index).
        """
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple, range)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """assign b to this field.

        an int is first converted to a SelectableInt as wide as the
        field; a SelectableInt is copied bit by bit through the map.
        any other object is treated as another field: its si and br
        are (shallow) copied, re-pointing this object.
        """
        if isinstance(b, int):
            # convert integer to same SelectableInt of same bitlength as range
            b = SelectableInt(b, len(self.br))
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """apply binary operator op to (field value, b), merge the result"""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """apply unary operator op to the field value, merge the result"""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        """read one bit (int / SelectableInt key) or several (slice).

        field-local indices are translated through br and then looked up
        in the target.  for a slice, br[key] is expected to give an
        iterable of target indices (that behaviour lives in BitRange,
        which is not visible here); the selected bits are concatenated.
        any other key type falls through and returns None.
        """
        log("getitem", key, self.br)
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        if isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])

    def __setitem__(self, key, value):
        """write one bit or a run of bits through the index map.

        when br[key] is a single int the value is stored directly in the
        target.  otherwise br[key] is iterated as a list of target
        indices and value (wrapped as a SelectableInt of matching width
        if it is not one already) is copied bit by bit.
        """
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        """unary minus applied to the field value"""
        return self._op1(neg)

    # the original spelling was "__negate__", which python never calls
    # for unary minus (and which used an undefined name).  kept as an
    # alias in case anything invokes it explicitly.
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __div__(self, b):
        # NOTE(review): __div__ is the python 2 hook; python 3 never
        # dispatches "/" to it, so this only runs when called by name.
        return self._op(truediv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """return the selected bits as a new SelectableInt, len(br) wide"""
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """write the bits of vi into the mapped positions, return a copy.

        NOTE(review): copy() is shallow, so the returned object shares
        si with self and the bits are written into the original target
        too - confirm that this in-place update is intended.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return f"{self.__class__.__name__}(si={self.si}, br={self.br})"

    def asint(self, msb0=False):
        """return the selected bits as a python int.

        field-local bit i lands at position i of the result, or at
        position (len(br)-i-1) when msb0 is True.
        """
        res = 0
        brlen = len(self.br)
        for i, key in self.br.items():
            bit = self.si[key].value
            #log("asint", i, key, bit)
            res |= bit << ((brlen-i-1) if msb0 else i)
        return res
133
134
class FieldSelectableIntTestCase(unittest.TestCase):
    """checks for FieldSelectableInt built on top of a SelectableInt"""

    @staticmethod
    def _bitrange(*targets):
        """build a BitRange mapping position n onto targets[n]"""
        mapping = BitRange()
        for position, target in enumerate(targets):
            mapping[position] = target
        return mapping

    def test_arith(self):
        target = SelectableInt(0b10101, 5)
        addend = SelectableInt(0b011, 3)
        field = FieldSelectableInt(target, self._bitrange(0, 2, 3))
        total = field + addend
        # only checks that the addition runs: the result is logged,
        # there is no value assertion yet
        log(total)

    def test_select(self):
        target = SelectableInt(0b00001111, 8)
        field = FieldSelectableInt(target, self._bitrange(0, 1, 4, 5))

        self.assertEqual(field.get_range(), 0b0011)

    def test_select_range(self):
        target = SelectableInt(0b00001111, 8)
        field = FieldSelectableInt(target, self._bitrange(0, 1, 4, 5))

        # reading a slice of the field
        self.assertEqual(field[2:4], 0b11)

        # writing a slice of the field changes what get_range() returns
        field[0:2] = 0b10
        self.assertEqual(field.get_range(), 0b1011)
172
173
class SelectableInt:
    """SelectableInt - a class that behaves like python int, with a width

    this class is designed to mirror the behaviour of python int.
    the only difference is that it must contain the context of the bitwidth
    (number of bits) associated with that integer: value is always kept
    masked to that width.

    bit indices and slices use the POWER numbering in which index 0 is
    the most significant bit.  slices must have step 1 and lie inside
    0..bits; an omitted start/stop defaults to 0/bits.  negative
    indices are rejected.

    FieldSelectableInt can then operate on partial bits.
    """

    def __init__(self, value, bits=None):
        """value: int, or a SelectableInt to duplicate (in which case
        its width is used and the bits argument is ignored).
        bits: width in bits; required when value is a plain int.

        overflow records whether value had bits set outside the width.
        raises TypeError when no width is available.
        """
        if isinstance(value, SelectableInt):
            bits = value.bits
            value = value.value
        if bits is None:
            raise TypeError("SelectableInt needs a bit width: "
                            "bits must be given for a plain int value")
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits
        self.overflow = (value & ~mask) != 0

    def eq(self, b):
        """take over value and width of b (overflow is left as it was)"""
        self.value = b.value
        self.bits = b.bits

    def to_signed_int(self):
        """return value as a python int, top bit treated as the sign"""
        log ("to signed?", self.value & (1<<(self.bits-1)), self.value)
        if self.value & (1<<(self.bits-1)) != 0:  # negative
            res = self.value - (1<<self.bits)
            log (" val -ve:", self.bits, res)
        else:
            res = self.value
            log (" val +ve:", res)
        return res

    def _op(self, op, b):
        """apply op to the raw values; b must end up self.bits wide"""
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        """multiply; the result is self.bits + b.bits wide"""
        # different case: mul result needs to fit the total bitsize
        if isinstance(b, FieldSelectableInt):
            # same conversion that _op gets through check_extsign
            b = b.get_range()
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        log("SelectableInt mul", hex(self.value), hex(b.value),
            self.bits, b.bits)
        return SelectableInt(self.value * b.value, self.bits + b.bits)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        # NOTE(review): truediv on two ints gives a float, which the
        # constructor cannot mask, so this raises TypeError as written.
        # use // for integer division.
        return self._op(truediv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __abs__(self):
        """negate when the top bit is set, otherwise return self"""
        log("abs", self.value & (1 << (self.bits-1)))
        if self.value & (1 << (self.bits-1)) != 0:
            return -self
        return self

    def __rsub__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value - self.value, self.bits)

    def __radd__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value + self.value, self.bits)

    def __rxor__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        """two's complement negation within self.bits"""
        res = SelectableInt((~self.value) + 1, self.bits)
        log ("neg", hex(self.value), hex(res.value))
        return res

    def __lshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def _slice_span(self, key):
        """check slice key and convert it to shift positions.

        returns (start, stop) counted from the least significant bit:
        the selected bits are value[start:stop] in that numbering, so
        the width is stop - start.  an omitted slice start/stop means
        "from bit 0" / "up to self.bits".
        """
        assert key.step is None or key.step == 1
        first = 0 if key.start is None else key.start
        last = self.bits if key.stop is None else key.stop
        assert first < last
        assert first >= 0
        assert last <= self.bits, \
            "key stop %d bits %d" % (last, self.bits)
        return self.bits - last, self.bits - first

    def __getitem__(self, key):
        """read bit key (result is 1 bit wide) or a slice of bits.

        key types other than int, SelectableInt and slice return None.
        """
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order!  see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            log("getitem", key, self.bits, hex(self.value), value)
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            start, stop = self._slice_span(key)

            bits = stop - start
            #log ("__getitem__ slice num bits", start, stop, bits)
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            log("getitem", stop, start, self.bits, hex(self.value), value)
            return SelectableInt(value, bits)

    def __setitem__(self, key, value):
        """write bit key or a slice of bits; value may be an int or a
        SelectableInt whose width matches what is being written.
        """
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value
            log("setitem", key, self.bits, hex(self.value), hex(value))

            assert key < self.bits
            assert key >= 0
            # same MSB-first numbering as __getitem__
            key = self.bits - (key + 1)

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            start, stop = self._slice_span(key)

            bits = stop - start
            #log ("__setitem__ slice num bits", bits)
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            log("setitem", key, self.bits, hex(self.value), hex(value))
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)

    def _signed_operand(self, other):
        """turn the right-hand side of an ordering comparison into a
        python int: SelectableInt / FieldSelectableInt operands are
        width-checked against self and read as signed, ints pass
        through.  anything else is an AssertionError.
        """
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if not isinstance(other, int):
            raise AssertionError("cannot compare SelectableInt with %r"
                                 % (other,))
        return other

    # the ordering comparisons are SIGNED and return a 1-bit
    # SelectableInt.  selectltu / selectgtu give the unsigned versions.

    def __ge__(self, other):
        other = self._signed_operand(other)
        return onebit(self.to_signed_int() >= other)

    def __le__(self, other):
        other = self._signed_operand(other)
        return onebit(self.to_signed_int() <= other)

    def __gt__(self, other):
        other = self._signed_operand(other)
        return onebit(self.to_signed_int() > other)

    def __lt__(self, other):
        log ("SelectableInt lt", self, other)
        other = self._signed_operand(other)
        a = self.to_signed_int()
        res = onebit(a < other)
        log (" a < b", a, other, res)
        return res

    def __eq__(self, other):
        """compare raw (unsigned) values; returns a 1-bit SelectableInt"""
        log("__eq__", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        log (" eq", other, self.value, other == self.value)
        if isinstance(other, int):
            return onebit(other == self.value)
        assert False, "cannot compare SelectableInt with %r" % (other,)

    def narrow(self, bits):
        """return the low bits of value as a narrower SelectableInt"""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        value = f"value=0x{self.value:x}, bits={self.bits}"
        return f"{self.__class__.__name__}({value})"

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value

    def __float__(self):
        """convert to double-precision float.  TODO, properly convert
        rather than a hack-job: must actually support Power IEEE754 FP

        the 64 bits of value are reinterpreted as an IEEE754 double.
        """
        assert self.bits == 64  # must be 64-bit
        data = self.value.to_bytes(8, byteorder='little')
        return struct.unpack('<d', data)[0]
439
440
class SelectableIntMapping(dict):
    """dict of named fields onto one target, readable as attributes.

    fields maps each name either to a bit-index specification, which is
    wrapped as FieldSelectableInt(si, spec), or to a nested dict that is
    converted the same way, recursively (nested levels are plain dicts).
    """

    def __init__(self, si, fields=None):
        """si: target handed to every FieldSelectableInt that is built.
        fields: mapping of name -> spec or nested dict; None (the
        default) gives an empty mapping.
        """
        def field(item):
            (key, value) = item
            if isinstance(value, dict):
                value = dict(map(field, value.items()))
            else:
                value = FieldSelectableInt(si, value)
            return (key, value)

        if fields is None:
            # the default used to crash on fields.items()
            fields = {}
        super().__init__(map(field, fields.items()))

    def __getattr__(self, attr):
        """only called when normal attribute lookup fails: fall back
        to the dict entry of that name.
        """
        try:
            return self[attr]
        except KeyError as error:
            # name the missing attribute in the error
            raise AttributeError(attr) from error
458
459
def onebit(bit):
    """wrap the truthiness of bit as a 1-bit SelectableInt (1 or 0)"""
    flag = int(bool(bit))
    return SelectableInt(flag, 1)
462
463
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    compares the raw value of lhs against rhs, which may be a
    SelectableInt (its raw value is used) or a plain number.
    the answer comes back as a 1-bit SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    is_below = lhs.value < rhs_value
    return onebit(is_below)
470
471
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    compares the raw value of lhs against rhs, which may be a
    SelectableInt (its raw value is used) or a plain number.
    the answer comes back as a 1-bit SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    is_above = lhs.value > rhs_value
    return onebit(is_above)
478
479
480 # XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """copy items of rhs into lhs at the position(s) given by idx.

    idx is either a single index, in which case lhs[idx] = rhs[0], or
    a tuple (lower, upper) / (lower, upper, step) describing
    range(lower, upper, step) in lhs.  source items are taken from
    range(0, upper-lower, step) in rhs.  a missing or None step is 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            # range() rejects None as a step: fall back to stride 1
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper-lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
495
496
def selectconcat(*args, repeat=1):
    """join the arguments into one wider SelectableInt, first argument
    in the most significant position.

    with repeat != 1 the argument list is repeated that many times
    first; in that mode a lone python int is wrapped as a 1-bit
    SelectableInt.  the result starts as a copy of the first argument
    and every following one is shifted in below it.  a
    FieldSelectableInt after the first position contributes its
    underlying si.
    """
    if repeat != 1:
        if len(args) == 1 and isinstance(args[0], int):
            args = [SelectableInt(args[0], 1)]
        # multiplies the incoming arguments
        args = list(args) * repeat
    joined = copy(args[0])
    for part in args[1:]:
        if isinstance(part, FieldSelectableInt):
            part = part.si
        assert isinstance(part, SelectableInt), "can only concat SIs, sorry"
        joined.bits += part.bits
        joined.value = (joined.value << part.bits) | part.value
    log("concat", repeat, joined)
    return joined
514
515
class SelectableIntTestCase(unittest.TestCase):
    """unit tests for SelectableInt and the select* helper functions"""

    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = abs(f)
        h = abs(a)
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        self.assertEqual(e.value, a.value * b.value)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(g.value, a.value)
        self.assertEqual(h.value, a.value)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        # multiplication is the odd one out: __mul__ returns a result
        # as wide as both operands together
        self.assertEqual(e.bits, a.bits + b.bits)
        self.assertEqual(a.bits, f.bits)
        self.assertEqual(a.bits, h.bits)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        # repr() must round-trip through eval() for every 16-bit value
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)

    def test_cmp(self):
        a = SelectableInt(10, bits=8)
        b = SelectableInt(5, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_unsigned(self):
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        # the > and < operators compare as signed numbers (0x80 is -128
        # there), so unsigned ordering goes through selectgtu/selectltu
        self.assertTrue(selectgtu(a, b))
        self.assertFalse(selectltu(a, b))
        self.assertTrue(a < b)
        self.assertFalse(a > b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_maxint(self):
        a = SelectableInt(0xffffffffffffffff, bits=64)
        b = SelectableInt(0, bits=64)
        result = a + b
        self.assertTrue(result.value == 0xffffffffffffffff)

    def test_double_1(self):
        """use http://weitz.de/ieee/,
        """
        for asint, asfloat in [(0x4000000000000000, 2.0),
                               (0x4056C00000000000, 91.0),
                               (0xff80000000000000, -1.4044477616111843e+306),
                               ]:
            a = SelectableInt(asint, bits=64)
            convert = float(a)
            log ("test_double_1", asint, asfloat, convert)
            self.assertTrue(asfloat == convert)
615
616
# run the unit tests of this module when it is executed as a script
if __name__ == "__main__":
    unittest.main()