fc06923800c5960c0c6de0220bae9abc891314f7
[litex.git] / migen / pytholite / compiler.py
1 import inspect
2 import ast
3 from operator import itemgetter
4
5 from migen.fhdl.structure import *
6 from migen.fhdl.structure import _Slice
7 from migen.fhdl import visit as fhdl
8 from migen.corelogic.fsm import FSM
9 from migen.pytholite import transel
10 from migen.pytholite.io import make_io_object, gen_io
11
class FinalizeError(Exception):
    """Misuse of register finalization.

    Raised when a register is lowered or turned into a fragment before
    finalize() was called, or when finalize() is called a second time.
    """
14
class _AbstractLoad:
    """Placeholder statement meaning "load `source` into register `target`".

    It stands in the FSM actions until the target register is finalized;
    lower() then turns it into a drive of the register's select signal.
    """

    def __init__(self, target, source):
        self.target = target  # the _Register being loaded
        self.source = source  # the value to load into it

    def lower(self):
        """Return the statement setting target.sel to source's code.

        Raises FinalizeError if the target register is not finalized yet
        (its ``sel`` signal does not exist before finalize()).
        """
        register = self.target
        if register.finalized:
            return register.sel.eq(register.source_encoding[self.source])
        raise FinalizeError
24
class _LowerAbstractLoad(fhdl.NodeTransformer):
    """Replace every _AbstractLoad in an FHDL tree by its lowered statement."""

    def visit_unknown(self, node):
        # Presumably invoked for node types fhdl.NodeTransformer has no
        # dedicated visitor for; anything other than an abstract load is
        # returned unchanged.
        return node.lower() if isinstance(node, _AbstractLoad) else node
31
class _Register:
    """A storage register loaded from a set of sources chosen at run time.

    Every distinct source passed to load() receives an integer code
    (1, 2, ...). finalize() creates the select signal ``sel`` sized for
    those codes, and get_fragment() builds the synchronous logic that
    copies the selected source into ``storage``. A select value of 0
    matches no case, so the register keeps its value.
    """

    def __init__(self, name, nbits):
        self.name = name
        self.storage = Signal(BV(nbits), name=self.name)
        # source -> select code; codes start at 1, 0 is reserved for "hold".
        self.source_encoding = {}
        self.finalized = False

    def load(self, source):
        """Return an _AbstractLoad of source into this register.

        Raises FinalizeError if source is new and the register is already
        finalized: ``sel`` was sized for the sources known at finalize()
        time, so a new code might not fit in it.
        """
        if source not in self.source_encoding:
            if self.finalized:
                raise FinalizeError
            self.source_encoding[source] = len(self.source_encoding) + 1
        return _AbstractLoad(self, source)

    def finalize(self):
        """Create the select signal; may only be called once."""
        if self.finalized:
            raise FinalizeError
        self.sel = Signal(BV(bits_for(len(self.source_encoding) + 1)), name="pl_regsel_"+self.name)
        self.finalized = True

    def get_fragment(self):
        """Return the Fragment implementing the source multiplexer.

        Raises FinalizeError if finalize() has not been called yet.
        """
        if not self.finalized:
            raise FinalizeError
        # do nothing when sel == 0: no case is generated for that value
        items = sorted(self.source_encoding.items(), key=itemgetter(1))
        cases = [(Constant(v, self.sel.bv),
            self.storage.eq(k)) for k, v in items]
        sync = [Case(self.sel, *cases)]
        return Fragment(sync=sync)
59
class _AbstractNextState:
    """Placeholder statement for a transition to ``target_state``.

    The target is a state object (a list of statements) matched by
    identity; _LowerAbstractNextState replaces this placeholder with the
    corresponding FSM next_state statement.
    """

    def __init__(self, target_state):
        # Kept by reference: the state list may still grow after this
        # placeholder is created.
        self.target_state = target_state
63
# entry state is first state returned
class _StateAssembler:
    """Chain sequences of abstract states into a single sequence.

    A state is a list of statements. assemble() appends a sub-sequence
    and makes every still-pending exit state jump to the sub-sequence's
    first (entry) state. ret() returns (states, exit_states), with the
    entry state first.
    """

    def __init__(self):
        self.states = []       # all assembled states, entry state first
        self.exit_states = []  # states still waiting for a successor

    def assemble(self, n_states, n_exit_states):
        """Append n_states and link pending exits to n_states[0].

        An empty n_states (e.g. a for loop over an empty iterable) has no
        entry state. In that case the pending exit states are kept so that
        the next non-empty sub-sequence picks them up; previously they were
        dropped, or an IndexError was raised.
        """
        if not n_states:
            return
        self.states += n_states
        for exit_state in self.exit_states:
            # Prepended, presumably so that statements already in the
            # state (e.g. a conditional transition) override this default
            # transition - confirm against the FSM's statement semantics.
            exit_state.insert(0, _AbstractNextState(n_states[0]))
        self.exit_states = n_exit_states

    def ret(self):
        return self.states, self.exit_states
