hdl, back: add and use SignalSet/SignalDict.
authorwhitequark <cz@m-labs.hk>
Mon, 17 Dec 2018 17:21:12 +0000 (17:21 +0000)
committerwhitequark <cz@m-labs.hk>
Mon, 17 Dec 2018 17:21:29 +0000 (17:21 +0000)
nmigen/back/pysim.py
nmigen/back/rtlil.py
nmigen/hdl/ast.py
nmigen/hdl/dsl.py
nmigen/hdl/ir.py
nmigen/test/test_hdl_dsl.py
nmigen/test/test_hdl_ir.py
nmigen/test/test_hdl_xfrm.py

index 4d0364dd326d10eb63c6d7247a1e625ce650c62f..1871389ae47988da824adc04f19fc740330f88df 100644 (file)
@@ -20,10 +20,10 @@ class _State:
     __slots__ = ("curr", "curr_dirty", "next", "next_dirty")
 
     def __init__(self):
-        self.curr = ValueDict()
-        self.next = ValueDict()
-        self.curr_dirty = ValueSet()
-        self.next_dirty = ValueSet()
+        self.curr = SignalDict()
+        self.next = SignalDict()
+        self.curr_dirty = SignalSet()
+        self.next_dirty = SignalSet()
 
     def set(self, signal, value):
         assert isinstance(value, int)
@@ -236,7 +236,7 @@ class _LHSValueCompiler(AbstractValueTransformer):
 
 class _StatementCompiler(AbstractStatementTransformer):
     def __init__(self):
-        self.sensitivity   = ValueSet()
+        self.sensitivity   = SignalSet()
         self.rrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="rhs")
         self.lrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="lhs")
         self.lhs_compiler  = _LHSValueCompiler(self.lrhs_compiler)
@@ -284,13 +284,13 @@ class Simulator:
         self._fragment        = fragment
 
         self._domains         = dict()        # str/domain -> ClockDomain
-        self._domain_triggers = ValueDict()   # Signal -> str/domain
+        self._domain_triggers = SignalDict()  # Signal -> str/domain
         self._domain_signals  = dict()        # str/domain -> {Signal}
 
-        self._signals         = ValueSet()    # {Signal}
-        self._comb_signals    = ValueSet()    # {Signal}
-        self._sync_signals    = ValueSet()    # {Signal}
-        self._user_signals    = ValueSet()    # {Signal}
+        self._signals         = SignalSet()    # {Signal}
+        self._comb_signals    = SignalSet()    # {Signal}
+        self._sync_signals    = SignalSet()    # {Signal}
+        self._user_signals    = SignalSet()    # {Signal}
 
         self._started         = False
         self._timestamp       = 0.
@@ -306,12 +306,12 @@ class Simulator:
         self._wait_deadline   = dict()        # process -> float/timestamp
         self._wait_tick       = dict()        # process -> str/domain
 
-        self._funclets        = ValueDict()   # Signal -> set(lambda)
+        self._funclets        = SignalDict()  # Signal -> set(lambda)
 
         self._vcd_file        = vcd_file
         self._vcd_writer      = None
-        self._vcd_signals     = ValueDict()   # signal -> set(vcd_signal)
-        self._vcd_names       = ValueDict()   # signal -> str/name
+        self._vcd_signals     = SignalDict()  # signal -> set(vcd_signal)
+        self._vcd_names       = SignalDict()  # signal -> str/name
         self._gtkw_file       = gtkw_file
         self._traces          = traces
 
@@ -381,7 +381,7 @@ class Simulator:
             self._domain_triggers[cd.clk] = domain
             if cd.rst is not None:
                 self._domain_triggers[cd.rst] = domain
-            self._domain_signals[domain] = ValueSet()
+            self._domain_signals[domain] = SignalSet()
 
         hierarchy = {}
         def add_fragment(fragment, scope=()):
index 70b9596a4217925ea2dc4058ab3d49ffc2b7f2e9..4ab2d8c649e5ad9eae4a6f289a5a4032376238e5 100644 (file)
@@ -213,9 +213,9 @@ class LegalizeValue(Exception):
 class _ValueCompilerState:
     def __init__(self, rtlil):
         self.rtlil  = rtlil
-        self.wires  = ast.ValueDict()
-        self.driven = ast.ValueDict()
-        self.ports  = ast.ValueDict()
+        self.wires  = ast.SignalDict()
+        self.driven = ast.SignalDict()
+        self.ports  = ast.SignalDict()
 
         self.expansions = ast.ValueDict()
 
