From 958cb18b886ae40d0a28140b40c187d90a692475 Mon Sep 17 00:00:00 2001 From: whitequark Date: Sun, 12 May 2019 05:21:23 +0000 Subject: [PATCH] hdl.ir: only pull explicitly specified ports to toplevel, if any. Fixes #30. --- nmigen/hdl/ir.py | 175 ++++++++++++++++++++++++++----------- nmigen/test/test_hdl_ir.py | 37 ++++---- nmigen/tools.py | 13 ++- 3 files changed, 158 insertions(+), 67 deletions(-) diff --git a/nmigen/hdl/ir.py b/nmigen/hdl/ir.py index c587969..6c05018 100644 --- a/nmigen/hdl/ir.py +++ b/nmigen/hdl/ir.py @@ -1,5 +1,6 @@ from abc import ABCMeta, abstractmethod from collections import defaultdict, OrderedDict +from functools import reduce import warnings import traceback import sys @@ -340,69 +341,140 @@ class Fragment: return DomainLowerer(self.domains)(self) - def _propagate_ports(self, ports): + def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top): + def add_uses(*sigs, self=self): + for sig in flatten(sigs): + if sig not in uses: + uses[sig] = set() + uses[sig].add(self) + + def add_defs(*sigs): + for sig in flatten(sigs): + if sig not in defs: + defs[sig] = self + else: + assert defs[sig] is self + + def add_io(sig): + assert sig not in ios + ios[sig] = self + # Collect all signals we're driving (on LHS of statements), and signals we're using # (on RHS of statements, or in clock domains). if isinstance(self, Instance): - self_driven = SignalSet() - self_used = SignalSet() for port_name, (value, dir) in self.named_ports.items(): if dir == "i": - for signal in value._rhs_signals(): - self_used.add(signal) - self.add_ports(signal, dir="i") + add_uses(value._rhs_signals()) if dir == "o": - for signal in value._lhs_signals(): - self_driven.add(signal) - self.add_ports(signal, dir="o") + add_defs(value._lhs_signals()) if dir == "io": - self.add_ports(value, dir="io") + add_io(value) else: - self_driven = union((s._lhs_signals() for s in self.statements), start=SignalSet()) - self_used = union((s._rhs_signals() for s in self.statements), start=SignalSet()) + for stmt in self.statements: + add_uses(stmt._rhs_signals()) + add_defs(stmt._lhs_signals()) + for domain, _ in self.iter_sync(): cd = self.domains[domain] - self_used.add(cd.clk) + add_uses(cd.clk) if cd.rst is not None: - self_used.add(cd.rst) - - # Our input ports are all the signals we're using but not driving. This is an over- - # approximation: some of these signals may be driven by our subfragments. - ins = self_used - self_driven - # Our output ports are all the signals we're asked to provide that we're driving. This is - # an underapproximation: some of these signals may be driven by subfragments. - outs = ports & self_driven + add_uses(cd.rst) - # Go through subfragments and refine our approximation for inputs. + # Repeat for subfragments. for subfrag, name in self.subfragments: - # Refine the input port approximation: if a subfragment requires a signal as an input, - # and we aren't driving it, it has to be our input as well. - sub_ins, sub_outs, sub_inouts = subfrag._propagate_ports(ports=()) - ins |= sub_ins - self_driven + parent[subfrag] = self + level [subfrag] = level[self] + 1 + + subfrag._prepare_use_def_graph(parent, level, uses, defs, ios, top) + + def _propagate_ports(self, ports, all_undef_as_ports): + # Take this fragment graph: + # + # __ B (def: q, use: p r) + # / + # A (def: p, use: q r) + # \ + # \_ C (def: r, use: p q) + # + # We need to consider three cases. + # 1. Signal p requires an input port in B; + # 2. Signal r requires an output port in C; + # 3. Signal r requires an output port in C and an input port in B. + # + # Adding these ports can be in general done in three steps for each signal: + # 1. Find the least common ancestor of all uses and defs. + # 2. Going upwards from the single def, add output ports. + # 3. Going upwards from all uses, add input ports. + + parent = {self: None} + level = {self: 0} + uses = SignalDict() + defs = SignalDict() + ios = SignalDict() + self._prepare_use_def_graph(parent, level, uses, defs, ios, self) + + ports = SignalSet(ports) + if all_undef_as_ports: + for sig in uses: + if sig in defs: + continue + ports.add(sig) + for sig in ports: + if sig not in uses: + uses[sig] = set() + uses[sig].add(self) + + @memoize + def lca_of(fragu, fragv): + # Normalize fragu to be deeper than fragv. + if level[fragu] < level[fragv]: + fragu, fragv = fragv, fragu + # Find ancestor of fragu on the same level as fragv. + for _ in range(level[fragu] - level[fragv]): + fragu = parent[fragu] + # If fragv was the ancestor of fragv, we're done. + if fragu == fragv: + return fragu + # Otherwise, they are at the same level but in different branches. Step both fragu + # and fragv until we find the common ancestor. + while parent[fragu] != parent[fragv]: + fragu = parent[fragu] + fragv = parent[fragv] + return parent[fragu] + + for sig in uses: + if sig in defs: + lca = reduce(lca_of, uses[sig], defs[sig]) + + frag = defs[sig] + while frag != lca: + frag.add_ports(sig, dir="o") + frag = parent[frag] + else: + lca = reduce(lca_of, uses[sig]) + + for frag in uses[sig]: + if sig in defs and frag is defs[sig]: + continue + while frag != lca: + frag.add_ports(sig, dir="i") + frag = parent[frag] + + for sig in ios: + frag = ios[sig] + while frag is not None: + frag.add_ports(sig, dir="io") + frag = parent[frag] + + for sig in ports: + if sig in ios: + continue + if sig in defs: + self.add_ports(sig, dir="o") + else: + self.add_ports(sig, dir="i") - for subfrag, name in self.subfragments: - # Always ask subfragments to provide all signals that are our inputs. - # If the subfragment is not driving it, it will silently ignore it. - sub_ins, sub_outs, sub_inouts = subfrag._propagate_ports(ports=ins | ports) - # Refine the input port appropximation further: if any subfragment is driving a signal - # that we currently think should be our input, it shouldn't actually be our input. - ins -= sub_outs - # Refine the output port approximation: if a subfragment is driving a signal, - # and we're asked to provide it, we can provide it now. - outs |= ports & sub_outs - # All of our subfragments' bidirectional ports are also our bidirectional ports, - # since these are only used for pins. - self.add_ports(sub_inouts, dir="io") - - # We've computed the precise set of input and output ports. - self.add_ports(ins, dir="i") - self.add_ports(outs, dir="o") - - return (SignalSet(self.iter_ports("i")), - SignalSet(self.iter_ports("o")), - SignalSet(self.iter_ports("io"))) - - def prepare(self, ports=(), ensure_sync_exists=True): + def prepare(self, ports=None, ensure_sync_exists=True): from .xfrm import SampleLowerer fragment = SampleLowerer()(self) @@ -410,7 +482,10 @@ class Fragment: fragment._resolve_hierarchy_conflicts() fragment = fragment._insert_domain_resets() fragment = fragment._lower_domain_signals() - fragment._propagate_ports(ports) + if ports is None: + fragment._propagate_ports(ports=(), all_undef_as_ports=True) + else: + fragment._propagate_ports(ports=ports, all_undef_as_ports=False) return fragment diff --git a/nmigen/test/test_hdl_ir.py b/nmigen/test/test_hdl_ir.py index f9f6ab8..34a31e6 100644 --- a/nmigen/test/test_hdl_ir.py +++ b/nmigen/test/test_hdl_ir.py @@ -65,7 +65,7 @@ class FragmentPortsTestCase(FHDLTestCase): f = Fragment() self.assertEqual(list(f.iter_ports()), []) - f._propagate_ports(ports=()) + f._propagate_ports(ports=(), all_undef_as_ports=True) self.assertEqual(f.ports, SignalDict([])) def test_iter_signals(self): @@ -80,7 +80,7 @@ class FragmentPortsTestCase(FHDLTestCase): self.s1.eq(self.c1) ) - f._propagate_ports(ports=()) + f._propagate_ports(ports=(), all_undef_as_ports=True) self.assertEqual(f.ports, SignalDict([])) def test_infer_input(self): @@ -89,7 +89,7 @@ class FragmentPortsTestCase(FHDLTestCase): self.c1.eq(self.s1) ) - f._propagate_ports(ports=()) + f._propagate_ports(ports=(), all_undef_as_ports=True) self.assertEqual(f.ports, SignalDict([ (self.s1, "i") ])) @@ -100,7 +100,7 @@ class FragmentPortsTestCase(FHDLTestCase): self.c1.eq(self.s1) ) - f._propagate_ports(ports=(self.c1,)) + f._propagate_ports(ports=(self.c1,), all_undef_as_ports=True) self.assertEqual(f.ports, SignalDict([ (self.s1, "i"), (self.c1, "o") @@ -116,7 +116,7 @@ class FragmentPortsTestCase(FHDLTestCase): self.s1.eq(0) ) f1.add_subfragment(f2) - f1._propagate_ports(ports=()) + f1._propagate_ports(ports=(), all_undef_as_ports=True) self.assertEqual(f1.ports, SignalDict()) self.assertEqual(f2.ports, SignalDict([ (self.s1, "o"), @@ -129,7 +129,7 @@ class FragmentPortsTestCase(FHDLTestCase): self.c1.eq(self.s1) ) f1.add_subfragment(f2) - f1._propagate_ports(ports=()) + f1._propagate_ports(ports=(), all_undef_as_ports=True) self.assertEqual(f1.ports, SignalDict([ (self.s1, "i"), ])) @@ -148,7 +148,7 @@ class FragmentPortsTestCase(FHDLTestCase): ) f1.add_subfragment(f2) - f1._propagate_ports(ports=(self.c2,)) + f1._propagate_ports(ports=(self.c2,), all_undef_as_ports=True) self.assertEqual(f1.ports, SignalDict([ (self.c2, "o"), ])) @@ -170,7 +170,7 @@ class FragmentPortsTestCase(FHDLTestCase): f3.add_driver(self.c2) f1.add_subfragment(f3) - f1._propagate_ports(ports=()) + f1._propagate_ports(ports=(), all_undef_as_ports=True) self.assertEqual(f1.ports, SignalDict()) def test_output_input_sibling(self): @@ -187,7 +187,7 @@ class FragmentPortsTestCase(FHDLTestCase): ) f1.add_subfragment(f3) - f1._propagate_ports(ports=()) + f1._propagate_ports(ports=(), all_undef_as_ports=True) self.assertEqual(f1.ports, SignalDict()) def test_input_cd(self): @@ -199,7 +199,7 @@ class FragmentPortsTestCase(FHDLTestCase): f.add_domains(sync) f.add_driver(self.c1, "sync") - f._propagate_ports(ports=()) + f._propagate_ports(ports=(), all_undef_as_ports=True) self.assertEqual(f.ports, SignalDict([ (self.s1, "i"), (sync.clk, "i"), @@ -215,7 +215,7 @@ class FragmentPortsTestCase(FHDLTestCase): f.add_domains(sync) f.add_driver(self.c1, "sync") - f._propagate_ports(ports=()) + f._propagate_ports(ports=(), all_undef_as_ports=True) self.assertEqual(f.ports, SignalDict([ (self.s1, "i"), (sync.clk, "i"), @@ -224,11 +224,10 @@ class FragmentPortsTestCase(FHDLTestCase): def test_inout(self): s = Signal() f1 = Fragment() - f2 = Fragment() - f2.add_ports(s, dir="io") + f2 = Instance("foo", io_x=s) f1.add_subfragment(f2) - f1._propagate_ports(ports=()) + f1._propagate_ports(ports=(), all_undef_as_ports=True) self.assertEqual(f1.ports, SignalDict([ (s, "io") ])) @@ -556,9 +555,15 @@ class InstanceTestCase(FHDLTestCase): clk = f.domains["sync"].clk self.assertEqual(f.ports, SignalDict([ (clk, "i"), + (self.rst, "i"), + (self.pins, "io"), + ])) + + def test_prepare_explicit_ports(self): + self.setUp_cpu() + f = self.inst.prepare(ports=[self.rst, self.stb]) + self.assertEqual(f.ports, SignalDict([ (self.rst, "i"), (self.stb, "o"), - (self.datal, "o"), - (self.datah, "o"), (self.pins, "io"), ])) diff --git a/nmigen/tools.py b/nmigen/tools.py index 25b6893..edf2195 100644 --- a/nmigen/tools.py +++ b/nmigen/tools.py @@ -1,11 +1,12 @@ import contextlib import functools import warnings +from collections import OrderedDict from collections.abc import Iterable from contextlib import contextmanager -__all__ = ["flatten", "union", "log2_int", "bits_for", "deprecated"] +__all__ = ["flatten", "union", "log2_int", "bits_for", "memoize", "deprecated"] def flatten(i): @@ -46,6 +47,16 @@ def bits_for(n, require_sign_bit=False): return r +def memoize(f): + memo = OrderedDict() + @functools.wraps(f) + def g(*args): + if args not in memo: + memo[args] = f(*args) + return memo[args] + return g + + def deprecated(message, stacklevel=2): def decorator(f): @functools.wraps(f) -- 2.30.2