pytholite: move expression and register handling to separate modules
authorSebastien Bourdeauducq <sebastien@milkymist.org>
Sun, 11 Nov 2012 22:48:23 +0000 (23:48 +0100)
committerSebastien Bourdeauducq <sebastien@milkymist.org>
Sun, 11 Nov 2012 22:48:23 +0000 (23:48 +0100)
migen/pytholite/compiler.py
migen/pytholite/expr.py [new file with mode: 0644]
migen/pytholite/reg.py [new file with mode: 0644]

index 5ba96e9a6af30a1df625af867833f3a3fc3d7cb2..ec5c41d1f18e663f6009e0ae46ac35e0a22bba04 100644 (file)
@@ -1,62 +1,14 @@
 import inspect
 import ast
-from operator import itemgetter
 
 from migen.fhdl.structure import *
 from migen.fhdl.structure import _Slice
-from migen.fhdl import visit as fhdl
+from migen.pytholite.reg import *
+from migen.pytholite.expr import *
 from migen.pytholite import transel
 from migen.pytholite.io import make_io_object, gen_io
 from migen.pytholite.fsm import *
 
-class FinalizeError(Exception):
-       pass
-
-class _AbstractLoad:
-       def __init__(self, target, source):
-               self.target = target
-               self.source = source
-       
-       def lower(self):
-               if not self.target.finalized:
-                       raise FinalizeError
-               return self.target.sel.eq(self.target.source_encoding[self.source])
-
-class _LowerAbstractLoad(fhdl.NodeTransformer):
-       def visit_unknown(self, node):
-               if isinstance(node, _AbstractLoad):
-                       return node.lower()
-               else:
-                       return node
-
-class _Register:
-       def __init__(self, name, nbits):
-               self.name = name
-               self.storage = Signal(BV(nbits), name=self.name)
-               self.source_encoding = {}
-               self.finalized = False
-       
-       def load(self, source):
-               if source not in self.source_encoding:
-                       self.source_encoding[source] = len(self.source_encoding) + 1
-               return _AbstractLoad(self, source)
-       
-       def finalize(self):
-               if self.finalized:
-                       raise FinalizeError
-               self.sel = Signal(BV(bits_for(len(self.source_encoding) + 1)), name="pl_regsel_"+self.name)
-               self.finalized = True
-       
-       def get_fragment(self):
-               if not self.finalized:
-                       raise FinalizeError
-               # do nothing when sel == 0
-               items = sorted(self.source_encoding.items(), key=itemgetter(1))
-               cases = [(Constant(v, self.sel.bv),
-                       self.storage.eq(k)) for k, v in items]
-               sync = [Case(self.sel, *cases)]
-               return Fragment(sync=sync)
-
 def _is_name_used(node, name):
        for n in ast.walk(node):
                if isinstance(n, ast.Name) and n.id == name:
@@ -68,6 +20,7 @@ class _Compiler:
                self.ioo = ioo
                self.symdict = symdict
                self.registers = registers
+               self.ec = ExprCompiler(self.symdict)
        
        def visit_top(self, node):
                if isinstance(node, ast.Module) \
@@ -109,12 +62,15 @@ class _Compiler:
        
        def visit_assign(self, sa, node, statements):
                if isinstance(node.value, ast.Call):
+                       is_special = False
                        try:
-                               value = self.visit_expr_call(node.value)
+                               value = self.ec.visit_expr_call(node.value)
                        except NotImplementedError:
+                               is_special = True
+                       if is_special:
                                return self.visit_assign_special(sa, node, statements)
                else:
-                       value = self.visit_expr(node.value)
+                       value = self.ec.visit_expr(node.value)
                if isinstance(value, Value):
                        r = []
                        for target in node.targets:
@@ -146,7 +102,7 @@ class _Compiler:
                                targetname = node.targets[0].id
                        else:
                                targetname = "unk"
-                       reg = _Register(targetname, nbits)
+                       reg = ImplRegister(targetname, nbits)
                        self.registers.append(reg)
                        for target in node.targets:
                                if isinstance(target, ast.Name):
@@ -173,6 +129,7 @@ class _Compiler:
                  or not isinstance(ystatement.value, ast.Yield) \
                  or not isinstance(ystatement.value.value, ast.Name) \
                  or ystatement.value.value.id != modelname:
+                       print(ast.dump(ystatement))
                        raise NotImplementedError("Unrecognized I/O pattern")
                
                # following optional statements are assignments to registers
@@ -202,7 +159,7 @@ class _Compiler:
                return fstatement
        
        def visit_if(self, sa, node):
-               test = self.visit_expr(node.test)
+               test = self.ec.visit_expr(node.test)
                states_t, exit_states_t = self.visit_block(node.body)
                states_f, exit_states_f = self.visit_block(node.orelse)
                exit_states = exit_states_t + exit_states_f
@@ -218,7 +175,7 @@ class _Compiler:
                        exit_states)
        
        def visit_while(self, sa, node):
