From 80c53436003d5247270bcbf2bb71aead86e0efd9 Mon Sep 17 00:00:00 2001 From: whitequark Date: Sat, 15 Dec 2018 17:16:22 +0000 Subject: [PATCH] hdl.ast: implement Array and ArrayProxy. --- nmigen/__init__.py | 2 +- nmigen/hdl/ast.py | 133 ++++++++++++++++++++++++++++++++++-- nmigen/test/test_hdl_ast.py | 60 ++++++++++++++++ nmigen/test/tools.py | 8 +++ nmigen/tools.py | 4 +- 5 files changed, 199 insertions(+), 8 deletions(-) diff --git a/nmigen/__init__.py b/nmigen/__init__.py index 739d4df..8dcd2cf 100644 --- a/nmigen/__init__.py +++ b/nmigen/__init__.py @@ -1,4 +1,4 @@ -from .hdl.ast import Value, Const, Mux, Cat, Repl, Signal, ClockSignal, ResetSignal +from .hdl.ast import Value, Const, Mux, Cat, Repl, Array, Signal, ClockSignal, ResetSignal from .hdl.dsl import Module from .hdl.cd import ClockDomain from .hdl.ir import Fragment diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index 31ea24a..a962324 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod import builtins import traceback from collections import OrderedDict -from collections.abc import Iterable, MutableMapping, MutableSet +from collections.abc import Iterable, MutableMapping, MutableSet, MutableSequence from .. import tracer from ..tools import * @@ -10,6 +10,7 @@ from ..tools import * __all__ = [ "Value", "Const", "C", "Operator", "Mux", "Part", "Slice", "Cat", "Repl", + "Array", "ArrayProxy", "Signal", "ClockSignal", "ResetSignal", "Statement", "Assign", "Switch", "Delay", "Tick", "Passive", "ValueKey", "ValueDict", "ValueSet", @@ -39,8 +40,7 @@ class Value(metaclass=ABCMeta): def __init__(self, src_loc_at=0): super().__init__() - src_loc_at += 3 - tb = traceback.extract_stack(limit=src_loc_at) + tb = traceback.extract_stack(limit=3 + src_loc_at) if len(tb) < src_loc_at: self.src_loc = None else: @@ -664,6 +664,127 @@ class ResetSignal(Value): return "(rst {})".format(self.domain) +class Array(MutableSequence): + """Addressable multiplexer. + + An array is similar to a ``list`` that can also be indexed by ``Value``s; indexing by an integer or a slice works the same as for Python lists, but indexing by a ``Value`` results + in a proxy. + + The array proxy can be used as an ordinary ``Value``, i.e. participate in calculations and + assignments, provided that all elements of the array are values. The array proxy also supports + attribute access and further indexing, each returning another array proxy; this means that + the results of indexing into arrays, arrays of records, and arrays of arrays can all + be used as first-class values. + + It is an error to change an array or any of its elements after an array proxy was created. + Changing the array directly will raise an exception. However, it is not possible to detect + the elements being modified; if an element's attribute or element is modified after the proxy + for it has been created, the proxy will refer to stale data. + + Examples + -------- + + Simple array:: + + gpios = Array(Signal() for _ in range(10)) + with m.If(bus.we): + m.d.sync += gpios[bus.adr].eq(bus.dat_w) + with m.Else(): + m.d.sync += bus.dat_r.eq(gpios[bus.adr]) + + Multidimensional array:: + + mult = Array(Array(x * y for y in range(10)) for x in range(10)) + a = Signal(max=10) + b = Signal(max=10) + r = Signal(8) + m.d.comb += r.eq(mult[a][b]) + + Array of records:: + + layout = [ + ("re", 1), + ("dat_r", 16), + ] + buses = Array(Record(layout) for busno in range(4)) + master = Record(layout) + m.d.comb += [ + buses[sel].re.eq(master.re), + master.dat_r.eq(buses[sel].dat_r), + ] + """ + def __init__(self, iterable): + self._inner = list(iterable) + self._proxy_at = None + self._mutable = True + + def __getitem__(self, index): + if isinstance(index, Value): + if self._mutable: + tb = traceback.extract_stack(limit=2) + self._proxy_at = (tb[0].filename, tb[0].lineno) + self._mutable = False + return ArrayProxy(self, index) + else: + return self._inner[index] + + def __len__(self): + return len(self._inner) + + def _check_mutability(self): + if not self._mutable: + raise ValueError("Array can no longer be mutated after it was indexed with a value " + "at {}:{}".format(*self._proxy_at)) + + def __setitem__(self, index, value): + self._check_mutability() + self._inner[index] = value + + def __delitem__(self, index): + self._check_mutability() + del self._inner[index] + + def insert(self, index, value): + self._check_mutability() + self._inner.insert(index, value) + + def __repr__(self): + return "(array{} [{}])".format(" mutable" if self._mutable else "", + ", ".join(map(repr, self._inner))) + + +class ArrayProxy(Value): + def __init__(self, elems, index): + super().__init__(src_loc_at=1) + self.elems = elems + self.index = Value.wrap(index) + + def __getattr__(self, attr): + return ArrayProxy([getattr(elem, attr) for elem in self.elems], self.index) + + def __getitem__(self, index): + return ArrayProxy([ elem[index] for elem in self.elems], self.index) + + def _iter_as_values(self): + return (Value.wrap(elem) for elem in self.elems) + + def shape(self): + bits, sign = 0, False + for elem_bits, elem_sign in (elem.shape() for elem in self._iter_as_values()): + bits = max(bits, elem_bits + elem_sign) + sign = max(sign, elem_sign) + return bits, sign + + def _lhs_signals(self): + return union((elem._lhs_signals() for elem in self._iter_as_values()), start=ValueSet()) + + def _rhs_signals(self): + return union((elem._rhs_signals() for elem in self._iter_as_values()), start=ValueSet()) + + def __repr__(self): + return "(proxy (array [{}]) {!r})".format(", ".join(map(repr, self.elems)), self.index) + + class _StatementList(list): def __repr__(self): return "({})".format(" ".join(map(repr, self))) @@ -713,11 +834,13 @@ class Switch(Statement): self.cases[key] = Statement.wrap(stmts) def _lhs_signals(self): - signals = union(s._lhs_signals() for ss in self.cases.values() for s in ss) or ValueSet() + signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss), + start=ValueSet()) return signals def _rhs_signals(self): - signals = union(s._rhs_signals() for ss in self.cases.values() for s in ss) or ValueSet() + signals = union((s._rhs_signals() for ss in self.cases.values() for s in ss), + start=ValueSet()) return self.test._rhs_signals() | signals def __repr__(self): diff --git a/nmigen/test/test_hdl_ast.py b/nmigen/test/test_hdl_ast.py index 0fde4c6..85e72f0 100644 --- a/nmigen/test/test_hdl_ast.py +++ b/nmigen/test/test_hdl_ast.py @@ -334,6 +334,66 @@ class ReplTestCase(FHDLTestCase): self.assertEqual(repr(s), "(repl (const 4'd10) 3)") +class ArrayTestCase(FHDLTestCase): + def test_acts_like_array(self): + a = Array([1,2,3]) + self.assertSequenceEqual(a, [1,2,3]) + self.assertEqual(a[1], 2) + a[1] = 4 + self.assertSequenceEqual(a, [1,4,3]) + del a[1] + self.assertSequenceEqual(a, [1,3]) + a.insert(1, 2) + self.assertSequenceEqual(a, [1,2,3]) + + def test_becomes_immutable(self): + a = Array([1,2,3]) + s1 = Signal(max=len(a)) + s2 = Signal(max=len(a)) + v1 = a[s1] + v2 = a[s2] + with self.assertRaisesRegex(ValueError, + regex=r"^Array can no longer be mutated after it was indexed with a value at "): + a[1] = 2 + with self.assertRaisesRegex(ValueError, + regex=r"^Array can no longer be mutated after it was indexed with a value at "): + del a[1] + with self.assertRaisesRegex(ValueError, + regex=r"^Array can no longer be mutated after it was indexed with a value at "): + a.insert(1, 2) + + def test_repr(self): + a = Array([1,2,3]) + self.assertEqual(repr(a), "(array mutable [1, 2, 3])") + s = Signal(max=len(a)) + v = a[s] + self.assertEqual(repr(a), "(array [1, 2, 3])") + + +class ArrayProxyTestCase(FHDLTestCase): + def test_index_shape(self): + m = Array(Array(x * y for y in range(1, 4)) for x in range(1, 4)) + a = Signal(max=3) + b = Signal(max=3) + v = m[a][b] + self.assertEqual(v.shape(), (4, False)) + + def test_attr_shape(self): + from collections import namedtuple + pair = namedtuple("pair", ("p", "n")) + a = Array(pair(i, -i) for i in range(10)) + s = Signal(max=len(a)) + v = a[s] + self.assertEqual(v.p.shape(), (4, False)) + self.assertEqual(v.n.shape(), (6, True)) + + def test_repr(self): + a = Array([1, 2, 3]) + s = Signal(max=3) + v = a[s] + self.assertEqual(repr(v), "(proxy (array [1, 2, 3]) (sig s))") + + class SignalTestCase(FHDLTestCase): def test_shape(self): s1 = Signal() diff --git a/nmigen/test/tools.py b/nmigen/test/tools.py index 097925b..7c89e88 100644 --- a/nmigen/test/tools.py +++ b/nmigen/test/tools.py @@ -25,6 +25,14 @@ class FHDLTestCase(unittest.TestCase): # WTF? unittest.assertRaises is completely broken. self.assertEqual(str(cm.exception), msg) + @contextmanager + def assertRaisesRegex(self, exception, regex=None): + with super().assertRaises(exception) as cm: + yield + if regex is not None: + # unittest.assertRaisesRegex also seems broken... + self.assertRegex(str(cm.exception), regex) + @contextmanager def assertWarns(self, category, msg=None): with warnings.catch_warnings(record=True) as warns: diff --git a/nmigen/tools.py b/nmigen/tools.py index 0ab4638..1f03180 100644 --- a/nmigen/tools.py +++ b/nmigen/tools.py @@ -14,8 +14,8 @@ def flatten(i): yield e -def union(i): - r = None +def union(i, start=None): + r = start for e in i: if r is None: r = e -- 2.30.2