hdl.ast: implement Array and ArrayProxy.
authorwhitequark <cz@m-labs.hk>
Sat, 15 Dec 2018 17:16:22 +0000 (17:16 +0000)
committerwhitequark <cz@m-labs.hk>
Sat, 15 Dec 2018 17:16:31 +0000 (17:16 +0000)
nmigen/__init__.py
nmigen/hdl/ast.py
nmigen/test/test_hdl_ast.py
nmigen/test/tools.py
nmigen/tools.py

index 739d4dfb5cc89c7c41d59843b1ba17b267a76af7..8dcd2cfe58edc63cd5cccac9edb0505d0a9383ae 100644 (file)
@@ -1,4 +1,4 @@
-from .hdl.ast import Value, Const, Mux, Cat, Repl, Signal, ClockSignal, ResetSignal
+from .hdl.ast import Value, Const, Mux, Cat, Repl, Array, Signal, ClockSignal, ResetSignal
 from .hdl.dsl import Module
 from .hdl.cd import ClockDomain
 from .hdl.ir import Fragment
index 31ea24a3cd047aa979ee644987de4bdac7e6c6d1..a9623249ebe605eaf8bd83df5a27824af72970f7 100644 (file)
@@ -2,7 +2,7 @@ from abc import ABCMeta, abstractmethod
 import builtins
 import traceback
 from collections import OrderedDict
-from collections.abc import Iterable, MutableMapping, MutableSet
+from collections.abc import Iterable, MutableMapping, MutableSet, MutableSequence
 
 from .. import tracer
 from ..tools import *
