Implement bctr and mtspr
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, truediv, mod, or_, and_, xor, neg, inv)
5
6
def check_extsign(a, b):
    """Return operand *b* adapted to the width of *a* where required.

    A FieldSelectableInt operand is first collapsed into the SelectableInt
    made of the bits it selects.  An operand that is exactly 256 bits wide
    is rebuilt with ``a.bits`` bits (its value truncated to that width);
    every other operand is handed back untouched.
    """
    # NOTE(review): 256 appears to be a sentinel width meaning "unsized"
    # (e.g. literals produced by the pseudocode compiler) -- TODO confirm.
    operand = b.get_range() if isinstance(b, FieldSelectableInt) else b
    if operand.bits == 256:
        return SelectableInt(operand.value, a.bits)
    return operand
13
14
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target (typically a SelectableInt) and ``br`` maps each
    field bit index to a bit index of the target.  Bit reads and writes on
    the field are forwarded to the mapped bits of the target.
    """

    def __init__(self, si, br):
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple, range)):
            # build a BitRange mapping field bit i -> target bit br[i]
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """Assign *b* to this field.

        A SelectableInt is written bit-by-bit into the mapped target bits
        (bit i of *b* goes to field bit i).  Anything else is treated as
        another field: its target and bit map are shallow-copied into this
        one.
        """
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary *op* to the field's value and *b*; merge result."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary *op* to the field's value; merge result."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        return self.si[key]

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        return self.si.__setitem__(key, value)

    def __neg__(self):
        return self._op1(neg)

    # Python never calls __negate__ (unary minus uses __neg__); the old
    # name is kept as an alias so any explicit callers keep working.
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __div__(self, b):
        return self._op(truediv, b)

    # Python 3 dispatches ``/`` to __truediv__, never to __div__
    __truediv__ = __div__

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Return the selected bits as a new SelectableInt of len(br) bits.

        Field bit k of the result is taken from target bit br[k].
        """
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Write the bits of *vi* back into the mapped target bits.

        NOTE: ``copy(self)`` is shallow, so the returned field shares
        ``si`` with this one -- the target is modified in place.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
96
97
class FieldSelectableIntTestCase(unittest.TestCase):
    """Unit tests for FieldSelectableInt."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        # field bit i -> bit br[i] of a (MSB0 numbering)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        # bits 0, 2 and 3 of 0b10101 are 1, 1, 0
        self.assertEqual(fs.get_range().value, 0b110)
        c = fs + b
        # 0b110 + 0b011 = 0b1001, truncated to the 3-bit field width
        self.assertEqual(c.get_range().value, 0b001)
110
111
class SelectableInt:
    """Fixed-width unsigned integer with POWER (MSB0) bit numbering.

    ``value`` is always kept masked to ``bits`` bits.  Bit index 0 is the
    most significant bit (POWER 3.0B, p4, 1.3.2).  Arithmetic and logical
    operators return a new SelectableInt of the same width, truncating the
    result; comparison operators return a 1-bit SelectableInt (1 = true).
    """

    def __init__(self, value, bits):
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def eq(self, b):
        """Copy the value and width of *b* into this instance (in place)."""
        self.value = b.value
        self.bits = b.bits

    # -- helpers ---------------------------------------------------------

    def _operand(self, b):
        """Return arithmetic operand *b* as a SelectableInt of our width.

        Plain ints are wrapped (and truncated) to ``self.bits``; other
        operands go through check_extsign.  Widths must then match.
        """
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return b

    def _cmp_value(self, other):
        """Return the plain int to compare against, or NotImplemented.

        Unlike _operand, plain ints are NOT truncated to our width.
        """
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            return other.value
        if isinstance(other, int):
            return other
        return NotImplemented

    def _slice_bounds(self, key):
        """Validate an MSB0 slice; return (LSB0 shift, width in bits).

        A missing start defaults to 0 and a missing stop to ``self.bits``.
        """
        assert key.step is None or key.step == 1
        first = 0 if key.start is None else key.start
        last = self.bits if key.stop is None else key.stop
        assert first < last
        assert first >= 0
        assert last <= self.bits
        # MSB0 [first, last) covers LSB0 bits [bits - last, bits - first)
        return self.bits - last, last - first

    # -- arithmetic / logic ------------------------------------------------

    def __add__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value + b.value, self.bits)

    def __sub__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value - b.value, self.bits)

    def __mul__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value * b.value, self.bits)

    def __floordiv__(self, b):
        # unsigned integer division: true division would give a float,
        # which cannot be masked by __init__ (TypeError under Python 3)
        b = self._operand(b)
        return SelectableInt(self.value // b.value, self.bits)

    # Python 3 maps ``/`` to __truediv__; __div__ kept for existing callers
    __truediv__ = __floordiv__
    __div__ = __floordiv__

    def __mod__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value % b.value, self.bits)

    def __or__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value | b.value, self.bits)

    def __and__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value & b.value, self.bits)

    def __xor__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        return SelectableInt(~self.value + 1, self.bits)

    # -- bit access --------------------------------------------------------

    def __getitem__(self, key):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)
            return SelectableInt((self.value >> key) & 1, 1)
        if isinstance(key, slice):
            start, bits = self._slice_bounds(key)
            mask = (1 << bits) - 1
            return SelectableInt((self.value >> start) & mask, bits)
        raise TypeError("unsupported key type %s" % type(key).__name__)

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)  # MSB0 -> LSB0 shift
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value
            mask = 1 << key
            self.value = (self.value & ~mask) | ((value << key) & mask)
        elif isinstance(key, slice):
            start, bits = self._slice_bounds(key)
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            self.value = (self.value & ~mask) | ((value << start) & mask)
        else:
            raise TypeError("unsupported key type %s" % type(key).__name__)

    # -- comparisons (unsigned; return a 1-bit SelectableInt) --------------

    def __ge__(self, other):
        other = self._cmp_value(other)
        if other is NotImplemented:
            return NotImplemented
        return onebit(self.value >= other)

    def __le__(self, other):
        other = self._cmp_value(other)
        if other is NotImplemented:
            return NotImplemented
        return onebit(self.value <= other)

    def __gt__(self, other):
        other = self._cmp_value(other)
        if other is NotImplemented:
            return NotImplemented
        return onebit(self.value > other)

    def __lt__(self, other):
        other = self._cmp_value(other)
        if other is NotImplemented:
            return NotImplemented
        return onebit(self.value < other)

    def __eq__(self, other):
        other = self._cmp_value(other)
        if other is NotImplemented:
            return NotImplemented
        return onebit(self.value == other)

    # -- misc ----------------------------------------------------------------

    def narrow(self, bits):
        """Return a copy truncated to the low *bits* bits."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                             self.bits)

    def __len__(self):
        return self.bits
