From: whitequark Date: Thu, 17 Jan 2019 01:41:02 +0000 (+0000) Subject: hdl.xfrm: add SampleLowerer. X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=0199ce8375fbc258a6e787bd92157684fa2b0c2b;p=nmigen.git hdl.xfrm: add SampleLowerer. --- diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py index 95b2355..bea24b4 100644 --- a/nmigen/back/pysim.py +++ b/nmigen/back/pysim.py @@ -74,6 +74,15 @@ normalize = Const.normalize class _ValueCompiler(ValueVisitor): + def on_AnyConst(self, value): + raise NotImplementedError # :nocov: + + def on_AnySeq(self, value): + raise NotImplementedError # :nocov: + + def on_Sample(self, value): + raise NotImplementedError # :nocov: + def on_Record(self, value): return self(Cat(value.fields.values())) @@ -87,12 +96,6 @@ class _RHSValueCompiler(_ValueCompiler): def on_Const(self, value): return lambda state: value.value - def on_AnyConst(self, value): - raise NotImplementedError # :nocov: - - def on_AnySeq(self, value): - raise NotImplementedError # :nocov: - def on_Signal(self, value): if self.sensitivity is not None: self.sensitivity.add(value) @@ -225,12 +228,6 @@ class _LHSValueCompiler(_ValueCompiler): def on_Const(self, value): raise TypeError # :nocov: - def on_AnyConst(self, value): - raise TypeError # :nocov: - - def on_AnySeq(self, value): - raise TypeError # :nocov: - def on_Signal(self, value): shape = value.shape() value_slot = self.signal_slots[value] diff --git a/nmigen/back/rtlil.py b/nmigen/back/rtlil.py index ec157fe..5225cf9 100644 --- a/nmigen/back/rtlil.py +++ b/nmigen/back/rtlil.py @@ -308,6 +308,9 @@ class _ValueCompiler(xfrm.ValueVisitor): def on_ResetSignal(self, value): raise NotImplementedError # :nocov: + def on_Sample(self, value): + raise NotImplementedError # :nocov: + def on_Record(self, value): return self(ast.Cat(value.fields.values())) diff --git a/nmigen/hdl/xfrm.py b/nmigen/hdl/xfrm.py index 728b95b..b335d0c 100644 --- a/nmigen/hdl/xfrm.py +++ b/nmigen/hdl/xfrm.py @@ -13,7 +13,7 @@ from .rec import * __all__ = ["ValueVisitor", "ValueTransformer", "StatementVisitor", "StatementTransformer", "FragmentTransformer", - "DomainRenamer", "DomainLowerer", + "DomainRenamer", "DomainLowerer", "SampleLowerer", "SwitchCleaner", "LHSGroupAnalyzer", "LHSGroupFilter", "ResetInserter", "CEInserter"] @@ -71,6 +71,10 @@ class ValueVisitor(metaclass=ABCMeta): def on_ArrayProxy(self, value): pass # :nocov: + @abstractmethod + def on_Sample(self, value): + pass # :nocov: + def on_unknown_value(self, value): raise TypeError("Cannot transform value '{!r}'".format(value)) # :nocov: @@ -102,6 +106,8 @@ class ValueVisitor(metaclass=ABCMeta): new_value = self.on_Repl(value) elif type(value) is ArrayProxy: new_value = self.on_ArrayProxy(value) + elif type(value) is Sample: + new_value = self.on_Sample(value) else: new_value = self.on_unknown_value(value) if isinstance(new_value, Value): @@ -153,6 +159,9 @@ class ValueTransformer(ValueVisitor): return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()], self.on_value(value.index)) + def on_Sample(self, value): + return Sample(self.on_value(value.value), value.clocks, value.domain) + class StatementVisitor(metaclass=ABCMeta): @abstractmethod @@ -331,6 +340,48 @@ class DomainLowerer(FragmentTransformer, ValueTransformer, StatementTransformer) return cd.rst +class SampleLowerer(FragmentTransformer, ValueTransformer, StatementTransformer): + def __init__(self): + self.sample_cache = ValueDict() + self.sample_stmts = OrderedDict() + + def _name_reset(self, value): + if isinstance(value, Const): + return "c${}".format(value.value), value.value + elif isinstance(value, Signal): + return "s${}".format(value.name), value.reset + else: + raise NotImplementedError # :nocov: + + def on_Sample(self, value): + if value in self.sample_cache: + return self.sample_cache[value] + + if value.clocks == 0: + sample = value.value + else: + assert value.domain is not None + sampled_name, sampled_reset = self._name_reset(value.value) + name = "$sample${}${}${}".format(sampled_name, value.domain, value.clocks) + sample = Signal.like(value.value, name=name, reset_less=True, reset=sampled_reset) + + prev_sample = self.on_Sample(Sample(value.value, value.clocks - 1, value.domain)) + if value.domain not in self.sample_stmts: + self.sample_stmts[value.domain] = [] + self.sample_stmts[value.domain].append(sample.eq(prev_sample)) + + self.sample_cache[value] = sample + return sample + + def on_fragment(self, fragment): + new_fragment = super().on_fragment(fragment) + for domain, stmts in self.sample_stmts.items(): + new_fragment.add_statements(stmts) + for stmt in stmts: + new_fragment.add_driver(stmt.lhs, domain) + return new_fragment + + class SwitchCleaner(StatementVisitor): def on_Assign(self, stmt): return stmt diff --git a/nmigen/test/test_hdl_xfrm.py b/nmigen/test/test_hdl_xfrm.py index 4483e7c..f983ab7 100644 --- a/nmigen/test/test_hdl_xfrm.py +++ b/nmigen/test/test_hdl_xfrm.py @@ -170,6 +170,52 @@ class DomainLowererTestCase(FHDLTestCase): DomainLowerer({"sync": sync})(f) +class SampleLowererTestCase(FHDLTestCase): + def setUp(self): + self.i = Signal() + self.o1 = Signal() + self.o2 = Signal() + self.o3 = Signal() + + def test_lower_signal(self): + f = Fragment() + f.add_statements( + self.o1.eq(Sample(self.i, 2, "sync")), + self.o2.eq(Sample(self.i, 1, "sync")), + self.o3.eq(Sample(self.i, 1, "pix")), + ) + + f = SampleLowerer()(f) + self.assertRepr(f.statements, """ + ( + (eq (sig o1) (sig $sample$s$i$sync$2)) + (eq (sig o2) (sig $sample$s$i$sync$1)) + (eq (sig o3) (sig $sample$s$i$pix$1)) + (eq (sig $sample$s$i$sync$1) (sig i)) + (eq (sig $sample$s$i$sync$2) (sig $sample$s$i$sync$1)) + (eq (sig $sample$s$i$pix$1) (sig i)) + ) + """) + self.assertEqual(len(f.drivers["sync"]), 2) + self.assertEqual(len(f.drivers["pix"]), 1) + + def test_lower_const(self): + f = Fragment() + f.add_statements( + self.o1.eq(Sample(1, 2, "sync")), + ) + + f = SampleLowerer()(f) + self.assertRepr(f.statements, """ + ( + (eq (sig o1) (sig $sample$c$1$sync$2)) + (eq (sig $sample$c$1$sync$1) (const 1'd1)) + (eq (sig $sample$c$1$sync$2) (sig $sample$c$1$sync$1)) + ) + """) + self.assertEqual(len(f.drivers["sync"]), 2) + + class SwitchCleanerTestCase(FHDLTestCase): def test_clean(self): a = Signal()