allow int in addition and subtraction
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3
4
class SelectableInt:
    """Fixed-width unsigned integer with MSB0 (big-endian) bit indexing.

    ``value`` is always kept masked to ``bits`` bits.  Bit 0 is the most
    significant bit: ``x[0]`` is the top bit and ``x[a:b]`` is the half-open
    range of bits ``a`` .. ``b-1`` counted from the MSB.

    Binary arithmetic/logic operators accept either a SelectableInt of the
    same width or a plain int (which is truncated to ``self.bits``); results
    are truncated to ``self.bits``.  Comparisons return a 1-bit SelectableInt
    (see ``onebit``) rather than a Python bool.
    """

    def __init__(self, value, bits):
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def _operand(self, b):
        """Return *b* as a SelectableInt of this width.

        Plain ints are wrapped (and truncated) to ``self.bits``; a
        SelectableInt operand must already have the same width.
        """
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        assert b.bits == self.bits
        return b

    def __add__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value + b.value, self.bits)

    def __sub__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value - b.value, self.bits)

    def __mul__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value * b.value, self.bits)

    def __floordiv__(self, b):
        """Integer (floor) division; raises ZeroDivisionError when b is 0."""
        b = self._operand(b)
        # Integer division: true division would produce a float, which cannot
        # be masked.  ``/`` (``__truediv__``) is deliberately not provided.
        return SelectableInt(self.value // b.value, self.bits)

    # Python 2 spelling, kept for backward compatibility; Python 3 never
    # dispatches ``/`` to it.
    __div__ = __floordiv__

    def __mod__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value % b.value, self.bits)

    def __or__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value | b.value, self.bits)

    def __and__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value & b.value, self.bits)

    def __xor__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        # Two's-complement negation, truncated to the width.
        return SelectableInt(~self.value + 1, self.bits)

    def _slice_range(self, key):
        """Map an MSB0 half-open slice to ``(shift, width)`` in LSB0 terms.

        Missing start/stop default to 0 and ``self.bits``.
        """
        assert key.step is None or key.step == 1
        lo = 0 if key.start is None else key.start
        hi = self.bits if key.stop is None else key.stop
        assert lo < hi
        assert lo >= 0
        assert hi <= self.bits
        # MSB0 bit i sits at LSB0 position bits-1-i, so MSB0 bits [lo, hi)
        # occupy LSB0 positions [bits-hi, bits-lo): width is hi-lo, not +1.
        return self.bits - hi, hi - lo

    def __getitem__(self, key):
        """Return one bit (int key) or a bit range (slice key), MSB0."""
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            shift = self.bits - (key + 1)
            return SelectableInt((self.value >> shift) & 1, 1)
        if isinstance(key, slice):
            shift, width = self._slice_range(key)
            # The constructor masks the shifted value down to ``width`` bits.
            return SelectableInt(self.value >> shift, width)
        raise TypeError("SelectableInt indices must be int or slice, not %s"
                        % type(key).__name__)

    def __setitem__(self, key, value):
        """Overwrite one bit (int key) or a bit range (slice key), MSB0.

        *value* may be an int (truncated to the target width) or a
        SelectableInt whose width matches the target exactly.
        """
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            shift = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value
            mask = 1 << shift
        elif isinstance(key, slice):
            shift, width = self._slice_range(key)
            if isinstance(value, SelectableInt):
                assert value.bits == width, "%d into %d" % (value.bits, width)
                value = value.value
            mask = ((1 << width) - 1) << shift
        else:
            raise TypeError("SelectableInt indices must be int or slice, not %s"
                            % type(key).__name__)
        self.value = (self.value & ~mask) | ((value << shift) & mask)

    def _compare_value(self, other):
        """Return the int that *other* denotes for a comparison.

        SelectableInt operands must have the same width; ints are used
        as-is (not truncated).  Any other type raises TypeError.
        """
        if isinstance(other, SelectableInt):
            assert other.bits == self.bits
            return other.value
        if isinstance(other, int):
            return other
        raise TypeError("cannot compare SelectableInt with %s"
                        % type(other).__name__)

    # NOTE(review): these compare the stored (unsigned) values; selectltu /
    # selectgtu below exist as explicitly-unsigned variants, so whether these
    # operators were meant to be signed is not established here — confirm.
    def __ge__(self, other):
        return onebit(self.value >= self._compare_value(other))

    def __le__(self, other):
        return onebit(self.value <= self._compare_value(other))

    def __gt__(self, other):
        return onebit(self.value > self._compare_value(other))

    def __lt__(self, other):
        return onebit(self.value < self._compare_value(other))

    # Defining __eq__ without __hash__ leaves instances unhashable.
    def __eq__(self, other):
        return onebit(self.value == self._compare_value(other))

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value={:x}, bits={})".format(self.value,
                                                           self.bits)


