hdl.dsl: add clock domain support.
authorwhitequark <cz@m-labs.hk>
Sun, 16 Dec 2018 23:51:24 +0000 (23:51 +0000)
committerwhitequark <cz@m-labs.hk>
Sun, 16 Dec 2018 23:51:24 +0000 (23:51 +0000)
nmigen/compat/fhdl/module.py
nmigen/hdl/dsl.py
nmigen/hdl/ir.py
nmigen/test/test_hdl_dsl.py

index a9730ae3878d8fc04d78101952e4b39614fad2e1..84b51a1d0baee64081fdd0be84280910c4d63dc9 100644 (file)
@@ -90,7 +90,7 @@ class _CompatModuleClockDomains(_CompatModuleProxy):
 
     @deprecated("TODO")
     def __iadd__(self, other):
-        self._cm._fragment.clock_domains += _flat_list(other)
+        self._cm._module.domains += _flat_list(other)
         return self
 
 
index e240513d2b908122205d35a03fc4389ff1d592f3..7a386dbcc3382559e4bec900b026b29815b71018 100644 (file)
@@ -2,6 +2,7 @@ from collections import OrderedDict
 from collections.abc import Iterable
 from contextlib import contextmanager
 
+from ..tools import flatten
 from .ast import *
 from .ir import *
 from .xfrm import *
@@ -20,7 +21,7 @@ class _ModuleBuilderProxy:
         object.__setattr__(self, "_depth", depth)
 
 
-class _ModuleBuilderDomain(_ModuleBuilderProxy):
+class _ModuleBuilderDomainExplicit(_ModuleBuilderProxy):
     def __init__(self, builder, depth, domain):
         super().__init__(builder, depth)
         self._domain = domain
@@ -30,13 +31,13 @@ class _ModuleBuilderDomain(_ModuleBuilderProxy):
         return self
 
 
-class _ModuleBuilderDomains(_ModuleBuilderProxy):
+class _ModuleBuilderDomainImplicit(_ModuleBuilderProxy):
     def __getattr__(self, name):
         if name == "comb":
             domain = None
         else:
             domain = name
-        return _ModuleBuilderDomain(self._builder, self._depth, domain)
+        return _ModuleBuilderDomainExplicit(self._builder, self._depth, domain)
 
     def __getitem__(self, name):
         return self.__getattr__(name)
@@ -44,7 +45,7 @@ class _ModuleBuilderDomains(_ModuleBuilderProxy):
     def __setattr__(self, name, value):
         if name == "_depth":
             object.__setattr__(self, name, value)
-        elif not isinstance(value, _ModuleBuilderDomain):
+        elif not isinstance(value, _ModuleBuilderDomainExplicit):
             raise AttributeError("Cannot assign 'd.{}' attribute; did you mean 'd.{} +='?"
                                  .format(name, name))
 
@@ -55,7 +56,7 @@ class _ModuleBuilderDomains(_ModuleBuilderProxy):
 class _ModuleBuilderRoot:
     def __init__(self, builder, depth):
         self._builder = builder
-        self.domain = self.d = _ModuleBuilderDomains(builder, depth)
+        self.domain = self.d = _ModuleBuilderDomainImplicit(builder, depth)
 
     def __getattr__(self, name):
         if name in ("comb", "sync"):
@@ -70,11 +71,7 @@ class _ModuleBuilderSubmodules:
         object.__setattr__(self, "_builder", builder)
 
     def __iadd__(self, modules):
-        if isinstance(modules, Iterable):
-            for module in modules:
-                self._builder._add_submodule(module)
-        else:
-            module = modules
+        for module in flatten([modules]):
             self._builder._add_submodule(module)
         return self
 
@@ -82,17 +79,33 @@ class _ModuleBuilderSubmodules:
         self._builder._add_submodule(submodule, name)
 
 
+class _ModuleBuilderDomainSet:
+    def __init__(self, builder):
+        object.__setattr__(self, "_builder", builder)
+
+    def __iadd__(self, domains):
+        for domain in flatten([domains]):
+            self._builder._add_domain(domain)
+        return self
+
+    def __setattr__(self, name, domain):
+        self._builder._add_domain(domain)
+
+
 class Module(_ModuleBuilderRoot):
     def __init__(self):
         _ModuleBuilderRoot.__init__(self, self, depth=0)
-        self.submodules = _ModuleBuilderSubmodules(self)
+        self.submodules    = _ModuleBuilderSubmodules(self)
+        self.domains       = _ModuleBuilderDomainSet(self)
 
-        self._submodules   = []
-        self._driving      = ValueDict()
         self._statements   = Statement.wrap([])
         self._ctrl_context = None
         self._ctrl_stack   = []
 
+        self._driving      = ValueDict()
+        self._submodules   = []
+        self._domains      = []
+
     def _check_context(self, construct, context):
         if self._ctrl_context != context:
             if self._ctrl_context is None:
@@ -259,6 +272,9 @@ class Module(_ModuleBuilderRoot):
                             "a submodule".format(submodule))
         self._submodules.append((submodule, name))
 
+    def _add_domain(self, cd):
+        self._domains.append(cd)
+
     def _flush(self):
         while self._ctrl_stack:
             self._pop_ctrl()
@@ -272,6 +288,7 @@ class Module(_ModuleBuilderRoot):
         fragment.add_statements(self._statements)
         for signal, domain in self._driving.items():
             fragment.add_driver(signal, domain)
+        fragment.add_domains(self._domains)
         return fragment
 
     get_fragment = lower
index 844cc7b8abee82154bfcca87ec621860bb269196..1ff1961749559aac6904a57062043c59831e83fd 100644 (file)
@@ -63,7 +63,7 @@ class Fragment:
         return signals
 
     def add_domains(self, *domains):
-        for domain in domains:
+        for domain in flatten(domains):
             assert isinstance(domain, ClockDomain)
             assert domain.name not in self.domains
             self.domains[domain.name] = domain
index f156880e9b31a3d897eae87f6bce9fcf2ea2e304..f5fec900b02dcf1b326419ac338abb5543d70dd7 100644 (file)
@@ -1,4 +1,5 @@
 from ..hdl.ast import *
+from ..hdl.cd import *
 from ..hdl.dsl import *
 from .tools import *
 
@@ -342,6 +343,17 @@ class DSLTestCase(FHDLTestCase):
                 msg="Trying to add '1', which does not implement .get_fragment(), as a submodule"):
             m.submodules += 1
 
+    def test_domain_named_implicit(self):
+        m = Module()
+        m.domains += ClockDomain("sync")
+        self.assertEqual(len(m._domains), 1)
+
+    def test_domain_named_explicit(self):
+        m = Module()
+        m.domains.foo = ClockDomain()
+        self.assertEqual(len(m._domains), 1)
+        self.assertEqual(m._domains[0].name, "foo")
+
     def test_lower(self):
         m1 = Module()
         m1.d.comb += self.c1.eq(self.s1)