From: whitequark Date: Fri, 14 Dec 2018 22:47:58 +0000 (+0000) Subject: fhdl.ir: automatically flatten hierarchy to resolve driver conflicts. X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=c692d27fe7ae0f641b0c0f9c75ac1d753399a16c;p=nmigen.git fhdl.ir: automatically flatten hierarchy to resolve driver conflicts. Fixes #5. --- diff --git a/.coveragerc b/.coveragerc index e70df50..6435aa3 100644 --- a/.coveragerc +++ b/.coveragerc @@ -9,3 +9,5 @@ omit = [report] exclude_lines = :nocov: +partial_branches = + :nobr: diff --git a/nmigen/fhdl/ir.py b/nmigen/fhdl/ir.py index dbac817..844cc7b 100644 --- a/nmigen/fhdl/ir.py +++ b/nmigen/fhdl/ir.py @@ -1,3 +1,4 @@ +import warnings from collections import defaultdict, OrderedDict from ..tools import * @@ -5,7 +6,11 @@ from .ast import * from .cd import * -__all__ = ["Fragment"] +__all__ = ["Fragment", "DriverConflict"] + + +class DriverConflict(UserWarning): + pass class Fragment: @@ -73,6 +78,77 @@ class Fragment: assert isinstance(subfragment, Fragment) self.subfragments.append((subfragment, name)) + def _resolve_driver_conflicts(self, hierarchy=("top",), mode="warn"): + assert mode in ("silent", "warn", "error") + + driver_subfrags = ValueDict() + + # For each signal driven by this fragment and/or its subfragments, determine which + # subfragments also drive it. + for domain, signal in self.iter_drivers(): + if signal not in driver_subfrags: + driver_subfrags[signal] = set() + driver_subfrags[signal].add((None, hierarchy)) + + for i, (subfrag, name) in enumerate(self.subfragments): + # First, recurse into subfragments and let them detect driver conflicts as well. + if name is None: + name = "".format(i) + subfrag_hierarchy = hierarchy + (name,) + subfrag_drivers = subfrag._resolve_driver_conflicts(subfrag_hierarchy, mode) + + # Second, classify subfragments by domains they define. + for signal in subfrag_drivers: + if signal not in driver_subfrags: + driver_subfrags[signal] = set() + driver_subfrags[signal].add((subfrag, subfrag_hierarchy)) + + # Find out the set of subfragments that needs to be flattened into this fragment + # to resolve driver-driver conflicts. + flatten_subfrags = set() + for signal, subfrags in driver_subfrags.items(): + if len(subfrags) > 1: + flatten_subfrags.update((f, h) for f, h in subfrags if f is not None) + + # While we're at it, show a message. + subfrag_names = ", ".join(sorted(".".join(h) for f, h in subfrags)) + message = ("Signal '{}' is driven from multiple fragments: {}" + .format(signal, subfrag_names)) + if mode == "error": + raise DriverConflict(message) + elif mode == "warn": + message += "; hierarchy will be flattened" + warnings.warn_explicit(message, DriverConflict, *signal.src_loc) + + for subfrag, subfrag_hierarchy in sorted(flatten_subfrags, key=lambda x: x[1]): + # Merge subfragment's everything except clock domains into this fragment. + # Flattening is done after clock domain propagation, so we can assume the domains + # are already the same in every involved fragment in the first place. + self.ports.update(subfrag.ports) + for domain, signal in subfrag.iter_drivers(): + self.add_driver(signal, domain) + self.statements += subfrag.statements + self.subfragments += subfrag.subfragments + + # Remove the merged subfragment. + for i, (check_subfrag, check_name) in enumerate(self.subfragments): # :nobr: + if subfrag == check_subfrag: + del self.subfragments[i] + break + + # If we flattened anything, we might be in a situation where we have a driver conflict + # again, e.g. if we had a tree of fragments like A --- B --- C where only fragments + # A and C were driving a signal S. In that case, since B is not driving S itself, + # processing B will not result in any flattening, but since B is transitively driving S, + # processing A will flatten B into it. Afterwards, we have a tree like AB --- C, which + # has another conflict. + if any(flatten_subfrags): + # Try flattening again. + return self._resolve_driver_conflicts(hierarchy, mode) + + # Nothing was flattened, we're done! + return ValueSet(driver_subfrags.keys()) + def _propagate_domains_up(self, hierarchy=("top",)): from .xfrm import DomainRenamer @@ -193,6 +269,7 @@ class Fragment: fragment = FragmentTransformer()(self) fragment._propagate_domains(ensure_sync_exists) + fragment._resolve_driver_conflicts() fragment = fragment._insert_domain_resets() fragment = fragment._lower_domain_signals() fragment._propagate_ports(ports) diff --git a/nmigen/test/test_fhdl_ir.py b/nmigen/test/test_fhdl_ir.py index 3068af4..ace2e12 100644 --- a/nmigen/test/test_fhdl_ir.py +++ b/nmigen/test/test_fhdl_ir.py @@ -245,3 +245,123 @@ class FragmentDomainsTestCase(FHDLTestCase): self.assertEqual(f1.domains.keys(), {"sync"}) self.assertEqual(f2.domains.keys(), {"sync"}) self.assertEqual(f1.domains["sync"], f2.domains["sync"]) + + +class FragmentDriverConflictTestCase(FHDLTestCase): + def setUp_self_sub(self): + self.s1 = Signal() + self.c1 = Signal() + self.c2 = Signal() + + self.f1 = Fragment() + self.f1.add_statements(self.c1.eq(0)) + self.f1.add_driver(self.s1) + self.f1.add_driver(self.c1, "sync") + + self.f1a = Fragment() + self.f1.add_subfragment(self.f1a, "f1a") + + self.f2 = Fragment() + self.f2.add_statements(self.c2.eq(1)) + self.f2.add_driver(self.s1) + self.f2.add_driver(self.c2, "sync") + self.f1.add_subfragment(self.f2) + + self.f1b = Fragment() + self.f1.add_subfragment(self.f1b, "f1b") + + self.f2a = Fragment() + self.f2.add_subfragment(self.f2a, "f2a") + + def test_conflict_self_sub(self): + self.setUp_self_sub() + + self.f1._resolve_driver_conflicts(mode="silent") + self.assertEqual(self.f1.subfragments, [ + (self.f1a, "f1a"), + (self.f1b, "f1b"), + (self.f2a, "f2a"), + ]) + self.assertRepr(self.f1.statements, """ + ( + (eq (sig c1) (const 1'd0)) + (eq (sig c2) (const 1'd1)) + ) + """) + self.assertEqual(self.f1.drivers, { + None: ValueSet((self.s1,)), + "sync": ValueSet((self.c1, self.c2)), + }) + + def test_conflict_self_sub_error(self): + self.setUp_self_sub() + + with self.assertRaises(DriverConflict, + msg="Signal '(sig s1)' is driven from multiple fragments: top, top."): + self.f1._resolve_driver_conflicts(mode="error") + + def test_conflict_self_sub_warning(self): + self.setUp_self_sub() + + with self.assertWarns(DriverConflict, + msg="Signal '(sig s1)' is driven from multiple fragments: top, top.; " + "hierarchy will be flattened"): + self.f1._resolve_driver_conflicts(mode="warn") + + def setUp_sub_sub(self): + self.s1 = Signal() + self.c1 = Signal() + self.c2 = Signal() + + self.f1 = Fragment() + + self.f2 = Fragment() + self.f2.add_driver(self.s1) + self.f2.add_statements(self.c1.eq(0)) + self.f1.add_subfragment(self.f2) + + self.f3 = Fragment() + self.f3.add_driver(self.s1) + self.f3.add_statements(self.c2.eq(1)) + self.f1.add_subfragment(self.f3) + + def test_conflict_sub_sub(self): + self.setUp_sub_sub() + + self.f1._resolve_driver_conflicts(mode="silent") + self.assertEqual(self.f1.subfragments, []) + self.assertRepr(self.f1.statements, """ + ( + (eq (sig c1) (const 1'd0)) + (eq (sig c2) (const 1'd1)) + ) + """) + + def setUp_self_subsub(self): + self.s1 = Signal() + self.c1 = Signal() + self.c2 = Signal() + + self.f1 = Fragment() + self.f1.add_driver(self.s1) + + self.f2 = Fragment() + self.f2.add_statements(self.c1.eq(0)) + self.f1.add_subfragment(self.f2) + + self.f3 = Fragment() + self.f3.add_driver(self.s1) + self.f3.add_statements(self.c2.eq(1)) + self.f2.add_subfragment(self.f3) + + def test_conflict_self_subsub(self): + self.setUp_self_subsub() + + self.f1._resolve_driver_conflicts(mode="silent") + self.assertEqual(self.f1.subfragments, []) + self.assertRepr(self.f1.statements, """ + ( + (eq (sig c1) (const 1'd0)) + (eq (sig c2) (const 1'd1)) + ) + """) diff --git a/nmigen/test/tools.py b/nmigen/test/tools.py index 65cf0ff..297e7f9 100644 --- a/nmigen/test/tools.py +++ b/nmigen/test/tools.py @@ -1,5 +1,6 @@ import re import unittest +import warnings from contextlib import contextmanager from ..fhdl.ast import * @@ -23,3 +24,12 @@ class FHDLTestCase(unittest.TestCase): if msg is not None: # WTF? unittest.assertRaises is completely broken. self.assertEqual(str(cm.exception), msg) + + @contextmanager + def assertWarns(self, category, msg=None): + with warnings.catch_warnings(record=True) as warns: + yield + self.assertEqual(len(warns), 1) + self.assertEqual(warns[0].category, category) + if msg is not None: + self.assertEqual(str(warns[0].message), msg)