index f9c111b936c4e4cf6544fd2ecf74b6db9bb6771f..33a3625fa61e8a76279e9213332cef801ceac48d 100644 (file)
@@ -13,7 +13,7 @@ __all__ = [
     "Array", "ArrayProxy",
     "Signal", "ClockSignal", "ResetSignal",
     "Statement", "Assign", "Switch", "Delay", "Tick", "Passive",
-    "ValueKey", "ValueDict", "ValueSet",
+    "ValueKey", "ValueDict", "ValueSet", "SignalKey", "SignalDict", "SignalSet",
 ]
 
 
@@ -28,14 +28,14 @@ class DUID:
 class Value(metaclass=ABCMeta):
     @staticmethod
     def wrap(obj):
-        """Ensures that the passed object is a Migen value. Booleans and integers
+        """Ensures that the passed object is an nMigen value. Booleans and integers
         are automatically wrapped into ``Const``."""
         if isinstance(obj, Value):
             return obj
         elif isinstance(obj, (bool, int)):
             return Const(obj)
         else:
-            raise TypeError("Object '{!r}' is not a Migen value".format(obj))
+            raise TypeError("Object '{!r}' is not an nMigen value".format(obj))
 
     def __init__(self, src_loc_at=0):
         super().__init__()
@@ -47,7 +47,7 @@ class Value(metaclass=ABCMeta):
             self.src_loc = (tb[0].filename, tb[0].lineno)
 
     def __bool__(self):
-        raise TypeError("Attempted to convert Migen value to boolean")
+        raise TypeError("Attempted to convert nMigen value to boolean")
 
     def __invert__(self):
         return Operator("~", [self])
@@ -801,7 +801,7 @@ class Statement:
             if isinstance(obj, Statement):
                 return _StatementList([obj])
             else:
-                raise TypeError("Object '{!r}' is not a Migen statement".format(obj))
+                raise TypeError("Object '{!r}' is not an nMigen statement".format(obj))
 
 
 class Assign(Statement):
@@ -936,7 +936,8 @@ class _MappedKeyDict(MutableMapping, _MappedKeyCollection):
 
     def __repr__(self):
         pairs = ["({!r}, {!r})".format(k, v) for k, v in self.items()]
-        return "{}([{}])".format(type(self).__name__, ", ".join(pairs))
+        return "{}.{}([{}])".format(type(self).__module__, type(self).__name__,
+                                    ", ".join(pairs))
 
 
 class _MappedKeySet(MutableSet, _MappedKeyCollection):
@@ -967,7 +968,8 @@ class _MappedKeySet(MutableSet, _MappedKeyCollection):
         return len(self._storage)
 
     def __repr__(self):
-        return "{}({})".format(type(self).__name__, ", ".join(repr(x) for x in self))
+        return "{}.{}({})".format(type(self).__module__, type(self).__name__,
+                                  ", ".join(repr(x) for x in self))
 
 
 class ValueKey:
@@ -1060,3 +1062,34 @@ class ValueDict(_MappedKeyDict):
 class ValueSet(_MappedKeySet):
     _map_key   = ValueKey
     _unmap_key = lambda self, key: key.value
+
+
+class SignalKey:
+    def __init__(self, signal):
+        if not isinstance(signal, Signal):
+            raise TypeError("Object '{!r}' is not an nMigen signal")
+        self.signal = signal
+
+    def __hash__(self):
+        return hash(self.signal.duid)
+
+    def __eq__(self, other):
+        return isinstance(other, SignalKey) and self.signal.duid == other.signal.duid
+
+    def __lt__(self, other):
+        if not isinstance(other, SignalKey):
+            raise TypeError("Object '{!r}' cannot be compared to a SignalKey")
+        return self.signal.duid < other.signal.duid
+
+    def __repr__(self):
+        return "<{}.SignalKey {!r}>".format(__name__, self.signal)
+
+
+class SignalDict(_MappedKeyDict):
+    _map_key   = SignalKey
+    _unmap_key = lambda self, key: key.signal
+
+
+class SignalSet(_MappedKeySet):
+    _map_key   = SignalKey
+    _unmap_key = lambda self, key: key.signal
index 7a386dbcc3382559e4bec900b026b29815b71018..86274c59b244a761aa827940c914f3683bf81d53 100644 (file)
@@ -102,7 +102,7 @@ class Module(_ModuleBuilderRoot):
         self._ctrl_context = None
         self._ctrl_stack   = []
 
-        self._driving      = ValueDict()
+        self._driving      = SignalDict()
         self._submodules   = []
         self._domains      = []
 
index 1ff1961749559aac6904a57062043c59831e83fd..b9dd84defb111af8cbe82a65c0602a39ec009ed0 100644 (file)
@@ -15,7 +15,7 @@ class DriverConflict(UserWarning):
 
 class Fragment:
     def __init__(self):
-        self.ports = ValueDict()
+        self.ports = SignalDict()
         self.drivers = OrderedDict()
         self.statements = []
         self.domains = OrderedDict()
@@ -31,7 +31,7 @@ class Fragment:
 
     def add_driver(self, signal, domain=None):
         if domain not in self.drivers:
-            self.drivers[domain] = ValueSet()
+            self.drivers[domain] = SignalSet()
         self.drivers[domain].add(signal)
 
     def iter_drivers(self):
@@ -51,7 +51,7 @@ class Fragment:
                 yield domain, signal
 
     def iter_signals(self):
-        signals = ValueSet()
+        signals = SignalSet()
         signals |= self.ports.keys()
         for domain, domain_signals in self.drivers.items():
             if domain is not None:
@@ -81,7 +81,7 @@ class Fragment:
     def _resolve_driver_conflicts(self, hierarchy=("top",), mode="warn"):
         assert mode in ("silent", "warn", "error")
 
-        driver_subfrags = ValueDict()
+        driver_subfrags = SignalDict()
 
         # For each signal driven by this fragment and/or its subfragments, determine which
         # subfragments also drive it.
@@ -147,7 +147,7 @@ class Fragment:
             return self._resolve_driver_conflicts(hierarchy, mode)
 
         # Nothing was flattened, we're done!
-        return ValueSet(driver_subfrags.keys())
+        return SignalSet(driver_subfrags.keys())
 
     def _propagate_domains_up(self, hierarchy=("top",)):
         from .xfrm import DomainRenamer
@@ -229,8 +229,8 @@ class Fragment:
     def _propagate_ports(self, ports):
         # Collect all signals we're driving (on LHS of statements), and signals we're using
         # (on RHS of statements, or in clock domains).
-        self_driven = union(s._lhs_signals() for s in self.statements) or ValueSet()
-        self_used   = union(s._rhs_signals() for s in self.statements) or ValueSet()
+        self_driven = union(s._lhs_signals() for s in self.statements) or SignalSet()
+        self_used   = union(s._rhs_signals() for s in self.statements) or SignalSet()
         for domain, _ in self.iter_sync():
             cd = self.domains[domain]
             self_used.add(cd.clk)
index f5fec900b02dcf1b326419ac338abb5543d70dd7..9a87c5e9042425f5beccec640f79bbfd03a3ce81 100644 (file)
@@ -369,7 +369,7 @@ class DSLTestCase(FHDLTestCase):
         )
         """)
         self.assertEqual(f1.drivers, {
-            None: ValueSet((self.c1,))
+            None: SignalSet((self.c1,))
         })
         self.assertEqual(len(f1.subfragments), 1)
         (f2, f2_name), = f1.subfragments
