fix SelectableInt abs
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
5 neg, inv, lshift, rshift)
6
7
def check_extsign(a, b):
    """Coerce operand ``b`` so it can be combined with SelectableInt ``a``.

    - a FieldSelectableInt is first reduced to its selected bits;
    - a plain int becomes a SelectableInt of width ``a.bits``;
    - a 256-bit SelectableInt is re-created at width ``a.bits`` (i.e.
      truncated).  NOTE(review): 256 looks like a marker width for wide
      intermediates (e.g. from sign extension) -- confirm against callers;
    - any other SelectableInt is returned unchanged.
    """
    if isinstance(b, FieldSelectableInt):
        b = b.get_range()
    if not isinstance(b, int) and b.bits != 256:
        # already a SelectableInt with an explicit, ordinary width
        return b
    raw = b if isinstance(b, int) else b.value
    return SelectableInt(raw, a.bits)
16
17
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target SelectableInt.  ``br`` maps field bit index ->
    target bit index: either a BitRange (used via ``.items()``, ``len()``
    and indexing) or a list/tuple whose position ``i`` holds the target
    index of field bit ``i``.

    Arithmetic/logic operators work on the selected field (see get_range)
    and return a NEW FieldSelectableInt over a copy of the target, with the
    result written back into the mapped bits (see merge).  Use item
    assignment or eq() to modify the target in place.
    """
    def __init__(self, si, br):
        self.si = si # target selectable int
        if isinstance(br, (list, tuple)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br # map of indices.

    def eq(self, b):
        """In-place assignment.

        A SelectableInt ``b`` is written bit-by-bit through the field
        mapping.  Any other ``b`` (presumably another FieldSelectableInt --
        only ``.si`` and ``.br`` are read) replaces this object's target and
        mapping with copies of ``b``'s.
        """
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary ``op`` to the selected field and merge the result."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary ``op`` to the selected field and merge the result."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        print ("getitem", key, self.br)
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            key = self.br[key] # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        if isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key] # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)
    def __negate__(self):
        # old misspelling (not a Python dunder); kept for explicit callers
        return self.__neg__()
    def __invert__(self):
        return self._op1(inv)
    def __add__(self, b):
        return self._op(add, b)
    def __sub__(self, b):
        return self._op(sub, b)
    def __mul__(self, b):
        return self._op(mul, b)
    def __truediv__(self, b):
        return self._op(truediv, b)
    def __div__(self, b):
        # Python 2 name: Python 3 never calls this for "/"; kept for
        # explicit callers
        return self._op(truediv, b)
    def __floordiv__(self, b):
        return self._op(floordiv, b)
    def __mod__(self, b):
        return self._op(mod, b)
    def __and__(self, b):
        return self._op(and_, b)
    def __or__(self, b):
        return self._op(or_, b)
    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Gather the mapped target bits into a new len(br)-bit
        SelectableInt (field bit k <- target bit br[k])."""
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Return a new FieldSelectableInt whose target is a COPY of
        ``self.si`` with the mapped bits replaced by the bits of ``vi``.

        The target is copied because ``copy(self)`` alone is shallow: without
        it, operations such as ``fs + b`` or ``~fs`` would silently overwrite
        the register the field was selected from.
        """
        fi = copy(self)
        fi.si = copy(self.si)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
106
107
class FieldSelectableIntTestCase(unittest.TestCase):
    """Tests for FieldSelectableInt selection, slicing and arithmetic."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        # field = a[0], a[2], a[3] (MSB0 numbering) = 1, 1, 0
        self.assertEqual(fs.get_range(), 0b110)
        c = fs + b
        print (c)
        # 0b110 + 0b011 = 0b1001, truncated to the 3-bit field width
        self.assertEqual(c.get_range(), 0b001)

    def test_select(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
145
146
class SelectableInt:
    """SelectableInt - a class that behaves exactly like python int

    this class is designed to mirror precisely the behaviour of python int.
    the only difference is that it must contain the context of the bitwidth
    (number of bits) associated with that integer.

    FieldSelectableInt can then operate on partial bits, and because there
    is a bit width associated with SelectableInt, slices operate correctly
    including negative start/end points.

    Notes:
    * ``value`` is always masked to ``bits`` bits, so it is never negative
      and every comparison below is unsigned.
    * bit and slice indices use POWER MSB0 numbering: index 0 is the most
      significant bit, ``bits - 1`` the least significant.
    * comparison operators return a 1-bit SelectableInt (via onebit), which
      is truthy when the comparison holds.
    """
    def __init__(self, value, bits):
        if isinstance(value, SelectableInt):
            value = value.value
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def eq(self, b):
        """In-place assignment: take ``b``'s value and width."""
        self.value = b.value
        self.bits = b.bits

    def _op(self, op, b):
        """Apply binary ``op`` to two same-width operands; the result is
        masked back to ``self.bits``."""
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)
    def __sub__(self, b):
        return self._op(sub, b)
    def __mul__(self, b):
        # different case: mul result needs to fit the total bitsize
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        elif isinstance(b, FieldSelectableInt):
            b = b.get_range()
        print ("SelectableInt mul", hex(self.value), hex(b.value),
               self.bits, b.bits)
        return SelectableInt(self.value * b.value, self.bits + b.bits)
    def __floordiv__(self, b):
        return self._op(floordiv, b)
    def __truediv__(self, b):
        # operands are masked, hence non-negative, ints: floor division is
        # the exact integer quotient.  operator.truediv would return a
        # float, which __init__ cannot mask ("&" on a float is a TypeError).
        return self._op(floordiv, b)
    def __mod__(self, b):
        return self._op(mod, b)
    def __and__(self, b):
        return self._op(and_, b)
    def __or__(self, b):
        return self._op(or_, b)
    def __xor__(self, b):
        return self._op(xor, b)
    def __abs__(self):
        """Two's-complement absolute value at the same width.  The most
        negative value (only the MSB set) maps to itself."""
        msb = 1 << (self.bits - 1)
        print ("abs", self.value & msb)
        if self.value & msb != 0:
            # MSB set: negative in two's complement; identical to 0 - self
            return -self
        return self

    def __rsub__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value - self.value, self.bits)

    def __radd__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value + self.value, self.bits)

    def __rxor__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    # and/or are commutative, so the reflected forms reuse _op
    def __rand__(self, b):
        return self._op(and_, b)
    def __ror__(self, b):
        return self._op(or_, b)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        return SelectableInt(~self.value + 1, self.bits)

    def __lshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def _slice_range(self, key):
        """Convert an MSB0 slice into ``(lsb_shift, width)``.

        ``None`` bounds default to the full width, negative bounds count
        from the end (as for Python sequences) and SelectableInt bounds
        are used by value.  Step must be None or 1.
        """
        assert key.step is None or key.step == 1
        start, stop = key.start, key.stop
        if isinstance(start, SelectableInt):
            start = start.value
        if isinstance(stop, SelectableInt):
            stop = stop.value
        if start is None:
            start = 0
        if stop is None:
            stop = self.bits
        if start < 0:
            start += self.bits
        if stop < 0:
            stop += self.bits
        assert start < stop
        assert start >= 0
        assert stop <= self.bits
        # MSB0 [start, stop) -> shift of the lowest selected bit, and width
        return self.bits - stop, stop - start

    def __getitem__(self, key):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            start, bits = self._slice_range(key)
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            return SelectableInt(value, bits)

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            start, bits = self._slice_range(key)
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)

    def _cmp_operand(self, other):
        """Reduce a comparison operand to a plain unsigned int.

        Accepts int, SelectableInt (same width after check_extsign) or
        FieldSelectableInt; anything else fails an assertion.
        """
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        assert isinstance(other, int), "cannot compare with %r" % (other,)
        return other

    def __ge__(self, other):
        # previously read ``other.value`` on an int -> AttributeError
        return onebit(self.value >= self._cmp_operand(other))

    def __le__(self, other):
        return onebit(self.value <= self._cmp_operand(other))

    def __gt__(self, other):
        return onebit(self.value > self._cmp_operand(other))

    def __lt__(self, other):
        return onebit(self.value < self._cmp_operand(other))

    def __eq__(self, other):
        print ("__eq__", self, other)
        return onebit(self._cmp_operand(other) == self.value)

    def narrow(self, bits):
        """Return a copy truncated to the low ``bits`` bits."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                             self.bits)

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value
369
370
def onebit(bit):
    """Wrap the truth value of ``bit`` as a 1-bit SelectableInt (1 or 0)."""
    return SelectableInt(int(bool(bit)), 1)
373
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    ``lhs`` is read via ``.value``; ``rhs`` may be a SelectableInt or a
    plain int.  Returns a 1-bit SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_value)
380
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    ``lhs`` is read via ``.value``; ``rhs`` may be a SelectableInt or a
    plain int.  Returns a 1-bit SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value > rhs_value)
387
388
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the positions named by ``idx``.

    ``idx`` is either a single index (``lhs[idx] = rhs[0]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)``.  For a tuple,
    ``lhs[lower], lhs[lower+step], ...`` (below ``upper``) receive
    ``rhs[0], rhs[step], rhs[2*step], ...`` -- note the rhs index also
    advances by ``step``.  A missing or ``None`` step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            # range() rejects a None step (TypeError); default to 1
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper-lower, step) # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
404
405
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, first argument most significant.

    The result width is the sum of the argument widths.  A
    FieldSelectableInt argument contributes its selected field
    (``get_range()``), not its whole target.  ``repeat`` replicates the
    argument list ``repeat`` times; with ``repeat != 1`` a single plain int
    argument is treated as a 1-bit value (e.g. ``selectconcat(1, repeat=8)``
    gives 0xff, 8 bits wide).
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1: # multiplies the incoming arguments
        args = list(args) * repeat
    # reduce fields to their selected bits: FieldSelectableInt.si is the
    # whole target register, which is not what is being concatenated
    parts = [a.get_range() if isinstance(a, FieldSelectableInt) else a
             for a in args]
    res = copy(parts[0]) # copy: res is mutated below
    for part in parts[1:]:
        assert isinstance(part, SelectableInt), "can only concat SIs, sorry"
        res.bits += part.bits
        res.value = (res.value << part.bits) | part.value
    print ("concat", repeat, res)
    return res
423
424
class SelectableIntTestCase(unittest.TestCase):
    """Unit tests for SelectableInt arithmetic, logic, bit access and
    (unsigned) comparisons."""

    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = abs(f)
        h = abs(a)
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        # multiply widens to a.bits + b.bits, so the product is not
        # truncated to 8 bits
        self.assertEqual(e.value, a.value * b.value)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        # abs() of a negative (MSB set) value negates it back
        self.assertEqual(g.value, a.value)
        self.assertEqual(h.value, a.value)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        self.assertEqual(e.bits, a.bits + b.bits)
        self.assertEqual(a.bits, f.bits)
        self.assertEqual(a.bits, g.bits)
        self.assertEqual(a.bits, h.bits)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)

    def test_cmp(self):
        a = SelectableInt(10, bits=8)
        b = SelectableInt(5, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_unsigned(self):
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)
if __name__ == "__main__":
    # run every unittest.TestCase defined in this module
    unittest.main()