From 6957908441b0a9027683cad4e4eb3e0405799149 Mon Sep 17 00:00:00 2001 From: whitequark Date: Tue, 4 Jun 2019 10:19:54 +0000 Subject: [PATCH] hdl.xfrm: handle empty lhs in LHSGroup{Analyzer,Filter}. --- nmigen/hdl/xfrm.py | 12 ++++++++---- nmigen/test/test_hdl_ast.py | 2 ++ nmigen/test/test_hdl_xfrm.py | 16 ++++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/nmigen/hdl/xfrm.py b/nmigen/hdl/xfrm.py index 7b66764..9e00cc6 100644 --- a/nmigen/hdl/xfrm.py +++ b/nmigen/hdl/xfrm.py @@ -485,7 +485,9 @@ class LHSGroupAnalyzer(StatementVisitor): return groups def on_Assign(self, stmt): - self.unify(*stmt._lhs_signals()) + lhs_signals = stmt._lhs_signals() + if lhs_signals: + self.unify(*stmt._lhs_signals()) on_Assert = on_Assign @@ -511,9 +513,11 @@ class LHSGroupFilter(SwitchCleaner): def on_Assign(self, stmt): # The invariant provided by LHSGroupAnalyzer is that all signals that ever appear together # on LHS are a part of the same group, so it is sufficient to check any of them. - any_lhs_signal = next(iter(stmt.lhs._lhs_signals())) - if any_lhs_signal in self.signals: - return stmt + lhs_signals = stmt.lhs._lhs_signals() + if lhs_signals: + any_lhs_signal = next(iter(lhs_signals)) + if any_lhs_signal in self.signals: + return stmt def on_Assert(self, stmt): any_lhs_signal = next(iter(stmt._lhs_signals())) diff --git a/nmigen/test/test_hdl_ast.py b/nmigen/test/test_hdl_ast.py index 3661622..a3fadb2 100644 --- a/nmigen/test/test_hdl_ast.py +++ b/nmigen/test/test_hdl_ast.py @@ -314,6 +314,8 @@ class PartTestCase(FHDLTestCase): class CatTestCase(FHDLTestCase): def test_shape(self): + c0 = Cat() + self.assertEqual(c0.shape(), (0, False)) c1 = Cat(Const(10)) self.assertEqual(c1.shape(), (4, False)) c2 = Cat(Const(10), Const(1)) diff --git a/nmigen/test/test_hdl_xfrm.py b/nmigen/test/test_hdl_xfrm.py index 3973b8d..8a281d1 100644 --- a/nmigen/test/test_hdl_xfrm.py +++ b/nmigen/test/test_hdl_xfrm.py @@ -303,6 +303,15 @@ class LHSGroupAnalyzerTestCase(FHDLTestCase): SignalSet((b,)), ]) + def test_lhs_empty(self): + stmts = [ + Cat().eq(0) + ] + + groups = LHSGroupAnalyzer()(stmts) + self.assertEqual(list(groups.values()), [ + ]) + class LHSGroupFilterTestCase(FHDLTestCase): def test_filter(self): @@ -329,6 +338,13 @@ class LHSGroupFilterTestCase(FHDLTestCase): ) """) + def test_lhs_empty(self): + stmts = [ + Cat().eq(0) + ] + + self.assertRepr(LHSGroupFilter(SignalSet())(stmts), "()") + class ResetInserterTestCase(FHDLTestCase): def setUp(self): -- 2.30.2