fhdl.ir: implement clock domain propagation.
authorwhitequark <whitequark@whitequark.org>
Thu, 13 Dec 2018 11:01:03 +0000 (11:01 +0000)
committerwhitequark <whitequark@whitequark.org>
Thu, 13 Dec 2018 11:01:03 +0000 (11:01 +0000)
12 files changed:
examples/arst.py
examples/cdc.py
examples/clkdiv.py
examples/ctrl.py
nmigen/back/rtlil.py
nmigen/fhdl/ir.py
nmigen/fhdl/xfrm.py
nmigen/genlib/cdc.py
nmigen/test/test_fhdl_dsl.py
nmigen/test/test_fhdl_ir.py
nmigen/test/test_fhdl_xfrm.py
nmigen/test/tools.py

index d6fcd04da0ed46ee196de2da189f7c960f99d07c..c35640ed10a16029b0fa3f1e857789f4277b2a3b 100644 (file)
@@ -14,8 +14,8 @@ class ClockDivisor:
         return m.lower(platform)
 
 
-sync = ClockDomain(async_reset=True)
 ctr  = ClockDivisor(factor=16)
 frag = ctr.get_fragment(platform=None)
-# print(rtlil.convert(frag, ports=[sync.clk, sync.reset, ctr.o], clock_domains={"sync": sync}))
-print(verilog.convert(frag, ports=[sync.clk, sync.reset, ctr.o], clock_domains={"sync": sync}))
+frag.add_domains(ClockDomain("sync", async_reset=True))
+# print(rtlil.convert(frag, ports=[ctr.o]))
+print(verilog.convert(frag, ports=[ctr.o]))
index caaf679c0fe0609a4ff92208f65f3721e84c99ce..7e38885334a90c8d78da927409631fa9ff367959 100644 (file)
@@ -3,8 +3,7 @@ from nmigen.back import rtlil, verilog
 from nmigen.genlib.cdc import *
 
 
-sys  = ClockDomain()
 i, o = Signal(name="i"), Signal(name="o")
 frag = MultiReg(i, o).get_fragment(platform=None)
-# print(rtlil.convert(frag, ports=[i, o], clock_domains={"sys": sys}))
-print(verilog.convert(frag, ports=[i, o], clock_domains={"sys": sys}))
+# print(rtlil.convert(frag, ports=[i, o]))
+print(verilog.convert(frag, ports=[i, o]))
index 79eec933476d9c16fc6ece90b8094ef9d7c3d7c0..6900e1d1c5a2cd180ac6de70abf8bdc7c2aa8607 100644 (file)
@@ -14,8 +14,7 @@ class ClockDivisor:
         return m.lower(platform)
 
 
-sync = ClockDomain()
 ctr  = ClockDivisor(factor=16)
 frag = ctr.get_fragment(platform=None)
-# print(rtlil.convert(frag, ports=[sync.clk, ctr.o], clock_domains={"sync": sync}))
-print(verilog.convert(frag, ports=[sync.clk, ctr.o], clock_domains={"sync": sync}))
+# print(rtlil.convert(frag, ports=[ctr.o]))
+print(verilog.convert(frag, ports=[ctr.o]))
index 2d8fcf87121cdce35cf2b5acc2897a387b417cb2..17a9e9481a2c6d90ddc2dc4de24028a6a91a34ca 100644 (file)
@@ -15,8 +15,7 @@ class ClockDivisor:
         return CEInserter(self.ce)(m.lower(platform))
 
 
-sync = ClockDomain()
 ctr  = ClockDivisor(factor=16)
 frag = ctr.get_fragment(platform=None)
