hdl.xfrm: handle empty lhs in LHSGroup{Analyzer,Filter}.
authorwhitequark <cz@m-labs.hk>
Tue, 4 Jun 2019 10:19:54 +0000 (10:19 +0000)
committerwhitequark <cz@m-labs.hk>
Tue, 4 Jun 2019 10:26:01 +0000 (10:26 +0000)
nmigen/hdl/xfrm.py
nmigen/test/test_hdl_ast.py
nmigen/test/test_hdl_xfrm.py

index 7b667646411f4810ce1ce867b0add6c8d20a75bb..9e00cc6d0e21aff0297d2b64a9b8e205b8e00747 100644 (file)
@@ -485,7 +485,9 @@ class LHSGroupAnalyzer(StatementVisitor):
         return groups
 
     def on_Assign(self, stmt):
-        self.unify(*stmt._lhs_signals())
+        lhs_signals = stmt._lhs_signals()
+        if lhs_signals:
+            self.unify(*stmt._lhs_signals())
 
     on_Assert = on_Assign
 
@@ -511,9 +513,11 @@ class LHSGroupFilter(SwitchCleaner):
     def on_Assign(self, stmt):
         # The invariant provided by LHSGroupAnalyzer is that all signals that ever appear together
         # on LHS are a part of the same group, so it is sufficient to check any of them.
-        any_lhs_signal = next(iter(stmt.lhs._lhs_signals()))
-        if any_lhs_signal in self.signals:
-            return stmt
+        lhs_signals = stmt.lhs._lhs_signals()
+        if lhs_signals:
+            any_lhs_signal = next(iter(lhs_signals))
+            if any_lhs_signal in self.signals:
+                return stmt
 
     def on_Assert(self, stmt):
         any_lhs_signal = next(iter(stmt._lhs_signals()))
index 36616224eaff33e8d3375c9d019726deb1c6d6e9..a3fadb27ec3fb4c18142c968fa87bdae9612538c 100644 (file)
@@ -314,6 +314,8 @@ class PartTestCase(FHDLTestCase):
 
 class CatTestCase(FHDLTestCase):
     def test_shape(self):
+        c0 = Cat()
+        self.assertEqual(c0.shape(), (0, False))
         c1 = Cat(Const(10))
         self.assertEqual(c1.shape(), (4, False))
         c2 = Cat(Const(10), Const(1))
index 3973b8df59a417fb307f4f63f8dbe46da11d0c73..8a281d16b7bd7e2f570c2ec960577f2febd4770d 100644 (file)
@@ -303,6 +303,15 @@ class LHSGroupAnalyzerTestCase(FHDLTestCase):
             SignalSet((b,)),
         ])
 
+    def test_lhs_empty(self):
+        stmts = [
+            Cat().eq(0)
+        ]
+
+        groups = LHSGroupAnalyzer()(stmts)
+        self.assertEqual(list(groups.values()), [
+        ])
+
 
 class LHSGroupFilterTestCase(FHDLTestCase):
     def test_filter(self):
@@ -329,6 +338,13 @@ class LHSGroupFilterTestCase(FHDLTestCase):
         )
         """)
 
+    def test_lhs_empty(self):
+        stmts = [
+            Cat().eq(0)
+        ]
+
+        self.assertRepr(LHSGroupFilter(SignalSet())(stmts), "()")
+
 
 class ResetInserterTestCase(FHDLTestCase):
     def setUp(self):