add docstring comment for SelectableInt
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
5 neg, inv, lshift, rshift)
6
7
def check_extsign(a, b):
    """Normalise operand ``b`` so it can be combined with SelectableInt ``a``.

    * A FieldSelectableInt is flattened to the SelectableInt holding its
      selected bits (``get_range()``).
    * A plain int is wrapped as a SelectableInt of width ``a.bits``.
    * A SelectableInt of width exactly 256 is re-created at width ``a.bits``
      (its value masked down); any other SelectableInt is returned as-is.

    NOTE(review): 256 looks like a sentinel for "width not yet known" used
    by whatever builds such values -- confirm against the callers.
    """
    operand = b.get_range() if isinstance(b, FieldSelectableInt) else b
    if isinstance(operand, int):
        return SelectableInt(operand, a.bits)
    if operand.bits == 256:
        return SelectableInt(operand.value, a.bits)
    return operand
16
17
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target SelectableInt.  ``br`` maps field bit index to
    target bit index; it is a BitRange, or a list/tuple that is converted
    into one.  Indexing the field reads and writes the mapped bits of the
    target.  Arithmetic and logic operators work in three steps:
    ``get_range`` gathers the selected bits into a SelectableInt, the
    operation is applied to that value, and ``merge`` scatters the result
    back through the mapping.
    """

    def __init__(self, si, br):
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple)):
            # sequence shorthand: field position i -> target bit br[i]
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """Assign from ``b``.

        If ``b`` is a SelectableInt, it is written bit-by-bit into the
        selected bits of the current target.  Anything else is treated as
        another FieldSelectableInt.  In that case this object takes shallow
        copies of ``b``'s target and mapping, so it is re-pointed rather
        than written through.
        """
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary ``op`` to the selected bits and ``b``; see merge()."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary ``op`` to the selected bits; see merge()."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        """Read field bit(s).

        An int key reads target bit ``br[key]``.  A slice key maps through
        ``br`` to an iterable of target indices, and those target bits are
        concatenated.  Other key types fall through and return None.
        """
        print ("getitem", key, self.br)
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            key = self.br[key] # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        if isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])

    def __setitem__(self, key, value):
        """Write field bit(s) through the mapping into the target."""
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key] # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            # slice: key is now the list of target indices; widen a plain
            # int value to that many bits before distributing it
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # ``__negate__`` is not a Python operator hook; it is kept as an alias
    # so any explicit callers of that name keep working.
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)
    def __add__(self, b):
        return self._op(add, b)
    def __sub__(self, b):
        return self._op(sub, b)
    def __mul__(self, b):
        return self._op(mul, b)
    def __div__(self, b):
        # Python 2 operator name; kept for explicit callers
        return self._op(truediv, b)
    def __mod__(self, b):
        return self._op(mod, b)
    def __and__(self, b):
        return self._op(and_, b)
    def __or__(self, b):
        return self._op(or_, b)
    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Gather the selected bits into a new SelectableInt.

        The result is ``len(self.br)`` bits wide.  Its bit k is target bit
        ``self.br[k]``.
        """
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Scatter ``vi`` back through the mapping and return a copy of self.

        For every mapped index i, ``vi[i]`` is written into target bit
        ``br[i]``.

        NOTE(review): ``copy(self)`` is shallow, so ``fi.si`` is the same
        object as ``self.si``.  The writes below therefore modify the
        original target, and e.g. ``fs + b`` also updates the SelectableInt
        behind ``fs``.  Confirm whether callers rely on this before
        changing it.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