-# print(rtlil.convert(frag, ports=[sync.clk, ctr.o, ctr.ce], clock_domains={"sync": sync}))
-print(verilog.convert(frag, ports=[sync.clk, ctr.o, ctr.ce], clock_domains={"sync": sync}))
+# print(rtlil.convert(frag, ports=[ctr.o, ctr.ce]))
+print(verilog.convert(frag, ports=[ctr.o, ctr.ce]))
index c670e73cd84f29e37ccb53170553ca7788c72f7b..93e2d40e71c279a211206706d551f5b1582511c5 100644 (file)
@@ -396,7 +396,7 @@ class _ValueTransformer(xfrm.ValueTransformer):
         return "{{ {} }}".format(" ".join(self(node.value) for _ in range(node.count)))
 
 
-def convert_fragment(builder, fragment, name, top, clock_domains):
+def convert_fragment(builder, fragment, name, top):
     with builder.module(name, attrs={"top": 1} if top else {}) as module:
         xformer = _ValueTransformer(module)
 
@@ -413,7 +413,7 @@ def convert_fragment(builder, fragment, name, top, clock_domains):
         # Transform all clocks clocks and resets eagerly and outside of any hierarchy, to make
         # sure they get sensible (non-prefixed) names. This does not affect semantics.
         for domain, _ in fragment.iter_sync():
-            cd = clock_domains[domain]
+            cd = fragment.domains[domain]
             xformer(cd.clk)
             xformer(cd.rst)
 
@@ -422,8 +422,7 @@ def convert_fragment(builder, fragment, name, top, clock_domains):
         # name) names.
         for subfragment, sub_name in fragment.subfragments:
             sub_name, sub_port_map = \
-                convert_fragment(builder, subfragment, top=False, name=sub_name,
-                                 clock_domains=clock_domains)
+                convert_fragment(builder, subfragment, top=False, name=sub_name)
             with xformer.hierarchy(sub_name):
                 module.cell(sub_name, name=sub_name, ports={
                     p: xformer(s) for p, s in sub_port_map.items()
@@ -484,13 +483,11 @@ def convert_fragment(builder, fragment, name, top, clock_domains):
                 triggers = []
                 if domain is None:
                     triggers.append(("always",))
-                elif domain in clock_domains:
-                    cd = clock_domains[domain]
+                else:
+                    cd = fragment.domains[domain]
                     triggers.append(("posedge", xformer(cd.clk)))
                     if cd.async_reset:
                         triggers.append(("posedge", xformer(cd.rst)))
-                else:
-                    raise ValueError("Clock domain {} not found in design".format(domain))
 
                 for trigger in triggers:
                     with process.sync(*trigger) as sync:
@@ -509,15 +506,17 @@ def convert_fragment(builder, fragment, name, top, clock_domains):
     return module.name, port_map
 
 
-def convert(fragment, ports=[], clock_domains={}):
+def convert(fragment, ports=[]):
+    fragment._propagate_domains(ensure_sync_exists=True)
+
     # Clock domain reset always takes priority over all other logic. To ensure this, insert
     # decision trees for clock domain reset as the very last step before synthesis.
     fragment = xfrm.ResetInserter({
-        cd.name: cd.rst for cd in clock_domains.values() if cd.rst is not None
+        cd.name: cd.rst for cd in fragment.domains.values() if cd.rst is not None
     })(fragment)
 
-    ins, outs = fragment._propagate_ports(ports, clock_domains)
+    ins, outs = fragment._propagate_ports(ports)
 
     builder = _Builder()
-    convert_fragment(builder, fragment, name="top", top=True, clock_domains=clock_domains)
+    convert_fragment(builder, fragment, name="top", top=True)
     return str(builder)
index 8b838c9d00c88aefd19c3ff86a3afc653310cd0f..0945fcb5b4d85af3ba2dcb6b4af3ed526d663d3a 100644 (file)
@@ -2,9 +2,14 @@ from collections import defaultdict, OrderedDict
 
 from ..tools import *
 from .ast import *
+from .cd import *
 
 
-__all__ = ["Fragment"]
+__all__ = ["Fragment", "DomainError"]
+
+
+class DomainError(Exception):
+    pass
 
 
 class Fragment:
@@ -12,6 +17,7 @@ class Fragment:
         self.ports = ValueSet()
         self.drivers = OrderedDict()
         self.statements = []
+        self.domains = OrderedDict()
         self.subfragments = []
 
     def add_ports(self, *ports):
@@ -40,6 +46,15 @@ class Fragment:
             for signal in signals:
                 yield domain, signal
 
+    def add_domains(self, *domains):
+        for domain in domains:
+            assert isinstance(domain, ClockDomain)
+            assert domain.name not in self.domains
+            self.domains[domain.name] = domain
+
+    def iter_domains(self):
+        yield from self.domains
+
     def add_statements(self, *stmts):
         self.statements += Statement.wrap(stmts)
 
@@ -47,13 +62,79 @@ class Fragment:
         assert isinstance(subfragment, Fragment)
         self.subfragments.append((subfragment, name))
 
-    def _propagate_ports(self, ports, clock_domains={}):
+    def _propagate_domains_up(self, hierarchy=("top",)):
+        from .xfrm import DomainRenamer
+
+        domain_subfrags = defaultdict(lambda: set())
+
+        # For each domain defined by a subfragment, determine which subfragments define it.
+        for i, (subfrag, name) in enumerate(self.subfragments):
+            # First, recurse into subfragments and let them propagate domains up as well.
+            hier_name = name
+            if hier_name is None:
+                hier_name = "<unnamed #{}>".format(i)
+            subfrag._propagate_domains_up(hierarchy + (hier_name,))
+
+            # Second, classify subfragments by domains they define.
+            for domain in subfrag.iter_domains():
+                domain_subfrags[domain].add((subfrag, name, i))
+
+        # For each domain defined by more than one subfragment, rename the domain in each
+        # of the subfragments such that they no longer conflict.
+        for domain, subfrags in domain_subfrags.items():
+            if len(subfrags) == 1:
+                continue
+
+            names = [n for f, n, i in subfrags]
+            if not all(names):
+                names = sorted("<unnamed #{}>".format(i) if n is None else "'{}'".format(n)
+                               for f, n, i in subfrags)
+                raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}'; "
+                                  "it is necessary to either rename subfragment domains "
+                                  "explicitly, or give names to subfragments"
+                                  .format(domain, ", ".join(names), ".".join(hierarchy)))
+
+            if len(names) != len(set(names)):
+                names = sorted("#{}".format(i) for f, n, i in subfrags)
+                raise DomainError("Domain '{}' is defined by subfragments {} of fragment '{}', "
+                                  "some of which have identical names; it is necessary to either "
+                                  "rename subfragment domains explicitly, or give distinct names "
+                                  "to subfragments"
+                                  .format(domain, ", ".join(names), ".".join(hierarchy)))
+
+            for subfrag, name, i in subfrags:
+                self.subfragments[i] = \
+                    (DomainRenamer({domain: "{}_{}".format(name, domain)})(subfrag), name)
+
+        # Finally, collect the (now unique) subfragment domains, and merge them into our domains.
+        for subfrag, name in self.subfragments:
+            for domain in subfrag.iter_domains():
+                self.add_domains(subfrag.domains[domain])
+
+    def _propagate_domains_down(self):
+        # For each domain defined in this fragment, ensure it also exists in all subfragments.
+        for subfrag, name in self.subfragments:
+            for domain in self.iter_domains():
+                if domain in subfrag.domains:
+                    assert self.domains[domain] is subfrag.domains[domain]
+                else:
+                    subfrag.add_domains(self.domains[domain])
+
+            subfrag._propagate_domains_down()
+
+    def _propagate_domains(self, ensure_sync_exists=False):
+        self._propagate_domains_up()
+        if ensure_sync_exists and not self.domains:
+            self.add_domains(ClockDomain("sync"))
+        self._propagate_domains_down()
+
+    def _propagate_ports(self, ports):
         # Collect all signals we're driving (on LHS of statements), and signals we're using
         # (on RHS of statements, or in clock domains).
         self_driven = union(s._lhs_signals() for s in self.statements)
         self_used   = union(s._rhs_signals() for s in self.statements)
         for domain, _ in self.iter_sync():
