fhdl.dsl: add tests for submodules.
authorwhitequark <cz@m-labs.hk>
Thu, 13 Dec 2018 07:24:28 +0000 (07:24 +0000)
committerwhitequark <cz@m-labs.hk>
Thu, 13 Dec 2018 07:24:28 +0000 (07:24 +0000)
nmigen/fhdl/dsl.py
nmigen/test/test_fhdl_dsl.py

index 9ef6303fe546d2439cbd995f53b395dccfbb2115..4437766ce01c14ab6fe6ea0c74c3ed07940f4264 100644 (file)
@@ -1,4 +1,4 @@
-from collections import OrderedDict
+from collections import OrderedDict, Iterable
 from contextlib import contextmanager
 
 from .ast import *
@@ -68,9 +68,13 @@ class _ModuleBuilderSubmodules:
     def __init__(self, builder):
         object.__setattr__(self, "_builder", builder)
 
-    def __iadd__(self, submodules):
-        for submodule in submodules:
-            self._builder._add_submodule(submodule)
+    def __iadd__(self, modules):
+        if isinstance(modules, Iterable):
+            for module in modules:
+                self._builder._add_submodule(module)
+        else:
+            module = modules
+            self._builder._add_submodule(module)
         return self
 
     def __setattr__(self, name, submodule):
@@ -254,7 +258,7 @@ class Module(_ModuleBuilderRoot):
 
     def _add_submodule(self, submodule, name=None):
         if not hasattr(submodule, "get_fragment"):
-            raise TypeError("Trying to add {!r}, which does not implement .get_fragment(), as "
+            raise TypeError("Trying to add '{!r}', which does not implement .get_fragment(), as "
                             "a submodule".format(submodule))
         self._submodules.append((submodule, name))
 
index ba102dce43d67d624593358c2e5419996cf89aca..cba77578ba8e1caeea350e4719f53bd2b19b2ab9 100644 (file)
@@ -317,3 +317,31 @@ class DSLTestCase(unittest.TestCase):
             (eq (sig c2) (const 1'd1))
         )
         """)
+
+    def test_submodule_anon(self):
+        m1 = Module()
+        m2 = Module()
+        m1.submodules += m2
+        self.assertEqual(m1._submodules, [(m2, None)])
+
+    def test_submodule_anon_multi(self):
+        m1 = Module()
+        m2 = Module()
+        m3 = Module()
+        m1.submodules += m2, m3
+        self.assertEqual(m1._submodules, [(m2, None), (m3, None)])
+
+    def test_submodule_named(self):
+        m1 = Module()
+        m2 = Module()
+        m1.submodules.foo = m2
+        self.assertEqual(m1._submodules, [(m2, "foo")])
+
+    def test_submodule_wrong(self):
+        m = Module()
+        with self.assertRaises(TypeError,
+                msg="Trying to add '1', which does not implement .get_fragment(), as a submodule"):
+            m.submodules.foo = 1
+        with self.assertRaises(TypeError,
+                msg="Trying to add '1', which does not implement .get_fragment(), as a submodule"):
+            m.submodules += 1