@@ -10,6 +10,7 @@ from ..tools import *
 
 __all__ = [
     "Value", "Const", "C", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
+    "Array", "ArrayProxy",
     "Signal", "ClockSignal", "ResetSignal",
     "Statement", "Assign", "Switch", "Delay", "Tick", "Passive",
     "ValueKey", "ValueDict", "ValueSet",
@@ -39,8 +40,7 @@ class Value(metaclass=ABCMeta):
     def __init__(self, src_loc_at=0):
         super().__init__()
 
-        src_loc_at += 3
-        tb = traceback.extract_stack(limit=src_loc_at)
+        tb = traceback.extract_stack(limit=3 + src_loc_at)
         if len(tb) < src_loc_at:
             self.src_loc = None
         else:
@@ -664,6 +664,127 @@ class ResetSignal(Value):
         return "(rst {})".format(self.domain)
 
 
+class Array(MutableSequence):
+    """Addressable multiplexer.
+
+    An array is similar to a ``list`` that can also be indexed by ``Value``s; indexing by an integer or a slice works the same as for Python lists, but indexing by a ``Value`` results
+    in a proxy.
+
+    The array proxy can be used as an ordinary ``Value``, i.e. participate in calculations and
+    assignments, provided that all elements of the array are values. The array proxy also supports
+    attribute access and further indexing, each returning another array proxy; this means that
+    the results of indexing into arrays, arrays of records, and arrays of arrays can all
+    be used as first-class values.
+
+    It is an error to change an array or any of its elements after an array proxy was created.
+    Changing the array directly will raise an exception. However, it is not possible to detect
+    the elements being modified; if an element's attribute or element is modified after the proxy
+    for it has been created, the proxy will refer to stale data.
+
+    Examples
+    --------
+
+    Simple array::
+
+        gpios = Array(Signal() for _ in range(10))
+        with m.If(bus.we):
+            m.d.sync += gpios[bus.adr].eq(bus.dat_w)
+        with m.Else():
+            m.d.sync += bus.dat_r.eq(gpios[bus.adr])
+
+    Multidimensional array::
+
+        mult = Array(Array(x * y for y in range(10)) for x in range(10))
+        a = Signal(max=10)
+        b = Signal(max=10)
+        r = Signal(8)
+        m.d.comb += r.eq(mult[a][b])
+
+    Array of records::
+
+        layout = [
+            ("re",     1),
+            ("dat_r", 16),
+        ]
+        buses  = Array(Record(layout) for busno in range(4))
+        master = Record(layout)
+        m.d.comb += [
+            buses[sel].re.eq(master.re),
+            master.dat_r.eq(buses[sel].dat_r),
+        ]
+    """
+    def __init__(self, iterable):
+        self._inner    = list(iterable)
+        self._proxy_at = None
+        self._mutable  = True
+
+    def __getitem__(self, index):
+        if isinstance(index, Value):
+            if self._mutable:
+                tb = traceback.extract_stack(limit=2)
+                self._proxy_at = (tb[0].filename, tb[0].lineno)
+                self._mutable  = False
+            return ArrayProxy(self, index)
+        else:
+            return self._inner[index]
+
+    def __len__(self):
+        return len(self._inner)
+
+    def _check_mutability(self):
+        if not self._mutable:
+            raise ValueError("Array can no longer be mutated after it was indexed with a value "
+                             "at {}:{}".format(*self._proxy_at))
+
+    def __setitem__(self, index, value):
+        self._check_mutability()
+        self._inner[index] = value
+
+    def __delitem__(self, index):
+        self._check_mutability()
+        del self._inner[index]
+
+    def insert(self, index, value):
+        self._check_mutability()
+        self._inner.insert(index, value)
+
+    def __repr__(self):
+        return "(array{} [{}])".format(" mutable" if self._mutable else "",
+                                       ", ".join(map(repr, self._inner)))
+
+
+class ArrayProxy(Value):
+    def __init__(self, elems, index):
+        super().__init__(src_loc_at=1)
+        self.elems = elems
+        self.index = Value.wrap(index)
+
+    def __getattr__(self, attr):
+        return ArrayProxy([getattr(elem, attr) for elem in self.elems], self.index)
+
+    def __getitem__(self, index):
+        return ArrayProxy([        elem[index] for elem in self.elems], self.index)
+
+    def _iter_as_values(self):
+        return (Value.wrap(elem) for elem in self.elems)
+
+    def shape(self):
+        bits, sign = 0, False
+        for elem_bits, elem_sign in (elem.shape() for elem in self._iter_as_values()):
+            bits = max(bits, elem_bits + elem_sign)
+            sign = max(sign, elem_sign)
+        return bits, sign
+
+    def _lhs_signals(self):
+        return union((elem._lhs_signals() for elem in self._iter_as_values()), start=ValueSet())
+
+    def _rhs_signals(self):
+        return union((elem._rhs_signals() for elem in self._iter_as_values()), start=ValueSet())
+
+    def __repr__(self):
+        return "(proxy (array [{}]) {!r})".format(", ".join(map(repr, self.elems)), self.index)
+
+
 class _StatementList(list):
     def __repr__(self):
         return "({})".format(" ".join(map(repr, self)))
@@ -713,11 +834,13 @@ class Switch(Statement):
             self.cases[key] = Statement.wrap(stmts)
 
     def _lhs_signals(self):
-        signals = union(s._lhs_signals() for ss in self.cases.values() for s in ss) or ValueSet()
+        signals = union((s._lhs_signals() for ss in self.cases.values() for s in ss),
+                        start=ValueSet())
         return signals
 
     def _rhs_signals(self):
-        signals = union(s._rhs_signals() for ss in self.cases.values() for s in ss) or ValueSet()
+        signals = union((s._rhs_signals() for ss in self.cases.values() for s in ss),
+                        start=ValueSet())
         return self.test._rhs_signals() | signals
 
     def __repr__(self):
index 0fde4c6c70b745fadde9de8307ee758130431d07..85e72f036b575d0b10abd6b361d926fbb9ef6924 100644 (file)
@@ -334,6 +334,66 @@ class ReplTestCase(FHDLTestCase):
         self.assertEqual(repr(s), "(repl (const 4'd10) 3)")
 
 
+class ArrayTestCase(FHDLTestCase):
+    def test_acts_like_array(self):
+        a = Array([1,2,3])
+        self.assertSequenceEqual(a, [1,2,3])
+        self.assertEqual(a[1], 2)
+        a[1] = 4
+        self.assertSequenceEqual(a, [1,4,3])
+        del a[1]
+        self.assertSequenceEqual(a, [1,3])
+        a.insert(1, 2)
+        self.assertSequenceEqual(a, [1,2,3])
+
+    def test_becomes_immutable(self):
+        a = Array([1,2,3])
+        s1 = Signal(max=len(a))
+        s2 = Signal(max=len(a))
+        v1 = a[s1]
+        v2 = a[s2]
+        with self.assertRaisesRegex(ValueError,
+                regex=r"^Array can no longer be mutated after it was indexed with a value at "):
+            a[1] = 2
+        with self.assertRaisesRegex(ValueError,
+                regex=r"^Array can no longer be mutated after it was indexed with a value at "):
+            del a[1]
+        with self.assertRaisesRegex(ValueError,
+                regex=r"^Array can no longer be mutated after it was indexed with a value at "):
+            a.insert(1, 2)
+
+    def test_repr(self):
+        a = Array([1,2,3])
+        self.assertEqual(repr(a), "(array mutable [1, 2, 3])")
+        s = Signal(max=len(a))
+        v = a[s]
+        self.assertEqual(repr(a), "(array [1, 2, 3])")
+
+
+class ArrayProxyTestCase(FHDLTestCase):
+    def test_index_shape(self):
+        m = Array(Array(x * y for y in range(1, 4)) for x in range(1, 4))
+        a = Signal(max=3)
+        b = Signal(max=3)
+        v = m[a][b]
+        self.assertEqual(v.shape(), (4, False))
+
+    def test_attr_shape(self):
+        from collections import namedtuple
+        pair = namedtuple("pair", ("p", "n"))
+        a = Array(pair(i, -i) for i in range(10))
+        s = Signal(max=len(a))
+        v = a[s]
+        self.assertEqual(v.p.shape(), (4, False))
+        self.assertEqual(v.n.shape(), (6, True))
+
+    def test_repr(self):
+        a = Array([1, 2, 3])
+        s = Signal(max=3)
+        v = a[s]
+        self.assertEqual(repr(v), "(proxy (array [1, 2, 3]) (sig s))")
+
+
 class SignalTestCase(FHDLTestCase):
     def test_shape(self):
         s1 = Signal()
index 097925bb263740c976ad428f094ee37ece7550a8..7c89e885bf412ad87fac92a7d36431747e518e2d 100644 (file)
@@ -25,6 +25,14 @@ class FHDLTestCase(unittest.TestCase):
             # WTF? unittest.assertRaises is completely broken.
             self.assertEqual(str(cm.exception), msg)
 
+    @contextmanager
+    def assertRaisesRegex(self, exception, regex=None):
+        with super().assertRaises(exception) as cm:
+            yield
+        if regex is not None:
+            # unittest.assertRaisesRegex also seems broken...
+            self.assertRegex(str(cm.exception), regex)
+
     @contextmanager
     def assertWarns(self, category, msg=None):
         with warnings.catch_warnings(record=True) as warns:
index 0ab4638d0d53f52c19540fa866a8670978941c9b..1f03180f51f9ea86a56af6560a00cdf27a7f09d7 100644 (file)
@@ -14,8 +14,8 @@ def flatten(i):
             yield e
 
 
-def union(i):
-    r = None
+def union(i, start=None):
+    r = start
     for e in i:
         if r is None:
             r = e