297
def onebit(bit):
    """Return a 1-bit SelectableInt: 1 when *bit* is truthy, else 0."""
    return SelectableInt(int(bool(bit)), 1)
300
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    Compares ``lhs.value`` with *rhs* (a SelectableInt or a plain int) and
    returns the outcome as a 1-bit SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_value)
307
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    Compares ``lhs.value`` with *rhs* (a SelectableInt or a plain int) and
    returns the outcome as a 1-bit SelectableInt.
    """
    if not isinstance(rhs, SelectableInt):
        return onebit(lhs.value > rhs)
    return onebit(lhs.value > rhs.value)
314
315
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of *rhs* into *lhs* at the position(s) given by *idx*.

    *idx* is either a single index (``lhs[idx] = rhs[0]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)``, in which case
    ``lhs[lower + k] = rhs[k]`` for each k in
    ``range(0, upper - lower, step)``.  A missing or None step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            # range() rejects a None step, so the 2-tuple form used to
            # raise TypeError unconditionally
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
331
332
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, the first argument being most significant.

    :param args: values to join.  FieldSelectableInts are collapsed to the
        bits they select (via ``get_range``).  When *repeat* != 1 and the
        only argument is a plain int, it is treated as a 1-bit value.
    :param repeat: number of times the whole argument list is repeated
        before concatenation.
    :returns: a new value whose width is the sum of the input widths.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        tmp = []
        for _ in range(repeat):
            tmp += args
        args = tmp

    def _as_si(x):
        # a field contributes the bits it selects, not the whole target
        return x.get_range() if isinstance(x, FieldSelectableInt) else x

    res = copy(_as_si(args[0]))
    for i in args[1:]:
        i = _as_si(i)
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    return res
348
349
class SelectableIntTestCase(unittest.TestCase):
    """Unit tests for SelectableInt arithmetic, bit access and helpers."""

    def test_arith(self):
        lhs = SelectableInt(5, 8)
        rhs = SelectableInt(9, 8)
        total, diff, prod, negated = lhs + rhs, lhs - rhs, lhs * rhs, -lhs
        self.assertEqual(total.value, lhs.value + rhs.value)
        self.assertEqual(diff.value, (lhs.value - rhs.value) & 0xFF)
        self.assertEqual(prod.value, (lhs.value * rhs.value) & 0xFF)
        self.assertEqual(negated.value, (-lhs.value) & 0xFF)
        for result in (total, diff, prod):
            self.assertEqual(result.bits, lhs.bits)

    def test_logic(self):
        low = SelectableInt(0x0F, 8)
        pattern = SelectableInt(0xA5, 8)
        conj, disj, excl, flipped = low & pattern, low | pattern, \
            low ^ pattern, ~low
        self.assertEqual(conj.value, low.value & pattern.value)
        self.assertEqual(disj.value, low.value | pattern.value)
        self.assertEqual(excl.value, low.value ^ pattern.value)
        self.assertEqual(flipped.value, 0xF0)

    def test_get(self):
        word = SelectableInt(0xa2, 8)
        # These should be big endian
        checks = ((word[7], 0), (word[0:4], 10), (word[4:8], 2))
        for got, want in checks:
            self.assertEqual(got, want)

    def test_set(self):
        reg = SelectableInt(0x5, 8)
        steps = (
            (7, SelectableInt(0, 1), 4),
            (slice(4, 8), 9, 9),
            (slice(0, 4), 3, 0x39),
        )
        for key, val, want in steps:
            reg[key] = val
            self.assertEqual(reg, want)
        reg[0:4] = reg[4:8]
        self.assertEqual(reg, 0x99)

    def test_concat(self):
        for bit, want in ((0x1, 0xff), (0x0, 0x00)):
            joined = selectconcat(SelectableInt(bit, 1), repeat=8)
            self.assertEqual(joined, want)
            self.assertEqual(joined.bits, 8)

    def test_repr(self):
        # repr() must round-trip through eval() for every 16-bit value
        for n in range(65536):
            original = SelectableInt(n, 16)
            roundtrip = eval(repr(original))
            self.assertEqual(original, roundtrip)
411
if __name__ == "__main__":
    # running this file directly executes its unittest.TestCase classes
    unittest.main()