From 48d13e47ec085bb8921bf7bff77803a17cab3fe1 Mon Sep 17 00:00:00 2001 From: whitequark Date: Fri, 21 Dec 2018 12:32:08 +0000 Subject: [PATCH] back.pysim: handle out of bounds ArrayProxy indexes. --- nmigen/back/pysim.py | 12 ++++++++++-- nmigen/test/test_sim.py | 15 +++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py index bd90d7a..87fef0c 100644 --- a/nmigen/back/pysim.py +++ b/nmigen/back/pysim.py @@ -193,7 +193,12 @@ class _RHSValueCompiler(AbstractValueTransformer): shape = value.shape() elems = list(map(self, value.elems)) index = self(value.index) - return lambda state: normalize(elems[index(state)](state), shape) + def eval(state): + index_value = index(state) + if index_value >= len(elems): + index_value = len(elems) - 1 + return normalize(elems[index_value](state), shape) + return eval class _LHSValueCompiler(AbstractValueTransformer): @@ -263,7 +268,10 @@ class _LHSValueCompiler(AbstractValueTransformer): elems = list(map(self, value.elems)) index = self.rhs_compiler(value.index) def eval(state, rhs): - elems[index(state)](state, rhs) + index_value = index(state) + if index_value >= len(elems): + index_value = len(elems) - 1 + elems[index_value](state, rhs) return eval diff --git a/nmigen/test/test_sim.py b/nmigen/test/test_sim.py index 963652a..070fcdb 100644 --- a/nmigen/test/test_sim.py +++ b/nmigen/test/test_sim.py @@ -184,6 +184,12 @@ class SimulatorUnitTestCase(FHDLTestCase): self.assertStatement(stmt, [C(1)], C(4)) self.assertStatement(stmt, [C(2)], C(10)) + def test_array_oob(self): + array = Array([1, 4, 10]) + stmt = lambda y, a: y.eq(array[a]) + self.assertStatement(stmt, [C(3)], C(10)) + self.assertStatement(stmt, [C(4)], C(10)) + def test_array_lhs(self): l = Signal(3, reset=1) m = Signal(3, reset=4) @@ -194,6 +200,15 @@ class SimulatorUnitTestCase(FHDLTestCase): self.assertStatement(stmt, [C(1), C(0b010)], C(0b111010001)) self.assertStatement(stmt, [C(2), C(0b100)], C(0b100100001)) + def test_array_lhs_oob(self): + l = Signal(3) + m = Signal(3) + n = Signal(3) + array = Array([l, m, n]) + stmt = lambda y, a, b: [array[a].eq(b), y.eq(Cat(*array))] + self.assertStatement(stmt, [C(3), C(0b001)], C(0b001000000)) + self.assertStatement(stmt, [C(4), C(0b010)], C(0b010000000)) + def test_array_index(self): array = Array(Array(x * y for y in range(10)) for x in range(10)) stmt = lambda y, a, b: y.eq(array[a][b]) -- 2.30.2