def onebit(bit):
    """Return a 1-bit SelectableInt: 1 if *bit* is truthy, else 0."""
    return SelectableInt(1 if bit else 0, 1)


def selectltu(lhs, rhs):
    """ less-than (unsigned)

    *lhs* must be a SelectableInt; *rhs* may be a SelectableInt or an int.
    Returns a 1-bit SelectableInt.
    """
    if isinstance(rhs, SelectableInt):
        rhs = rhs.value
    return onebit(lhs.value < rhs)


def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    *lhs* must be a SelectableInt; *rhs* may be a SelectableInt or an int.
    Returns a 1-bit SelectableInt.
    """
    if isinstance(rhs, SelectableInt):
        rhs = rhs.value
    return onebit(lhs.value > rhs)


# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy bits of *rhs* into *lhs* one bit at a time.

    *idx* is a single bit index, or a ``(lower, upper)`` /
    ``(lower, upper, step)`` tuple describing a half-open range of *lhs*
    bits; those are filled from *rhs* starting at *rhs* bit 0.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = 1
        else:
            lower, upper, step = idx
        # range() rejects a None step, so default it to 1.
        if step is None:
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper-lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]


def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, first argument in the most-significant bits.

    With ``repeat != 1`` the argument list is repeated that many times; a
    single int argument is then treated as a 1-bit SelectableInt.  The first
    argument is copied so the inputs are not mutated.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat
    res = copy(args[0])
    for i in args[1:]:
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    print("concat", repeat, res)
    return res


class SelectableIntTestCase(unittest.TestCase):
    """Unit tests for SelectableInt and the select* helpers."""

    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        self.assertEqual(e.value, (a.value * b.value) & 0xFF)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        self.assertEqual(e.bits, a.bits)

    def test_int_operands(self):
        a = SelectableInt(5, 8)
        self.assertEqual((a + 300).value, (5 + 300) & 0xFF)
        self.assertEqual((a * 3).value, 15)
        self.assertEqual((SelectableInt(17, 8) // 5).value, 3)
        self.assertEqual((SelectableInt(17, 8) % 5).value, 2)
        self.assertEqual((a | 0xF0).value, 0xF5)
        self.assertEqual((a & 4).value, 4)
        self.assertEqual((a ^ 1).value, 4)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_compare(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        self.assertTrue(a < b)
        self.assertFalse(a > b)
        self.assertTrue(a <= 5)
        self.assertTrue(a >= 5)
        self.assertFalse(a >= 6)
        self.assertTrue(b > 5)
        self.assertEqual((a < b).bits, 1)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)
        self.assertEqual(a[4:8].bits, 4)
        self.assertEqual(SelectableInt(0xb2, 8)[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)
        self.assertEqual(a.bits, 8)

    def test_selectassign(self):
        a = SelectableInt(0, 8)
        selectassign(a, (0, 2), SelectableInt(3, 2))
        self.assertEqual(a, 0xC0)

    def test_concat(self):
        c = selectconcat(SelectableInt(1, 1), SelectableInt(2, 2))
        self.assertEqual(c.bits, 3)
        self.assertEqual(c, 6)
        r = selectconcat(1, repeat=3)
        self.assertEqual(r.bits, 3)
        self.assertEqual(r, 7)
if __name__ == "__main__":
    # Script entry point: run this module's unit tests.
    unittest.main(module="__main__")