From: Sebastien Bourdeauducq Date: Sun, 11 Nov 2012 22:48:23 +0000 (+0100) Subject: pytholite: move expression and register handling to separate modules X-Git-Tag: 24jan2021_ls180~2099^2~790 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=bf5ce8dc20339c93a201a38d02b9a59d58082278;p=litex.git pytholite: move expression and register handling to separate modules --- diff --git a/migen/pytholite/compiler.py b/migen/pytholite/compiler.py index 5ba96e9a..ec5c41d1 100644 --- a/migen/pytholite/compiler.py +++ b/migen/pytholite/compiler.py @@ -1,62 +1,14 @@ import inspect import ast -from operator import itemgetter from migen.fhdl.structure import * from migen.fhdl.structure import _Slice -from migen.fhdl import visit as fhdl +from migen.pytholite.reg import * +from migen.pytholite.expr import * from migen.pytholite import transel from migen.pytholite.io import make_io_object, gen_io from migen.pytholite.fsm import * -class FinalizeError(Exception): - pass - -class _AbstractLoad: - def __init__(self, target, source): - self.target = target - self.source = source - - def lower(self): - if not self.target.finalized: - raise FinalizeError - return self.target.sel.eq(self.target.source_encoding[self.source]) - -class _LowerAbstractLoad(fhdl.NodeTransformer): - def visit_unknown(self, node): - if isinstance(node, _AbstractLoad): - return node.lower() - else: - return node - -class _Register: - def __init__(self, name, nbits): - self.name = name - self.storage = Signal(BV(nbits), name=self.name) - self.source_encoding = {} - self.finalized = False - - def load(self, source): - if source not in self.source_encoding: - self.source_encoding[source] = len(self.source_encoding) + 1 - return _AbstractLoad(self, source) - - def finalize(self): - if self.finalized: - raise FinalizeError - self.sel = Signal(BV(bits_for(len(self.source_encoding) + 1)), name="pl_regsel_"+self.name) - self.finalized = True - - def get_fragment(self): - if not self.finalized: - raise FinalizeError - # do nothing when sel == 0 - items = sorted(self.source_encoding.items(), key=itemgetter(1)) - cases = [(Constant(v, self.sel.bv), - self.storage.eq(k)) for k, v in items] - sync = [Case(self.sel, *cases)] - return Fragment(sync=sync) - def _is_name_used(node, name): for n in ast.walk(node): if isinstance(n, ast.Name) and n.id == name: @@ -68,6 +20,7 @@ class _Compiler: self.ioo = ioo self.symdict = symdict self.registers = registers + self.ec = ExprCompiler(self.symdict) def visit_top(self, node): if isinstance(node, ast.Module) \ @@ -109,12 +62,15 @@ class _Compiler: def visit_assign(self, sa, node, statements): if isinstance(node.value, ast.Call): + is_special = False try: - value = self.visit_expr_call(node.value) + value = self.ec.visit_expr_call(node.value) except NotImplementedError: + is_special = True + if is_special: return self.visit_assign_special(sa, node, statements) else: - value = self.visit_expr(node.value) + value = self.ec.visit_expr(node.value) if isinstance(value, Value): r = [] for target in node.targets: @@ -146,7 +102,7 @@ class _Compiler: targetname = node.targets[0].id else: targetname = "unk" - reg = _Register(targetname, nbits) + reg = ImplRegister(targetname, nbits) self.registers.append(reg) for target in node.targets: if isinstance(target, ast.Name): @@ -173,6 +129,7 @@ class _Compiler: or not isinstance(ystatement.value, ast.Yield) \ or not isinstance(ystatement.value.value, ast.Name) \ or ystatement.value.value.id != modelname: + print(ast.dump(ystatement)) raise NotImplementedError("Unrecognized I/O pattern") # following optional statements are assignments to registers @@ -202,7 +159,7 @@ class _Compiler: return fstatement def visit_if(self, sa, node): - test = self.visit_expr(node.test) + test = self.ec.visit_expr(node.test) states_t, exit_states_t = self.visit_block(node.body) states_f, exit_states_f = self.visit_block(node.orelse) exit_states = exit_states_t + exit_states_f @@ -218,7 +175,7 @@ class _Compiler: exit_states) def visit_while(self, sa, node): - test = self.visit_expr(node.test) + test = self.ec.visit_expr(node.test) states_b, exit_states_b = self.visit_block(node.body) test_state = [If(test, AbstractNextState(states_b[0]))] @@ -269,102 +226,6 @@ class _Compiler: sa.assemble(states, exit_states) else: raise NotImplementedError - - # expressions - def visit_expr(self, node): - if isinstance(node, ast.Call): - return self.visit_expr_call(node) - elif isinstance(node, ast.BinOp): - return self.visit_expr_binop(node) - elif isinstance(node, ast.Compare): - return self.visit_expr_compare(node) - elif isinstance(node, ast.Name): - return self.visit_expr_name(node) - elif isinstance(node, ast.Num): - return self.visit_expr_num(node) - else: - raise NotImplementedError - - def visit_expr_call(self, node): - if isinstance(node.func, ast.Name): - callee = self.symdict[node.func.id] - else: - raise NotImplementedError - if callee == transel.bitslice: - if len(node.args) != 2 and len(node.args) != 3: - raise TypeError("bitslice() takes 2 or 3 arguments") - val = self.visit_expr(node.args[0]) - low = ast.literal_eval(node.args[1]) - if len(node.args) == 3: - up = ast.literal_eval(node.args[2]) - else: - up = low + 1 - return _Slice(val, low, up) - else: - raise NotImplementedError - - def visit_expr_binop(self, node): - left = self.visit_expr(node.left) - right = self.visit_expr(node.right) - if isinstance(node.op, ast.Add): - return left + right - elif isinstance(node.op, ast.Sub): - return left - right - elif isinstance(node.op, ast.Mult): - return left * right - elif isinstance(node.op, ast.LShift): - return left << right - elif isinstance(node.op, ast.RShift): - return left >> right - elif isinstance(node.op, ast.BitOr): - return left | right - elif isinstance(node.op, ast.BitXor): - return left ^ right - elif isinstance(node.op, ast.BitAnd): - return left & right - else: - raise NotImplementedError - - def visit_expr_compare(self, node): - test = self.visit_expr(node.left) - r = None - for op, rcomparator in zip(node.ops, node.comparators): - comparator = self.visit_expr(rcomparator) - if isinstance(op, ast.Eq): - comparison = test == comparator - elif isinstance(op, ast.NotEq): - comparison = test != comparator - elif isinstance(op, ast.Lt): - comparison = test < comparator - elif isinstance(op, ast.LtE): - comparison = test <= comparator - elif isinstance(op, ast.Gt): - comparison = test > comparator - elif isinstance(op, ast.GtE): - comparison = test >= comparator - else: - raise NotImplementedError - if r is None: - r = comparison - else: - r = r & comparison - test = comparator - return r - - def visit_expr_name(self, node): - if node.id == "True": - return Constant(1) - if node.id == "False": - return Constant(0) - r = self.symdict[node.id] - if isinstance(r, _Register): - r = r.storage - if isinstance(r, int): - r = Constant(r) - return r - - def visit_expr_num(self, node): - return Constant(node.n) def make_pytholite(func, **ioresources): ioo = make_io_object(**ioresources) @@ -381,7 +242,7 @@ def make_pytholite(func, **ioresources): regf += register.get_fragment() fsm = implement_fsm(states) - fsmf = _LowerAbstractLoad().visit(fsm.get_fragment()) + fsmf = LowerAbstractLoad().visit(fsm.get_fragment()) ioo.fragment = regf + fsmf return ioo diff --git a/migen/pytholite/expr.py b/migen/pytholite/expr.py new file mode 100644 index 00000000..8c1c3551 --- /dev/null +++ b/migen/pytholite/expr.py @@ -0,0 +1,104 @@ +import ast + +from migen.fhdl.structure import * +from migen.pytholite import transel +from migen.pytholite.reg import * + +class ExprCompiler: + def __init__(self, symdict): + self.symdict = symdict + + def visit_expr(self, node): + if isinstance(node, ast.Call): + return self.visit_expr_call(node) + elif isinstance(node, ast.BinOp): + return self.visit_expr_binop(node) + elif isinstance(node, ast.Compare): + return self.visit_expr_compare(node) + elif isinstance(node, ast.Name): + return self.visit_expr_name(node) + elif isinstance(node, ast.Num): + return self.visit_expr_num(node) + else: + raise NotImplementedError + + def visit_expr_call(self, node): + if isinstance(node.func, ast.Name): + callee = self.symdict[node.func.id] + else: + raise NotImplementedError + if callee == transel.bitslice: + if len(node.args) != 2 and len(node.args) != 3: + raise TypeError("bitslice() takes 2 or 3 arguments") + val = self.visit_expr(node.args[0]) + low = ast.literal_eval(node.args[1]) + if len(node.args) == 3: + up = ast.literal_eval(node.args[2]) + else: + up = low + 1 + return _Slice(val, low, up) + else: + raise NotImplementedError + + def visit_expr_binop(self, node): + left = self.visit_expr(node.left) + right = self.visit_expr(node.right) + if isinstance(node.op, ast.Add): + return left + right + elif isinstance(node.op, ast.Sub): + return left - right + elif isinstance(node.op, ast.Mult): + return left * right + elif isinstance(node.op, ast.LShift): + return left << right + elif isinstance(node.op, ast.RShift): + return left >> right + elif isinstance(node.op, ast.BitOr): + return left | right + elif isinstance(node.op, ast.BitXor): + return left ^ right + elif isinstance(node.op, ast.BitAnd): + return left & right + else: + raise NotImplementedError + + def visit_expr_compare(self, node): + test = self.visit_expr(node.left) + r = None + for op, rcomparator in zip(node.ops, node.comparators): + comparator = self.visit_expr(rcomparator) + if isinstance(op, ast.Eq): + comparison = test == comparator + elif isinstance(op, ast.NotEq): + comparison = test != comparator + elif isinstance(op, ast.Lt): + comparison = test < comparator + elif isinstance(op, ast.LtE): + comparison = test <= comparator + elif isinstance(op, ast.Gt): + comparison = test > comparator + elif isinstance(op, ast.GtE): + comparison = test >= comparator + else: + raise NotImplementedError + if r is None: + r = comparison + else: + r = r & comparison + test = comparator + return r + + def visit_expr_name(self, node): + if node.id == "True": + return Constant(1) + if node.id == "False": + return Constant(0) + r = self.symdict[node.id] + if isinstance(r, ImplRegister): + r = r.storage + if isinstance(r, int): + r = Constant(r) + return r + + def visit_expr_num(self, node): + return Constant(node.n) diff --git a/migen/pytholite/reg.py b/migen/pytholite/reg.py new file mode 100644 index 00000000..610f2330 --- /dev/null +++ b/migen/pytholite/reg.py @@ -0,0 +1,52 @@ +from operator import itemgetter + +from migen.fhdl.structure import * +from migen.fhdl import visit as fhdl + +class FinalizeError(Exception): + pass + +class AbstractLoad: + def __init__(self, target, source): + self.target = target + self.source = source + + def lower(self): + if not self.target.finalized: + raise FinalizeError + return self.target.sel.eq(self.target.source_encoding[self.source]) + +class LowerAbstractLoad(fhdl.NodeTransformer): + def visit_unknown(self, node): + if isinstance(node, AbstractLoad): + return node.lower() + else: + return node + +class ImplRegister: + def __init__(self, name, nbits): + self.name = name + self.storage = Signal(BV(nbits), name=self.name) + self.source_encoding = {} + self.finalized = False + + def load(self, source): + if source not in self.source_encoding: + self.source_encoding[source] = len(self.source_encoding) + 1 + return AbstractLoad(self, source) + + def finalize(self): + if self.finalized: + raise FinalizeError + self.sel = Signal(BV(bits_for(len(self.source_encoding) + 1)), name="pl_regsel_"+self.name) + self.finalized = True + + def get_fragment(self): + if not self.finalized: + raise FinalizeError + # do nothing when sel == 0 + items = sorted(self.source_encoding.items(), key=itemgetter(1)) + cases = [(Constant(v, self.sel.bv), + self.storage.eq(k)) for k, v in items] + sync = [Case(self.sel, *cases)] + return Fragment(sync=sync)