add SelectableInt.abs
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
5 neg, inv, lshift, rshift)
6
7
def check_extsign(a, b):
    """Coerce operand ``b`` so it can be combined with SelectableInt ``a``.

    * a FieldSelectableInt is first reduced to its selected bits
      (``get_range()``);
    * a plain int becomes a SelectableInt of ``a``'s width;
    * a SelectableInt whose width is exactly 256 is re-created at ``a``'s
      width (NOTE(review): 256 looks like a sentinel width meaning
      "adopt the other operand's width" - confirm with the code that
      produces such values);
    * any other SelectableInt is returned unchanged.
    """
    if isinstance(b, FieldSelectableInt):
        b = b.get_range()
    if isinstance(b, int):
        return SelectableInt(b, a.bits)
    if b.bits == 256:
        return SelectableInt(b.value, a.bits)
    return b
16
17
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    A view onto selected bits of a target SelectableInt ``si``.  ``br`` maps
    field bit index -> target bit index.  Reading (``get_range``) gathers
    those target bits into a new SelectableInt; writing (``merge`` /
    ``__setitem__``) scatters bits back into the target.
    """

    def __init__(self, si, br):
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple)):
            # convert a plain sequence of target indices into a BitRange
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """Assign from ``b``.

        A SelectableInt is written bit by bit through the field mapping into
        the target.  Anything else is treated as another FieldSelectableInt
        whose target and mapping are (shallow-)copied into this view.
        """
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary ``op`` to the field value and ``b``, then merge."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary ``op`` to the field value, then merge."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        """Read one field bit (int key) or a concatenation of bits (slice).

        Indices are remapped through ``br`` without POWER 1.3.4 bit
        inversion.  For a slice, this assumes ``br[slice]`` yields an
        iterable of target indices (BitRange behaviour - not visible here).
        """
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        if isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])

    def __setitem__(self, key, value):
        """Write one field bit or a range of field bits into the target."""
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # Kept for backwards compatibility: the old method had a non-standard
    # name (so unary minus never dispatched to it) and called an undefined
    # ``negate``.
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    # Python 2 spelling of "/", kept as an alias for existing callers.
    __div__ = __truediv__

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __lshift__(self, b):
        return self._op(lshift, b)

    def __rshift__(self, b):
        return self._op(rshift, b)

    def get_range(self):
        """Gather the selected target bits into a new SelectableInt.

        The result is ``len(self.br)`` bits wide; field bit ``k`` is taken
        from target bit ``self.br[k]``.
        """
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Scatter the bits of ``vi`` back into the target and return a view.

        NOTE: the returned view is a *shallow* copy, so ``fi.si`` is the same
        object as ``self.si``; the target SelectableInt is therefore updated
        in place.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
106
107
class FieldSelectableIntTestCase(unittest.TestCase):
    """Unit tests for FieldSelectableInt bit-range views."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        # target bits 0, 2, 3 (MSB-first) of 0b10101 are 1, 1, 0
        self.assertEqual(fs.get_range(), 0b110)
        c = fs + b
        # the sum wraps at the 3-bit field width: (0b110 + 0b011) & 0b111
        self.assertEqual(c.get_range(), 0b001)

    def test_select(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
145
146
class SelectableInt:
    """Fixed-width unsigned integer with POWER-ISA bit numbering.

    ``value`` is masked to ``bits`` bits on construction, so arithmetic
    results wrap modulo ``2**bits``.  Bit indexing follows the POWER 3.0B
    convention (section 1.3.2): index 0 is the most significant bit.

    Comparison operators return 1-bit SelectableInts (via ``onebit``) rather
    than ``bool``.  Because ``__eq__`` is defined without ``__hash__``,
    instances are unhashable.
    """

    def __init__(self, value, bits):
        if isinstance(value, SelectableInt):
            value = value.value
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def eq(self, b):
        """Overwrite this instance's value and width with ``b``'s (in place)."""
        self.value = b.value
        self.bits = b.bits

    def _op(self, op, b):
        """Apply binary ``op`` to ``(self.value, b.value)`` at ``self.bits``.

        ``b`` may be an int or anything ``check_extsign`` accepts; after
        coercion it must have the same width as ``self``.
        """
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        # operator.truediv yields a float, which __init__ cannot mask
        # (float & int raises TypeError).  A fixed-width integer quotient is
        # the only representable result; both values are non-negative after
        # masking, so floor division equals truncating division here.
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    # These operations are commutative, so ``int <op> SelectableInt`` can
    # reuse the forward implementation.
    def __rmul__(self, b):
        return self._op(mul, b)

    def __rand__(self, b):
        return self._op(and_, b)

    def __ror__(self, b):
        return self._op(or_, b)

    def __abs__(self):
        """Two's-complement absolute value at the same width.

        A value whose top bit is clear is returned unchanged (as a new
        instance); one whose top bit is set is negated.  As usual in two's
        complement, the most negative value (only the top bit set) maps to
        itself.
        """
        negative = self.bits > 0 and (self.value >> (self.bits - 1)) & 1
        if negative:
            return -self
        return SelectableInt(self.value, self.bits)

    def __rsub__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value - self.value, self.bits)

    def __radd__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value + self.value, self.bits)

    def __rxor__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        return SelectableInt(~self.value + 1, self.bits)

    def __lshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def __getitem__(self, key):
        """Return one bit (int key) or a bit range (slice) as a SelectableInt.

        Indices use POWER MSB-0 numbering.  Slices must have step None/1 and
        satisfy ``0 <= start < stop <= bits``.
        """
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            assert key.step is None or key.step == 1
            assert key.start < key.stop
            assert key.start >= 0
            assert key.stop <= self.bits

            # convert MSB-0 slice bounds into LSB-0 shift positions
            stop = self.bits - key.start
            start = self.bits - key.stop

            bits = stop - start
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            return SelectableInt(value, bits)

    def __setitem__(self, key, value):
        """Overwrite one bit (int key) or a bit range (slice), MSB-0 indexed."""
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            assert key.step is None or key.step == 1
            assert key.start < key.stop
            assert key.start >= 0
            assert key.stop <= self.bits

            # convert MSB-0 slice bounds into LSB-0 shift positions
            stop = self.bits - key.start
            start = self.bits - key.stop

            bits = stop - start
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)

    def _compare(self, other, op):
        """Shared implementation of the (unsigned) comparison operators.

        ``other`` may be a FieldSelectableInt (its selected bits are used), a
        SelectableInt of the same width (after ``check_extsign``) or an int.
        Returns ``onebit(op(self.value, other_value))``.
        """
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        if isinstance(other, int):
            return onebit(op(self.value, other))
        assert False, "cannot compare SelectableInt with %r" % (other,)

    def __ge__(self, other):
        return self._compare(other, lambda x, y: x >= y)

    def __le__(self, other):
        return self._compare(other, lambda x, y: x <= y)

    def __gt__(self, other):
        return self._compare(other, lambda x, y: x > y)

    def __lt__(self, other):
        return self._compare(other, lambda x, y: x < y)

    def __eq__(self, other):
        return self._compare(other, lambda x, y: x == y)

    def narrow(self, bits):
        """Return a copy truncated to the low ``bits`` bits."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                            self.bits)

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value
351
352
def onebit(bit):
    """Wrap a truth value as a 1-bit SelectableInt (1 if truthy, else 0)."""
    return SelectableInt(int(bool(bit)), 1)
355
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    Returns a 1-bit SelectableInt: 1 if ``lhs.value < rhs``.  ``rhs`` may
    be a SelectableInt (its ``value`` is used) or a plain number.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_value)
362
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    Returns a 1-bit SelectableInt: 1 if ``lhs.value > rhs``.  ``rhs`` may
    be a SelectableInt (its ``value`` is used) or a plain number.
    """
    if not isinstance(rhs, SelectableInt):
        return onebit(lhs.value > rhs)
    return onebit(lhs.value > rhs.value)
