a275de0724d0c6fed689d6ceaef2904cf99c01a6
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, truediv, mod, or_, and_, xor, neg, inv)
5
6
def check_extsign(a, b):
    """Coerce operand *b* so it can be combined with SelectableInt *a*.

    - A FieldSelectableInt is first collapsed to its selected bits.
    - A plain int becomes a SelectableInt of ``a.bits`` width (masked).
    - A 256-bit SelectableInt is re-created at ``a.bits`` width (masked).
      Presumably 256 bits is the width used for sign-extended
      intermediates -- TODO confirm against the producers of such values.
    - Any other SelectableInt is returned unchanged.
    """
    if isinstance(b, FieldSelectableInt):
        b = b.get_range()
    if isinstance(b, int):
        return SelectableInt(b, a.bits)
    if b.bits == 256:
        return SelectableInt(b.value, a.bits)
    return b
15
16
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target SelectableInt and ``br`` maps each field bit index
    to a target bit index.  Reads gather the mapped target bits; writes
    scatter into the target.  Keys are passed through the map as-is: no
    extra POWER 1.3.4 bit-inversion is applied at this level.
    """
    def __init__(self, si, br):
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple)):
            # plain sequence of target indices -> BitRange keyed 0..n-1
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """In-place assignment from *b*.

        A SelectableInt is written bit-by-bit through the field map into
        the target.  Anything else is treated as a FieldSelectableInt and
        its target and map are (shallow-)copied into self.
        """
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary *op* to the selected bits and *b*; see merge()."""
        return self.merge(op(self.get_range(), b))

    def _op1(self, op):
        """Apply unary *op* to the selected bits; see merge()."""
        return self.merge(op(self.get_range()))

    def __getitem__(self, key):
        """Read field bit(s).

        An int (or SelectableInt) key returns one target bit; a slice
        returns the mapped target bits concatenated via selectconcat.

        Raises:
            TypeError: for any other key type.
        """
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        if isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])
        # previously fell through and silently returned None
        raise TypeError("FieldSelectableInt indices must be int, slice or "
                        "SelectableInt, not %s" % type(key).__name__)

    def __setitem__(self, key, value):
        """Write field bit(s) into the target.

        If the mapped key is an int, *value* is written to that one target
        bit.  Otherwise the mapped key is iterated as target indices and
        bit i of *value* (MSB0, converted to a SelectableInt of matching
        width if necessary) is written to the i-th mapped index.
        """
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            self.si[key] = value
            return
        if not isinstance(value, SelectableInt):
            value = SelectableInt(value, bits=len(key))
        for i, k in enumerate(key):
            self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # The original method was spelt __negate__ (not a Python special
    # method) and referenced an undefined name ``negate``; the old name is
    # kept as an alias so any explicit callers keep working.
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    # Python 2 spelling; Python 3 dispatches "/" to __truediv__ only
    __div__ = __truediv__

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Gather the mapped target bits into a new SelectableInt.

        The result has len(self.br) bits; field bit k comes from target
        bit br[k].
        """
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Return a new FieldSelectableInt with *vi* scattered into a copy
        of the target (field bit i -> target bit br[i]).

        The target is copied so that operators such as ``fs + b`` do not
        mutate the operand's underlying SelectableInt as a side effect.
        """
        fi = copy(self)
        fi.si = copy(self.si)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
