fhdl.xfrm: add tests for ResetInserter, CEInserter.
authorwhitequark <cz@m-labs.hk>
Thu, 13 Dec 2018 08:39:02 +0000 (08:39 +0000)
committerwhitequark <cz@m-labs.hk>
Thu, 13 Dec 2018 08:39:02 +0000 (08:39 +0000)
nmigen/fhdl/xfrm.py
nmigen/test/test_fhdl_xfrm.py [new file with mode: 0644]

index 5bc1ec2042f85bde93ad1944a36cd98ca28151dd..8a20aab12ac457e4d7ad19048b159440d28f79d4 100644 (file)
@@ -89,36 +89,58 @@ class StatementTransformer:
         return self.on_statement(value)
 
 
-class _ControlInserter:
-    def __init__(self, controls):
-        if isinstance(controls, Value):
-            controls = {"sys": controls}
-        self.controls = OrderedDict(controls)
-
-    def __call__(self, fragment):
-        new_fragment = Fragment()
+class FragmentTransformer:
+    def map_subfragments(self, fragment, new_fragment):
         for subfragment, name in fragment.subfragments:
             new_fragment.add_subfragment(self(subfragment), name)
-        new_fragment.add_statements(fragment.statements)
+
+    def map_statements(self, fragment, new_fragment):
+        if hasattr(self, "on_statement"):
+            new_fragment.add_statements(map(self.on_statement, fragment.statements))
+        else:
+            new_fragment.add_statements(fragment.statements)
+
+    def map_drivers(self, fragment, new_fragment):
         for cd_name, signals in fragment.iter_domains():
             for signal in signals:
                 new_fragment.drive(signal, cd_name)
+
+    def on_fragment(self, fragment):
+        new_fragment = Fragment()
+        self.map_subfragments(fragment, new_fragment)
+        self.map_statements(fragment, new_fragment)
+        self.map_drivers(fragment, new_fragment)
+        return new_fragment
+
+    def __call__(self, value):
+        return self.on_fragment(value)
+
+
+class _ControlInserter(FragmentTransformer):
+    def __init__(self, controls):
+        if isinstance(controls, Value):
+            controls = {"sync": controls}
+        self.controls = OrderedDict(controls)
+
+    def on_fragment(self, fragment):
+        new_fragment = super().on_fragment(fragment)
+        for cd_name, signals in fragment.iter_domains():
             if cd_name is None or cd_name not in self.controls:
                 continue
-            self._wrap_control(new_fragment, cd_name, signals)
+            self._insert_control(new_fragment, cd_name, signals)
         return new_fragment
 
-    def _wrap_control(self, fragment, cd_name, signals):
-        raise NotImplementedError
+    def _insert_control(self, fragment, cd_name, signals):
+        raise NotImplementedError # :nocov:
 
 
 class ResetInserter(_ControlInserter):
-    def _wrap_control(self, fragment, cd_name, signals):
+    def _insert_control(self, fragment, cd_name, signals):
         stmts = [s.eq(Const(s.reset, s.nbits)) for s in signals if not s.reset_less]
         fragment.add_statements(Switch(self.controls[cd_name], {1: stmts}))
 
 
 class CEInserter(_ControlInserter):
-    def _wrap_control(self, fragment, cd_name, signals):
+    def _insert_control(self, fragment, cd_name, signals):
         stmts = [s.eq(s) for s in signals]
         fragment.add_statements(Switch(self.controls[cd_name], {0: stmts}))
