fix repr (0x prefix) and add repr test for selectable int
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3
def check_extsign(a, b):
    """Return *b*, rebuilt at *a*'s width when *b* is exactly 256 bits wide.

    An operand whose width is anything other than 256 is handed back
    unchanged (the very same object).  A 256-bit operand is replaced by a
    new SelectableInt that carries ``a.bits``, which also truncates its
    value to that width.

    NOTE(review): 256 looks like a "width not decided yet" marker produced
    by code outside this file -- confirm before relying on that reading.
    """
    if b.bits != 256:
        return b
    return SelectableInt(b.value, a.bits)


class SelectableInt:
    """Fixed-width unsigned integer with MSB0 bit indexing.

    ``value`` is always kept within ``bits`` bits by the constructor.
    Indexing counts from the most significant end: ``x[0]`` is the top
    bit and ``x[x.bits - 1]`` the bottom bit.  Slices are half-open, so
    ``x[a:b]`` is exactly ``b - a`` bits wide.

    Arithmetic and logic operators return a new SelectableInt of the same
    width (results wrap).  Comparisons return a 1-bit SelectableInt, which
    is truthy when the comparison holds.  Instances are mutable (see
    ``__setitem__``) and, because ``__eq__`` is overridden without
    ``__hash__``, unhashable.
    """

    def __init__(self, value, bits):
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    # ------------------------------------------------------------------
    # operand helpers
    # ------------------------------------------------------------------

    def _coerce(self, b):
        """Turn an arithmetic/logic operand into a same-width SelectableInt.

        Plain ints are wrapped at this object's width; SelectableInt
        operands go through check_extsign.  A width mismatch that remains
        afterwards trips the assertion.
        """
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return b

    def _rhs_value(self, other):
        """Return the plain integer to compare ``self.value`` against.

        Ints are used as they are (no truncation); SelectableInt operands
        must match this object's width after check_extsign.  Anything else
        raises AssertionError, the same exception type the comparisons
        have always used for unsupported operands.
        """
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            return other.value
        if isinstance(other, int):
            return other
        raise AssertionError("cannot compare SelectableInt with %s"
                             % type(other).__name__)

    def _slice_span(self, key):
        """Translate an MSB0 slice into ``(shift, width)`` of the value.

        ``shift`` is how far the field sits above bit 0 of ``self.value``
        and ``width`` is the number of bits it covers.  A missing start
        means 0 and a missing stop means ``self.bits``.
        """
        assert key.step is None or key.step == 1
        first = 0 if key.start is None else key.start
        last = self.bits if key.stop is None else key.stop
        assert first < last
        assert first >= 0
        assert last <= self.bits
        # half-open range: the field is (last - first) bits wide
        return self.bits - last, last - first

    # ------------------------------------------------------------------
    # arithmetic and logic
    # ------------------------------------------------------------------

    def __add__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value + b.value, self.bits)

    def __sub__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value - b.value, self.bits)

    def __mul__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value * b.value, self.bits)

    def __floordiv__(self, b):
        # integer division: a float quotient could not be masked to width
        b = self._coerce(b)
        return SelectableInt(self.value // b.value, self.bits)

    # Python 3 dispatches "/" to __truediv__ and never looks at __div__;
    # the old name is kept for code that calls it explicitly.
    __truediv__ = __floordiv__
    __div__ = __floordiv__

    def __mod__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value % b.value, self.bits)

    def __or__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value | b.value, self.bits)

    def __and__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value & b.value, self.bits)

    def __xor__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        # two's complement negation, wrapped to width by the constructor
        return SelectableInt(~self.value + 1, self.bits)

    # ------------------------------------------------------------------
    # bit / field access (MSB0 numbering)
    # ------------------------------------------------------------------

    def __getitem__(self, key):
        """Return bit ``key`` (1-bit result) or the field ``key.start:key.stop``."""
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # index 0 is the most significant bit
            shift = self.bits - (key + 1)
            return SelectableInt((self.value >> shift) & 1, 1)
        if isinstance(key, slice):
            shift, width = self._slice_span(key)
            mask = (1 << width) - 1
            return SelectableInt((self.value >> shift) & mask, width)
        raise TypeError("SelectableInt indices must be int or slice, not %s"
                        % type(key).__name__)

    def __setitem__(self, key, value):
        """Overwrite bit ``key`` or the field ``key.start:key.stop`` in place.

        ``value`` may be a plain int (excess high bits are dropped) or a
        SelectableInt whose width equals the width of the target.
        """
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            shift = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value
            mask = 1 << shift
            self.value = (self.value & ~mask) | ((value << shift) & mask)
        elif isinstance(key, slice):
            shift, width = self._slice_span(key)
            if isinstance(value, SelectableInt):
                assert value.bits == width, "%d into %d" % (value.bits, width)
                value = value.value
            mask = ((1 << width) - 1) << shift
            self.value = (self.value & ~mask) | ((value << shift) & mask)
        else:
            raise TypeError("SelectableInt indices must be int or slice, "
                            "not %s" % type(key).__name__)

    # ------------------------------------------------------------------
    # comparisons (unsigned, each returns a 1-bit SelectableInt)
    # ------------------------------------------------------------------

    def __ge__(self, other):
        return onebit(self.value >= self._rhs_value(other))

    def __le__(self, other):
        return onebit(self.value <= self._rhs_value(other))

    def __gt__(self, other):
        return onebit(self.value > self._rhs_value(other))

    def __lt__(self, other):
        return onebit(self.value < self._rhs_value(other))

    def __eq__(self, other):
        return onebit(self.value == self._rhs_value(other))

    # ------------------------------------------------------------------
    # misc
    # ------------------------------------------------------------------

    def narrow(self, bits):
        """Return a copy truncated to the low ``bits`` bits (never wider)."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                             self.bits)


def onebit(bit):
    """Wrap a truth value as a 1-bit SelectableInt (1 if truthy, else 0)."""
    return SelectableInt(1 if bit else 0, 1)


def selectltu(lhs, rhs):
    """ less-than (unsigned)

    lhs must be a SelectableInt; rhs may be a SelectableInt or a plain int.
    """
    if isinstance(rhs, SelectableInt):
        rhs = rhs.value
    return onebit(lhs.value < rhs)


def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    lhs must be a SelectableInt; rhs may be a SelectableInt or a plain int.
    """
    if isinstance(rhs, SelectableInt):
        rhs = rhs.value
    return onebit(lhs.value > rhs)


# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy bits of *rhs* into *lhs*, one bit at a time.

    idx is either a single bit index (receives ``rhs[0]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)``.  For a tuple, the
    destination indices are ``range(lower, upper, step)`` and the source
    indices are ``range(0, upper - lower, step)``; step defaults to 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = 1  # range() rejects None as a step
        else:
            lower, upper, step = idx
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]


def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, first argument most significant.

    With ``repeat`` other than 1 the argument list is repeated that many
    times before concatenating; in that case a lone plain int argument is
    first turned into a 1-bit SelectableInt.  The result is a new object;
    the arguments are not modified.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat
    for part in args:
        assert isinstance(part, SelectableInt), "can only concat SIs, sorry"
    res = copy(args[0])
    for part in args[1:]:
        res.bits += part.bits
        res.value = (res.value << part.bits) | part.value
    return res


class SelectableIntTestCase(unittest.TestCase):
    """Unit tests for SelectableInt and the select* helper functions."""

    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        self.assertEqual(e.value, (a.value * b.value) & 0xFF)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        self.assertEqual(e.bits, a.bits)

    def test_div(self):
        a = SelectableInt(20, 8)
        b = SelectableInt(6, 8)
        self.assertEqual((a / b).value, 3)
        self.assertEqual((a // b).value, 3)
        self.assertEqual((a % b).value, 2)
        self.assertEqual((a / b).bits, 8)

    def test_int_operands(self):
        a = SelectableInt(0x0F, 8)
        self.assertEqual((a + 1).value, 0x10)
        self.assertEqual((a * 2).value, 0x1E)
        self.assertEqual((a & 0x3).value, 0x03)
        self.assertEqual((a | 0x80).value, 0x8F)
        self.assertEqual((a ^ 0xFF).value, 0xF0)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_compare(self):
        a = SelectableInt(3, 8)
        b = SelectableInt(5, 8)
        self.assertTrue(a < b)
        self.assertTrue(a <= b)
        self.assertTrue(b > a)
        self.assertTrue(b >= a)
        self.assertFalse(a > b)
        self.assertFalse(a >= b)
        self.assertFalse(b < a)
        self.assertFalse(b <= a)
        self.assertTrue(a <= 3)
        self.assertTrue(a >= 3)
        self.assertTrue(a < 4)
        self.assertTrue(a > 2)
        self.assertEqual((b >= a).bits, 1)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_get_width(self):
        # bit 3 (MSB0) is set here, so a too-wide low slice would see it
        a = SelectableInt(0xb2, 8)
        self.assertEqual(a[4:8].value, 0x2)
        self.assertEqual(a[4:8].bits, 4)
        self.assertEqual(a[0:4].value, 0xb)
        self.assertEqual(a[0:4].bits, 4)
        self.assertEqual(a[:4].value, 0xb)
        self.assertEqual(a[4:].value, 0x2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)
        self.assertEqual(a.bits, 8)

    def test_set_does_not_spill(self):
        a = SelectableInt(0xFF, 8)
        a[4:8] = 0
        # only the low nibble may change
        self.assertEqual(a.value, 0xF0)

    def test_selectassign(self):
        lhs = SelectableInt(0, 8)
        selectassign(lhs, (0, 3), SelectableInt(0b101, 3))
        self.assertEqual(lhs.value, 0xA0)
        selectassign(lhs, 7, SelectableInt(1, 1))
        self.assertEqual(lhs.value, 0xA1)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)
286
# Executed directly (not imported): discover and run the test cases
# defined in this module.
if __name__ == "__main__":
    unittest.main()