add selectconcat test
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3
def check_extsign(a, b):
    """Coerce operand *b* so it can be combined with the SelectableInt *a*.

    - A plain ``int`` is wrapped as a SelectableInt of ``a.bits`` width
      (masked to that width), so every binary operator accepts int operands.
    - A SelectableInt whose width is exactly 256 is re-wrapped (and thereby
      truncated) to ``a.bits``.
    - Any other SelectableInt is returned unchanged; callers then assert that
      the widths match.
    """
    if isinstance(b, int):
        return SelectableInt(b, a.bits)
    # NOTE(review): 256 looks like a sentinel width for values whose final
    # width is not yet known (e.g. sign-extended constants) -- confirm where
    # such values are created.
    if b.bits != 256:
        return b
    return SelectableInt(b.value, a.bits)
10 class SelectableInt:
11 def __init__(self, value, bits):
12 mask = (1 << bits) - 1
13 self.value = value & mask
14 self.bits = bits
15
16 def __add__(self, b):
17 if isinstance(b, int):
18 b = SelectableInt(b, self.bits)
19 b = check_extsign(self, b)
20 assert b.bits == self.bits
21 return SelectableInt(self.value + b.value, self.bits)
22
23 def __sub__(self, b):
24 if isinstance(b, int):
25 b = SelectableInt(b, self.bits)
26 b = check_extsign(self, b)
27 assert b.bits == self.bits
28 return SelectableInt(self.value - b.value, self.bits)
29
30 def __mul__(self, b):
31 b = check_extsign(self, b)
32 assert b.bits == self.bits
33 return SelectableInt(self.value * b.value, self.bits)
34
35 def __div__(self, b):
36 b = check_extsign(self, b)
37 assert b.bits == self.bits
38 return SelectableInt(self.value / b.value, self.bits)
39
40 def __mod__(self, b):
41 b = check_extsign(self, b)
42 assert b.bits == self.bits
43 return SelectableInt(self.value % b.value, self.bits)
44
45 def __or__(self, b):
46 b = check_extsign(self, b)
47 assert b.bits == self.bits
48 return SelectableInt(self.value | b.value, self.bits)
49
50 def __and__(self, b):
51 b = check_extsign(self, b)
52 assert b.bits == self.bits
53 return SelectableInt(self.value & b.value, self.bits)
54
55 def __xor__(self, b):
56 b = check_extsign(self, b)
57 assert b.bits == self.bits
58 return SelectableInt(self.value ^ b.value, self.bits)
59
60 def __invert__(self):
61 return SelectableInt(~self.value, self.bits)
62
63 def __neg__(self):
64 return SelectableInt(~self.value + 1, self.bits)
65
66 def __getitem__(self, key):
67 if isinstance(key, int):
68 assert key < self.bits, "key %d accessing %d" % (key, self.bits)
69 assert key >= 0
70 key = self.bits - (key + 1)
71
72 value = (self.value >> key) & 1
73 return SelectableInt(value, 1)
74 elif isinstance(key, slice):
75 assert key.step is None or key.step == 1
76 assert key.start < key.stop
77 assert key.start >= 0
78 assert key.stop <= self.bits
79
80 stop = self.bits - key.start
81 start = self.bits - key.stop
82
83 bits = stop - start + 1
84 mask = (1 << bits) - 1
85 value = (self.value >> start) & mask
86 return SelectableInt(value, bits)
87
88 def __setitem__(self, key, value):
89 if isinstance(key, int):
90 assert key < self.bits
91 assert key >= 0
92 key = self.bits - (key + 1)
93 if isinstance(value, SelectableInt):
94 assert value.bits == 1
95 value = value.value
96
97 value = value << key
98 mask = 1 << key
99 self.value = (self.value & ~mask) | (value & mask)
100 elif isinstance(key, slice):
101 assert key.step is None or key.step == 1
102 assert key.start < key.stop
103 assert key.start >= 0
104 assert key.stop <= self.bits
105
106 stop = self.bits - key.start
107 start = self.bits - key.stop
108
109 bits = stop - start + 1
110 if isinstance(value, SelectableInt):
111 assert value.bits == bits, "%d into %d" % (value.bits, bits)
112 value = value.value
113 mask = ((1 << bits) - 1) << start
114 value = value << start
115 self.value = (self.value & ~mask) | (value & mask)
116
117 def __ge__(self, other):
118 if isinstance(other, SelectableInt):
119 other = check_extsign(self, other)
120 assert other.bits == self.bits
121 other = other.value
122 if isinstance(other, int):
123 return other >= self.value
124 assert False
125
126 def __le__(self, other):
127 if isinstance(other, SelectableInt):
128 other = check_extsign(self, other)
129 assert other.bits == self.bits
130 other = other.value
131 if isinstance(other, int):
132 return onebit(other <= self.value)
133 assert False
134
135 def __gt__(self, other):
136 if isinstance(other, SelectableInt):
137 other = check_extsign(self, other)
138 assert other.bits == self.bits
139 other = other.value
140 if isinstance(other, int):
141 return onebit(other > self.value)
142 assert False
143
144 def __lt__(self, other):
145 if isinstance(other, SelectableInt):
146 other = check_extsign(self, other)
147 assert other.bits == self.bits
148 other = other.value
149 if isinstance(other, int):
150 return onebit(other < self.value)
151 assert False
152
153 def __eq__(self, other):
154 if isinstance(other, SelectableInt):
155 other = check_extsign(self, other)
156 assert other.bits == self.bits
157 other = other.value
158 if isinstance(other, int):
159 return onebit(other == self.value)
160 assert False
161
162 def __bool__(self):
163 return self.value != 0
164
165 def __repr__(self):
166 return "SelectableInt(value={:x}, bits={})".format(self.value,
167 self.bits)
168
169 def onebit(bit):
170 return SelectableInt(1 if bit else 0, 1)
171
172 def selectltu(lhs, rhs):
173 """ less-than (unsigned)
174 """
175 if isinstance(rhs, SelectableInt):
176 rhs = rhs.value
177 return onebit(lhs.value < rhs)
178
179 def selectgtu(lhs, rhs):
180 """ greater-than (unsigned)
181 """
182 if isinstance(rhs, SelectableInt):
183 rhs = rhs.value
184 return onebit(lhs.value > rhs)
185
186
187 # XXX this probably isn't needed...
188 def selectassign(lhs, idx, rhs):
189 if isinstance(idx, tuple):
190 if len(idx) == 2:
191 lower, upper = idx
192 step = None
193 else:
194 lower, upper, step = idx
195 toidx = range(lower, upper, step)
196 fromidx = range(0, upper-lower, step) # XXX eurgh...
197 else:
198 toidx = [idx]
199 fromidx = [0]
200 for t, f in zip(toidx, fromidx):
201 lhs[t] = rhs[f]
202
203
204 def selectconcat(*args, repeat=1):
205 if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
206 args = [SelectableInt(args[0], 1)]
207 if repeat != 1: # multiplies the incoming arguments
208 tmp = []
209 for i in range(repeat):
210 tmp += args
211 args = tmp
212 res = copy(args[0])
213 for i in args[1:]:
214 assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
215 res.bits += i.bits
216 res.value = (res.value << i.bits) | i.value
217 print ("concat", repeat, res)
218 return res
219
220
221 class SelectableIntTestCase(unittest.TestCase):
222 def test_arith(self):
223 a = SelectableInt(5, 8)
224 b = SelectableInt(9, 8)
225 c = a + b
226 d = a - b
227 e = a * b
228 f = -a
229 self.assertEqual(c.value, a.value + b.value)
230 self.assertEqual(d.value, (a.value - b.value) & 0xFF)
231 self.assertEqual(e.value, (a.value * b.value) & 0xFF)
232 self.assertEqual(f.value, (-a.value) & 0xFF)
233 self.assertEqual(c.bits, a.bits)
234 self.assertEqual(d.bits, a.bits)
235 self.assertEqual(e.bits, a.bits)
236
237 def test_logic(self):
238 a = SelectableInt(0x0F, 8)
239 b = SelectableInt(0xA5, 8)
240 c = a & b
241 d = a | b
242 e = a ^ b
243 f = ~a
244 self.assertEqual(c.value, a.value & b.value)
245 self.assertEqual(d.value, a.value | b.value)
246 self.assertEqual(e.value, a.value ^ b.value)
247 self.assertEqual(f.value, 0xF0)
248
249 def test_get(self):
250 a = SelectableInt(0xa2, 8)
251 # These should be big endian
252 self.assertEqual(a[7], 0)
253 self.assertEqual(a[0:4], 10)
254 self.assertEqual(a[4:8], 2)
255
256 def test_set(self):
257 a = SelectableInt(0x5, 8)
258 a[7] = SelectableInt(0, 1)
259 self.assertEqual(a, 4)
260 a[4:8] = 9
261 self.assertEqual(a, 9)
262 a[0:4] = 3
263 self.assertEqual(a, 0x39)
264 a[0:4] = a[4:8]
265 self.assertEqual(a, 0x199)
266
267 def test_concat(self):
268 a = SelectableInt(0x1, 1)
269 c = selectconcat(a, repeat=8)
270 self.assertEqual(c, 0xff)
271 self.assertEqual(c.bits, 8)
272 a = SelectableInt(0x0, 1)
273 c = selectconcat(a, repeat=8)
274 self.assertEqual(c, 0x00)
275 self.assertEqual(c.bits, 8)
276
277
if __name__ == "__main__":
    # Run directly as a script: execute this module's unit tests.
    unittest.main()