106
107
class FieldSelectableIntTestCase(unittest.TestCase):
    """Tests for FieldSelectableInt bit selection and field arithmetic."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        # bits 0, 2, 3 of a (MSB0 numbering) are 1, 1, 0 -> 0b110.
        # checked before the addition because merge() writes into a.
        self.assertEqual(fs.get_range(), 0b110)
        c = fs + b
        print (c)
        # 0b110 + 0b011 = 0b1001, truncated to the 3-bit field width
        self.assertEqual(c.get_range(), 0b001)

    def test_select(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
145
146
class SelectableInt:
    """SelectableInt - a class that behaves exactly like python int

    this class is designed to mirror precisely the behaviour of python int.
    the only difference is that it must contain the context of the bitwidth
    (number of bits) associated with that integer.

    FieldSelectableInt can then operate on partial bits, and because there
    is a bit width associated with SelectableInt, slices operate correctly
    including negative start/end points.

    Conventions:

    * ``value`` is stored unsigned, masked to ``bits`` bits on construction
      and by every arithmetic/logic operator.
    * Binary operators require equal widths.  An int operand is first
      converted to a SelectableInt of width ``self.bits`` (see
      ``check_extsign``).
    * Bit indexing follows the POWER MSB0 convention: ``x[0]`` is the most
      significant bit and ``x[bits-1]`` the least significant.  Slices are
      half-open ``[start, stop)`` in the same numbering.  Omitted bounds
      default to ``0`` / ``bits``, and negative bounds count back from
      ``bits``.
    * Comparisons are unsigned and return a 1-bit SelectableInt.
    """

    def __init__(self, value, bits):
        if isinstance(value, SelectableInt):
            value = value.value
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def eq(self, b):
        """Assign in place: take both value and width from SelectableInt b."""
        self.value = b.value
        self.bits = b.bits

    def _op(self, op, b):
        """Return ``op(self, b)`` as a new SelectableInt of width self.bits.

        ``b`` may be an int, a SelectableInt or a FieldSelectableInt.  It is
        normalised by check_extsign and must then be as wide as ``self``.
        """
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)
    def __sub__(self, b):
        return self._op(sub, b)
    def __mul__(self, b):
        return self._op(mul, b)
    def __floordiv__(self, b):
        return self._op(floordiv, b)
    def __truediv__(self, b):
        return self._op(truediv, b)
    def __mod__(self, b):
        return self._op(mod, b)
    def __and__(self, b):
        return self._op(and_, b)
    def __or__(self, b):
        return self._op(or_, b)
    def __xor__(self, b):
        return self._op(xor, b)

    def __abs__(self):
        """Absolute value, interpreting the bits as two's complement.

        If the MSB is clear, the value is returned unchanged as a new
        object.  Otherwise it is negated.  The width is preserved, so the
        most negative value maps to itself.
        """
        if self.bits > 0 and self.value & (1 << (self.bits - 1)):
            return -self
        return SelectableInt(self.value, self.bits)

    def __rsub__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value - self.value, self.bits)

    def __radd__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value + self.value, self.bits)

    def __rxor__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        return SelectableInt(~self.value + 1, self.bits)

    def __lshift__(self, b):
        # an int shift amount is masked to self.bits bits by check_extsign
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def _slice_bounds(self, key):
        """Convert an MSB0 slice into ``(shift, nbits)`` on ``self.value``.

        Bounds that are None default to 0 / self.bits.  Negative bounds
        count back from self.bits.  The normalised range must be non-empty,
        lie within the value, and use a step of None or 1.
        """
        assert key.step is None or key.step == 1
        start = 0 if key.start is None else key.start
        stop = self.bits if key.stop is None else key.stop
        if start < 0:
            start += self.bits
        if stop < 0:
            stop += self.bits
        assert start < stop
        assert start >= 0
        assert stop <= self.bits
        # MSB0 bits [start, stop) are LSB0 bits [bits-stop, bits-start)
        return self.bits - stop, stop - start

    def __getitem__(self, key):
        """Return bit ``key`` as a 1-bit SelectableInt, or a slice of bits."""
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            shift, nbits = self._slice_bounds(key)
            mask = (1 << nbits) - 1
            return SelectableInt((self.value >> shift) & mask, nbits)

    def __setitem__(self, key, value):
        """Overwrite bit ``key`` or slice ``key`` with ``value`` (int or SI)."""
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            shift, nbits = self._slice_bounds(key)
            if isinstance(value, SelectableInt):
                assert value.bits == nbits, "%d into %d" % (value.bits, nbits)
                value = value.value
            mask = ((1 << nbits) - 1) << shift
            value = value << shift
            self.value = (self.value & ~mask) | (value & mask)

    def _cmp_operand(self, other):
        """Reduce a comparison operand to a plain unsigned int.

        FieldSelectableInt and SelectableInt operands are normalised through
        check_extsign and must match self.bits.  Other non-int operands are
        rejected.
        """
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        assert isinstance(other, int), "cannot compare with %r" % (other,)
        return other

    def __ge__(self, other):
        return onebit(self.value >= self._cmp_operand(other))

    def __le__(self, other):
        return onebit(self.value <= self._cmp_operand(other))

    def __gt__(self, other):
        return onebit(self.value > self._cmp_operand(other))

    def __lt__(self, other):
        return onebit(self.value < self._cmp_operand(other))

    def __eq__(self, other):
        print ("__eq__", self, other)
        return onebit(self._cmp_operand(other) == self.value)

    def narrow(self, bits):
        """Return a copy truncated to the low ``bits`` bits (bits <= width)."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                              self.bits)

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value
361
362
def onebit(bit):
    """Return a 1-bit SelectableInt: 1 when ``bit`` is truthy, else 0."""
    return SelectableInt(int(bool(bit)), 1)
365
def selectltu(lhs, rhs):
    """Unsigned less-than.

    Compares the raw stored value of SelectableInt ``lhs`` with ``rhs``.
    ``rhs`` may be a SelectableInt or a plain int; no width check is made.
    The result is a 1-bit SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_value)
372
def selectgtu(lhs, rhs):
    """Unsigned greater-than.

    Compares the raw stored value of SelectableInt ``lhs`` with ``rhs``.
    ``rhs`` may be a SelectableInt or a plain int; no width check is made.
    The result is a 1-bit SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value > rhs_value)
