hdl.xfrm: add SampleLowerer.
authorwhitequark <whitequark@whitequark.org>
Thu, 17 Jan 2019 01:41:02 +0000 (01:41 +0000)
committerwhitequark <whitequark@whitequark.org>
Thu, 17 Jan 2019 01:41:02 +0000 (01:41 +0000)
nmigen/back/pysim.py
nmigen/back/rtlil.py
nmigen/hdl/xfrm.py
nmigen/test/test_hdl_xfrm.py

index 95b23552fae16c1f62c0d3818b0a929a1b8043a3..bea24b4418f139e717cc8f79c57104f43f5fbcdc 100644 (file)
@@ -74,6 +74,15 @@ normalize = Const.normalize
 
 
 class _ValueCompiler(ValueVisitor):
+    def on_AnyConst(self, value):
+        raise NotImplementedError # :nocov:
+
+    def on_AnySeq(self, value):
+        raise NotImplementedError # :nocov:
+
+    def on_Sample(self, value):
+        raise NotImplementedError # :nocov:
+
     def on_Record(self, value):
         return self(Cat(value.fields.values()))
 
@@ -87,12 +96,6 @@ class _RHSValueCompiler(_ValueCompiler):
     def on_Const(self, value):
         return lambda state: value.value
 
-    def on_AnyConst(self, value):
-        raise NotImplementedError # :nocov:
-
-    def on_AnySeq(self, value):
-        raise NotImplementedError # :nocov:
-
     def on_Signal(self, value):
         if self.sensitivity is not None:
             self.sensitivity.add(value)
@@ -225,12 +228,6 @@ class _LHSValueCompiler(_ValueCompiler):
     def on_Const(self, value):
         raise TypeError # :nocov:
 
-    def on_AnyConst(self, value):
-        raise TypeError # :nocov:
-
-    def on_AnySeq(self, value):
-        raise TypeError # :nocov:
-
     def on_Signal(self, value):
         shape = value.shape()
         value_slot = self.signal_slots[value]
index ec157fecfbf4b8f1d7bd324d34d1eb568a21cce1..5225cf975dc2a1aa23391264f4d3d5dc6726be09 100644 (file)
@@ -308,6 +308,9 @@ class _ValueCompiler(xfrm.ValueVisitor):
     def on_ResetSignal(self, value):
         raise NotImplementedError # :nocov:
 
+    def on_Sample(self, value):
+        raise NotImplementedError # :nocov:
+
     def on_Record(self, value):
         return self(ast.Cat(value.fields.values()))
 
index 728b95bac72c54174f5244697a032f45adf78b7e..b335d0ceb3f173b5a208fdda4c468e9145dddee6 100644 (file)
@@ -13,7 +13,7 @@ from .rec import *
 __all__ = ["ValueVisitor", "ValueTransformer",
            "StatementVisitor", "StatementTransformer",
            "FragmentTransformer",
-           "DomainRenamer", "DomainLowerer",
+           "DomainRenamer", "DomainLowerer", "SampleLowerer",
            "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter",
            "ResetInserter", "CEInserter"]
 
@@ -71,6 +71,10 @@ class ValueVisitor(metaclass=ABCMeta):
     def on_ArrayProxy(self, value):
         pass # :nocov:
 
+    @abstractmethod
+    def on_Sample(self, value):
+        pass # :nocov:
+
     def on_unknown_value(self, value):
         raise TypeError("Cannot transform value '{!r}'".format(value)) # :nocov:
 
@@ -102,6 +106,8 @@ class ValueVisitor(metaclass=ABCMeta):
             new_value = self.on_Repl(value)
         elif type(value) is ArrayProxy:
             new_value = self.on_ArrayProxy(value)
+        elif type(value) is Sample:
+            new_value = self.on_Sample(value)
         else:
             new_value = self.on_unknown_value(value)
         if isinstance(new_value, Value):
@@ -153,6 +159,9 @@ class ValueTransformer(ValueVisitor):
         return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()],
                           self.on_value(value.index))
 
