hdl.xfrm: separate AST traversal from AST identity mapping.
authorwhitequark <cz@m-labs.hk>
Sun, 16 Dec 2018 11:24:23 +0000 (11:24 +0000)
committerwhitequark <cz@m-labs.hk>
Sun, 16 Dec 2018 11:25:52 +0000 (11:25 +0000)
This is useful because backends don't generally want or need AST
identity mapping (unlike all other transforms) and when adding a new
node, it results in confusing type errors.

nmigen/back/pysim.py
nmigen/back/rtlil.py
nmigen/hdl/xfrm.py

index 9a2ae8634a33b25199ded0e0fdb30b63e90aedcc..4d0364dd326d10eb63c6d7247a1e625ce650c62f 100644 (file)
@@ -6,7 +6,7 @@ from vcd.gtkw import GTKWSave
 
 from ..tools import flatten
 from ..hdl.ast import *
-from ..hdl.xfrm import ValueTransformer, StatementTransformer
+from ..hdl.xfrm import AbstractValueTransformer, AbstractStatementTransformer
 
 
 __all__ = ["Simulator", "Delay", "Tick", "Passive", "DeadlineError"]
@@ -44,7 +44,7 @@ class _State:
 normalize = Const.normalize
 
 
-class _RHSValueCompiler(ValueTransformer):
+class _RHSValueCompiler(AbstractValueTransformer):
     def __init__(self, sensitivity=None, mode="rhs"):
         self.sensitivity = sensitivity
         self.signal_mode = mode
@@ -165,7 +165,7 @@ class _RHSValueCompiler(ValueTransformer):
         return lambda state: normalize(elems[index(state)](state), shape)
 
 
-class _LHSValueCompiler(ValueTransformer):
+class _LHSValueCompiler(AbstractValueTransformer):
     def __init__(self, rhs_compiler):
         self.rhs_compiler = rhs_compiler
 
@@ -234,7 +234,7 @@ class _LHSValueCompiler(ValueTransformer):
         return eval
 
 
-class _StatementCompiler(StatementTransformer):
+class _StatementCompiler(AbstractStatementTransformer):
     def __init__(self):
         self.sensitivity   = ValueSet()
         self.rrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="rhs")
index 356e0f5aa381d265c93dc98c81366dc9339d2d60..7dee209dfda9d2c6732dffdda5e819853d93ee7b 100644 (file)
@@ -195,7 +195,7 @@ def src(src_loc):
     return "{}:{}".format(file, line)
 
 
-class _ValueTransformer(xfrm.ValueTransformer):
+class _ValueTransformer(xfrm.AbstractValueTransformer):
     operator_map = {
         (1, "~"):    "$not",
         (1, "-"):    "$neg",
@@ -300,6 +300,12 @@ class _ValueTransformer(xfrm.ValueTransformer):
         else:
             return wire_curr
 
+    def on_ClockSignal(self, value):
+        raise NotImplementedError # :nocov:
+
+    def on_ResetSignal(self, value):
+        raise NotImplementedError # :nocov:
+
     def on_Operator_unary(self, node):
         arg, = node.operands
         arg_bits, arg_sign = arg.shape()
@@ -397,8 +403,8 @@ class _ValueTransformer(xfrm.ValueTransformer):
         else:
             return "{} [{}:{}]".format(self(node.value), node.end - 1, node.start)
 
-    def on_Part(self, node):
-    #     return _Part(self(node.value), self(node.offset), node.width)
+    def on_Part(self, node):
+        raise NotImplementedError
 
     def on_Cat(self, node):
         return "{{ {} }}".format(" ".join(reversed([self(o) for o in node.operands])))
@@ -406,6 +412,9 @@ class _ValueTransformer(xfrm.ValueTransformer):
     def on_Repl(self, node):
         return "{{ {} }}".format(" ".join(self(node.value) for _ in range(node.count)))
 
+    def on_ArrayProxy(self, node):
+        raise NotImplementedError
+
 
 def convert_fragment(builder, fragment, name, top):
     with builder.module(name or "anonymous", attrs={"top": 1} if top else {}) as module:
index 8d09ac83145915af79446241b1a680200958c411..62f7fea697ed1e33a6d87b5e8910bab2b2147457 100644 (file)
@@ -1,3 +1,4 @@
+from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
 from collections.abc import Iterable
 
@@ -8,41 +9,52 @@ from .cd import *
 from .ir import *
 
 
-__all__ = ["ValueTransformer", "StatementTransformer", "FragmentTransformer",
+__all__ = ["AbstractValueTransformer", "ValueTransformer",
+           "AbstractStatementTransformer", "StatementTransformer",
+           "FragmentTransformer",
            "DomainRenamer", "DomainLowerer", "ResetInserter", "CEInserter"]
 
 
-class ValueTransformer:
+class AbstractValueTransformer(metaclass=ABCMeta):
+    @abstractmethod
     def on_Const(self, value):
-        return value
+        pass
 
+    @abstractmethod
     def on_Signal(self, value):
-        return value
+        pass
 
+    @abstractmethod
     def on_ClockSignal(self, value):
-        return value
+        pass
 
+    @abstractmethod
     def on_ResetSignal(self, value):
-        return value
+        pass
 
+    @abstractmethod
     def on_Operator(self, value):
-        return Operator(value.op, [self.on_value(o) for o in value.operands])
+        pass
 
+    @abstractmethod
     def on_Slice(self, value):
-        return Slice(self.on_value(value.value), value.start, value.end)
+        pass
 
+    @abstractmethod
     def on_Part(self, value):
-        return Part(self.on_value(value.value), self.on_value(value.offset), value.width)
+        pass
 
+    @abstractmethod
     def on_Cat(self, value):
-        return Cat(self.on_value(o) for o in value.operands)
+        pass
 
+    @abstractmethod
     def on_Repl(self, value):
-        return Repl(self.on_value(value.value), value.count)
+        pass
 
+    @abstractmethod
     def on_ArrayProxy(self, value):
-        return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()],
-                          self.on_value(value.index))
+        pass
 
     def on_unknown_value(self, value):
         raise TypeError("Cannot transform value '{!r}'".format(value)) # :nocov:
