shape = value.shape()
elems = list(map(self, value.elems))
index = self(value.index)
- return lambda state: normalize(elems[index(state)](state), shape)
+ def eval(state):
+ index_value = index(state)
+ if index_value >= len(elems):
+ index_value = len(elems) - 1
+ return normalize(elems[index_value](state), shape)
+ return eval
class _LHSValueCompiler(AbstractValueTransformer):
elems = list(map(self, value.elems))
index = self.rhs_compiler(value.index)
def eval(state, rhs):
- elems[index(state)](state, rhs)
+ index_value = index(state)
+ if index_value >= len(elems):
+ index_value = len(elems) - 1
+ elems[index_value](state, rhs)
return eval
self.assertStatement(stmt, [C(1)], C(4))
self.assertStatement(stmt, [C(2)], C(10))
+ def test_array_oob(self):
+ array = Array([1, 4, 10])
+ stmt = lambda y, a: y.eq(array[a])
+ self.assertStatement(stmt, [C(3)], C(10))
+ self.assertStatement(stmt, [C(4)], C(10))
+
def test_array_lhs(self):
l = Signal(3, reset=1)
m = Signal(3, reset=4)
self.assertStatement(stmt, [C(1), C(0b010)], C(0b111010001))
self.assertStatement(stmt, [C(2), C(0b100)], C(0b100100001))
+ def test_array_lhs_oob(self):
+ l = Signal(3)
+ m = Signal(3)
+ n = Signal(3)
+ array = Array([l, m, n])
+ stmt = lambda y, a, b: [array[a].eq(b), y.eq(Cat(*array))]
+ self.assertStatement(stmt, [C(3), C(0b001)], C(0b001000000))
+ self.assertStatement(stmt, [C(4), C(0b010)], C(0b010000000))
+
def test_array_index(self):
array = Array(Array(x * y for y in range(10)) for x in range(10))
stmt = lambda y, a, b: y.eq(array[a][b])