+    def on_Sample(self, value):
+        return Sample(self.on_value(value.value), value.clocks, value.domain)
+
 
 class StatementVisitor(metaclass=ABCMeta):
     @abstractmethod
@@ -331,6 +340,48 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer)
         return cd.rst
 
 
+class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer):
+    def __init__(self):
+        self.sample_cache = ValueDict()
+        self.sample_stmts = OrderedDict()
+
+    def _name_reset(self, value):
+        if isinstance(value, Const):
+            return "c${}".format(value.value), value.value
+        elif isinstance(value, Signal):
+            return "s${}".format(value.name), value.reset
+        else:
+            raise NotImplementedError # :nocov:
+
+    def on_Sample(self, value):
+        if value in self.sample_cache:
+            return self.sample_cache[value]
+
+        if value.clocks == 0:
+            sample = value.value
+        else:
+            assert value.domain is not None
+            sampled_name, sampled_reset = self._name_reset(value.value)
+            name = "$sample${}${}${}".format(sampled_name, value.domain, value.clocks)
+            sample = Signal.like(value.value, name=name, reset_less=True, reset=sampled_reset)
+
+            prev_sample = self.on_Sample(Sample(value.value, value.clocks - 1, value.domain))
+            if value.domain not in self.sample_stmts:
+                self.sample_stmts[value.domain] = []
+            self.sample_stmts[value.domain].append(sample.eq(prev_sample))
+
+        self.sample_cache[value] = sample
+        return sample
+
+    def on_fragment(self, fragment):
+        new_fragment = super().on_fragment(fragment)
+        for domain, stmts in self.sample_stmts.items():
+            new_fragment.add_statements(stmts)
+            for stmt in stmts:
+                new_fragment.add_driver(stmt.lhs, domain)
+        return new_fragment
+
+
 class SwitchCleaner(StatementVisitor):
     def on_Assign(self, stmt):
         return stmt
index 4483e7cb46fe017d3ae0626b4517a5b8f9bb1860..f983ab74a84ba4b2bd1c7a57a23290ded3d25102 100644 (file)
@@ -170,6 +170,52 @@ class DomainLowererTestCase(FHDLTestCase):
             DomainLowerer({"sync": sync})(f)
 
 
+class SampleLowererTestCase(FHDLTestCase):
+    def setUp(self):
+        self.i = Signal()
+        self.o1 = Signal()
+        self.o2 = Signal()
+        self.o3 = Signal()
+
+    def test_lower_signal(self):
+        f = Fragment()
+        f.add_statements(
+            self.o1.eq(Sample(self.i, 2, "sync")),
+            self.o2.eq(Sample(self.i, 1, "sync")),
+            self.o3.eq(Sample(self.i, 1, "pix")),
+        )
+
+        f = SampleLowerer()(f)
+        self.assertRepr(f.statements, """
+        (
+            (eq (sig o1) (sig $sample$s$i$sync$2))
+            (eq (sig o2) (sig $sample$s$i$sync$1))
+            (eq (sig o3) (sig $sample$s$i$pix$1))
+            (eq (sig $sample$s$i$sync$1) (sig i))
+            (eq (sig $sample$s$i$sync$2) (sig $sample$s$i$sync$1))
+            (eq (sig $sample$s$i$pix$1) (sig i))
+        )
+        """)
+        self.assertEqual(len(f.drivers["sync"]), 2)
+        self.assertEqual(len(f.drivers["pix"]), 1)
+
+    def test_lower_const(self):
+        f = Fragment()
+        f.add_statements(
+            self.o1.eq(Sample(1, 2, "sync")),
+        )
+
+        f = SampleLowerer()(f)
+        self.assertRepr(f.statements, """
+        (
+            (eq (sig o1) (sig $sample$c$1$sync$2))
+            (eq (sig $sample$c$1$sync$1) (const 1'd1))
+            (eq (sig $sample$c$1$sync$2) (sig $sample$c$1$sync$1))
+        )
+        """)
+        self.assertEqual(len(f.drivers["sync"]), 2)
+
+
 class SwitchCleanerTestCase(FHDLTestCase):
     def test_clean(self):
         a = Signal()