@@ -78,19 +90,51 @@ class ValueTransformer:
         return self.on_value(value)
 
 
-class StatementTransformer:
-    def on_value(self, value):
+class ValueTransformer(AbstractValueTransformer):
+    def on_Const(self, value):
+        return value
+
+    def on_Signal(self, value):
+        return value
+
+    def on_ClockSignal(self, value):
         return value
 
+    def on_ResetSignal(self, value):
+        return value
+
+    def on_Operator(self, value):
+        return Operator(value.op, [self.on_value(o) for o in value.operands])
+
+    def on_Slice(self, value):
+        return Slice(self.on_value(value.value), value.start, value.end)
+
+    def on_Part(self, value):
+        return Part(self.on_value(value.value), self.on_value(value.offset), value.width)
+
+    def on_Cat(self, value):
+        return Cat(self.on_value(o) for o in value.operands)
+
+    def on_Repl(self, value):
+        return Repl(self.on_value(value.value), value.count)
+
+    def on_ArrayProxy(self, value):
+        return ArrayProxy([self.on_value(elem) for elem in value._iter_as_values()],
+                          self.on_value(value.index))
+
+
+class AbstractStatementTransformer(metaclass=ABCMeta):
+    @abstractmethod
     def on_Assign(self, stmt):
-        return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs))
+        pass
 
+    @abstractmethod
     def on_Switch(self, stmt):
-        cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items())
-        return Switch(self.on_value(stmt.test), cases)
+        pass
 
+    @abstractmethod
     def on_statements(self, stmt):
-        return _StatementList(flatten(self.on_statement(stmt) for stmt in stmt))
+        pass
 
     def on_unknown_statement(self, stmt):
         raise TypeError("Cannot transform statement '{!r}'".format(stmt)) # :nocov:
@@ -109,6 +153,21 @@ class StatementTransformer:
         return self.on_statement(value)
 
 
+class StatementTransformer(AbstractStatementTransformer):
+    def on_value(self, value):
+        return value
+
+    def on_Assign(self, stmt):
+        return Assign(self.on_value(stmt.lhs), self.on_value(stmt.rhs))
+
+    def on_Switch(self, stmt):
+        cases = OrderedDict((k, self.on_statement(v)) for k, v in stmt.cases.items())
+        return Switch(self.on_value(stmt.test), cases)
+
+    def on_statements(self, stmt):
+        return _StatementList(flatten(self.on_statement(stmt) for stmt in stmt))
+
+
 class FragmentTransformer:
     def map_subfragments(self, fragment, new_fragment):
         for subfragment, name in fragment.subfragments: