format code
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
5 neg, inv, lshift, rshift)
6
7
def check_extsign(a, b):
    """Coerce ``b`` into a SelectableInt suitable for combining with ``a``.

    * a FieldSelectableInt is first collapsed to its selected bits
      (``get_range()``);
    * a plain int is wrapped as a SelectableInt of ``a.bits`` width;
    * a SelectableInt of exactly 256 bits is re-wrapped (i.e. masked) to
      ``a.bits`` width -- 256 appears to act as a sentinel width here;
      NOTE(review): confirm with the callers that produce 256-bit values;
    * any other SelectableInt is returned unchanged.
    """
    if isinstance(b, FieldSelectableInt):
        b = b.get_range()
    if isinstance(b, int):
        return SelectableInt(b, a.bits)
    if b.bits == 256:
        return SelectableInt(b.value, a.bits)
    return b
16
17
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target SelectableInt and ``br`` maps field bit positions
    (0, 1, 2, ...) onto bit positions within ``si``.  Reads gather the
    mapped bits out of ``si``.  Writes (``__setitem__``, and the arithmetic
    and logic operators via ``merge``) modify ``si`` in place.
    """

    def __init__(self, si, br):
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple)):
            # convert a plain sequence of target indices into a BitRange
            # mapping field index i -> target bit br[i]
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """Assign ``b`` to this field.

        A SelectableInt is written bit-by-bit through the field mapping,
        which modifies the target in place.  Anything else is treated as
        another FieldSelectableInt, and its target and mapping are
        shallow-copied into this object.
        """
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary ``op`` to the selected bits and write the result back."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary ``op`` to the selected bits and write the result back."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        """Read field bit(s) ``key`` from the target.

        The field index is translated through ``br`` and then used to index
        ``si``.  An int key returns ``si[br[key]]``.  A slice key
        concatenates the target bits selected by ``br[key]``.
        """
        print("getitem", key, self.br)
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        if isinstance(key, slice):
            # presumably BitRange slicing yields a sequence of target
            # indices -- NOTE(review): confirm against BitRange
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])

    def __setitem__(self, key, value):
        """Write ``value`` into field bit(s) ``key`` (modifies ``si``)."""
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # Backward-compatible alias.  ``__negate__`` is not a Python operator
    # hook, and it previously referenced the undefined name ``negate``.
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    # Python 2 spelling of "/", kept so existing explicit callers still work
    __div__ = __truediv__

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __lshift__(self, b):
        return self._op(lshift, b)

    def __rshift__(self, b):
        return self._op(rshift, b)

    def get_range(self):
        """Gather the selected target bits into a new SelectableInt.

        The result is ``len(br)`` bits wide.  Bit ``k`` of the result is
        read from target bit ``br[k]``.
        """
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Scatter ``vi``'s bits back into the target and return a copy.

        ``copy`` is shallow, so the returned object shares ``si`` with
        ``self``.  The target is therefore modified in place.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
116
117
class FieldSelectableIntTestCase(unittest.TestCase):
    """Exercises FieldSelectableInt reads, writes and arithmetic."""

    @staticmethod
    def _bitrange(*targets):
        """Build a BitRange mapping field bit i onto target bit targets[i]."""
        mapping = BitRange()
        for field_bit, target_bit in enumerate(targets):
            mapping[field_bit] = target_bit
        return mapping

    def test_arith(self):
        target = SelectableInt(0b10101, 5)
        operand = SelectableInt(0b011, 3)
        field = FieldSelectableInt(target, self._bitrange(0, 2, 3))
        total = field + operand
        print(total)
        # TODO: expected result not pinned down yet; an earlier draft
        # compared total.value against target.value + operand.value

    def test_select(self):
        target = SelectableInt(0b00001111, 8)
        field = FieldSelectableInt(target, self._bitrange(0, 1, 4, 5))

        self.assertEqual(field.get_range(), 0b0011)

    def test_select_range(self):
        target = SelectableInt(0b00001111, 8)
        field = FieldSelectableInt(target, self._bitrange(0, 1, 4, 5))

        self.assertEqual(field[2:4], 0b11)

        field[0:2] = 0b10
        self.assertEqual(field.get_range(), 0b1011)
155
156
class SelectableInt:
    """SelectableInt - a class that behaves exactly like python int

    this class is designed to mirror precisely the behaviour of python int.
    the only difference is that it must contain the context of the bitwidth
    (number of bits) associated with that integer.

    FieldSelectableInt can then operate on partial bits, and because there
    is a bit width associated with SelectableInt, slices operate correctly
    including negative start/end points.

    Notes:
    * ``value`` is always masked to ``bits`` bits, so it is stored and
      compared as an unsigned, non-negative integer.
    * bit indices use POWER (MSB0) order: index 0 is the most significant
      bit and index ``bits - 1`` is the least significant.
    * comparison operators return a 1-bit SelectableInt rather than a bool.
    """

    def __init__(self, value, bits):
        if isinstance(value, SelectableInt):
            value = value.value
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def eq(self, b):
        """Overwrite this instance's value and width with ``b``'s."""
        self.value = b.value
        self.bits = b.bits

    def _op(self, op, b):
        """Compute ``op(self, b)`` for same-width operands.

        ``b`` may be an int (wrapped at this width), a SelectableInt or a
        FieldSelectableInt; the latter two are normalised by
        ``check_extsign``.  The result is masked back to ``self.bits``.
        """
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def _rop(self, op, b):
        """Reflected form of ``_op``: computes ``op(b, self)``."""
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(b.value, self.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        # different case: mul result needs to fit the total bitsize
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        print("SelectableInt mul", hex(self.value), hex(b.value),
              self.bits, b.bits)
        return SelectableInt(self.value * b.value, self.bits + b.bits)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        # Operands are unsigned integers.  True division would produce a
        # float, which cannot be masked (float & int raises TypeError),
        # so "/" is integer division here.
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __abs__(self):
        """Negate when the top (sign) bit is set, else return self."""
        sign = self.value & (1 << (self.bits-1))
        print("abs", sign)
        if sign != 0:
            return -self
        return self

    def __rsub__(self, b):
        return self._rop(sub, b)

    def __radd__(self, b):
        return self._rop(add, b)

    def __rxor__(self, b):
        return self._rop(xor, b)

    def __rand__(self, b):
        return self._rop(and_, b)

    def __ror__(self, b):
        return self._rop(or_, b)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        # two's complement negation, masked to the width
        return SelectableInt(~self.value + 1, self.bits)

    def __lshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def _slice_bounds(self, key):
        """Translate an MSB0 slice into ``(shift, width)`` in LSB0 terms.

        A missing start/stop defaults to 0 / ``self.bits``.  Negative
        bounds count back from ``self.bits``, as for Python sequences.
        SelectableInt bounds are unwrapped to plain ints.
        """
        assert key.step is None or key.step == 1
        first, last = key.start, key.stop
        if first is None:
            first = 0
        if last is None:
            last = self.bits
        if isinstance(first, SelectableInt):
            first = first.value
        if isinstance(last, SelectableInt):
            last = last.value
        if first < 0:
            first += self.bits
        if last < 0:
            last += self.bits
        assert first < last
        assert first >= 0
        assert last <= self.bits
        # MSB0 indices [first, last) occupy LSB0 bits [bits-last, bits-first)
        return self.bits - last, last - first

    def __getitem__(self, key):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            start, bits = self._slice_bounds(key)
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            return SelectableInt(value, bits)

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)  # MSB0 -> LSB0
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            start, bits = self._slice_bounds(key)
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)

    def _cmp(self, other, op):
        """Compare unsigned values; returns a 1-bit SelectableInt.

        ``other`` may be an int, a SelectableInt or a FieldSelectableInt.
        SelectableInt operands must have the same width after
        ``check_extsign``.
        """
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        if isinstance(other, int):
            return onebit(op(self.value, other))
        assert False

    def __ge__(self, other):
        return self._cmp(other, lambda x, y: x >= y)

    def __le__(self, other):
        return self._cmp(other, lambda x, y: x <= y)

    def __gt__(self, other):
        return self._cmp(other, lambda x, y: x > y)

    def __lt__(self, other):
        return self._cmp(other, lambda x, y: x < y)

    def __eq__(self, other):
        print("__eq__", self, other)
        return self._cmp(other, lambda x, y: x == y)

    def narrow(self, bits):
        """Return a copy truncated (masked) to ``bits`` bits."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                             self.bits)

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value
389
390
def onebit(bit):
    """Return a 1-bit SelectableInt: 1 when ``bit`` is truthy, else 0."""
    return SelectableInt(int(bool(bit)), 1)