-            cd = clock_domains[domain]
+            cd = self.domains[domain]
             self_used.add(cd.clk)
             if cd.rst is not None:
                 self_used.add(cd.rst)
@@ -69,8 +150,7 @@ class Fragment:
         for subfrag, name in self.subfragments:
             # Always ask subfragments to provide all signals we're using and signals we're asked
             # to provide. If the subfragment is not driving it, it will silently ignore it.
-            sub_ins, sub_outs = subfrag._propagate_ports(ports=self_used | ports,
-                                                         clock_domains=clock_domains)
+            sub_ins, sub_outs = subfrag._propagate_ports(ports=self_used | ports)
             # Refine the input port approximation: if a subfragment is driving a signal,
             # it is definitely not our input.
             ins  -= sub_outs
index 07c454d435a2d9719da9a9eda11ad5d93b75e4c7..86268f62cda50813a5ec8b3f55befea5aa9c7f7e 100644 (file)
@@ -56,7 +56,7 @@ class ValueTransformer:
         elif isinstance(value, Repl):
             return self.on_Repl(value)
         else:
-            raise TypeError("Cannot transform value {!r}".format(value))
+            raise TypeError("Cannot transform value {!r}".format(value)) # :nocov:
 
     def __call__(self, value):
         return self.on_value(value)
@@ -84,7 +84,7 @@ class StatementTransformer:
         elif isinstance(stmt, (list, tuple)):
             return self.on_statements(stmt)
         else:
-            raise TypeError("Cannot transform statement {!r}".format(stmt))
+            raise TypeError("Cannot transform statement {!r}".format(stmt)) # :nocov:
 
     def __call__(self, value):
         return self.on_statement(value)
@@ -95,6 +95,10 @@ class FragmentTransformer:
         for subfragment, name in fragment.subfragments:
             new_fragment.add_subfragment(self(subfragment), name)
 
+    def map_domains(self, fragment, new_fragment):
+        for domain in fragment.iter_domains():
+            new_fragment.add_domains(fragment.domains[domain])
+
     def map_statements(self, fragment, new_fragment):
         if hasattr(self, "on_statement"):
             new_fragment.add_statements(map(self.on_statement, fragment.statements))
@@ -108,6 +112,7 @@ class FragmentTransformer:
     def on_fragment(self, fragment):
         new_fragment = Fragment()
         self.map_subfragments(fragment, new_fragment)
+        self.map_domains(fragment, new_fragment)
         self.map_statements(fragment, new_fragment)
         self.map_drivers(fragment, new_fragment)
         return new_fragment
@@ -117,25 +122,36 @@ class FragmentTransformer:
 
 
 class DomainRenamer(FragmentTransformer, ValueTransformer, StatementTransformer):
-    def __init__(self, domains):
-        if isinstance(domains, str):
-            domains = {"sync": domains}
-        self.domains = OrderedDict(domains)
+    def __init__(self, domain_map):
+        if isinstance(domain_map, str):
+            domain_map = {"sync": domain_map}
+        self.domain_map = OrderedDict(domain_map)
 
     def on_ClockSignal(self, value):
-        if value.domain in self.domains:
-            return ClockSignal(self.domains[value.domain])
+        if value.domain in self.domain_map:
+            return ClockSignal(self.domain_map[value.domain])
         return value
 
     def on_ResetSignal(self, value):
-        if value.domain in self.domains:
-            return ResetSignal(self.domains[value.domain])
+        if value.domain in self.domain_map:
+            return ResetSignal(self.domain_map[value.domain])
         return value
 