-               test = self.visit_expr(node.test)
+               test = self.ec.visit_expr(node.test)
                states_b, exit_states_b = self.visit_block(node.body)
 
                test_state = [If(test, AbstractNextState(states_b[0]))]
@@ -269,102 +226,6 @@ class _Compiler:
                        sa.assemble(states, exit_states)
                else:
                        raise NotImplementedError
-       
-       # expressions
-       def visit_expr(self, node):
-               if isinstance(node, ast.Call):
-                       return self.visit_expr_call(node)
-               elif isinstance(node, ast.BinOp):
-                       return self.visit_expr_binop(node)
-               elif isinstance(node, ast.Compare):
-                       return self.visit_expr_compare(node)
-               elif isinstance(node, ast.Name):
-                       return self.visit_expr_name(node)
-               elif isinstance(node, ast.Num):
-                       return self.visit_expr_num(node)
-               else:
-                       raise NotImplementedError
-       
-       def visit_expr_call(self, node):
-               if isinstance(node.func, ast.Name):
-                       callee = self.symdict[node.func.id]
-               else:
-                       raise NotImplementedError
-               if callee == transel.bitslice:
-                       if len(node.args) != 2 and len(node.args) != 3:
-                               raise TypeError("bitslice() takes 2 or 3 arguments")
-                       val = self.visit_expr(node.args[0])
-                       low = ast.literal_eval(node.args[1])
-                       if len(node.args) == 3:
-                               up = ast.literal_eval(node.args[2])
-                       else:
-                               up = low + 1
-                       return _Slice(val, low, up)
-               else:
-                       raise NotImplementedError
-       
-       def visit_expr_binop(self, node):
-               left = self.visit_expr(node.left)
-               right = self.visit_expr(node.right)
-               if isinstance(node.op, ast.Add):
-                       return left + right
-               elif isinstance(node.op, ast.Sub):
-                       return left - right
-               elif isinstance(node.op, ast.Mult):
-                       return left * right
-               elif isinstance(node.op, ast.LShift):
-                       return left << right
-               elif isinstance(node.op, ast.RShift):
-                       return left >> right
-               elif isinstance(node.op, ast.BitOr):
-                       return left | right
-               elif isinstance(node.op, ast.BitXor):
-                       return left ^ right
-               elif isinstance(node.op, ast.BitAnd):
-                       return left & right
-               else:
-                       raise NotImplementedError
-       
-       def visit_expr_compare(self, node):
-               test = self.visit_expr(node.left)
-               r = None
-               for op, rcomparator in zip(node.ops, node.comparators):
-                       comparator = self.visit_expr(rcomparator)
-                       if isinstance(op, ast.Eq):
-                               comparison = test == comparator
-                       elif isinstance(op, ast.NotEq):
-                               comparison = test != comparator
-                       elif isinstance(op, ast.Lt):
-                               comparison = test < comparator
-                       elif isinstance(op, ast.LtE):
-                               comparison = test <= comparator
-                       elif isinstance(op, ast.Gt):
-                               comparison = test > comparator
-                       elif isinstance(op, ast.GtE):
-                               comparison = test >= comparator
-                       else:
-                               raise NotImplementedError
-                       if r is None:
-                               r = comparison
-                       else:
-                               r = r & comparison
-                       test = comparator
-               return r
-       
-       def visit_expr_name(self, node):
-               if node.id == "True":
-                       return Constant(1)
-               if node.id == "False":
-                       return Constant(0)
-               r = self.symdict[node.id]
-               if isinstance(r, _Register):
-                       r = r.storage
-               if isinstance(r, int):
-                       r = Constant(r)
-               return r
-       
-       def visit_expr_num(self, node):
-               return Constant(node.n)
 
 def make_pytholite(func, **ioresources):
        ioo = make_io_object(**ioresources)
@@ -381,7 +242,7 @@ def make_pytholite(func, **ioresources):
                regf += register.get_fragment()
        
        fsm = implement_fsm(states)
-       fsmf = _LowerAbstractLoad().visit(fsm.get_fragment())
+       fsmf = LowerAbstractLoad().visit(fsm.get_fragment())
        
        ioo.fragment = regf + fsmf
        return ioo