393
394
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    ``lhs`` must expose ``.value``; ``rhs`` may be a SelectableInt or a
    plain int.  Returns a 1-bit SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_value)
401
402
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    ``lhs`` must expose ``.value``; ``rhs`` may be a SelectableInt or a
    plain int.  Returns a 1-bit SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value > rhs_value)
409
410
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the positions given by ``idx``.

    ``idx`` is either a single index, or a ``(lower, upper)`` /
    ``(lower, upper, step)`` tuple describing a half-open range.  In the
    tuple form ``lhs[lower + j] = rhs[j]`` for ``j`` in
    ``range(0, upper - lower, step)``.  A missing (or ``None``) step means
    contiguous.  A single index copies ``rhs[0]`` into ``lhs[idx]``.

    NOTE(review): with ``step > 1`` the source ``rhs`` is read at the same
    strided offsets rather than packed; confirm that is intended.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            # range() rejects None as a step; no step means contiguous
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
426
427
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, with ``args[0]`` most significant.

    With ``repeat != 1`` the whole argument list is repeated ``repeat``
    times.  As a convenience, a single int argument is then treated as a
    1-bit value, e.g. ``selectconcat(1, repeat=8)`` gives 8 bits of 0xff.

    A FieldSelectableInt argument contributes its selected bits
    (``get_range()``), the same way ``check_extsign`` and the comparison
    operators collapse it.  Returns a new object; the inputs are not
    modified.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat
    first = args[0]
    if isinstance(first, FieldSelectableInt):
        first = first.get_range()
    res = copy(first)
    for i in args[1:]:
        if isinstance(i, FieldSelectableInt):
            # use the selected field bits, not the whole target
            i = i.get_range()
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    print("concat", repeat, res)
    return res
445
446
class SelectableIntTestCase(unittest.TestCase):
    """Unit tests for SelectableInt and selectconcat."""

    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = abs(f)
        h = abs(a)
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        # multiplication widens: the product is a.bits + b.bits wide
        self.assertEqual(e.value, (a.value * b.value) & 0xFFFF)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        # f has its top bit set, so abs() negates it back to a
        self.assertEqual(g.value, a.value)
        self.assertEqual(h.value, a.value)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        self.assertEqual(e.bits, a.bits + b.bits)
        self.assertEqual(a.bits, f.bits)
        self.assertEqual(a.bits, g.bits)
        self.assertEqual(a.bits, h.bits)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)

    def test_cmp(self):
        a = SelectableInt(10, bits=8)
        b = SelectableInt(5, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_unsigned(self):
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)
528
529
if __name__ == "__main__":
    # discover and run every unittest.TestCase defined in this module
    unittest.main()