+    def map_domains(self, fragment, new_fragment):
+        for domain in fragment.iter_domains():
+            cd = fragment.domains[domain]
+            if domain in self.domain_map:
+                if cd.name == domain:
+                    # Rename the actual ClockDomain object.
+                    cd.name = self.domain_map[domain]
+                else:
+                    assert cd.name == self.domain_map[domain]
+            new_fragment.add_domains(cd)
+
     def map_drivers(self, fragment, new_fragment):
         for domain, signals in fragment.drivers.items():
-            if domain in self.domains:
-                domain = self.domains[domain]
+            if domain in self.domain_map:
+                domain = self.domain_map[domain]
             for signal in signals:
                 new_fragment.drive(signal, domain)
 
index 411296ea76a8529b84f79a00b3fb574db4a5d211..396e45a23b17a29faeb874024e4cc7d8f387044b 100644 (file)
@@ -5,13 +5,13 @@ __all__ = ["MultiReg"]
 
 
 class MultiReg:
-    def __init__(self, i, o, odomain="sys", n=2, reset=0):
+    def __init__(self, i, o, odomain="sync", n=2, reset=0):
         self.i = i
         self.o = o
         self.odomain = odomain
 
-        self._regs = [Signal(self.i.shape(), name="cdc{}".format(i), reset=reset, reset_less=True,
-                             attrs={"no_retiming": True})
+        self._regs = [Signal(self.i.shape(), name="cdc{}".format(i),
+                             reset=reset, reset_less=True, attrs={"no_retiming": True})
                       for i in range(n)]
 
     def get_fragment(self, platform):
index b4ad259a486aefa5a37fec85288fbb937cce34bc..c0acaed2eb40e907ed9b39baa10f9ba328a0db60 100644 (file)
@@ -1,5 +1,3 @@
-from contextlib import contextmanager
-
 from ..fhdl.ast import *
 from ..fhdl.dsl import *
 from .tools import *
index 708acec9af12e7932e2af1ea1a18e57346b255cf..86cd55ad38b221b88f83f73f81aed3a76334a7bc 100644 (file)
@@ -1,4 +1,5 @@
 from ..fhdl.ast import *
+from ..fhdl.cd import *
 from ..fhdl.ir import *
 from .tools import *
 
@@ -18,6 +19,7 @@ class FragmentPortsTestCase(FHDLTestCase):
             self.c1.eq(self.s1),
             self.s1.eq(self.c1)
         )
+
         ins, outs = f._propagate_ports(ports=())
         self.assertEqual(ins, ValueSet())
         self.assertEqual(outs, ValueSet())
@@ -28,6 +30,7 @@ class FragmentPortsTestCase(FHDLTestCase):
         f.add_statements(
             self.c1.eq(self.s1)
         )
+
         ins, outs = f._propagate_ports(ports=())
         self.assertEqual(ins, ValueSet((self.s1,)))
         self.assertEqual(outs, ValueSet())
@@ -38,6 +41,7 @@ class FragmentPortsTestCase(FHDLTestCase):
         f.add_statements(
             self.c1.eq(self.s1)
         )
+
         ins, outs = f._propagate_ports(ports=(self.c1,))
         self.assertEqual(ins, ValueSet((self.s1,)))
         self.assertEqual(outs, ValueSet((self.c1,)))
@@ -69,8 +73,150 @@ class FragmentPortsTestCase(FHDLTestCase):
             self.c2.eq(1)
         )
         f1.add_subfragment(f2)
+
         ins, outs = f1._propagate_ports(ports=(self.c2,))
         self.assertEqual(ins, ValueSet())
         self.assertEqual(outs, ValueSet((self.c2,)))
         self.assertEqual(f1.ports, ValueSet((self.c2,)))
         self.assertEqual(f2.ports, ValueSet((self.c2,)))