diff --git a/nmigen/test/test_fhdl_xfrm.py b/nmigen/test/test_fhdl_xfrm.py
new file mode 100644 (file)
index 0000000..f7e45ad
--- /dev/null
@@ -0,0 +1,144 @@
+import re
+import unittest
+
+from nmigen.fhdl.ast import *
+from nmigen.fhdl.ir import *
+from nmigen.fhdl.xfrm import *
+
+
+class ResetInserterTestCase(unittest.TestCase):
+    def setUp(self):
+        self.s1 = Signal()
+        self.s2 = Signal(reset=1)
+        self.s3 = Signal(reset=1, reset_less=True)
+        self.c1 = Signal()
+
+    def assertRepr(self, obj, repr_str):
+        obj = Statement.wrap(obj)
+        repr_str = re.sub(r"\s+",   " ",  repr_str)
+        repr_str = re.sub(r"\( (?=\()", "(", repr_str)
+        repr_str = re.sub(r"\) (?=\))", ")", repr_str)
+        self.assertEqual(repr(obj), repr_str.strip())
+
+    def test_reset_default(self):
+        f = Fragment()
+        f.add_statements(
+            self.s1.eq(1)
+        )
+        f.drive(self.s1, "sync")
+
+        f = ResetInserter(self.c1)(f)
+        self.assertRepr(f.statements, """
+        (
+            (eq (sig s1) (const 1'd1))
+            (switch (sig c1)
+                (case 1 (eq (sig s1) (const 1'd0)))
+            )
+        )
+        """)
+
+    def test_reset_cd(self):
+        f = Fragment()
+        f.add_statements(
+            self.s1.eq(1),
+            self.s2.eq(0),
+        )
+        f.drive(self.s1, "sync")
+        f.drive(self.s2, "pix")
+
+        f = ResetInserter({"pix": self.c1})(f)
+        self.assertRepr(f.statements, """
+        (
+            (eq (sig s1) (const 1'd1))
+            (eq (sig s2) (const 0'd0))
+            (switch (sig c1)
+                (case 1 (eq (sig s2) (const 1'd1)))
+            )
+        )
+        """)
+
+    def test_reset_value(self):
+        f = Fragment()
+        f.add_statements(
+            self.s2.eq(0)
+        )
+        f.drive(self.s2, "sync")
+
+        f = ResetInserter(self.c1)(f)
+        self.assertRepr(f.statements, """
+        (
+            (eq (sig s2) (const 0'd0))
+            (switch (sig c1)
+                (case 1 (eq (sig s2) (const 1'd1)))
+            )
+        )
+        """)
+
+    def test_reset_less(self):
+        f = Fragment()
+        f.add_statements(
+            self.s3.eq(0)
+        )
+        f.drive(self.s3, "sync")
+
+        f = ResetInserter(self.c1)(f)
+        self.assertRepr(f.statements, """
+        (
+            (eq (sig s3) (const 0'd0))
+            (switch (sig c1)
+                (case 1 )
+            )
+        )
+        """)
+
+
+class CEInserterTestCase(unittest.TestCase):
+    def setUp(self):
+        self.s1 = Signal()
+        self.s2 = Signal()
+        self.s3 = Signal()
+        self.c1 = Signal()
+
+    def assertRepr(self, obj, repr_str):
+        obj = Statement.wrap(obj)
+        repr_str = re.sub(r"\s+",   " ",  repr_str)
+        repr_str = re.sub(r"\( (?=\()", "(", repr_str)
+        repr_str = re.sub(r"\) (?=\))", ")", repr_str)
+        self.assertEqual(repr(obj), repr_str.strip())
+
+    def test_ce_default(self):
+        f = Fragment()
+        f.add_statements(
+            self.s1.eq(1)
+        )
+        f.drive(self.s1, "sync")
+
+        f = CEInserter(self.c1)(f)
+        self.assertRepr(f.statements, """
+        (
+            (eq (sig s1) (const 1'd1))
+            (switch (sig c1)
+                (case 0 (eq (sig s1) (sig s1)))
+            )
+        )
+        """)
+
+    def test_ce_cd(self):
+        f = Fragment()
+        f.add_statements(
+            self.s1.eq(1),
+            self.s2.eq(0),
+        )
+        f.drive(self.s1, "sync")
+        f.drive(self.s2, "pix")
+
+        f = CEInserter({"pix": self.c1})(f)
+        self.assertRepr(f.statements, """
+        (
+            (eq (sig s1) (const 1'd1))
+            (eq (sig s2) (const 0'd0))
+            (switch (sig c1)
+                (case 0 (eq (sig s2) (sig s2)))
+            )
+        )
+        """)