hdl.dsl: add getters to m.submodules.
authorN. Engelhardt <nakengelhardt@gmail.com>
Fri, 19 Jul 2019 12:39:47 +0000 (20:39 +0800)
committerwhitequark <cz@m-labs.hk>
Fri, 19 Jul 2019 12:39:47 +0000 (12:39 +0000)
nmigen/hdl/dsl.py
nmigen/test/test_hdl_dsl.py

index 58db1cf0fb4222bddba30a39aa7a587a020e518e..361fbebfa7f4d1e20ce228e1353a02deab18122e 100644 (file)
@@ -87,6 +87,12 @@ class _ModuleBuilderSubmodules:
     def __setitem__(self, name, value):
         return self.__setattr__(name, value)
 
+    def __getattr__(self, name):
+        return self._builder._get_submodule(name)
+
+    def __getitem__(self, name):
+        return self.__getattr__(name)
+
 
 class _ModuleBuilderDomainSet:
     def __init__(self, builder):
@@ -124,7 +130,8 @@ class Module(_ModuleBuilderRoot, Elaboratable):
         self._ctrl_stack   = []
 
         self._driving      = SignalDict()
-        self._submodules   = []
+        self._named_submodules   = {}
+        self._anon_submodules = []
         self._domains      = []
         self._generated    = {}
 
@@ -418,7 +425,18 @@ class Module(_ModuleBuilderRoot, Elaboratable):
         if not hasattr(submodule, "elaborate"):
             raise TypeError("Trying to add '{!r}', which does not implement .elaborate(), as "
                             "a submodule".format(submodule))
-        self._submodules.append((submodule, name))
+        if name == None:
+            self._anon_submodules.append(submodule)
+        else:
+            if name in self._named_submodules:
+                raise NameError("Submodule named '{}' already exists".format(name))
+            self._named_submodules[name] = submodule
+
+    def _get_submodule(self, name):
+        if name in self._named_submodules:
+            return self._named_submodules[name]
+        else:
+            raise AttributeError("No submodule named '{}' exists".format(name))
 
     def _add_domain(self, cd):
         self._domains.append(cd)
@@ -431,8 +449,10 @@ class Module(_ModuleBuilderRoot, Elaboratable):
         self._flush()
 
         fragment = Fragment()
-        for submodule, name in self._submodules:
-            fragment.add_subfragment(Fragment.get(submodule, platform), name)
+        for name in self._named_submodules:
+            fragment.add_subfragment(Fragment.get(self._named_submodules[name], platform), name)
+        for submodule in self._anon_submodules:
+            fragment.add_subfragment(Fragment.get(submodule, platform), None)
         statements = SampleDomainInjector("sync")(self._statements)
         fragment.add_statements(statements)
         for signal, domain in self._driving.items():
index 9af7038afc2f62ccc10d0621559ff55e8b754854..502e9a88d83796223e7d229dfcd2b5395eb36e38 100644 (file)
@@ -517,26 +517,30 @@ class DSLTestCase(FHDLTestCase):
         m1 = Module()
         m2 = Module()
         m1.submodules += m2
-        self.assertEqual(m1._submodules, [(m2, None)])
+        self.assertEqual(m1._anon_submodules, [m2])
+        self.assertEqual(m1._named_submodules, {})
 
     def test_submodule_anon_multi(self):
         m1 = Module()
         m2 = Module()
         m3 = Module()
         m1.submodules += m2, m3
-        self.assertEqual(m1._submodules, [(m2, None), (m3, None)])
+        self.assertEqual(m1._anon_submodules, [m2, m3])
+        self.assertEqual(m1._named_submodules, {})
 
     def test_submodule_named(self):
         m1 = Module()
         m2 = Module()
         m1.submodules.foo = m2
-        self.assertEqual(m1._submodules, [(m2, "foo")])
+        self.assertEqual(m1._anon_submodules, [])
+        self.assertEqual(m1._named_submodules, {"foo": m2})
 
     def test_submodule_named_index(self):
         m1 = Module()
         m2 = Module()
         m1.submodules["foo"] = m2
-        self.assertEqual(m1._submodules, [(m2, "foo")])
+        self.assertEqual(m1._anon_submodules, [])
+        self.assertEqual(m1._named_submodules, {"foo": m2})
 
     def test_submodule_wrong(self):
         m = Module()
@@ -547,6 +551,34 @@ class DSLTestCase(FHDLTestCase):
                 msg="Trying to add '1', which does not implement .elaborate(), as a submodule"):
             m.submodules += 1
 
+    def test_submodule_named_conflict(self):
+        m1 = Module()
+        m2 = Module()
+        m1.submodules.foo = m2
+        with self.assertRaises(NameError, msg="Submodule named 'foo' already exists"):
+            m1.submodules.foo = m2
+
+    def test_submodule_get(self):
+        m1 = Module()
+        m2 = Module()
+        m1.submodules.foo = m2
+        m3 = m1.submodules.foo
+        self.assertEqual(m2, m3)
+
+    def test_submodule_get_index(self):
+        m1 = Module()
+        m2 = Module()
+        m1.submodules["foo"] = m2
+        m3 = m1.submodules["foo"]
+        self.assertEqual(m2, m3)
+
+    def test_submodule_get_unset(self):
+        m1 = Module()
+        with self.assertRaises(AttributeError, msg="No submodule named 'foo' exists"):
+            m2 = m1.submodules.foo
+        with self.assertRaises(AttributeError, msg="No submodule named 'foo' exists"):
+            m2 = m1.submodules["foo"]
+
     def test_domain_named_implicit(self):
         m = Module()
         m.domains += ClockDomain("sync")