108
109
class FieldSelectableIntTestCase(unittest.TestCase):
    """Unit tests for FieldSelectableInt gather/scatter and arithmetic."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        # the field is a[0], a[2], a[3] (MSB0) = 0b110; adding 0b011 wraps
        # in 3 bits to 0b001, which is scattered back into bits 0, 2, 3:
        # 0b10101 -> 0b00011
        c = fs + b
        self.assertEqual(c.get_range(), 0b001)
        self.assertEqual(c.si, 0b00011)

    def test_select(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
147
148
149
class SelectableInt:
    """Fixed-width unsigned integer with POWER-style (MSB0) bit selection.

    ``value`` is kept masked to ``bits`` bits.  Arithmetic and logical
    operators return a new SelectableInt of the same width, so results
    wrap modulo 2**bits.  Integer indices and slices follow the POWER 3.0B
    annotation order (p4, 1.3.2): index 0 is the most-significant bit.
    Comparison operators return a 1-bit SelectableInt (from onebit)
    rather than a Python bool.
    """

    def __init__(self, value, bits):
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def eq(self, b):
        """In-place assignment: copy value and width from *b*."""
        self.value = b.value
        self.bits = b.bits

    def _operand(self, b):
        """Coerce *b* (int, SelectableInt or FieldSelectableInt) through
        check_extsign and require it to have the same width as self."""
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return b

    def __add__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value + b.value, self.bits)

    def __sub__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value - b.value, self.bits)

    def __rsub__(self, b):
        b = self._operand(b)
        return SelectableInt(b.value - self.value, self.bits)

    def __radd__(self, b):
        b = self._operand(b)
        return SelectableInt(b.value + self.value, self.bits)

    def __mul__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value * b.value, self.bits)

    def __floordiv__(self, b):
        """Unsigned integer division; raises ZeroDivisionError if b is 0."""
        b = self._operand(b)
        return SelectableInt(self.value // b.value, self.bits)

    # Python 3 dispatches "/" to __truediv__ only.  The old __div__ used
    # "/" on the raw ints, giving a float that cannot be masked with "&"
    # (TypeError); every spelling now performs unsigned integer division.
    __truediv__ = __floordiv__
    __div__ = __floordiv__

    def __mod__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value % b.value, self.bits)

    def __or__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value | b.value, self.bits)

    def __and__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value & b.value, self.bits)

    def __xor__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        # two's-complement negation; __init__ wraps it to self.bits
        return SelectableInt(~self.value + 1, self.bits)

    def __lshift__(self, b):
        # unlike the arithmetic ops, the shift amount's width is not checked
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def __getitem__(self, key):
        """Return one bit (1-bit SelectableInt) or a slice (SelectableInt
        of stop-start bits), indexed MSB0.

        Raises:
            TypeError: if *key* is not an int, SelectableInt or slice.
        """
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        if isinstance(key, slice):
            assert key.step is None or key.step == 1
            assert key.start < key.stop
            assert key.start >= 0
            assert key.stop <= self.bits

            # convert MSB0 [start, stop) into LSB0 shift positions
            stop = self.bits - key.start
            start = self.bits - key.stop

            bits = stop - start
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            return SelectableInt(value, bits)
        # previously fell through and silently returned None
        raise TypeError("SelectableInt indices must be int, slice or "
                        "SelectableInt, not %s" % type(key).__name__)

    def __setitem__(self, key, value):
        """Overwrite one bit or a slice (MSB0) with *value*.

        Raises:
            TypeError: if *key* is not an int, SelectableInt or slice.
        """
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            assert key.step is None or key.step == 1
            assert key.start < key.stop
            assert key.start >= 0
            assert key.stop <= self.bits

            stop = self.bits - key.start
            start = self.bits - key.stop

            bits = stop - start
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)
        else:
            # previously a silent no-op
            raise TypeError("SelectableInt indices must be int, slice or "
                            "SelectableInt, not %s" % type(key).__name__)

    def _cmp_value(self, other):
        """Return *other* as a plain int for comparison with self.value.

        FieldSelectableInt and SelectableInt operands go through
        check_extsign and must then match self's width.
        """
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        assert isinstance(other, int), \
            "cannot compare SelectableInt with %r" % (other,)
        return other

    def __ge__(self, other):
        # previously read ``other.value`` after other was already reduced
        # to an int, so every call raised AttributeError
        return onebit(self.value >= self._cmp_value(other))

    def __le__(self, other):
        return onebit(self.value <= self._cmp_value(other))

    def __gt__(self, other):
        return onebit(self.value > self._cmp_value(other))

    def __lt__(self, other):
        return onebit(self.value < self._cmp_value(other))

    def __eq__(self, other):
        return onebit(self.value == self._cmp_value(other))

    def narrow(self, bits):
        """Return a copy truncated (masked) to *bits* low-order bits."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                             self.bits)

    def __len__(self):
        return self.bits
357
def onebit(bit):
    """Convert a truth value into a 1-bit SelectableInt (1 or 0)."""
    return SelectableInt(int(bool(bit)), 1)
360
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    Compares the raw (always non-negative) values of *lhs* and *rhs* and
    returns the result as a 1-bit SelectableInt.  *lhs* may be a
    SelectableInt or FieldSelectableInt; *rhs* may additionally be a plain
    int.  A FieldSelectableInt contributes its selected bits, matching
    SelectableInt's own comparison operators.
    """
    if isinstance(lhs, FieldSelectableInt):
        lhs = lhs.get_range()
    if isinstance(rhs, FieldSelectableInt):
        rhs = rhs.get_range()
    if isinstance(rhs, SelectableInt):
        rhs = rhs.value
    return onebit(lhs.value < rhs)
367
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    Compares the raw (always non-negative) values of *lhs* and *rhs* and
    returns the result as a 1-bit SelectableInt.  *lhs* may be a
    SelectableInt or FieldSelectableInt; *rhs* may additionally be a plain
    int.  A FieldSelectableInt contributes its selected bits, matching
    SelectableInt's own comparison operators.
    """
    if isinstance(lhs, FieldSelectableInt):
        lhs = lhs.get_range()
    if isinstance(rhs, FieldSelectableInt):
        rhs = rhs.get_range()
    if isinstance(rhs, SelectableInt):
        rhs = rhs.value
    return onebit(lhs.value > rhs)