78
class _Compiler:
    """Compile the body of a pytholite function into abstract FSM states.

    A state is a list of FHDL statements that may still contain
    _AbstractNextState placeholders (transitions) and _AbstractLoad
    placeholders (register loads); both are lowered after compilation.
    """

    # Class names the running parser uses for literal nodes: ast.Num and
    # ast.NameConstant before Python 3.8, ast.Constant from 3.8 on. They
    # are matched by name so the deprecated (3.12) and removed (3.14)
    # ast.Num attribute is never accessed.
    _LITERAL_NODE_NAMES = ("Num", "NameConstant", "Constant")

    def __init__(self, ioo, symdict, registers):
        # Not used by this class directly; presumably read by gen_io,
        # which receives the compiler - confirm in pytholite.io.
        self.ioo = ioo
        # name -> object visible to the compiled code; assignments and
        # for loops add entries to it.
        self.symdict = symdict
        # Output list receiving every _Register declared by the code.
        self.registers = registers
        # Name of the assignment target being compiled, used to name
        # registers created by Register(...); "" outside assignments.
        self.targetname = ""

    def visit_top(self, node):
        """Compile a module consisting of one function definition.

        Returns the states of the function body, entry state first.
        Raises NotImplementedError for any other module shape.
        """
        if not (isinstance(node, ast.Module)
                and len(node.body) == 1
                and isinstance(node.body[0], ast.FunctionDef)):
            raise NotImplementedError
        states, _ = self.visit_block(node.body[0].body)
        return states

    # blocks and statements
    def visit_block(self, statements):
        """Compile a list of statements; return (states, exit_states)."""
        sa = _StateAssembler()
        for statement in statements:
            if isinstance(statement, ast.Assign):
                self.visit_assign(sa, statement)
            elif isinstance(statement, ast.If):
                self.visit_if(sa, statement)
            elif isinstance(statement, ast.While):
                self.visit_while(sa, statement)
            elif isinstance(statement, ast.For):
                self.visit_for(sa, statement)
            elif isinstance(statement, ast.Expr):
                self.visit_expr_statement(sa, statement)
            else:
                raise NotImplementedError
        return sa.ret()

    def visit_assign(self, sa, node):
        """Compile an assignment.

        ``name = Register(n)`` declares a register and emits no state;
        ``reg.store = expr`` emits one state loading expr into reg.
        """
        if isinstance(node.targets[0], ast.Name):
            self.targetname = node.targets[0].id
        value = self.visit_expr(node.value, True)
        self.targetname = ""

        if isinstance(value, _Register):
            self.registers.append(value)
            for target in node.targets:
                if isinstance(target, ast.Name):
                    self.symdict[target.id] = value
                else:
                    raise NotImplementedError
        elif isinstance(value, Value):
            r = []
            for target in node.targets:
                if isinstance(target, ast.Attribute) and target.attr == "store":
                    treg = target.value
                    if isinstance(treg, ast.Name):
                        r.append(self.symdict[treg.id].load(value))
                    else:
                        raise NotImplementedError
                else:
                    raise NotImplementedError
            # All loads of a chained assignment happen in the same state.
            sa.assemble([r], [r])
        else:
            raise NotImplementedError

    def visit_if(self, sa, node):
        """Compile if/else as a test state branching into either body."""
        test = self.visit_expr(node.test)
        states_t, exit_states_t = self.visit_block(node.body)
        states_f, exit_states_f = self.visit_block(node.orelse)
        exit_states = exit_states_t + exit_states_f

        test_state_stmt = If(test, _AbstractNextState(states_t[0]))
        test_state = [test_state_stmt]
        if states_f:
            test_state_stmt.Else(_AbstractNextState(states_f[0]))
        else:
            # Without an else branch, a false test falls through to
            # whatever follows the if statement.
            exit_states.append(test_state)

        sa.assemble([test_state] + states_t + states_f,
            exit_states)

    def visit_while(self, sa, node):
        """Compile a while loop.

        A test state enters the body; the body's exit states jump back to
        the test state, which is also the loop's only exit.
        """
        test = self.visit_expr(node.test)
        states_b, exit_states_b = self.visit_block(node.body)

        test_state = [If(test, _AbstractNextState(states_b[0]))]
        for exit_state in exit_states_b:
            exit_state.insert(0, _AbstractNextState(test_state))

        sa.assemble([test_state] + states_b, [test_state])

    def visit_for(self, sa, node):
        """Unroll a for loop over an iterable known at compile time.

        The body is compiled once per iteration with the loop variable
        bound in symdict, and the copies are chained one after another.
        An empty iterable produces no states.
        """
        if not isinstance(node.target, ast.Name):
            raise NotImplementedError
        target = node.target.id
        if target in self.symdict:
            raise NotImplementedError("For loop target must use an available name")
        it = self.visit_iterator(node.iter)
        fsa = _StateAssembler()
        for iteration in it:
            self.symdict[target] = iteration
            fsa.assemble(*self.visit_block(node.body))
        # pop() rather than del: the name is never bound when the iterable
        # is empty, and del would raise KeyError.
        self.symdict.pop(target, None)
        sa.assemble(*fsa.ret())

    def visit_iterator(self, node):
        """Evaluate a for loop's iterable at compile time.

        Supported: list or tuple literals, and range(...) with literal
        arguments. Anything else raises NotImplementedError.
        """
        if isinstance(node, (ast.List, ast.Tuple)):
            return ast.literal_eval(node)
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            if node.func.id == "range":
                return range(*[ast.literal_eval(arg) for arg in node.args])
            raise NotImplementedError
        raise NotImplementedError

    def visit_expr_statement(self, sa, node):
        """Compile ``yield callee(args)``, an I/O sequence built by gen_io."""
        if isinstance(node.value, ast.Yield):
            yvalue = node.value.value
            if not isinstance(yvalue, ast.Call) or not isinstance(yvalue.func, ast.Name):
                raise NotImplementedError("Unrecognized I/O sequence")
            callee = self.symdict[yvalue.func.id]
            states, exit_states = gen_io(self, callee, yvalue.args, [])
            sa.assemble(states, exit_states)
        else:
            raise NotImplementedError

    # expressions
    def visit_expr(self, node, allow_registers=False):
        """Compile an expression node into an FHDL value.

        A call yielding a _Register is accepted only when allow_registers
        is set, which visit_assign does for the right-hand side.
        """
        if isinstance(node, ast.Call):
            r = self.visit_expr_call(node)
            if not allow_registers and isinstance(r, _Register):
                raise NotImplementedError
            return r
        elif isinstance(node, ast.BinOp):
            return self.visit_expr_binop(node)
        elif isinstance(node, ast.Compare):
            return self.visit_expr_compare(node)
        elif isinstance(node, ast.Name):
            return self.visit_expr_name(node)
        elif type(node).__name__ in self._LITERAL_NODE_NAMES:
            return self.visit_expr_constant(node)
        else:
            raise NotImplementedError

    def visit_expr_call(self, node):
        """Compile a call to transel.Register() or transel.bitslice()."""
        if isinstance(node.func, ast.Name):
            callee = self.symdict[node.func.id]
        else:
            raise NotImplementedError
        # Compare by identity: FHDL values overload == to build
        # expressions, so ``callee == ...`` would not yield a plain bool.
        if callee is transel.Register:
            if len(node.args) != 1:
                raise TypeError("Register() takes exactly 1 argument")
            nbits = ast.literal_eval(node.args[0])
            return _Register(self.targetname, nbits)
        elif callee is transel.bitslice:
            if len(node.args) != 2 and len(node.args) != 3:
                raise TypeError("bitslice() takes 2 or 3 arguments")
            val = self.visit_expr(node.args[0])
            low = ast.literal_eval(node.args[1])
            if len(node.args) == 3:
                up = ast.literal_eval(node.args[2])
            else:
                up = low + 1
            return _Slice(val, low, up)
        else:
            raise NotImplementedError

    def visit_expr_binop(self, node):
        """Compile an arithmetic, shift or bitwise binary operation."""
        left = self.visit_expr(node.left)
        right = self.visit_expr(node.right)
        if isinstance(node.op, ast.Add):
            return left + right
        elif isinstance(node.op, ast.Sub):
            return left - right
        elif isinstance(node.op, ast.Mult):
            return left * right
        elif isinstance(node.op, ast.LShift):
            return left << right
        elif isinstance(node.op, ast.RShift):
            return left >> right
        elif isinstance(node.op, ast.BitOr):
            return left | right
        elif isinstance(node.op, ast.BitXor):
            return left ^ right
        elif isinstance(node.op, ast.BitAnd):
            return left & right
        else:
            raise NotImplementedError

    def visit_expr_compare(self, node):
        """Compile a (possibly chained) comparison.

        ``a < b < c`` becomes ``(a < b) & (b < c)``.
        """
        test = self.visit_expr(node.left)
        r = None
        for op, rcomparator in zip(node.ops, node.comparators):
            comparator = self.visit_expr(rcomparator)
            if isinstance(op, ast.Eq):
                comparison = test == comparator
            elif isinstance(op, ast.NotEq):
                comparison = test != comparator
            elif isinstance(op, ast.Lt):
                comparison = test < comparator
            elif isinstance(op, ast.LtE):
                comparison = test <= comparator
            elif isinstance(op, ast.Gt):
                comparison = test > comparator
            elif isinstance(op, ast.GtE):
                comparison = test >= comparator
            else:
                raise NotImplementedError
            if r is None:
                r = comparison
            else:
                r = r & comparison
            test = comparator
        return r

    def visit_expr_name(self, node):
        """Resolve a name: registers yield their storage, ints a Constant."""
        # Before Python 3.4, True and False are parsed as plain names.
        if node.id == "True":
            return Constant(1)
        if node.id == "False":
            return Constant(0)
        r = self.symdict[node.id]
        if isinstance(r, _Register):
            r = r.storage
        if isinstance(r, int):
            r = Constant(r)
        return r

    def visit_expr_constant(self, node):
        """Compile a numeric or boolean literal into a Constant.

        Handles ast.Num (``.n``) as well as ast.NameConstant and
        ast.Constant (``.value``). Booleans map to Constant(1)/Constant(0),
        like the names True/False. Other literals raise NotImplementedError.
        """
        if type(node).__name__ == "Num":
            value = node.n
        else:
            value = node.value
        if isinstance(value, bool):
            return Constant(int(value))
        if isinstance(value, (int, float, complex)):
            return Constant(value)
        raise NotImplementedError

    def visit_expr_num(self, node):
        # Kept for backward compatibility; see visit_expr_constant.
        return self.visit_expr_constant(node)