diff --git a/migen/pytholite/expr.py b/migen/pytholite/expr.py
new file mode 100644 (file)
index 0000000..8c1c355
--- /dev/null
@@ -0,0 +1,104 @@
+import ast
+
+from migen.fhdl.structure import *
+from migen.pytholite import transel
+from migen.pytholite.reg import *
+
+class ExprCompiler:
+       def __init__(self, symdict):
+               self.symdict = symdict
+       
+       def visit_expr(self, node):
+               if isinstance(node, ast.Call):
+                       return self.visit_expr_call(node)
+               elif isinstance(node, ast.BinOp):
+                       return self.visit_expr_binop(node)
+               elif isinstance(node, ast.Compare):
+                       return self.visit_expr_compare(node)
+               elif isinstance(node, ast.Name):
+                       return self.visit_expr_name(node)
+               elif isinstance(node, ast.Num):
+                       return self.visit_expr_num(node)
+               else:
+                       raise NotImplementedError
+       
+       def visit_expr_call(self, node):
+               if isinstance(node.func, ast.Name):
+                       callee = self.symdict[node.func.id]
+               else:
+                       raise NotImplementedError
+               if callee == transel.bitslice:
+                       if len(node.args) != 2 and len(node.args) != 3:
+                               raise TypeError("bitslice() takes 2 or 3 arguments")
+                       val = self.visit_expr(node.args[0])
+                       low = ast.literal_eval(node.args[1])
+                       if len(node.args) == 3:
+                               up = ast.literal_eval(node.args[2])
+                       else:
+                               up = low + 1
+                       return _Slice(val, low, up)
+               else:
+                       raise NotImplementedError
+       
+       def visit_expr_binop(self, node):
+               left = self.visit_expr(node.left)
+               right = self.visit_expr(node.right)
+               if isinstance(node.op, ast.Add):
+                       return left + right
+               elif isinstance(node.op, ast.Sub):
+                       return left - right
+               elif isinstance(node.op, ast.Mult):
+                       return left * right
+               elif isinstance(node.op, ast.LShift):
+                       return left << right
+               elif isinstance(node.op, ast.RShift):
+                       return left >> right
+               elif isinstance(node.op, ast.BitOr):
+                       return left | right
+               elif isinstance(node.op, ast.BitXor):
+                       return left ^ right
+               elif isinstance(node.op, ast.BitAnd):
+                       return left & right
+               else:
+                       raise NotImplementedError
+       
+       def visit_expr_compare(self, node):
+               test = self.visit_expr(node.left)
+               r = None
+               for op, rcomparator in zip(node.ops, node.comparators):
+                       comparator = self.visit_expr(rcomparator)
+                       if isinstance(op, ast.Eq):
+                               comparison = test == comparator
+                       elif isinstance(op, ast.NotEq):
+                               comparison = test != comparator
+                       elif isinstance(op, ast.Lt):
+                               comparison = test < comparator
+                       elif isinstance(op, ast.LtE):
+                               comparison = test <= comparator
+                       elif isinstance(op, ast.Gt):
+                               comparison = test > comparator
+                       elif isinstance(op, ast.GtE):
+                               comparison = test >= comparator
+                       else:
+                               raise NotImplementedError
+                       if r is None:
+                               r = comparison
+                       else:
+                               r = r & comparison
+                       test = comparator
+               return r
+       
+       def visit_expr_name(self, node):
+               if node.id == "True":
+                       return Constant(1)
+               if node.id == "False":
+                       return Constant(0)
+               r = self.symdict[node.id]
+               if isinstance(r, ImplRegister):
+                       r = r.storage
+               if isinstance(r, int):
+                       r = Constant(r)
+               return r
+       
+       def visit_expr_num(self, node):
+               return Constant(node.n)
diff --git a/migen/pytholite/reg.py b/migen/pytholite/reg.py
new file mode 100644 (file)
index 0000000..610f233
--- /dev/null
@@ -0,0 +1,52 @@
+from operator import itemgetter
+
+from migen.fhdl.structure import *
+from migen.fhdl import visit as fhdl
+
+class FinalizeError(Exception):
+       pass
+
+class AbstractLoad:
+       def __init__(self, target, source):
+               self.target = target
+               self.source = source
+       
+       def lower(self):
+               if not self.target.finalized:
+                       raise FinalizeError
+               return self.target.sel.eq(self.target.source_encoding[self.source])
+
+class LowerAbstractLoad(fhdl.NodeTransformer):
+       def visit_unknown(self, node):
+               if isinstance(node, AbstractLoad):
+                       return node.lower()
+               else:
+                       return node
+
+class ImplRegister:
+       def __init__(self, name, nbits):
+               self.name = name
+               self.storage = Signal(BV(nbits), name=self.name)
+               self.source_encoding = {}
+               self.finalized = False
+       
+       def load(self, source):
+               if source not in self.source_encoding:
+                       self.source_encoding[source] = len(self.source_encoding) + 1
+               return AbstractLoad(self, source)
+       
+       def finalize(self):
+               if self.finalized:
+                       raise FinalizeError
+               self.sel = Signal(BV(bits_for(len(self.source_encoding) + 1)), name="pl_regsel_"+self.name)
+               self.finalized = True
+       
+       def get_fragment(self):
+               if not self.finalized:
+                       raise FinalizeError
+               # do nothing when sel == 0
+               items = sorted(self.source_encoding.items(), key=itemgetter(1))
+               cases = [(Constant(v, self.sel.bv),
+                       self.storage.eq(k)) for k, v in items]
+               sync = [Case(self.sel, *cases)]
+               return Fragment(sync=sync)