hdl.ir: only pull explicitly specified ports to toplevel, if any.
authorwhitequark <whitequark@whitequark.org>
Sun, 12 May 2019 05:21:23 +0000 (05:21 +0000)
committerwhitequark <whitequark@whitequark.org>
Sun, 12 May 2019 05:21:23 +0000 (05:21 +0000)
Fixes #30.

nmigen/hdl/ir.py
nmigen/test/test_hdl_ir.py
nmigen/tools.py

index c5879690ac3fd7ae217899d7ca76f90d5d181021..6c050182a882d92a77a6c5f87fe02d30549cf783 100644 (file)
@@ -1,5 +1,6 @@
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict, OrderedDict
+from functools import reduce
 import warnings
 import traceback
 import sys
@@ -340,69 +341,140 @@ class Fragment:
 
         return DomainLowerer(self.domains)(self)
 
-    def _propagate_ports(self, ports):
+    def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top):
+        def add_uses(*sigs, self=self):
+            for sig in flatten(sigs):
+                if sig not in uses:
+                    uses[sig] = set()
+                uses[sig].add(self)
+
+        def add_defs(*sigs):
+            for sig in flatten(sigs):
+                if sig not in defs:
+                    defs[sig] = self
+                else:
+                    assert defs[sig] is self
+
+        def add_io(sig):
+            assert sig not in ios
+            ios[sig] = self
+
         # Collect all signals we're driving (on LHS of statements), and signals we're using
         # (on RHS of statements, or in clock domains).
         if isinstance(self, Instance):
-            self_driven = SignalSet()
-            self_used   = SignalSet()
             for port_name, (value, dir) in self.named_ports.items():
                 if dir == "i":
-                    for signal in value._rhs_signals():
-                        self_used.add(signal)
-                        self.add_ports(signal, dir="i")
+                    add_uses(value._rhs_signals())
                 if dir == "o":
-                    for signal in value._lhs_signals():
-                        self_driven.add(signal)
-                        self.add_ports(signal, dir="o")
+                    add_defs(value._lhs_signals())
                 if dir == "io":
-                    self.add_ports(value, dir="io")
+                    add_io(value)
         else:
-            self_driven = union((s._lhs_signals() for s in self.statements), start=SignalSet())
-            self_used   = union((s._rhs_signals() for s in self.statements), start=SignalSet())
+            for stmt in self.statements:
+                add_uses(stmt._rhs_signals())
+                add_defs(stmt._lhs_signals())
+
             for domain, _ in self.iter_sync():
                 cd = self.domains[domain]
-                self_used.add(cd.clk)
+                add_uses(cd.clk)
                 if cd.rst is not None:
-                    self_used.add(cd.rst)
-
-        # Our input ports are all the signals we're using but not driving. This is an over-
-        # approximation: some of these signals may be driven by our subfragments.
-        ins  = self_used - self_driven
-        # Our output ports are all the signals we're asked to provide that we're driving. This is
-        # an underapproximation: some of these signals may be driven by subfragments.
-        outs = ports & self_driven
+                    add_uses(cd.rst)
 
-        # Go through subfragments and refine our approximation for inputs.
+        # Repeat for subfragments.
         for subfrag, name in self.subfragments:
-            # Refine the input port approximation: if a subfragment requires a signal as an input,
-            # and we aren't driving it, it has to be our input as well.
-            sub_ins, sub_outs, sub_inouts = subfrag._propagate_ports(ports=())
-            ins  |= sub_ins - self_driven
+            parent[subfrag] = self
+            level [subfrag] = level[self] + 1
+
+            subfrag._prepare_use_def_graph(parent, level, uses, defs, ios, top)
+
+    def _propagate_ports(self, ports, all_undef_as_ports):
+        # Take this fragment graph:
+        #
+        #    __ B (def: q, use: p r)
+        #   /
+        #  A (def: p, use: q r)
+        #   \
+        #    \_ C (def: r, use: p q)
+        #
+        # We need to consider three cases.
+        #   1. Signal p requires an input port in B;
+        #   2. Signal r requires an output port in C;
+        #   3. Signal r requires an output port in C and an input port in B.
+        #
+        # Adding these ports can be in general done in three steps for each signal:
+        #   1. Find the least common ancestor of all uses and defs.
+        #   2. Going upwards from the single def, add output ports.
+        #   3. Going upwards from all uses, add input ports.
+
+        parent = {self: None}
+        level  = {self: 0}
+        uses   = SignalDict()
+        defs   = SignalDict()
+        ios    = SignalDict()
+        self._prepare_use_def_graph(parent, level, uses, defs, ios, self)
+
+        ports = SignalSet(ports)
+        if all_undef_as_ports:
+            for sig in uses:
+                if sig in defs:
+                    continue
+                ports.add(sig)
+        for sig in ports:
+            if sig not in uses:
+                uses[sig] = set()
+            uses[sig].add(self)
+
+        @memoize
+        def lca_of(fragu, fragv):
+            # Normalize fragu to be deeper than fragv.
+            if level[fragu] < level[fragv]:
+                fragu, fragv = fragv, fragu
+            # Find ancestor of fragu on the same level as fragv.
+            for _ in range(level[fragu] - level[fragv]):
+                fragu = parent[fragu]
+            # If fragv was the ancestor of fragv, we're done.
+            if fragu == fragv:
+                return fragu
+            # Otherwise, they are at the same level but in different branches. Step both fragu
+            # and fragv until we find the common ancestor.
+            while parent[fragu] != parent[fragv]:
+                fragu = parent[fragu]
+                fragv = parent[fragv]
+            return parent[fragu]
+
+        for sig in uses:
+            if sig in defs:
+                lca  = reduce(lca_of, uses[sig], defs[sig])
+
+                frag = defs[sig]
+                while frag != lca:
+                    frag.add_ports(sig, dir="o")
+                    frag = parent[frag]
+            else:
+                lca  = reduce(lca_of, uses[sig])
+
+            for frag in uses[sig]:
+                if sig in defs and frag is defs[sig]:
+                    continue
+                while frag != lca:
+                    frag.add_ports(sig, dir="i")
+                    frag = parent[frag]
+
+        for sig in ios:
+            frag = ios[sig]
+            while frag is not None:
+                frag.add_ports(sig, dir="io")
+                frag = parent[frag]
+
+        for sig in ports:
+            if sig in ios:
+                continue
+            if sig in defs:
+                self.add_ports(sig, dir="o")
+            else:
+                self.add_ports(sig, dir="i")
 
