add FieldSelectableInt which allows re-targeting of fields
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, div, mod, or_, and_, xor, neg, inv)
5
6
def check_extsign(a, b):
    """Resize operand *b* to the width of *a* when *b* is 256 bits wide.

    A *b* whose ``bits`` equals 256 is replaced by a new SelectableInt that
    holds ``b.value`` masked to ``a.bits`` bits. Any other *b* is returned
    unchanged, as the very same object.

    NOTE(review): 256 appears to act as an "unsized operand" marker that
    adopts the other operand's width -- confirm with the callers.
    """
    if b.bits == 256:
        return SelectableInt(b.value, a.bits)
    return b
11
12
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target.

    ``si`` is the target SelectableInt. ``br`` maps each field bit index
    (key) onto a bit index of the target (value).

    Indexing reads and writes the mapped target bit directly. Operators
    work in three steps:

    1. Gather the mapped bits into a fresh ``len(br)``-bit SelectableInt
       (see get_range).
    2. Apply the operation.
    3. Scatter the result back into the target (see merge).

    NOTE(review): ``br`` is assumed to be a dict-like BitRange supporting
    ``len()``, ``br[key]`` and ``.items()`` -- confirm in power_fields.
    """

    def __init__(self, si, br):
        self.si = si  # target selectable int
        self.br = br  # map of field bit index -> target bit index

    def _op(self, op, b):
        """Apply the two-operand function ``op(field, b)`` and write back."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply the one-operand function ``op(field)`` and write back."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        """Return the target bit that field bit *key* maps onto."""
        return self.si[self.br[key]]

    def __setitem__(self, key, value):
        """Write *value* to the target bit that field bit *key* maps onto."""
        self.si[self.br[key]] = value

    def __neg__(self):
        return self._op1(neg)

    # Python never invokes __negate__ itself; kept as an alias for any
    # explicit callers of the old name.
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __div__(self, b):
        # operator.div only exists on Python 2, so dispatch to the gathered
        # field value's own __div__ explicitly.
        return self._op(lambda lhs, rhs: lhs.__div__(rhs), b)

    # Python 3 routes ``/`` to __truediv__.
    __truediv__ = __div__

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Gather the mapped target bits into a new SelectableInt.

        The result is ``len(self.br)`` bits wide, and its bit *k* is
        ``self.si[self.br[k]]``.
        """
        print ("get_range", self.si)
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            print ("get_range", k, v, self.si[v])
            vi[k] = self.si[v]
        print ("get_range", vi)
        return vi

    def merge(self, vi):
        """Scatter the bits of *vi* back into the target.

        Bit *i* of *vi* is written to ``si[br[i]]``. ``copy`` is shallow,
        so the returned FieldSelectableInt shares ``si`` with ``self``. The
        original target is therefore modified in place as well.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
