Sorta kinda working bl and blr - need to properly implement lr
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, truediv, mod, or_, and_, xor, neg, inv)
5
6
def check_extsign(a, b):
    """Coerce operand ``b`` into a SelectableInt usable alongside ``a``.

    - A plain int is wrapped as a SelectableInt of width ``a.bits`` (so that
      ``a * 3``, ``a & 3`` etc. work the same way ``a + 3`` already did).
    - A FieldSelectableInt is flattened to the bits it selects.
    - A 256-bit SelectableInt is re-created at width ``a.bits``, which keeps
      only its low ``a.bits`` bits.  NOTE(review): 256 bits looks like an
      "unsized constant" marker from the pseudocode parser -- confirm.
    - Any other SelectableInt is returned unchanged.
    """
    if isinstance(b, int):
        return SelectableInt(b, a.bits)
    if isinstance(b, FieldSelectableInt):
        b = b.get_range()
    if b.bits != 256:
        return b
    return SelectableInt(b.value, a.bits)
13
14
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target.

    ``si`` is the target SelectableInt and ``br`` maps field bit positions
    (0, 1, 2, ...) onto bit positions of ``si``.  Reads and writes through
    this object go directly to the mapped bits of ``si``; the field indices
    are NOT subjected to POWER 1.3.4 bit-inversion.
    """

    def __init__(self, si, br):
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple)):
            # turn a plain sequence of target indices into a BitRange
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of field index -> target index

    def eq(self, b):
        """Assign ``b`` to this field.

        A SelectableInt is written bit-by-bit into the selected bits of the
        target.  Anything else is treated as another FieldSelectableInt whose
        target and bit map are (shallow-)copied into this object.
        """
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary ``op`` to the field value and ``b``, then merge."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary ``op`` to the field value, then merge."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        return self.si[key]

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        return self.si.__setitem__(key, value)

    def __neg__(self):
        return self._op1(neg)

    # Old (non-dunder) spelling, kept so any explicit callers still work.
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __div__(self, b):
        return self._op(truediv, b)

    # Python 3 routes ``/`` to __truediv__; __div__ is only the Py2 name.
    __truediv__ = __div__

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Return the selected bits as a new len(br)-bit SelectableInt.

        Field bit ``k`` of the result is target bit ``si[br[k]]``.
        """
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Write the bits of ``vi`` back through the bit map; return a copy.

        ``copy(self)`` is shallow, so ``fi.si`` is the same object as
        ``self.si``: the result bits are written into the ORIGINAL target.
        NOTE(review): this means e.g. ``fs + b`` also modifies ``fs.si`` --
        confirm this in-place behaviour is what callers rely on.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
