hdl.xfrm: consider fragment's own domains in DomainLowerer.
authorwhitequark <cz@m-labs.hk>
Mon, 19 Aug 2019 21:06:54 +0000 (21:06 +0000)
committerwhitequark <cz@m-labs.hk>
Mon, 19 Aug 2019 21:07:02 +0000 (21:07 +0000)
Changed in preparation for introducing local clock domains.

nmigen/hdl/ir.py
nmigen/hdl/xfrm.py
nmigen/test/test_hdl_xfrm.py

index 6dd5d08ffa95676ce607eef3076b46cc5ff0f80d..79e308affd52c7c2c17974253e1d873ee9320d03 100644 (file)
@@ -395,7 +395,7 @@ class Fragment:
     def _lower_domain_signals(self):
         from .xfrm import DomainLowerer
 
-        return DomainLowerer(self.domains)(self)
+        return DomainLowerer()(self)
 
     def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top):
         def add_uses(*sigs, self=self):
index 4c96b7c31ebd263d6c49078ae01888ab3af3a22c..2fdf707e2dc9ea50c932396c1e326d45e6722929 100644 (file)
@@ -337,6 +337,11 @@ class TransformedElaboratable(Elaboratable):
 class DomainCollector(ValueVisitor, StatementVisitor):
     def __init__(self):
         self.domains = set()
+        self._local_domains = set()
+
+    def _add_domain(self, domain_name):
+        if domain_name not in self._local_domains:
+            self.domains.add(domain_name)
 
     def on_ignore(self, value):
         pass
@@ -347,10 +352,10 @@ class DomainCollector(ValueVisitor, StatementVisitor):
     on_Signal = on_ignore
 
     def on_ClockSignal(self, value):
-        self.domains.add(value.domain)
+        self._add_domain(value.domain)
 
     def on_ResetSignal(self, value):
-        self.domains.add(value.domain)
+        self._add_domain(value.domain)
 
     on_Record = on_ignore
 
@@ -406,11 +411,20 @@ class DomainCollector(ValueVisitor, StatementVisitor):
         if isinstance(fragment, Instance):
             for name, (value, dir) in fragment.named_ports.items():
                 self.on_value(value)
+
+        old_local_domains, self._local_domains = self._local_domains, set(self._local_domains)
+        for domain_name, domain in fragment.domains.items():
+            if domain.local:
+                self._local_domains.add(domain_name)
+
         self.on_statements(fragment.statements)
-        self.domains.update(fragment.drivers.keys())
+        for domain_name in fragment.drivers:
+            self._add_domain(domain_name)
         for subfragment, name in fragment.subfragments:
             self.on_fragment(subfragment)
 
+        self._local_domains = old_local_domains
+
     def __call__(self, fragment):
         self.on_fragment(fragment)
         return self.domains
@@ -457,8 +471,8 @@ class DomainRenamer(FragmentTransformer, ValueTransformer, StatementTransformer)
 
 
 class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
-    def __init__(self, domains):
-        self.domains = domains
+    def __init__(self):
+        self.domains = None
 
     def _resolve(self, domain, context):
         if domain not in self.domains:
@@ -487,6 +501,11 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
                                   .format(value, value.domain))
         return cd.rst
 
+    def on_fragment(self, fragment):
+        self.domains = fragment.domains
+        new_fragment = super().on_fragment(fragment)
+        return new_fragment
+
 
 class SampleDomainInjector(ValueTransformer, StatementTransformer):
     def __init__(self, domain):
index 4081dd4d6e30804f4604cdf44c02ecc6c12be094..1a2a6c83b7f8de409d5b2413d5717a27c9c8b1bc 100644 (file)
@@ -107,11 +107,12 @@ class DomainLowererTestCase(FHDLTestCase):
     def test_lower_clk(self):
         sync = ClockDomain()
         f = Fragment()
+        f.add_domains(sync)
         f.add_statements(
             self.s.eq(ClockSignal("sync"))
         )
 
-        f = DomainLowerer({"sync": sync})(f)
+        f = DomainLowerer()(f)
         self.assertRepr(f.statements, """
         (
             (eq (sig s) (sig clk))
@@ -121,11 +122,12 @@ class DomainLowererTestCase(FHDLTestCase):
     def test_lower_rst(self):
         sync = ClockDomain()
         f = Fragment()
+        f.add_domains(sync)
         f.add_statements(
             self.s.eq(ResetSignal("sync"))
         )
 
-        f = DomainLowerer({"sync": sync})(f)
+        f = DomainLowerer()(f)
         self.assertRepr(f.statements, """
         (
             (eq (sig s) (sig rst))
@@ -135,11 +137,12 @@ class DomainLowererTestCase(FHDLTestCase):
     def test_lower_rst_reset_less(self):
         sync = ClockDomain(reset_less=True)
         f = Fragment()
+        f.add_domains(sync)
         f.add_statements(
             self.s.eq(ResetSignal("sync", allow_reset_less=True))
         )
 
-        f = DomainLowerer({"sync": sync})(f)
+        f = DomainLowerer()(f)
         self.assertRepr(f.statements, """
         (
             (eq (sig s) (const 1'd0))
@@ -149,17 +152,17 @@ class DomainLowererTestCase(FHDLTestCase):
     def test_lower_drivers(self):
         pix = ClockDomain()
         f = Fragment()
+        f.add_domains(pix)
         f.add_driver(ClockSignal("pix"), None)
         f.add_driver(ResetSignal("pix"), "sync")
 
-        f = DomainLowerer({"pix": pix})(f)
+        f = DomainLowerer()(f)
         self.assertEqual(f.drivers, {
             None: SignalSet((pix.clk,)),
             "sync": SignalSet((pix.rst,))
         })
 
     def test_lower_wrong_domain(self):
-        sync = ClockDomain()
         f = Fragment()
         f.add_statements(
             self.s.eq(ClockSignal("xxx"))
@@ -167,18 +170,19 @@ class DomainLowererTestCase(FHDLTestCase):
 
         with self.assertRaises(DomainError,
                 msg="Signal (clk xxx) refers to nonexistent domain 'xxx'"):
-            DomainLowerer({"sync": sync})(f)
+            DomainLowerer()(f)
 
     def test_lower_wrong_reset_less_domain(self):
         sync = ClockDomain(reset_less=True)
         f = Fragment()
+        f.add_domains(sync)
         f.add_statements(
             self.s.eq(ResetSignal("sync"))
         )
 
         with self.assertRaises(DomainError,
                 msg="Signal (rst sync) refers to reset of reset-less domain 'sync'"):
-            DomainLowerer({"sync": sync})(f)
+            DomainLowerer()(f)
 
 
 class SampleLowererTestCase(FHDLTestCase):
@@ -600,7 +604,7 @@ class UserValueTestCase(FHDLTestCase):
             f.add_driver(signal, "sync")
 
         f = ResetInserter(self.c)(f)
-        f = DomainLowerer({"sync": sync})(f)
+        f = DomainLowerer()(f)
         self.assertRepr(f.statements, """
         (
             (eq (sig s) (const 1'd1))