Merge branch 'master' of git.libre-soc.org:soc
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
5 neg, inv, lshift, rshift)
6
7
def check_extsign(a, b):
    """Coerce operand ``b`` into a SelectableInt usable alongside ``a``.

    * a FieldSelectableInt is first flattened via ``get_range()``;
    * a plain int is wrapped (and masked) to ``a.bits`` wide;
    * a SelectableInt exactly 256 bits wide is re-wrapped to ``a.bits``
      (NOTE(review): 256 presumably marks an "unsized" constant elsewhere
      in the project -- confirm);
    * any other SelectableInt is returned unchanged.
    """
    if isinstance(b, FieldSelectableInt):
        b = b.get_range()
    if not isinstance(b, int) and b.bits != 256:
        return b
    # ints and 256-bit operands both take on a's width
    raw = b if isinstance(b, int) else b.value
    return SelectableInt(raw, a.bits)
16
17
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target SelectableInt and ``br`` maps field bit positions
    (0, 1, 2, ...) onto bit positions of ``si``.  A list/tuple of target
    positions may be given for ``br``; it is converted to a BitRange.

    Item reads gather the mapped bits of ``si``; item assignment writes them
    back into ``si`` in place.  Arithmetic/logic operators compute on the
    gathered field (see ``get_range``) and return a new FieldSelectableInt
    whose target is a modified *copy* of ``si``.
    """

    def __init__(self, si, br):
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """Assign from ``b``.

        A SelectableInt is written bit-by-bit through the field mapping
        (modifying the target ``si``); otherwise ``b`` is assumed to be a
        FieldSelectableInt whose target and mapping are copied.
        """
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary ``op`` to the gathered field and ``b``, then merge."""
        return self.merge(op(self.get_range(), b))

    def _op1(self, op):
        """Apply unary ``op`` to the gathered field, then merge."""
        return self.merge(op(self.get_range()))

    def __getitem__(self, key):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        if isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # original (non-hook) spelling, which referenced an undefined
    # ``negate``; kept as an alias for any explicit callers
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    # python 2 name: never invoked by python 3's "/" operator
    __div__ = __truediv__

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Gather the mapped bits into a new ``len(br)``-bit SelectableInt
        (field bit k <- si[br[k]])."""
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Return a copy of this field with ``vi`` scattered into a copy of
        the target; ``self.si`` itself is left untouched."""
        fi = copy(self)
        # copy() is shallow: without this the result would share ``si``
        # and e.g. ``fs + b`` would silently modify fs's target
        fi.si = copy(self.si)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