-        for subfrag, name in self.subfragments:
-            # Always ask subfragments to provide all signals that are our inputs.
-            # If the subfragment is not driving it, it will silently ignore it.
-            sub_ins, sub_outs, sub_inouts = subfrag._propagate_ports(ports=ins | ports)
-            # Refine the input port appropximation further: if any subfragment is driving a signal
-            # that we currently think should be our input, it shouldn't actually be our input.
-            ins  -= sub_outs
-            # Refine the output port approximation: if a subfragment is driving a signal,
-            # and we're asked to provide it, we can provide it now.
-            outs |= ports & sub_outs
-            # All of our subfragments' bidirectional ports are also our bidirectional ports,
-            # since these are only used for pins.
-            self.add_ports(sub_inouts, dir="io")
-
-        # We've computed the precise set of input and output ports.
-        self.add_ports(ins,  dir="i")
-        self.add_ports(outs, dir="o")
-
-        return (SignalSet(self.iter_ports("i")),
-                SignalSet(self.iter_ports("o")),
-                SignalSet(self.iter_ports("io")))
-
-    def prepare(self, ports=(), ensure_sync_exists=True):
+    def prepare(self, ports=None, ensure_sync_exists=True):
         from .xfrm import SampleLowerer
 
         fragment = SampleLowerer()(self)
@@ -410,7 +482,10 @@ class Fragment:
         fragment._resolve_hierarchy_conflicts()
         fragment = fragment._insert_domain_resets()
         fragment = fragment._lower_domain_signals()
-        fragment._propagate_ports(ports)
+        if ports is None:
+            fragment._propagate_ports(ports=(), all_undef_as_ports=True)
+        else:
+            fragment._propagate_ports(ports=ports, all_undef_as_ports=False)
         return fragment
 
 
index f9f6ab8f3bbb7614609eee41eb381cdb2da9ef51..34a31e62d2d065d56e6050ea9167a6888e47a4e8 100644 (file)
@@ -65,7 +65,7 @@ class FragmentPortsTestCase(FHDLTestCase):
         f = Fragment()
         self.assertEqual(list(f.iter_ports()), [])
 
-        f._propagate_ports(ports=())
+        f._propagate_ports(ports=(), all_undef_as_ports=True)
         self.assertEqual(f.ports, SignalDict([]))
 
     def test_iter_signals(self):
@@ -80,7 +80,7 @@ class FragmentPortsTestCase(FHDLTestCase):
             self.s1.eq(self.c1)
         )
 
-        f._propagate_ports(ports=())
+        f._propagate_ports(ports=(), all_undef_as_ports=True)
         self.assertEqual(f.ports, SignalDict([]))
 
     def test_infer_input(self):
@@ -89,7 +89,7 @@ class FragmentPortsTestCase(FHDLTestCase):
             self.c1.eq(self.s1)
         )
 
-        f._propagate_ports(ports=())
+        f._propagate_ports(ports=(), all_undef_as_ports=True)
         self.assertEqual(f.ports, SignalDict([
             (self.s1, "i")
         ]))
@@ -100,7 +100,7 @@ class FragmentPortsTestCase(FHDLTestCase):
             self.c1.eq(self.s1)
         )
 