@@ -381,7 +381,7 @@ class DSLTestCase(FHDLTestCase):
         )
         """)
         self.assertEqual(f2.drivers, {
-            None: ValueSet((self.c2,)),
-            "sync": ValueSet((self.c3,))
+            None: SignalSet((self.c2,)),
+            "sync": SignalSet((self.c3,))
         })
         self.assertEqual(len(f2.subfragments), 0)
index fcc1b2e945ac10a2660f505fb96b20e4e1aad248..dd39ac8be403b7c55ee86540414b962b8320213a 100644 (file)
@@ -25,12 +25,12 @@ class FragmentPortsTestCase(FHDLTestCase):
         self.assertEqual(list(f.iter_ports()), [])
 
         f._propagate_ports(ports=())
-        self.assertEqual(f.ports, ValueDict([]))
+        self.assertEqual(f.ports, SignalDict([]))
 
     def test_iter_signals(self):
         f = Fragment()
         f.add_ports(self.s1, self.s2, kind="io")
-        self.assertEqual(ValueSet((self.s1, self.s2)), f.iter_signals())
+        self.assertEqual(SignalSet((self.s1, self.s2)), f.iter_signals())
 
     def test_self_contained(self):
         f = Fragment()
@@ -40,7 +40,7 @@ class FragmentPortsTestCase(FHDLTestCase):
         )
 
         f._propagate_ports(ports=())
-        self.assertEqual(f.ports, ValueDict([]))
+        self.assertEqual(f.ports, SignalDict([]))
 
     def test_infer_input(self):
         f = Fragment()
@@ -49,7 +49,7 @@ class FragmentPortsTestCase(FHDLTestCase):
         )
 
         f._propagate_ports(ports=())
-        self.assertEqual(f.ports, ValueDict([
+        self.assertEqual(f.ports, SignalDict([
             (self.s1, "i")
         ]))
 
@@ -60,7 +60,7 @@ class FragmentPortsTestCase(FHDLTestCase):
         )
 
         f._propagate_ports(ports=(self.c1,))
-        self.assertEqual(f.ports, ValueDict([
+        self.assertEqual(f.ports, SignalDict([
             (self.s1, "i"),
             (self.c1, "o")
         ]))
@@ -76,8 +76,8 @@ class FragmentPortsTestCase(FHDLTestCase):
         )
         f1.add_subfragment(f2)
         f1._propagate_ports(ports=())
-        self.assertEqual(f1.ports, ValueDict())
-        self.assertEqual(f2.ports, ValueDict([
+        self.assertEqual(f1.ports, SignalDict())
+        self.assertEqual(f2.ports, SignalDict([
             (self.s1, "o"),
         ]))
 
@@ -89,10 +89,10 @@ class FragmentPortsTestCase(FHDLTestCase):
         )
         f1.add_subfragment(f2)
         f1._propagate_ports(ports=())
-        self.assertEqual(f1.ports, ValueDict([
+        self.assertEqual(f1.ports, SignalDict([
             (self.s1, "i"),
         ]))
-        self.assertEqual(f2.ports, ValueDict([
+        self.assertEqual(f2.ports, SignalDict([
             (self.s1, "i"),
         ]))
 
@@ -108,10 +108,10 @@ class FragmentPortsTestCase(FHDLTestCase):
         f1.add_subfragment(f2)
 
         f1._propagate_ports(ports=(self.c2,))
-        self.assertEqual(f1.ports, ValueDict([
+        self.assertEqual(f1.ports, SignalDict([
             (self.c2, "o"),
         ]))
-        self.assertEqual(f2.ports, ValueDict([
+        self.assertEqual(f2.ports, SignalDict([
             (self.c2, "o"),
         ]))
 
@@ -125,7 +125,7 @@ class FragmentPortsTestCase(FHDLTestCase):
         f.add_driver(self.c1, "sync")
 
         f._propagate_ports(ports=())
-        self.assertEqual(f.ports, ValueDict([
+        self.assertEqual(f.ports, SignalDict([
             (self.s1,  "i"),
             (sync.clk, "i"),
             (sync.rst, "i"),
@@ -141,7 +141,7 @@ class FragmentPortsTestCase(FHDLTestCase):
         f.add_driver(self.c1, "sync")
 
         f._propagate_ports(ports=())
-        self.assertEqual(f.ports, ValueDict([
+        self.assertEqual(f.ports, SignalDict([
             (self.s1,  "i"),
             (sync.clk, "i"),
         ]))
@@ -157,9 +157,9 @@ class FragmentDomainsTestCase(FHDLTestCase):
         f = Fragment()
         f.add_domains(cd1, cd2)
         f.add_driver(s1, "cd1")
-        self.assertEqual(ValueSet((cd1.clk, cd1.rst, s1)), f.iter_signals())
+        self.assertEqual(SignalSet((cd1.clk, cd1.rst, s1)), f.iter_signals())
         f.add_driver(s2, "cd2")
-        self.assertEqual(ValueSet((cd1.clk, cd1.rst, cd2.clk, s1, s2)), f.iter_signals())
+        self.assertEqual(SignalSet((cd1.clk, cd1.rst, cd2.clk, s1, s2)), f.iter_signals())
 
     def test_propagate_up(self):
         cd = ClockDomain()
@@ -315,8 +315,8 @@ class FragmentDriverConflictTestCase(FHDLTestCase):
         )
         """)
         self.assertEqual(self.f1.drivers, {
-            None:   ValueSet((self.s1,)),
-            "sync": ValueSet((self.c1, self.c2)),
+            None:   SignalSet((self.s1,)),
+            "sync": SignalSet((self.c1, self.c2)),
         })
 
     def test_conflict_self_sub_error(self):
index 428bad78c6baae12844cbd9ff8b19f312b33973f..802761dd71ed81abbaab3581b9e5a328502ed1b8 100644 (file)
@@ -38,8 +38,8 @@ class DomainRenamerTestCase(FHDLTestCase):
         )
         """)
         self.assertEqual(f.drivers, {
-            None: ValueSet((self.s1, self.s2)),
-            "pix": ValueSet((self.s3,)),
+            None: SignalSet((self.s1, self.s2)),
+            "pix": SignalSet((self.s3,)),
         })
 
     def test_rename_multi(self):