116
117
class FieldSelectableIntTestCase(unittest.TestCase):
    """Tests for FieldSelectableInt gathering, slicing and arithmetic."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        c = fs + b
        # field = a[0], a[2], a[3] = 0b110; 0b110 + 0b011 = 0b1001, which
        # wraps to 0b001 in the 3-bit field
        self.assertEqual(c.get_range(), 0b001)

    def test_select(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
155
class SelectableInt:
    """SelectableInt - a class that behaves exactly like python int

    this class is designed to mirror precisely the behaviour of python int.
    the only difference is that it must contain the context of the bitwidth
    (number of bits) associated with that integer.

    FieldSelectableInt can then operate on partial bits, and because there
    is a bit width associated with SelectableInt, slices operate correctly
    including negative start/end points.

    Notes:

    * bit indices follow POWER ISA order (v3.0B p4, 1.3.2): index 0 is the
      MOST significant bit.  Negative indices / slice endpoints count back
      from ``bits`` and omitted slice endpoints default to the full width.
    * ``<``, ``<=``, ``>``, ``>=`` compare as *signed* two's-complement
      values and return a 1-bit SelectableInt; selectltu / selectgtu do
      unsigned comparison.
    * arithmetic wraps to ``bits``, except ``*``, whose result width is the
      sum of both operand widths.
    """

    def __init__(self, value, bits):
        if isinstance(value, SelectableInt):
            value = value.value
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits
        # True when the incoming value had bits set above ``bits``
        self.overflow = (value & ~mask) != 0

    def eq(self, b):
        self.value = b.value
        self.bits = b.bits

    def to_signed_int(self):
        """Return the value interpreted as a two's-complement signed int."""
        if self.value & (1 << (self.bits - 1)) != 0:  # negative
            return self.value - (1 << self.bits)
        return self.value

    def _op(self, op, b):
        """Apply ``op`` to same-width operands, wrapping to ``self.bits``."""
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        # different case: mul result needs to fit the total bitsize
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        elif isinstance(b, FieldSelectableInt):
            b = b.get_range()
        return SelectableInt(self.value * b.value, self.bits + b.bits)

    def __rmul__(self, b):
        return self.__mul__(b)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        # "/" on bit-vectors is integer division: true division would give
        # a float, which cannot be masked to ``bits`` (TypeError)
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __rand__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __ror__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __abs__(self):
        if self.value & (1 << (self.bits - 1)) != 0:
            return -self
        return self

    def __rsub__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value - self.value, self.bits)

    def __radd__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value + self.value, self.bits)

    def __rxor__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        return SelectableInt((~self.value) + 1, self.bits)

    def _shift_amount(self, b):
        """Return shift count ``b`` as a plain int.

        The count is deliberately NOT truncated to ``self.bits``: doing so
        would wrap large counts (e.g. a 2-bit value << 4 became << 0).
        """
        if isinstance(b, int):
            return b
        if isinstance(b, FieldSelectableInt):
            b = b.get_range()
        return b.value

    def __lshift__(self, b):
        return SelectableInt(self.value << self._shift_amount(b), self.bits)

    def __rshift__(self, b):
        return SelectableInt(self.value >> self._shift_amount(b), self.bits)

    def _bit_index(self, key):
        """Resolve an MSB-0 bit index (negative counts back from the end)
        into a shift distance from the least-significant bit."""
        if key < 0:
            key += self.bits
        assert key < self.bits, "key %d accessing %d" % (key, self.bits)
        assert key >= 0
        # NOTE: POWER 3.0B annotation order! see p4 1.3.2
        # MSB is indexed **LOWEST** (sigh)
        return self.bits - (key + 1)

    def _slice_bounds(self, key):
        """Resolve an MSB-0 ``slice`` into ``(shift, width)``.

        ``shift`` is the distance of the slice's least-significant bit from
        bit 0 of ``self.value``; ``width`` is the number of bits selected.
        """
        assert key.step is None or key.step == 1
        start, stop = key.start, key.stop
        if isinstance(start, SelectableInt):
            start = start.value
        if isinstance(stop, SelectableInt):
            stop = stop.value
        if start is None:
            start = 0
        elif start < 0:
            start += self.bits
        if stop is None:
            stop = self.bits
        elif stop < 0:
            stop += self.bits
        assert start < stop
        assert start >= 0
        assert stop <= self.bits
        return self.bits - stop, stop - start

    def __getitem__(self, key):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            shift = self._bit_index(key)
            return SelectableInt((self.value >> shift) & 1, 1)
        elif isinstance(key, slice):
            shift, bits = self._slice_bounds(key)
            mask = (1 << bits) - 1
            return SelectableInt((self.value >> shift) & mask, bits)

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            shift = self._bit_index(key)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value
            mask = 1 << shift
            self.value = (self.value & ~mask) | ((value << shift) & mask)
        elif isinstance(key, slice):
            shift, bits = self._slice_bounds(key)
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << shift
            self.value = (self.value & ~mask) | ((value << shift) & mask)

    def _signed_other(self, other):
        """Convert a comparison operand to a signed python int."""
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if isinstance(other, int):
            return other
        assert False

    def __ge__(self, other):
        return onebit(self.to_signed_int() >= self._signed_other(other))

    def __le__(self, other):
        return onebit(self.to_signed_int() <= self._signed_other(other))

    def __gt__(self, other):
        return onebit(self.to_signed_int() > self._signed_other(other))

    def __lt__(self, other):
        return onebit(self.to_signed_int() < self._signed_other(other))

    def __eq__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        if isinstance(other, int):
            return onebit(other == self.value)
        assert False

    def narrow(self, bits):
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                            self.bits)

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value
406
407
def onebit(bit):
    """Return a 1-bit SelectableInt holding 1 when ``bit`` is truthy, else 0."""
    as_int = int(bool(bit))
    return SelectableInt(as_int, 1)
