return self.on_statement(value)
-class _ControlInserter:
- def __init__(self, controls):
- if isinstance(controls, Value):
- controls = {"sys": controls}
- self.controls = OrderedDict(controls)
-
- def __call__(self, fragment):
- new_fragment = Fragment()
+class FragmentTransformer:
+ def map_subfragments(self, fragment, new_fragment):
for subfragment, name in fragment.subfragments:
new_fragment.add_subfragment(self(subfragment), name)
- new_fragment.add_statements(fragment.statements)
+
+ def map_statements(self, fragment, new_fragment):
+ if hasattr(self, "on_statement"):
+ new_fragment.add_statements(map(self.on_statement, fragment.statements))
+ else:
+ new_fragment.add_statements(fragment.statements)
+
+ def map_drivers(self, fragment, new_fragment):
for cd_name, signals in fragment.iter_domains():
for signal in signals:
new_fragment.drive(signal, cd_name)
+
+ def on_fragment(self, fragment):
+ new_fragment = Fragment()
+ self.map_subfragments(fragment, new_fragment)
+ self.map_statements(fragment, new_fragment)
+ self.map_drivers(fragment, new_fragment)
+ return new_fragment
+
+ def __call__(self, value):
+ return self.on_fragment(value)
+
+
+class _ControlInserter(FragmentTransformer):
+ def __init__(self, controls):
+ if isinstance(controls, Value):
+ controls = {"sync": controls}
+ self.controls = OrderedDict(controls)
+
+ def on_fragment(self, fragment):
+ new_fragment = super().on_fragment(fragment)
+ for cd_name, signals in fragment.iter_domains():
if cd_name is None or cd_name not in self.controls:
continue
- self._wrap_control(new_fragment, cd_name, signals)
+ self._insert_control(new_fragment, cd_name, signals)
return new_fragment
- def _wrap_control(self, fragment, cd_name, signals):
- raise NotImplementedError
+ def _insert_control(self, fragment, cd_name, signals):
+ raise NotImplementedError # :nocov:
class ResetInserter(_ControlInserter):
- def _wrap_control(self, fragment, cd_name, signals):
+ def _insert_control(self, fragment, cd_name, signals):
stmts = [s.eq(Const(s.reset, s.nbits)) for s in signals if not s.reset_less]
fragment.add_statements(Switch(self.controls[cd_name], {1: stmts}))
class CEInserter(_ControlInserter):
- def _wrap_control(self, fragment, cd_name, signals):
+ def _insert_control(self, fragment, cd_name, signals):
stmts = [s.eq(s) for s in signals]
fragment.add_statements(Switch(self.controls[cd_name], {0: stmts}))
--- /dev/null
+import re
+import unittest
+
+from nmigen.fhdl.ast import *
+from nmigen.fhdl.ir import *
+from nmigen.fhdl.xfrm import *
+
+
+class ResetInserterTestCase(unittest.TestCase):
+ def setUp(self):
+ self.s1 = Signal()
+ self.s2 = Signal(reset=1)
+ self.s3 = Signal(reset=1, reset_less=True)
+ self.c1 = Signal()
+
+ def assertRepr(self, obj, repr_str):
+ obj = Statement.wrap(obj)
+ repr_str = re.sub(r"\s+", " ", repr_str)
+ repr_str = re.sub(r"\( (?=\()", "(", repr_str)
+ repr_str = re.sub(r"\) (?=\))", ")", repr_str)
+ self.assertEqual(repr(obj), repr_str.strip())
+
+ def test_reset_default(self):
+ f = Fragment()
+ f.add_statements(
+ self.s1.eq(1)
+ )
+ f.drive(self.s1, "sync")
+
+ f = ResetInserter(self.c1)(f)
+ self.assertRepr(f.statements, """
+ (
+ (eq (sig s1) (const 1'd1))
+ (switch (sig c1)
+ (case 1 (eq (sig s1) (const 1'd0)))
+ )
+ )
+ """)
+
+ def test_reset_cd(self):
+ f = Fragment()
+ f.add_statements(
+ self.s1.eq(1),
+ self.s2.eq(0),
+ )
+ f.drive(self.s1, "sync")
+ f.drive(self.s2, "pix")
+
+ f = ResetInserter({"pix": self.c1})(f)
+ self.assertRepr(f.statements, """
+ (
+ (eq (sig s1) (const 1'd1))
+ (eq (sig s2) (const 0'd0))
+ (switch (sig c1)
+ (case 1 (eq (sig s2) (const 1'd1)))
+ )
+ )
+ """)
+
+ def test_reset_value(self):
+ f = Fragment()
+ f.add_statements(
+ self.s2.eq(0)
+ )
+ f.drive(self.s2, "sync")
+
+ f = ResetInserter(self.c1)(f)
+ self.assertRepr(f.statements, """
+ (
+ (eq (sig s2) (const 0'd0))
+ (switch (sig c1)
+ (case 1 (eq (sig s2) (const 1'd1)))
+ )
+ )
+ """)
+
+ def test_reset_less(self):
+ f = Fragment()
+ f.add_statements(
+ self.s3.eq(0)
+ )
+ f.drive(self.s3, "sync")
+
+ f = ResetInserter(self.c1)(f)
+ self.assertRepr(f.statements, """
+ (
+ (eq (sig s3) (const 0'd0))
+ (switch (sig c1)
+ (case 1 )
+ )
+ )
+ """)
+
+
+class CEInserterTestCase(unittest.TestCase):
+ def setUp(self):
+ self.s1 = Signal()
+ self.s2 = Signal()
+ self.s3 = Signal()
+ self.c1 = Signal()
+
+ def assertRepr(self, obj, repr_str):
+ obj = Statement.wrap(obj)
+ repr_str = re.sub(r"\s+", " ", repr_str)
+ repr_str = re.sub(r"\( (?=\()", "(", repr_str)
+ repr_str = re.sub(r"\) (?=\))", ")", repr_str)
+ self.assertEqual(repr(obj), repr_str.strip())
+
+ def test_ce_default(self):
+ f = Fragment()
+ f.add_statements(
+ self.s1.eq(1)
+ )
+ f.drive(self.s1, "sync")
+
+ f = CEInserter(self.c1)(f)
+ self.assertRepr(f.statements, """
+ (
+ (eq (sig s1) (const 1'd1))
+ (switch (sig c1)
+ (case 0 (eq (sig s1) (sig s1)))
+ )
+ )
+ """)
+
+ def test_ce_cd(self):
+ f = Fragment()
+ f.add_statements(
+ self.s1.eq(1),
+ self.s2.eq(0),
+ )
+ f.drive(self.s1, "sync")
+ f.drive(self.s2, "pix")
+
+ f = CEInserter({"pix": self.c1})(f)
+ self.assertRepr(f.statements, """
+ (
+ (eq (sig s1) (const 1'd1))
+ (eq (sig s2) (const 0'd0))
+ (switch (sig c1)
+ (case 0 (eq (sig s2) (sig s2)))
+ )
+ )
+ """)