from abc import ABCMeta, abstractmethod
from collections import defaultdict, OrderedDict
+from functools import reduce
import warnings
import traceback
import sys
return DomainLowerer(self.domains)(self)
- def _propagate_ports(self, ports):
+ def _prepare_use_def_graph(self, parent, level, uses, defs, ios, top):
+ def add_uses(*sigs, self=self):
+ for sig in flatten(sigs):
+ if sig not in uses:
+ uses[sig] = set()
+ uses[sig].add(self)
+
+ def add_defs(*sigs):
+ for sig in flatten(sigs):
+ if sig not in defs:
+ defs[sig] = self
+ else:
+ assert defs[sig] is self
+
+ def add_io(sig):
+ assert sig not in ios
+ ios[sig] = self
+
# Collect all signals we're driving (on LHS of statements), and signals we're using
# (on RHS of statements, or in clock domains).
if isinstance(self, Instance):
- self_driven = SignalSet()
- self_used = SignalSet()
for port_name, (value, dir) in self.named_ports.items():
if dir == "i":
- for signal in value._rhs_signals():
- self_used.add(signal)
- self.add_ports(signal, dir="i")
+ add_uses(value._rhs_signals())
if dir == "o":
- for signal in value._lhs_signals():
- self_driven.add(signal)
- self.add_ports(signal, dir="o")
+ add_defs(value._lhs_signals())
if dir == "io":
- self.add_ports(value, dir="io")
+ add_io(value)
else:
- self_driven = union((s._lhs_signals() for s in self.statements), start=SignalSet())
- self_used = union((s._rhs_signals() for s in self.statements), start=SignalSet())
+ for stmt in self.statements:
+ add_uses(stmt._rhs_signals())
+ add_defs(stmt._lhs_signals())
+
for domain, _ in self.iter_sync():
cd = self.domains[domain]
- self_used.add(cd.clk)
+ add_uses(cd.clk)
if cd.rst is not None:
- self_used.add(cd.rst)
-
- # Our input ports are all the signals we're using but not driving. This is an over-
- # approximation: some of these signals may be driven by our subfragments.
- ins = self_used - self_driven
- # Our output ports are all the signals we're asked to provide that we're driving. This is
- # an underapproximation: some of these signals may be driven by subfragments.
- outs = ports & self_driven
+ add_uses(cd.rst)
- # Go through subfragments and refine our approximation for inputs.
+ # Repeat for subfragments.
for subfrag, name in self.subfragments:
- # Refine the input port approximation: if a subfragment requires a signal as an input,
- # and we aren't driving it, it has to be our input as well.
- sub_ins, sub_outs, sub_inouts = subfrag._propagate_ports(ports=())
- ins |= sub_ins - self_driven
+ parent[subfrag] = self
+ level [subfrag] = level[self] + 1
+
+ subfrag._prepare_use_def_graph(parent, level, uses, defs, ios, top)
+
+ def _propagate_ports(self, ports, all_undef_as_ports):
+ # Take this fragment graph:
+ #
+ # __ B (def: q, use: p r)
+ # /
+ # A (def: p, use: q r)
+ # \
+ # \_ C (def: r, use: p q)
+ #
+ # We need to consider three cases.
+ # 1. Signal p requires an input port in B;
+ # 2. Signal r requires an output port in C;
+ # 3. Signal r requires an output port in C and an input port in B.
+ #
+ # Adding these ports can be in general done in three steps for each signal:
+ # 1. Find the least common ancestor of all uses and defs.
+ # 2. Going upwards from the single def, add output ports.
+ # 3. Going upwards from all uses, add input ports.
+
+ parent = {self: None}
+ level = {self: 0}
+ uses = SignalDict()
+ defs = SignalDict()
+ ios = SignalDict()
+ self._prepare_use_def_graph(parent, level, uses, defs, ios, self)
+
+ ports = SignalSet(ports)
+ if all_undef_as_ports:
+ for sig in uses:
+ if sig in defs:
+ continue
+ ports.add(sig)
+ for sig in ports:
+ if sig not in uses:
+ uses[sig] = set()
+ uses[sig].add(self)
+
+ @memoize
+ def lca_of(fragu, fragv):
+ # Normalize fragu to be deeper than fragv.
+ if level[fragu] < level[fragv]:
+ fragu, fragv = fragv, fragu
+ # Find ancestor of fragu on the same level as fragv.
+ for _ in range(level[fragu] - level[fragv]):
+ fragu = parent[fragu]
+ # If fragv was the ancestor of fragv, we're done.
+ if fragu == fragv:
+ return fragu
+ # Otherwise, they are at the same level but in different branches. Step both fragu
+ # and fragv until we find the common ancestor.
+ while parent[fragu] != parent[fragv]:
+ fragu = parent[fragu]
+ fragv = parent[fragv]
+ return parent[fragu]
+
+ for sig in uses:
+ if sig in defs:
+ lca = reduce(lca_of, uses[sig], defs[sig])
+
+ frag = defs[sig]
+ while frag != lca:
+ frag.add_ports(sig, dir="o")
+ frag = parent[frag]
+ else:
+ lca = reduce(lca_of, uses[sig])
+
+ for frag in uses[sig]:
+ if sig in defs and frag is defs[sig]:
+ continue
+ while frag != lca:
+ frag.add_ports(sig, dir="i")
+ frag = parent[frag]
+
+ for sig in ios:
+ frag = ios[sig]
+ while frag is not None:
+ frag.add_ports(sig, dir="io")
+ frag = parent[frag]
+
+ for sig in ports:
+ if sig in ios:
+ continue
+ if sig in defs:
+ self.add_ports(sig, dir="o")
+ else:
+ self.add_ports(sig, dir="i")
- for subfrag, name in self.subfragments:
- # Always ask subfragments to provide all signals that are our inputs.
- # If the subfragment is not driving it, it will silently ignore it.
- sub_ins, sub_outs, sub_inouts = subfrag._propagate_ports(ports=ins | ports)
- # Refine the input port appropximation further: if any subfragment is driving a signal
- # that we currently think should be our input, it shouldn't actually be our input.
- ins -= sub_outs
- # Refine the output port approximation: if a subfragment is driving a signal,
- # and we're asked to provide it, we can provide it now.
- outs |= ports & sub_outs
- # All of our subfragments' bidirectional ports are also our bidirectional ports,
- # since these are only used for pins.
- self.add_ports(sub_inouts, dir="io")
-
- # We've computed the precise set of input and output ports.
- self.add_ports(ins, dir="i")
- self.add_ports(outs, dir="o")
-
- return (SignalSet(self.iter_ports("i")),
- SignalSet(self.iter_ports("o")),
- SignalSet(self.iter_ports("io")))
-
- def prepare(self, ports=(), ensure_sync_exists=True):
+ def prepare(self, ports=None, ensure_sync_exists=True):
from .xfrm import SampleLowerer
fragment = SampleLowerer()(self)
fragment._resolve_hierarchy_conflicts()
fragment = fragment._insert_domain_resets()
fragment = fragment._lower_domain_signals()
- fragment._propagate_ports(ports)
+ if ports is None:
+ fragment._propagate_ports(ports=(), all_undef_as_ports=True)
+ else:
+ fragment._propagate_ports(ports=ports, all_undef_as_ports=False)
return fragment
f = Fragment()
self.assertEqual(list(f.iter_ports()), [])
- f._propagate_ports(ports=())
+ f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([]))
def test_iter_signals(self):
self.s1.eq(self.c1)
)
- f._propagate_ports(ports=())
+ f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([]))
def test_infer_input(self):
self.c1.eq(self.s1)
)
- f._propagate_ports(ports=())
+ f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([
(self.s1, "i")
]))
self.c1.eq(self.s1)
)
- f._propagate_ports(ports=(self.c1,))
+ f._propagate_ports(ports=(self.c1,), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([
(self.s1, "i"),
(self.c1, "o")
self.s1.eq(0)
)
f1.add_subfragment(f2)
- f1._propagate_ports(ports=())
+ f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict())
self.assertEqual(f2.ports, SignalDict([
(self.s1, "o"),
self.c1.eq(self.s1)
)
f1.add_subfragment(f2)
- f1._propagate_ports(ports=())
+ f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict([
(self.s1, "i"),
]))
)
f1.add_subfragment(f2)
- f1._propagate_ports(ports=(self.c2,))
+ f1._propagate_ports(ports=(self.c2,), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict([
(self.c2, "o"),
]))
f3.add_driver(self.c2)
f1.add_subfragment(f3)
- f1._propagate_ports(ports=())
+ f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict())
def test_output_input_sibling(self):
)
f1.add_subfragment(f3)
- f1._propagate_ports(ports=())
+ f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict())
def test_input_cd(self):
f.add_domains(sync)
f.add_driver(self.c1, "sync")
- f._propagate_ports(ports=())
+ f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([
(self.s1, "i"),
(sync.clk, "i"),
f.add_domains(sync)
f.add_driver(self.c1, "sync")
- f._propagate_ports(ports=())
+ f._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f.ports, SignalDict([
(self.s1, "i"),
(sync.clk, "i"),
def test_inout(self):
s = Signal()
f1 = Fragment()
- f2 = Fragment()
- f2.add_ports(s, dir="io")
+ f2 = Instance("foo", io_x=s)
f1.add_subfragment(f2)
- f1._propagate_ports(ports=())
+ f1._propagate_ports(ports=(), all_undef_as_ports=True)
self.assertEqual(f1.ports, SignalDict([
(s, "io")
]))
clk = f.domains["sync"].clk
self.assertEqual(f.ports, SignalDict([
(clk, "i"),
+ (self.rst, "i"),
+ (self.pins, "io"),
+ ]))
+
+ def test_prepare_explicit_ports(self):
+ self.setUp_cpu()
+ f = self.inst.prepare(ports=[self.rst, self.stb])
+ self.assertEqual(f.ports, SignalDict([
(self.rst, "i"),
(self.stb, "o"),
- (self.datal, "o"),
- (self.datah, "o"),
(self.pins, "io"),
]))