369
370
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the positions given by ``idx``.

    ``idx`` is either a single index (``lhs[idx] = rhs[0]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)``.  The destination positions
    are ``range(lower, upper, step)`` and the source positions are
    ``range(0, upper - lower, step)``.  A missing or ``None`` step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            step = 1  # range() rejects a step of None
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
386
387
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, first argument in the most significant bits.

    Each further argument's bits are appended below the running result, so
    the result width is the sum of the argument widths.  ``repeat``
    replicates the whole argument list that many times; a lone int argument
    combined with ``repeat != 1`` is treated as a 1-bit SelectableInt.

    The first argument is copied, so the caller's object is not mutated.
    Arguments after the first must be SelectableInts (AssertionError
    otherwise).
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat
    res = copy(args[0])
    for i in args[1:]:
        if isinstance(i, FieldSelectableInt):
            # NOTE(review): this concatenates the *whole* target
            # SelectableInt, not just the selected field bits
            # (``i.get_range()``) - confirm this is intended.
            i = i.si
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    return res
405
406
class SelectableIntTestCase(unittest.TestCase):
    """Unit tests for SelectableInt: arithmetic, logic, bit selection,
    concatenation, repr round-tripping and unsigned comparison."""

    def _check_ordered(self, larger, smaller):
        """Assert that ``larger`` compares strictly greater than ``smaller``."""
        self.assertTrue(larger > smaller)
        self.assertFalse(larger < smaller)
        self.assertTrue(larger != smaller)
        self.assertFalse(larger == smaller)

    def test_arith(self):
        x = SelectableInt(5, 8)
        y = SelectableInt(9, 8)
        total, diff, prod, negated = x + y, x - y, x * y, -x
        # results wrap at the 8-bit width
        self.assertEqual(total.value, x.value + y.value)
        self.assertEqual(diff.value, (x.value - y.value) & 0xFF)
        self.assertEqual(prod.value, (x.value * y.value) & 0xFF)
        self.assertEqual(negated.value, (-x.value) & 0xFF)
        for result in (total, diff, prod):
            self.assertEqual(result.bits, x.bits)

    def test_logic(self):
        x = SelectableInt(0x0F, 8)
        y = SelectableInt(0xA5, 8)
        conj, disj, excl, inverted = x & y, x | y, x ^ y, ~x
        self.assertEqual(conj.value, x.value & y.value)
        self.assertEqual(disj.value, x.value | y.value)
        self.assertEqual(excl.value, x.value ^ y.value)
        self.assertEqual(inverted.value, 0xF0)

    def test_get(self):
        reg = SelectableInt(0xa2, 8)
        # POWER numbering is big-endian: bit 0 is the MSB
        self.assertEqual(reg[7], 0)
        self.assertEqual(reg[0:4], 10)
        self.assertEqual(reg[4:8], 2)

    def test_set(self):
        reg = SelectableInt(0x5, 8)
        reg[7] = SelectableInt(0, 1)  # clear the LSB
        self.assertEqual(reg, 4)
        reg[4:8] = 9  # low nibble
        self.assertEqual(reg, 9)
        reg[0:4] = 3  # high nibble
        self.assertEqual(reg, 0x39)
        reg[0:4] = reg[4:8]  # copy low nibble into high nibble
        self.assertEqual(reg, 0x99)

    def test_concat(self):
        for bit, expected in ((0x1, 0xff), (0x0, 0x00)):
            replicated = selectconcat(SelectableInt(bit, 1), repeat=8)
            self.assertEqual(replicated, expected)
            self.assertEqual(replicated.bits, 8)

    def test_repr(self):
        # repr() must evaluate back to an equal SelectableInt
        for n in range(1 << 16):
            original = SelectableInt(n, 16)
            self.assertEqual(original, eval(repr(original)))

    def test_cmp(self):
        self._check_ordered(SelectableInt(10, bits=8),
                            SelectableInt(5, bits=8))

    def test_unsigned(self):
        # 0x80 would be negative if read as signed; comparison is unsigned
        self._check_ordered(SelectableInt(0x80, bits=8),
                            SelectableInt(0x7f, bits=8))
484
# Run this module's unit tests when it is executed as a script.
if "__main__" == __name__:
    unittest.main()