+
+    def test_input_cd(self):
+        sync = ClockDomain()
+        f = Fragment()
+        f.add_statements(
+            self.c1.eq(self.s1)
+        )
+        f.add_domains(sync)
+        f.drive(self.c1, "sync")
+
+        ins, outs = f._propagate_ports(ports=())
+        self.assertEqual(ins, ValueSet((self.s1, sync.clk, sync.rst)))
+        self.assertEqual(outs, ValueSet(()))
+        self.assertEqual(f.ports, ValueSet((self.s1, sync.clk, sync.rst)))
+
+    def test_input_cd_reset_less(self):
+        sync = ClockDomain(reset_less=True)
+        f = Fragment()
+        f.add_statements(
+            self.c1.eq(self.s1)
+        )
+        f.add_domains(sync)
+        f.drive(self.c1, "sync")
+
+        ins, outs = f._propagate_ports(ports=())
+        self.assertEqual(ins, ValueSet((self.s1, sync.clk)))
+        self.assertEqual(outs, ValueSet(()))
+        self.assertEqual(f.ports, ValueSet((self.s1, sync.clk)))
+
+
+class FragmentDomainsTestCase(FHDLTestCase):
+    def test_propagate_up(self):
+        cd = ClockDomain()
+
+        f1 = Fragment()
+        f2 = Fragment()
+        f1.add_subfragment(f2)
+        f2.add_domains(cd)
+
+        f1._propagate_domains_up()
+        self.assertEqual(f1.domains, {"cd": cd})
+
+    def test_domain_conflict(self):
+        cda = ClockDomain("sync")
+        cdb = ClockDomain("sync")
+
+        fa = Fragment()
+        fa.add_domains(cda)
+        fb = Fragment()
+        fb.add_domains(cdb)
+        f = Fragment()
+        f.add_subfragment(fa, "a")
+        f.add_subfragment(fb, "b")
+
+        f._propagate_domains_up()
+        self.assertEqual(f.domains, {"a_sync": cda, "b_sync": cdb})
+        (fa, _), (fb, _) = f.subfragments
+        self.assertEqual(fa.domains, {"a_sync": cda})
+        self.assertEqual(fb.domains, {"b_sync": cdb})
+
+    def test_domain_conflict_anon(self):
+        cda = ClockDomain("sync")
+        cdb = ClockDomain("sync")
+
+        fa = Fragment()
+        fa.add_domains(cda)
+        fb = Fragment()
+        fb.add_domains(cdb)
+        f = Fragment()
+        f.add_subfragment(fa, "a")
+        f.add_subfragment(fb)
+
+        with self.assertRaises(DomainError,
+                msg="Domain 'sync' is defined by subfragments 'a', <unnamed #1> of fragment "
+                    "'top'; it is necessary to either rename subfragment domains explicitly, "
+                    "or give names to subfragments"):
+            f._propagate_domains_up()
+
+    def test_domain_conflict_name(self):
+        cda = ClockDomain("sync")
+        cdb = ClockDomain("sync")
+
+        fa = Fragment()
+        fa.add_domains(cda)
+        fb = Fragment()
+        fb.add_domains(cdb)
+        f = Fragment()
+        f.add_subfragment(fa, "x")
+        f.add_subfragment(fb, "x")
+
+        with self.assertRaises(DomainError,
+                msg="Domain 'sync' is defined by subfragments #0, #1 of fragment 'top', some "
+                    "of which have identical names; it is necessary to either rename subfragment "
+                    "domains explicitly, or give distinct names to subfragments"):
+            f._propagate_domains_up()
+
+    def test_propagate_down(self):
+        cd = ClockDomain()
+
+        f1 = Fragment()
+        f2 = Fragment()
+        f1.add_domains(cd)
+        f1.add_subfragment(f2)
+
+        f1._propagate_domains_down()
+        self.assertEqual(f2.domains, {"cd": cd})
+
+    def test_propagate_down_idempotent(self):
+        cd = ClockDomain()
+
+        f1 = Fragment()
+        f1.add_domains(cd)
+        f2 = Fragment()
+        f2.add_domains(cd)
+        f1.add_subfragment(f2)
+
+        f1._propagate_domains_down()
+        self.assertEqual(f1.domains, {"cd": cd})
+        self.assertEqual(f2.domains, {"cd": cd})
+
+    def test_propagate(self):
+        cd = ClockDomain()
+
+        f1 = Fragment()
+        f2 = Fragment()
+        f1.add_domains(cd)
+        f1.add_subfragment(f2)
+
+        f1._propagate_domains()
+        self.assertEqual(f1.domains, {"cd": cd})
+        self.assertEqual(f2.domains, {"cd": cd})
+
+    def test_propagate_default(self):
+        f1 = Fragment()
+        f2 = Fragment()
+        f1.add_subfragment(f2)
+
+        f1._propagate_domains(ensure_sync_exists=True)
+        self.assertEqual(f1.domains.keys(), {"sync"})
+        self.assertEqual(f2.domains.keys(), {"sync"})
+        self.assertEqual(f1.domains["sync"], f2.domains["sync"])
index 0e1d3fae96adafdaac481153099193c0b16f6fa0..78346d8aacbd108e94d70012c4e3bdaf8b4d0429 100644 (file)
@@ -1,4 +1,5 @@
 from ..fhdl.ast import *