379
380
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the position(s) given by ``idx``.

    ``idx`` takes one of two forms:

    * A single index: ``lhs[idx] = rhs[0]``.
    * A tuple ``(lower, upper)`` or ``(lower, upper, step)``: for each k,
      ``lhs[lower + k*step] = rhs[k*step]`` over the half-open range.  A
      missing or None step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            # range() rejects step=None, so treat it as contiguous
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
396
397
def selectconcat(*args, repeat=1):
    """Concatenate bit-vectors; the first argument becomes the MSBs.

    When ``repeat`` != 1, the whole argument list is first replicated
    ``repeat`` times.  In that case a lone int argument is treated as a
    1-bit SelectableInt (masked to its lowest bit).  The first argument is
    shallow-copied, so the caller's object is not modified.  Each later
    argument is shifted in below the bits accumulated so far.

    NOTE(review): a FieldSelectableInt after the first position contributes
    its whole target (``.si``), not only its selected bits, and the first
    argument is never unwrapped -- confirm both are intended.
    """
    if repeat != 1:
        if len(args) == 1 and isinstance(args[0], int):
            args = [SelectableInt(args[0], 1)]
        # replicate the full argument list ``repeat`` times, in order
        args = [arg for _ in range(repeat) for arg in args]
    res = copy(args[0])
    for part in args[1:]:
        if isinstance(part, FieldSelectableInt):
            part = part.si
        assert isinstance(part, SelectableInt), "can only concat SIs, sorry"
        res.bits += part.bits
        res.value = (res.value << part.bits) | part.value
    print ("concat", repeat, res)
    return res
415
416
class SelectableIntTestCase(unittest.TestCase):
    """Behavioural tests for SelectableInt and selectconcat."""

    def test_arith(self):
        x = SelectableInt(5, 8)
        y = SelectableInt(9, 8)
        total, diff, prod = x + y, x - y, x * y
        negated = -x
        abs(negated)  # exercised only: must not raise
        abs_x = abs(x)
        self.assertEqual(total.value, x.value + y.value)
        self.assertEqual(diff.value, (x.value - y.value) & 0xFF)
        self.assertEqual(prod.value, (x.value * y.value) & 0xFF)
        self.assertEqual(negated.value, (-x.value) & 0xFF)
        for result in (total, diff, prod):
            self.assertEqual(result.bits, x.bits)
        self.assertEqual(x.bits, negated.bits)
        self.assertEqual(x.bits, abs_x.bits)

    def test_logic(self):
        low_nibble = SelectableInt(0x0F, 8)
        pattern = SelectableInt(0xA5, 8)
        results = (low_nibble & pattern,
                   low_nibble | pattern,
                   low_nibble ^ pattern)
        inverted = ~low_nibble
        expected = (low_nibble.value & pattern.value,
                    low_nibble.value | pattern.value,
                    low_nibble.value ^ pattern.value)
        for got, want in zip(results, expected):
            self.assertEqual(got.value, want)
        self.assertEqual(inverted.value, 0xF0)

    def test_get(self):
        word = SelectableInt(0xa2, 8)
        # MSB0 numbering: index 0 is the most significant bit
        self.assertEqual(word[7], 0)
        self.assertEqual(word[0:4], 10)
        self.assertEqual(word[4:8], 2)

    def test_set(self):
        word = SelectableInt(0x5, 8)
        word[7] = SelectableInt(0, 1)
        self.assertEqual(word, 4)
        word[4:8] = 9
        self.assertEqual(word, 9)
        word[0:4] = 3
        self.assertEqual(word, 0x39)
        word[0:4] = word[4:8]
        self.assertEqual(word, 0x99)

    def test_concat(self):
        for bit, expected in ((0x1, 0xff), (0x0, 0x00)):
            joined = selectconcat(SelectableInt(bit, 1), repeat=8)
            self.assertEqual(joined, expected)
            self.assertEqual(joined.bits, 8)

    def test_repr(self):
        for n in range(65536):
            original = SelectableInt(n, 16)
            self.assertEqual(original, eval(repr(original)))

    def _assert_strictly_greater(self, big, small):
        """Check every comparison operator agrees that big > small."""
        self.assertTrue(big > small)
        self.assertFalse(big < small)
        self.assertTrue(big != small)
        self.assertFalse(big == small)

    def test_cmp(self):
        self._assert_strictly_greater(SelectableInt(10, bits=8),
                                      SelectableInt(5, bits=8))

    def test_unsigned(self):
        # 0x80 would be negative if read as signed; comparison is unsigned
        self._assert_strictly_greater(SelectableInt(0x80, bits=8),
                                      SelectableInt(0x7f, bits=8))

if __name__ == "__main__":
    unittest.main()