374
375
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Assign ``lhs[t] = rhs[f]`` for each destination index t in *idx*.

    *idx* is either a single index (``lhs[idx] = rhs[0]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)`` meaning destinations
    ``range(lower, upper, step)``; source indices run over
    ``range(0, upper - lower, step)`` in lockstep.  A missing or ``None``
    step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            # range() rejects None: the 2-tuple form used to always raise
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
391
392
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, args[0] supplying the most-significant bits.

    With ``repeat != 1`` the argument list is repeated *repeat* times; as a
    special case a lone int argument is first turned into a 1-bit value, so
    ``selectconcat(1, repeat=8)`` has value 0xff and 8 bits.
    FieldSelectableInt arguments contribute only their selected bits
    (``get_range()``).  Returns a new object whose width is the sum of the
    argument widths; the arguments themselves are not modified.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat
    # A FieldSelectableInt stands for its selected field, not its whole
    # target (the old code concatenated ``.si``, i.e. every target bit).
    args = [a.get_range() if isinstance(a, FieldSelectableInt) else a
            for a in args]
    res = copy(args[0])
    for part in args[1:]:
        assert isinstance(part, SelectableInt), "can only concat SIs, sorry"
        res.bits += part.bits
        res.value = (res.value << part.bits) | part.value
    return res
410
411
class SelectableIntTestCase(unittest.TestCase):
    """Unit tests for SelectableInt: wrap-around arithmetic and logic,
    MSB0 indexing/slicing, concatenation, repr round-trip, comparisons."""

    def test_arith(self):
        a, b = SelectableInt(5, 8), SelectableInt(9, 8)
        total, diff, prod, negated = a + b, a - b, a * b, -a
        self.assertEqual(total.value, a.value + b.value)
        self.assertEqual(diff.value, (a.value - b.value) & 0xFF)
        self.assertEqual(prod.value, (a.value * b.value) & 0xFF)
        self.assertEqual(negated.value, (-a.value) & 0xFF)
        for result in (total, diff, prod):
            self.assertEqual(result.bits, a.bits)

    def test_logic(self):
        a, b = SelectableInt(0x0F, 8), SelectableInt(0xA5, 8)
        conj, disj, excl, flipped = a & b, a | b, a ^ b, ~a
        self.assertEqual(conj.value, a.value & b.value)
        self.assertEqual(disj.value, a.value | b.value)
        self.assertEqual(excl.value, a.value ^ b.value)
        self.assertEqual(flipped.value, 0xF0)

    def test_get(self):
        word = SelectableInt(0xa2, 8)
        # These should be big endian (index 0 is the MSB)
        self.assertEqual(word[7], 0)
        self.assertEqual(word[0:4], 10)
        self.assertEqual(word[4:8], 2)

    def test_set(self):
        word = SelectableInt(0x5, 8)
        word[7] = SelectableInt(0, 1)
        self.assertEqual(word, 4)
        word[4:8] = 9
        self.assertEqual(word, 9)
        word[0:4] = 3
        self.assertEqual(word, 0x39)
        word[0:4] = word[4:8]
        self.assertEqual(word, 0x99)

    def test_concat(self):
        # replicating a single bit 8 times gives all-ones / all-zeros
        for bit, expected in ((0x1, 0xff), (0x0, 0x00)):
            joined = selectconcat(SelectableInt(bit, 1), repeat=8)
            self.assertEqual(joined, expected)
            self.assertEqual(joined.bits, 8)

    def test_repr(self):
        # repr() must round-trip through eval() for every 16-bit value
        for n in range(65536):
            original = SelectableInt(n, 16)
            self.assertEqual(original, eval(repr(original)))

    def _assert_strictly_greater(self, big, small):
        """Check every comparison operator agrees that big > small."""
        self.assertTrue(big > small)
        self.assertFalse(big < small)
        self.assertTrue(big != small)
        self.assertFalse(big == small)

    def test_cmp(self):
        self._assert_strictly_greater(SelectableInt(10, bits=8),
                                      SelectableInt(5, bits=8))

    def test_unsigned(self):
        # 0x80 must compare above 0x7f: values are treated as unsigned
        self._assert_strictly_greater(SelectableInt(0x80, bits=8),
                                      SelectableInt(0x7f, bits=8))
489
if __name__ == "__main__":
    # run every unittest.TestCase defined in this module
    unittest.main(module="__main__")