+from ..fhdl.cd import *
 from ..fhdl.ir import *
 from ..fhdl.xfrm import *
 from .tools import *
@@ -56,6 +57,37 @@ class DomainRenamerTestCase(FHDLTestCase):
         )
         """)
 
+    def test_rename_cd(self):
+        cd_sync = ClockDomain()
+        cd_pix  = ClockDomain()
+
+        f = Fragment()
+        f.add_domains(cd_sync, cd_pix)
+
+        f = DomainRenamer("ext")(f)
+        self.assertEqual(cd_sync.name, "ext")
+        self.assertEqual(f.domains, {
+            "ext": cd_sync,
+            "pix": cd_pix,
+        })
+
+    def test_rename_cd_subfragment(self):
+        cd_sync = ClockDomain()
+        cd_pix  = ClockDomain()
+
+        f1 = Fragment()
+        f1.add_domains(cd_sync, cd_pix)
+        f2 = Fragment()
+        f2.add_domains(cd_sync)
+        f1.add_subfragment(f2)
+
+        f1 = DomainRenamer("ext")(f1)
+        self.assertEqual(cd_sync.name, "ext")
+        self.assertEqual(f1.domains, {
+            "ext": cd_sync,
+            "pix": cd_pix,
+        })
+
 
 class ResetInserterTestCase(FHDLTestCase):
     def setUp(self):
@@ -87,6 +119,7 @@ class ResetInserterTestCase(FHDLTestCase):
             self.s1.eq(1),
             self.s2.eq(0),
         )
+        f.add_domains(ClockDomain("sync"))
         f.drive(self.s1, "sync")
         f.drive(self.s2, "pix")
 
index 9e3a8f0cbf047e0902cf9279cae6730bf7bbfa25..65cf0ff797ba235cb0f5a8066e0370f8fc32dcb3 100644 (file)
@@ -1,5 +1,6 @@
 import re
 import unittest
+from contextlib import contextmanager
 
 from ..fhdl.ast import *
 
@@ -14,3 +15,11 @@ class FHDLTestCase(unittest.TestCase):
         repr_str = re.sub(r"\( (?=\()", "(", repr_str)
         repr_str = re.sub(r"\) (?=\))", ")", repr_str)
         self.assertEqual(repr(obj), repr_str.strip())
+
+    @contextmanager
+    def assertRaises(self, exception, msg=None):
+        with super().assertRaises(exception) as cm:
+            yield
+        if msg is not None:
+            # WTF? unittest.assertRaises is completely broken.
+            self.assertEqual(str(cm.exception), msg)