From 020aeea7de3be7752ed52fa7b6e4a31f2620ddb8 Mon Sep 17 00:00:00 2001 From: whitequark Date: Sat, 15 Dec 2018 19:37:36 +0000 Subject: [PATCH] back.pysim: implement ArrayProxy. --- nmigen/back/pysim.py | 6 ++++++ nmigen/hdl/ast.py | 6 ++++-- nmigen/hdl/xfrm.py | 6 ++++++ nmigen/test/test_sim.py | 23 +++++++++++++++++++++++ 4 files changed, 39 insertions(+), 2 deletions(-) diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py index 634dd20..a3c09d1 100644 --- a/nmigen/back/pysim.py +++ b/nmigen/back/pysim.py @@ -148,6 +148,12 @@ class _RHSValueCompiler(ValueTransformer): return normalize(result, shape) return eval + def on_ArrayProxy(self, value): + shape = value.shape() + elems = list(map(self, value.elems)) + index = self(value.index) + return lambda state: normalize(elems[index(state)](state), shape) + class _StatementCompiler(StatementTransformer): def __init__(self): diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index a962324..42923f4 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -776,10 +776,12 @@ class ArrayProxy(Value): return bits, sign def _lhs_signals(self): - return union((elem._lhs_signals() for elem in self._iter_as_values()), start=ValueSet()) + signals = union((elem._lhs_signals() for elem in self._iter_as_values()), start=ValueSet()) + return signals def _rhs_signals(self): - return union((elem._rhs_signals() for elem in self._iter_as_values()), start=ValueSet()) + signals = union((elem._rhs_signals() for elem in self._iter_as_values()), start=ValueSet()) + return self.index._rhs_signals() | signals def __repr__(self): return "(proxy (array [{}]) {!r})".format(", ".join(map(repr, self.elems)), self.index) diff --git a/nmigen/hdl/xfrm.py b/nmigen/hdl/xfrm.py index 1ef3275..8d09ac8 100644 --- a/nmigen/hdl/xfrm.py +++ b/nmigen/hdl/xfrm.py @@ -40,6 +40,10 @@ class ValueTransformer: def on_Repl(self, value): return Repl(self.on_value(value.value), value.count) + def on_ArrayProxy(self, value): + return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()], + self.on_value(value.index)) + def on_unknown_value(self, value): raise TypeError("Cannot transform value '{!r}'".format(value)) # :nocov: @@ -62,6 +66,8 @@ class ValueTransformer: new_value = self.on_Cat(value) elif isinstance(value, Repl): new_value = self.on_Repl(value) + elif isinstance(value, ArrayProxy): + new_value = self.on_ArrayProxy(value) else: new_value = self.on_unknown_value(value) if isinstance(new_value, Value): diff --git a/nmigen/test/test_sim.py b/nmigen/test/test_sim.py index f6446e9..7cecfd7 100644 --- a/nmigen/test/test_sim.py +++ b/nmigen/test/test_sim.py @@ -136,3 +136,26 @@ class SimulatorUnitTestCase(FHDLTestCase): def test_repl(self): stmt = lambda a: Repl(a, 3) self.assertOperator(stmt, [C(0b10, 2)], C(0b101010, 6)) + + def test_array(self): + array = Array([1, 4, 10]) + stmt = lambda a: array[a] + self.assertOperator(stmt, [C(0)], C(1)) + self.assertOperator(stmt, [C(1)], C(4)) + self.assertOperator(stmt, [C(2)], C(10)) + + def test_array_index(self): + array = Array(Array(x * y for y in range(10)) for x in range(10)) + stmt = lambda a, b: array[a][b] + for x in range(10): + for y in range(10): + self.assertOperator(stmt, [C(x), C(y)], C(x * y)) + + def test_array_attr(self): + from collections import namedtuple + pair = namedtuple("pair", ("p", "n")) + + array = Array(pair(x, -x) for x in range(10)) + stmt = lambda a: array[a].p + array[a].n + for i in range(10): + self.assertOperator(stmt, [C(i)], C(0)) -- 2.30.2