317
# Like list.index, but matches by identity ("is") instead of equality.
def _index_is(l, x):
    """Return the index of the first element of l that is x.

    Identity is required because states are lists (two states with equal
    contents are still different states) and FHDL values overload == to
    build expressions.

    Raises ValueError, like list.index, when x is not an element of l;
    previously None was returned, which failed later with a confusing
    TypeError when used as an index.
    """
    for i, e in enumerate(l):
        if e is x:
            return i
    raise ValueError("object is not in list")
323
class _LowerAbstractNextState(fhdl.NodeTransformer):
    """Replace _AbstractNextState placeholders with FSM transitions.

    ``states`` and ``stnames`` are parallel lists: the state object at
    position i is the FSM state whose name is stnames[i].
    """

    def __init__(self, fsm, states, stnames):
        self.fsm = fsm
        self.states = states
        self.stnames = stnames

    def visit_unknown(self, node):
        # Anything that is not an abstract transition is left untouched.
        if not isinstance(node, _AbstractNextState):
            return node
        # States are matched by identity, not by content.
        position = _index_is(self.states, node.target_state)
        destination = getattr(self.fsm, self.stnames[position])
        return self.fsm.next_state(destination)
337
def _create_fsm(states):
    """Build an FSM with one state named "S<i>" per abstract state.

    states[i] becomes FSM state S<i>. Each state's statements, with their
    abstract transitions lowered to next_state statements, become that
    state's actions. NOTE(review): S0 is presumably the FSM's initial
    state (the compiler puts the entry state first) - confirm against
    migen.corelogic.fsm.FSM.
    """
    stnames = ["S%d" % index for index in range(len(states))]
    fsm = FSM(*stnames)
    lowerer = _LowerAbstractNextState(fsm, states, stnames)
    for stname, state in zip(stnames, states):
        actions = lowerer.visit(state)
        fsm.act(getattr(fsm, stname), *actions)
    return fsm
346
def make_pytholite(func, **ioresources):
    """Compile the Python function ``func`` into hardware.

    The source of func is parsed and compiled into an FSM plus one
    fragment per register it declares. Only func's global names are
    visible to the compiled code: symdict is a copy of func.__globals__,
    so closure variables of a nested function are not resolved.

    Returns the I/O object built by make_io_object(**ioresources), with
    its ``fragment`` attribute set to the generated logic.
    """
    import textwrap

    ioo = make_io_object(**ioresources)

    # inspect.getsource keeps the original indentation; dedent it so that
    # functions defined inside a class or another function parse as well
    # (ast.parse would otherwise raise IndentationError).
    tree = ast.parse(textwrap.dedent(inspect.getsource(func)))
    symdict = func.__globals__.copy()
    registers = []

    states = _Compiler(ioo, symdict, registers).visit_top(tree)

    # Select signals are sized only now, once every source of every
    # register is known.
    regf = Fragment()
    for register in registers:
        register.finalize()
        regf += register.get_fragment()

    fsm = _create_fsm(states)
    # Register loads can only be lowered after finalize() created sel.
    fsmf = _LowerAbstractLoad().visit(fsm.get_fragment())

    ioo.fragment = regf + fsmf
    return ioo