From: whitequark Date: Thu, 13 Dec 2018 07:24:28 +0000 (+0000) Subject: fhdl.dsl: add tests for submodules. X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=2bf629245310a836bb08587efa48eaf6a505758b;p=nmigen.git fhdl.dsl: add tests for submodules. --- diff --git a/nmigen/fhdl/dsl.py b/nmigen/fhdl/dsl.py index 9ef6303..4437766 100644 --- a/nmigen/fhdl/dsl.py +++ b/nmigen/fhdl/dsl.py @@ -1,4 +1,4 @@ -from collections import OrderedDict +from collections import OrderedDict, Iterable from contextlib import contextmanager from .ast import * @@ -68,9 +68,13 @@ class _ModuleBuilderSubmodules: def __init__(self, builder): object.__setattr__(self, "_builder", builder) - def __iadd__(self, submodules): - for submodule in submodules: - self._builder._add_submodule(submodule) + def __iadd__(self, modules): + if isinstance(modules, Iterable): + for module in modules: + self._builder._add_submodule(module) + else: + module = modules + self._builder._add_submodule(module) return self def __setattr__(self, name, submodule): @@ -254,7 +258,7 @@ class Module(_ModuleBuilderRoot): def _add_submodule(self, submodule, name=None): if not hasattr(submodule, "get_fragment"): - raise TypeError("Trying to add {!r}, which does not implement .get_fragment(), as " + raise TypeError("Trying to add '{!r}', which does not implement .get_fragment(), as " "a submodule".format(submodule)) self._submodules.append((submodule, name)) diff --git a/nmigen/test/test_fhdl_dsl.py b/nmigen/test/test_fhdl_dsl.py index ba102dc..cba7757 100644 --- a/nmigen/test/test_fhdl_dsl.py +++ b/nmigen/test/test_fhdl_dsl.py @@ -317,3 +317,31 @@ class DSLTestCase(unittest.TestCase): (eq (sig c2) (const 1'd1)) ) """) + + def test_submodule_anon(self): + m1 = Module() + m2 = Module() + m1.submodules += m2 + self.assertEqual(m1._submodules, [(m2, None)]) + + def test_submodule_anon_multi(self): + m1 = Module() + m2 = Module() + m3 = Module() + m1.submodules += m2, m3 + self.assertEqual(m1._submodules, [(m2, None), (m3, None)]) + + def test_submodule_named(self): + m1 = Module() + m2 = Module() + m1.submodules.foo = m2 + self.assertEqual(m1._submodules, [(m2, "foo")]) + + def test_submodule_wrong(self): + m = Module() + with self.assertRaises(TypeError, + msg="Trying to add '1', which does not implement .get_fragment(), as a submodule"): + m.submodules.foo = 1 + with self.assertRaises(TypeError, + msg="Trying to add '1', which does not implement .get_fragment(), as a submodule"): + m.submodules += 1