-        f._propagate_ports(ports=(self.c1,))
+        f._propagate_ports(ports=(self.c1,), all_undef_as_ports=True)
         self.assertEqual(f.ports, SignalDict([
             (self.s1, "i"),
             (self.c1, "o")
@@ -116,7 +116,7 @@ class FragmentPortsTestCase(FHDLTestCase):
             self.s1.eq(0)
         )
         f1.add_subfragment(f2)
-        f1._propagate_ports(ports=())
+        f1._propagate_ports(ports=(), all_undef_as_ports=True)
         self.assertEqual(f1.ports, SignalDict())
         self.assertEqual(f2.ports, SignalDict([
             (self.s1, "o"),
@@ -129,7 +129,7 @@ class FragmentPortsTestCase(FHDLTestCase):
             self.c1.eq(self.s1)
         )
         f1.add_subfragment(f2)
-        f1._propagate_ports(ports=())
+        f1._propagate_ports(ports=(), all_undef_as_ports=True)
         self.assertEqual(f1.ports, SignalDict([
             (self.s1, "i"),
         ]))
@@ -148,7 +148,7 @@ class FragmentPortsTestCase(FHDLTestCase):
         )
         f1.add_subfragment(f2)
 
-        f1._propagate_ports(ports=(self.c2,))
+        f1._propagate_ports(ports=(self.c2,), all_undef_as_ports=True)
         self.assertEqual(f1.ports, SignalDict([
             (self.c2, "o"),
         ]))
@@ -170,7 +170,7 @@ class FragmentPortsTestCase(FHDLTestCase):
         f3.add_driver(self.c2)
         f1.add_subfragment(f3)
 
-        f1._propagate_ports(ports=())
+        f1._propagate_ports(ports=(), all_undef_as_ports=True)
         self.assertEqual(f1.ports, SignalDict())
 
     def test_output_input_sibling(self):
@@ -187,7 +187,7 @@ class FragmentPortsTestCase(FHDLTestCase):
         )
         f1.add_subfragment(f3)
 
-        f1._propagate_ports(ports=())
+        f1._propagate_ports(ports=(), all_undef_as_ports=True)
         self.assertEqual(f1.ports, SignalDict())
 
     def test_input_cd(self):
@@ -199,7 +199,7 @@ class FragmentPortsTestCase(FHDLTestCase):
         f.add_domains(sync)
         f.add_driver(self.c1, "sync")
 
-        f._propagate_ports(ports=())
+        f._propagate_ports(ports=(), all_undef_as_ports=True)
         self.assertEqual(f.ports, SignalDict([
             (self.s1,  "i"),
             (sync.clk, "i"),
@@ -215,7 +215,7 @@ class FragmentPortsTestCase(FHDLTestCase):
         f.add_domains(sync)
         f.add_driver(self.c1, "sync")
 
-        f._propagate_ports(ports=())
+        f._propagate_ports(ports=(), all_undef_as_ports=True)
         self.assertEqual(f.ports, SignalDict([
             (self.s1,  "i"),
             (sync.clk, "i"),
@@ -224,11 +224,10 @@ class FragmentPortsTestCase(FHDLTestCase):
     def test_inout(self):
         s = Signal()
         f1 = Fragment()
-        f2 = Fragment()
-        f2.add_ports(s, dir="io")
+        f2 = Instance("foo", io_x=s)
         f1.add_subfragment(f2)
 
-        f1._propagate_ports(ports=())
+        f1._propagate_ports(ports=(), all_undef_as_ports=True)
         self.assertEqual(f1.ports, SignalDict([
             (s, "io")
         ]))
@@ -556,9 +555,15 @@ class InstanceTestCase(FHDLTestCase):
         clk = f.domains["sync"].clk
         self.assertEqual(f.ports, SignalDict([
             (clk, "i"),
+            (self.rst, "i"),
+            (self.pins, "io"),
+        ]))
+
+    def test_prepare_explicit_ports(self):
+        self.setUp_cpu()
+        f = self.inst.prepare(ports=[self.rst, self.stb])
+        self.assertEqual(f.ports, SignalDict([
             (self.rst, "i"),
             (self.stb, "o"),
-            (self.datal, "o"),
-            (self.datah, "o"),
             (self.pins, "io"),
         ]))
index 25b6893ab1a4b85402bf25bc96ebdc66e58b4241..edf21959b2acaa833bcc605fca15070a190ebd04 100644 (file)
@@ -1,11 +1,12 @@
 import contextlib
 import functools
 import warnings
+from collections import OrderedDict
 from collections.abc import Iterable
 from contextlib import contextmanager
 
 
-__all__ = ["flatten", "union", "log2_int", "bits_for", "deprecated"]
+__all__ = ["flatten", "union", "log2_int", "bits_for", "memoize", "deprecated"]
 
 
 def flatten(i):
@@ -46,6 +47,16 @@ def bits_for(n, require_sign_bit=False):
     return r
 
 
+def memoize(f):
+    memo = OrderedDict()
+    @functools.wraps(f)
+    def g(*args):
+        if args not in memo:
+            memo[args] = f(*args)
+        return memo[args]
+    return g
+
+
 def deprecated(message, stacklevel=2):
     def decorator(f):
         @functools.wraps(f)