73 return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
74
75
class FieldSelectableIntTestCase(unittest.TestCase):
    """Exercise FieldSelectableInt read-modify-write through a BitRange."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        # gathered field (MSB0 indices): a[0], a[2], a[3] = 1, 1, 0 -> 0b110
        self.assertEqual(fs.get_range(), 0b110)
        c = fs + b
        # 0b110 + 0b011 = 0b1001, wrapped to 3 bits -> 0b001
        self.assertEqual(c.get_range(), 0b001)
        # result scattered back: a[0], a[2], a[3] = 0, 0, 1 -> 0b00011
        self.assertEqual(c.si, 0b00011)
        # merge() copies shallowly, so the original target is the one updated
        self.assertIs(c.si, a)
88
89
class SelectableInt:
    """Fixed-width integer with POWER-style (MSB0) bit selection.

    ``value`` is always masked to the low ``bits`` bits, so arithmetic wraps
    modulo 2**bits. Bit index 0 addresses the most-significant bit (POWER
    ISA 3.0B annotation order, p4 1.3.2).

    Comparisons are unsigned and return a 1-bit SelectableInt (see onebit).
    """

    def __init__(self, value, bits):
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    # -- private helpers ---------------------------------------------------

    def _coerce(self, b):
        """Return operand *b* as a SelectableInt of this object's width.

        Plain ints are wrapped at ``self.bits``. SelectableInt operands go
        through check_extsign(). Raises AssertionError on a width mismatch.
        """
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return b

    def _compare_value(self, other):
        """Return the plain int that a comparison against *other* uses.

        Raises AssertionError if *other* is neither an int nor a
        SelectableInt, or if the widths differ.
        """
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        assert isinstance(other, int)
        return other

    def _slice_span(self, key):
        """Map an MSB0 slice onto ``(shift, width)`` in LSB0 terms.

        The slice is half-open like a Python slice. ``start`` is the first
        (most-significant) selected bit and ``stop`` is one past the last,
        so ``stop - start`` bits are selected. An omitted start or stop
        defaults to 0 or ``self.bits`` respectively.
        """
        assert key.step is None or key.step == 1
        first = 0 if key.start is None else key.start
        last = self.bits if key.stop is None else key.stop
        assert first < last
        assert first >= 0
        assert last <= self.bits
        # MSB0 bit i sits at LSB0 position bits-1-i, so the lowest selected
        # LSB0 position corresponds to the highest MSB0 index (last - 1).
        shift = self.bits - last
        width = last - first
        return shift, width

    # -- arithmetic / logic --------------------------------------------------

    def __add__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value + b.value, self.bits)

    def __sub__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value - b.value, self.bits)

    def __mul__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value * b.value, self.bits)

    def __floordiv__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value // b.value, self.bits)

    # Python 3 routes ``/`` to __truediv__. True division would yield a
    # float, which cannot be masked, so ``/`` is integer division here.
    # __div__ is the Python 2 spelling, kept for explicit callers.
    __truediv__ = __floordiv__
    __div__ = __floordiv__

    def __mod__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value % b.value, self.bits)

    def __or__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value | b.value, self.bits)

    def __and__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value & b.value, self.bits)

    def __xor__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        return SelectableInt(~self.value + 1, self.bits)

    # -- bit selection -------------------------------------------------------

    def __getitem__(self, key):
        """Read one bit (int key) or a half-open bit range (slice key).

        Raises TypeError for any other key type.
        """
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            shift = self.bits - (key + 1)
            return SelectableInt((self.value >> shift) & 1, 1)
        if isinstance(key, slice):
            shift, width = self._slice_span(key)
            mask = (1 << width) - 1
            return SelectableInt((self.value >> shift) & mask, width)
        raise TypeError("SelectableInt indices must be int or slice, not %s"
                        % type(key).__name__)

    def __setitem__(self, key, value):
        """Overwrite one bit (int key) or a half-open bit range (slice key).

        *value* may be an int or a SelectableInt of matching width. An int
        value is truncated to the target width. Raises TypeError for any
        other key type.
        """
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            shift = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value
            mask = 1 << shift
            self.value = (self.value & ~mask) | ((value << shift) & mask)
        elif isinstance(key, slice):
            shift, width = self._slice_span(key)
            if isinstance(value, SelectableInt):
                assert value.bits == width, "%d into %d" % (value.bits, width)
                value = value.value
            mask = ((1 << width) - 1) << shift
            self.value = (self.value & ~mask) | ((value << shift) & mask)
        else:
            raise TypeError("SelectableInt indices must be int or slice, "
                            "not %s" % type(key).__name__)

    # -- unsigned comparisons (each returns a 1-bit SelectableInt) ----------

    def __ge__(self, other):
        return onebit(self.value >= self._compare_value(other))

    def __le__(self, other):
        return onebit(self.value <= self._compare_value(other))

    def __gt__(self, other):
        return onebit(self.value > self._compare_value(other))

    def __lt__(self, other):
        return onebit(self.value < self._compare_value(other))

    def __eq__(self, other):
        return onebit(self.value == self._compare_value(other))

    def narrow(self, bits):
        """Return a copy truncated to the low *bits* bits."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                             self.bits)


def onebit(bit):
    """Return a 1-bit SelectableInt holding 1 if *bit* is truthy, else 0."""
    return SelectableInt(1 if bit else 0, 1)


def selectltu(lhs, rhs):
    """ less-than (unsigned)
    """
    if isinstance(rhs, SelectableInt):
        rhs = rhs.value
    return onebit(lhs.value < rhs)


def selectgtu(lhs, rhs):
    """ greater-than (unsigned)
    """
    if isinstance(rhs, SelectableInt):
        rhs = rhs.value
    return onebit(lhs.value > rhs)


# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy bits of *rhs* into *lhs* at the positions described by *idx*.

    *idx* is either a single index or a ``(lower, upper[, step])`` tuple
    describing a half-open range. Source bits are taken from *rhs* starting
    at 0, using the same step.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            # range() rejects a None step; default to contiguous bits
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]


def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, first argument most significant.

    With ``repeat != 1`` the argument list is repeated that many times, and
    a single int argument is first treated as a 1-bit SelectableInt.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat
    res = copy(args[0])
    for i in args[1:]:
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    print ("concat", repeat, res)
    return res


class SelectableIntTestCase(unittest.TestCase):
    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        self.assertEqual(e.value, (a.value * b.value) & 0xFF)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        self.assertEqual(e.bits, a.bits)

    def test_divmod(self):
        a = SelectableInt(17, 8)
        self.assertEqual(a / SelectableInt(5, 8), 3)
        self.assertEqual(a // 5, 3)
        self.assertEqual(a % 5, 2)
        self.assertEqual((a * 2).value, 34)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_compare(self):
        a = SelectableInt(3, 8)
        b = SelectableInt(5, 8)
        self.assertTrue(a < b)
        self.assertTrue(a <= b)
        self.assertTrue(b > a)
        self.assertTrue(b >= a)
        self.assertFalse(a > b)
        self.assertFalse(a >= b)
        self.assertTrue(a >= 3)
        self.assertFalse(a < 3)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)
        # a half-open slice [start:stop] is exactly stop - start bits wide
        b = SelectableInt(0xb2, 8)
        self.assertEqual(b[4:8], 2)
        self.assertEqual(b[4:8].bits, 4)
        self.assertEqual(b[0:4], 0xb)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_selectassign(self):
        a = SelectableInt(0, 8)
        selectassign(a, (0, 4), SelectableInt(0xf, 4))
        self.assertEqual(a, 0xf0)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)


if __name__ == "__main__":
    unittest.main()
369 if __name__ == "__main__":
370 unittest.main()