96
97
class FieldSelectableIntTestCase(unittest.TestCase):
    """Tests for arithmetic on a FieldSelectableInt bit selection."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        # selected bits are a[0], a[2], a[3] (MSB-first) = 1, 1, 0
        self.assertEqual(fs.get_range().value, 0b110)
        c = fs + b
        # the sum is computed at the 3-bit field width and wraps
        self.assertEqual(c.get_range().value, (0b110 + 0b011) & 0b111)
110
111
class SelectableInt:
    """SelectableInt: a fixed-width unsigned integer with bit selection.

    ``value`` is always masked to ``bits`` bits, so arithmetic wraps modulo
    ``2 ** bits``.  Indexing follows POWER 3.0B annotation order (see p4
    1.3.2): index 0 is the MSB and index ``bits - 1`` is the LSB.
    Comparison operators return a 1-bit SelectableInt (1 = true, 0 = false).
    """

    def __init__(self, value, bits):
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def eq(self, b):
        """Copy ``b``'s value and width into this object (in place)."""
        self.value = b.value
        self.bits = b.bits

    # -- operand helpers -------------------------------------------------

    def _coerce(self, b):
        """Return operand ``b`` as a SelectableInt of width ``self.bits``.

        Plain ints are wrapped at this width; SelectableInt and
        FieldSelectableInt operands go through check_extsign and must then
        match this width.
        """
        if isinstance(b, int):
            return SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits, "%d bits vs %d bits" % (b.bits, self.bits)
        return b

    def _compare_value(self, other):
        """Return the int that ``self.value`` is to be compared against."""
        if isinstance(other, int):
            return other
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            return other.value
        assert False, "cannot compare SelectableInt with %r" % (other,)

    @staticmethod
    def _bit(cond):
        """Return a 1-bit SelectableInt holding 1 if ``cond`` else 0."""
        return SelectableInt(1 if cond else 0, 1)

    def _slice_shift(self, key):
        """Map a POWER-order slice to (shift of its LSB, width in bits).

        Open bounds default to the whole word (``[:n]`` / ``[n:]``).
        """
        assert key.step is None or key.step == 1
        start = 0 if key.start is None else key.start
        stop = self.bits if key.stop is None else key.stop
        assert start < stop
        assert start >= 0
        assert stop <= self.bits
        # MSB is indexed lowest, so [start:stop] covers LSB-numbered bits
        # [bits - stop, bits - start)
        return self.bits - stop, stop - start

    # -- arithmetic and logic (results keep self.bits and wrap) ----------

    def __add__(self, b):
        return SelectableInt(self.value + self._coerce(b).value, self.bits)

    def __sub__(self, b):
        return SelectableInt(self.value - self._coerce(b).value, self.bits)

    def __mul__(self, b):
        return SelectableInt(self.value * self._coerce(b).value, self.bits)

    def __floordiv__(self, b):
        """Integer division (values are unsigned, so floor == truncation)."""
        return SelectableInt(self.value // self._coerce(b).value, self.bits)

    # ``/`` must also give an integer: a float result cannot be masked.
    __truediv__ = __floordiv__
    __div__ = __floordiv__  # Python 2 spelling, kept for existing callers

    def __mod__(self, b):
        return SelectableInt(self.value % self._coerce(b).value, self.bits)

    def __or__(self, b):
        return SelectableInt(self.value | self._coerce(b).value, self.bits)

    def __and__(self, b):
        return SelectableInt(self.value & self._coerce(b).value, self.bits)

    def __xor__(self, b):
        return SelectableInt(self.value ^ self._coerce(b).value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        return SelectableInt(-self.value, self.bits)

    # -- bit / field access -----------------------------------------------

    def __getitem__(self, key):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            shift = self.bits - (key + 1)
            return SelectableInt((self.value >> shift) & 1, 1)
        elif isinstance(key, slice):
            shift, width = self._slice_shift(key)
            mask = (1 << width) - 1
            return SelectableInt((self.value >> shift) & mask, width)

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            shift = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value
            mask = 1 << shift
            self.value = (self.value & ~mask) | ((value << shift) & mask)
        elif isinstance(key, slice):
            shift, width = self._slice_shift(key)
            if isinstance(value, SelectableInt):
                assert value.bits == width, "%d into %d" % (value.bits, width)
                value = value.value
            mask = ((1 << width) - 1) << shift
            self.value = (self.value & ~mask) | ((value << shift) & mask)

    # -- comparisons: ``self OP other`` as a 1-bit SelectableInt ----------

    def __ge__(self, other):
        return self._bit(self.value >= self._compare_value(other))

    def __le__(self, other):
        return self._bit(self.value <= self._compare_value(other))

    def __gt__(self, other):
        return self._bit(self.value > self._compare_value(other))

    def __lt__(self, other):
        return self._bit(self.value < self._compare_value(other))

    def __eq__(self, other):
        return self._bit(self.value == self._compare_value(other))

    def narrow(self, bits):
        """Return a copy truncated to the low ``bits`` bits."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                            self.bits)
294
def onebit(bit):
    """Return a 1-bit SelectableInt: 1 when ``bit`` is truthy, else 0."""
    flag = int(bool(bit))
    return SelectableInt(flag, 1)
297
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    Compares ``lhs.value`` with ``rhs`` (unwrapped to its ``.value`` when it
    is a SelectableInt, otherwise used as-is) and returns the outcome as a
    1-bit SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_value)
304
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    Compares ``lhs.value`` with ``rhs`` (unwrapped to its ``.value`` when it
    is a SelectableInt, otherwise used as-is) and returns the outcome as a
    1-bit SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value > rhs_value)
311
312
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the positions given by ``idx``.

    ``idx`` is either a single index (then ``lhs[idx] = rhs[0]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)`` describing a half-open
    Python range; for each ``j`` in ``range(0, upper - lower, step)``,
    ``lhs[lower + j] = rhs[j]``.  A missing or ``None`` step means 1.
    NOTE(review): ``rhs`` is indexed with the same stride as ``lhs`` -- the
    original author flagged this ("eurgh"); confirm it is intended.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            step = 1  # range() rejects a None step
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
328
329
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts; the first argument lands in the MSBs.

    When ``repeat`` != 1 the whole argument list is repeated ``repeat``
    times before concatenating, and a lone int argument is first treated as
    a 1-bit SelectableInt (e.g. ``selectconcat(1, repeat=8)`` gives 0xff).
    The first argument is copied, so no input is modified.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat
    res = copy(args[0])
    for i in args[1:]:
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    return res
345
346
class SelectableIntTestCase(unittest.TestCase):
    """Tests for SelectableInt arithmetic, logic, bit access and concat."""

    def test_arith(self):
        x = SelectableInt(5, 8)
        y = SelectableInt(9, 8)
        total = x + y
        diff = x - y
        prod = x * y
        negated = -x
        self.assertEqual(total.value, x.value + y.value)
        self.assertEqual(diff.value, (x.value - y.value) & 0xFF)
        self.assertEqual(prod.value, (x.value * y.value) & 0xFF)
        self.assertEqual(negated.value, (-x.value) & 0xFF)
        # results keep the operand width
        for result in (total, diff, prod):
            self.assertEqual(result.bits, x.bits)

    def test_logic(self):
        low = SelectableInt(0x0F, 8)
        pattern = SelectableInt(0xA5, 8)
        cases = (
            (low & pattern, low.value & pattern.value),
            (low | pattern, low.value | pattern.value),
            (low ^ pattern, low.value ^ pattern.value),
            (~low, 0xF0),
        )
        for got, want in cases:
            self.assertEqual(got.value, want)

    def test_get(self):
        word = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(word[7], 0)
        self.assertEqual(word[0:4], 10)
        self.assertEqual(word[4:8], 2)

    def test_set(self):
        word = SelectableInt(0x5, 8)
        word[7] = SelectableInt(0, 1)
        self.assertEqual(word, 4)
        word[4:8] = 9
        self.assertEqual(word, 9)
        word[0:4] = 3
        self.assertEqual(word, 0x39)
        word[0:4] = word[4:8]
        self.assertEqual(word, 0x99)

    def test_concat(self):
        for bit, expected in ((0x1, 0xff), (0x0, 0x00)):
            joined = selectconcat(SelectableInt(bit, 1), repeat=8)
            self.assertEqual(joined, expected)
            self.assertEqual(joined.bits, 8)

    def test_repr(self):
        for n in range(1 << 16):
            original = SelectableInt(n, 16)
            self.assertEqual(original, eval(repr(original)))
408
# Running this module directly executes its unittest test cases.
if __name__ == "__main__":
    unittest.main()