sync = [Case(self.sel, *cases)]
return Fragment(sync=sync)
-class _AnonymousRegister:
- def __init__(self, nbits):
- self.nbits = nbits
-
-class _CompileVisitor(ast.NodeVisitor):
+class _Compiler:
def __init__(self, symdict, registers):
self.symdict = symdict
self.registers = registers
+ self.targetname = ""
- def visit_Assign(self, node):
- value = self.visit(node.value)
- if isinstance(value, _AnonymousRegister):
- if isinstance(node.targets[0], ast.Name):
- name = node.targets[0].id
+ def visit_top(self, node):
+ if isinstance(node, ast.Module) \
+ and len(node.body) == 1 \
+ and isinstance(node.body[0], ast.FunctionDef):
+ return self.visit_block(node.body[0].body)
+ else:
+ raise NotImplementedError
+
+ # blocks and statements
+ def visit_block(self, statements):
+ r = []
+ for statement in statements:
+ if isinstance(statement, ast.Assign):
+ r += self.visit_assign(statement)
else:
raise NotImplementedError
- value = _Register(name, value.nbits)
+ return r
+
+ def visit_assign(self, node):
+ if isinstance(node.targets[0], ast.Name):
+ self.targetname = node.targets[0].id
+ value = self.visit_expr(node.value, True)
+ self.targetname = ""
+
+ if isinstance(value, _Register):
self.registers.append(value)
for target in node.targets:
if isinstance(target, ast.Name):
self.symdict[target.id] = value
else:
raise NotImplementedError
+ return []
+ elif isinstance(value, Value):
+ r = []
+ for target in node.targets:
+ if isinstance(target, ast.Attribute) and target.attr == "store":
+ treg = target.value
+ if isinstance(treg, ast.Name):
+ r.append(self.symdict[treg.id].load(value))
+ else:
+ raise NotImplementedError
+ else:
+ raise NotImplementedError
+ return r
+ else:
+ raise NotImplementedError
- def visit_Call(self, node):
+ # expressions
+ def visit_expr(self, node, allow_call=False):
+ if isinstance(node, ast.Call):
+ if allow_call:
+ return self.visit_expr_call(node)
+ else:
+ raise NotImplementedError
+ elif isinstance(node, ast.BinOp):
+ return self.visit_expr_binop(node)
+ elif isinstance(node, ast.Name):
+ return self.visit_expr_name(node)
+ elif isinstance(node, ast.Num):
+ return self.visit_expr_num(node)
+ else:
+ raise NotImplementedError
+
+ def visit_expr_call(self, node):
if isinstance(node.func, ast.Name):
callee = self.symdict[node.func.id]
else:
if len(node.args) != 1:
raise TypeError("Register() takes exactly 1 argument")
nbits = ast.literal_eval(node.args[0])
- return _AnonymousRegister(nbits)
+ return _Register(self.targetname, nbits)
else:
raise NotImplementedError
+ def visit_expr_binop(self, node):
+ left = self.visit_expr(node.left)
+ right = self.visit_expr(node.right)
+ if isinstance(node.op, ast.Add):
+ return left + right
+ elif isinstance(node.op, ast.Sub):
+ return left - right
+ elif isinstance(node.op, ast.Mult):
+ return left * right
+ elif isinstance(node.op, ast.LShift):
+ return left << right
+ elif isinstance(node.op, ast.RShift):
+ return left >> right
+ elif isinstance(node.op, ast.BitOr):
+ return left | right
+ elif isinstance(node.op, ast.BitXor):
+ return left ^ right
+ elif isinstance(node.op, ast.BitAnd):
+ return left & right
+ else:
+ raise NotImplementedError
+
+ def visit_expr_name(self, node):
+ r = self.symdict[node.id]
+ if isinstance(r, _Register):
+ r = r.storage
+ return r
+
+ def visit_expr_num(self, node):
+ return node.n
+
def make_pytholite(func):
tree = ast.parse(inspect.getsource(func))
symdict = func.__globals__.copy()
registers = []
- cv = _CompileVisitor(symdict, registers)
- cv.visit(tree)
+ c = _Compiler(symdict, registers)
+ print("compilation result:")
+ print(c.visit_top(tree))
+ print("registers:")
print(registers)
- print(symdict)
+ #print("symdict:")
+ #print(symdict)
- #print(ast.dump(tree))
+ print("ast:")
+ print(ast.dump(tree))