410
411
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    Compares the raw (unsigned) bit patterns of ``lhs`` and ``rhs`` and
    returns the result as a 1-bit SelectableInt.  Either operand may be a
    SelectableInt, a FieldSelectableInt (its selected field is compared)
    or a plain int (compared as given).
    """
    def _unsigned(x):
        # flatten to a plain python int without sign interpretation
        if isinstance(x, FieldSelectableInt):
            x = x.get_range()
        if isinstance(x, SelectableInt):
            x = x.value
        return x
    return onebit(_unsigned(lhs) < _unsigned(rhs))
418
419
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    Compares the raw (unsigned) bit patterns of ``lhs`` and ``rhs`` and
    returns the result as a 1-bit SelectableInt.  Either operand may be a
    SelectableInt, a FieldSelectableInt (its selected field is compared)
    or a plain int (compared as given).
    """
    def _unsigned(x):
        # flatten to a plain python int without sign interpretation
        if isinstance(x, FieldSelectableInt):
            x = x.get_range()
        if isinstance(x, SelectableInt):
            x = x.value
        return x
    return onebit(_unsigned(lhs) > _unsigned(rhs))
426
427
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the position(s) ``idx``.

    ``idx`` is either a single index (``rhs[0]`` -> ``lhs[idx]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)``.  For a tuple, positions
    of ``range(lower, upper, step)`` in ``lhs`` receive the matching
    positions of ``range(0, upper - lower, step)`` in ``rhs``.  A missing or
    None step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            step = 1  # range() rejects a None step
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
443
444
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, the first argument being most significant.

    ``repeat`` replicates the whole argument list that many times; as a
    convenience, a single plain int combined with ``repeat != 1`` is treated
    as a 1-bit value.  A FieldSelectableInt argument contributes only its
    selected field (via ``get_range()``), not its whole target.  Returns a
    new SelectableInt whose width is the sum of the argument widths.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat
    first = args[0]
    if isinstance(first, FieldSelectableInt):
        first = first.get_range()
    res = copy(first)
    for i in args[1:]:
        if isinstance(i, FieldSelectableInt):
            i = i.get_range()
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    return res
462
463
class SelectableIntTestCase(unittest.TestCase):
    """Tests for SelectableInt arithmetic, logic, slicing and comparison."""

    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = abs(f)
        h = abs(a)
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        # multiply widens: the result is a.bits + b.bits wide
        self.assertEqual(e.value, (a.value * b.value) & 0xFFFF)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(g.value, a.value)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        self.assertEqual(e.bits, a.bits + b.bits)
        self.assertEqual(a.bits, f.bits)
        self.assertEqual(a.bits, h.bits)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)

    def test_cmp(self):
        a = SelectableInt(10, bits=8)
        b = SelectableInt(5, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_unsigned(self):
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        # unsigned comparison: 0x80 (128) > 0x7f (127)
        self.assertTrue(selectgtu(a, b))
        self.assertFalse(selectltu(a, b))
        # the < / > operators are signed: 0x80 is -128 in 8 bits
        self.assertTrue(a < b)
        self.assertFalse(a > b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)
545
546
def _run_tests():
    """Run this module's unittest test cases."""
    unittest.main()


if __name__ == "__main__":
    _run_tests()