add note about flatten function (one already exists in nmigen)
[nmutil.git] / src / nmutil / multipipe.py
1 """ Combinatorial Multi-input and Multi-output multiplexer blocks
2 conforming to Pipeline API
3
4 Multi-input is complex because if any one input is ready, the output
5 can be ready, and the decision comes from a separate module.
6
7 Multi-output is simple (pretty much identical to UnbufferedPipeline),
8 and the selection is just a mux. The only proviso (difference) being:
9 the outputs not being selected have to have their ready_o signals
10 DEASSERTED.
11 """
12
13 from math import log
14 from nmigen import Signal, Cat, Const, Mux, Module, Array, Elaboratable
15 from nmigen.cli import verilog, rtlil
16 from nmigen.lib.coding import PriorityEncoder
17 from nmigen.hdl.rec import Record, Layout
18 from nmutil.stageapi import _spec
19
20 from collections.abc import Sequence
21
22 from nmutil.nmoperator import eq
23 from nmutil.iocontrol import NextControl, PrevControl
24
25
class MultiInControlBase(Elaboratable):
    """ Common functions for Pipeline API
    """
    def __init__(self, in_multi=None, p_len=1, maskwid=0, routemask=False):
        """ Multi-input Control class.  Conforms to same API as ControlBase...
            mostly.  has additional indices to the *multiple* input stages

            * p: contains ready/valid to the previous stages PLURAL
            * n: contains ready/valid to the next stage

            User must also:
            * add data_i members to PrevControl and
            * add data_o member to NextControl

            :param in_multi: passed unchanged to every PrevControl
            :param p_len: number of input (previous-stage) controls to create
            :param maskwid: mask width given to each PrevControl
            :param routemask: if True, the NextControl mask is maskwid bits
                              (straight routing); otherwise it is
                              maskwid * p_len bits, room for every input's
                              mask side by side (fan-in).
        """
        self.routemask = routemask
        # set up input and output IO ACK (prev/next ready/valid)
        print ("multi_in", self, maskwid, p_len, routemask)
        self.p = Array([PrevControl(in_multi, maskwid=maskwid)
                        for _ in range(p_len)])
        if routemask:
            nmaskwid = maskwid # straight route mask mode
        else:
            nmaskwid = maskwid * p_len # fan-in mode
        self.n = NextControl(maskwid=nmaskwid) # masks fan in (Cat)

    def connect_to_next(self, nxt, p_idx=0):
        """ helper function to connect to the next stage data/valid/ready.

            :param nxt: next stage; its input ``nxt.p[p_idx]`` is used
            :param p_idx: which of ``nxt``'s inputs to connect to
        """
        return self.n.connect_to_next(nxt.p[p_idx])

    def _connect_in(self, prev, idx=0, prev_idx=None):
        """ helper function to connect stage to an input source.  do not
            use to connect stage-to-stage!

            :param prev: source object; ``prev.p`` (or ``prev.p[prev_idx]``)
            :param idx: which of this block's inputs to connect
            :param prev_idx: optional index into ``prev.p``
        """
        if prev_idx is None:
            return self.p[idx]._connect_in(prev.p)
        return self.p[idx]._connect_in(prev.p[prev_idx])

    def _connect_out(self, nxt, nxt_idx=None):
        """ helper function to connect stage to an output source.  do not
            use to connect stage-to-stage!

            :param nxt: target object whose ``n`` attribute is connected to
            :param nxt_idx: if given, connect to ``nxt.n[nxt_idx]`` (for a
                            target with multiple outputs), else to ``nxt.n``.
        """
        # previously this tested an undefined name (nxt_idx) and always
        # raised NameError; nxt_idx is now an optional parameter, matching
        # MultiOutControlBase._connect_out.
        if nxt_idx is None:
            return self.n._connect_out(nxt.n)
        return self.n._connect_out(nxt.n[nxt_idx])

    def set_input(self, i, idx=0):
        """ helper function to set the input data of input ``idx``
        """
        return eq(self.p[idx].data_i, i)

    def elaborate(self, platform):
        """ add every input control (as p0, p1, ...) and the output
            control (as n) as submodules.
        """
        m = Module()
        for i, p in enumerate(self.p):
            setattr(m.submodules, "p%d" % i, p)
        m.submodules.n = self.n
        return m

    def __iter__(self):
        """ yield the signals of all inputs, then those of the output. """
        for p in self.p:
            yield from p
        yield from self.n

    def ports(self):
        return list(self)
93
94
class MultiOutControlBase(Elaboratable):
    """ Common functions for Pipeline API
    """
    def __init__(self, n_len=1, in_multi=None, maskwid=0, routemask=False):
        """ Multi-output Control class.  Conforms to same API as ControlBase...
            mostly.  has additional indices to the multiple *output* stages
            [MultiInControlBase has multiple *input* stages]

            * p: contains ready/valid to the previous stage
            * n: contains ready/valid to the next stages PLURAL

            User must also:
            * add data_i member to PrevControl and
            * add data_o members to NextControl

            :param n_len: number of output (next-stage) controls to create
            :param in_multi: passed unchanged to the single PrevControl
            :param maskwid: mask width given to each NextControl
            :param routemask: if True, the PrevControl mask is maskwid bits
                              (straight routing); otherwise it is
                              maskwid * n_len bits (fan-out).
        """

        if routemask:
            nmaskwid = maskwid # straight route mask mode
        else:
            nmaskwid = maskwid * n_len # fan-out mode

        # set up input and output IO ACK (prev/next ready/valid)
        self.p = PrevControl(in_multi, maskwid=nmaskwid)
        n = []
        for i in range(n_len):
            n.append(NextControl(maskwid=maskwid))
        self.n = Array(n)

    def connect_to_next(self, nxt, n_idx=0):
        """ helper function to connect output ``n_idx`` to the next
            stage's data/valid/ready.
        """
        return self.n[n_idx].connect_to_next(nxt.p)

    def _connect_in(self, prev, idx=0):
        """ helper function to connect stage to an input source.  do not
            use to connect stage-to-stage!

            This block has a single input (``self.p``, a PrevControl), so
            that is what gets connected.  ``idx`` is accepted only for
            backwards compatibility and is ignored.
        """
        # previously this called _connect_in on self.n[idx] (an output-side
        # NextControl); the input side is self.p, mirroring
        # MultiInControlBase._connect_in.
        return self.p._connect_in(prev.p)

    def _connect_out(self, nxt, idx=0, nxt_idx=None):
        """ helper function to connect stage to an output source.  do not
            use to connect stage-to-stage!

            :param idx: which of this block's outputs to connect
            :param nxt_idx: optional index into ``nxt.n``
        """
        if nxt_idx is None:
            return self.n[idx]._connect_out(nxt.n)
        return self.n[idx]._connect_out(nxt.n[nxt_idx])

    def elaborate(self, platform):
        """ add the input control (as p) and every output control (as n0,
            n1, ...) as submodules.
        """
        m = Module()
        m.submodules.p = self.p
        for i, n in enumerate(self.n):
            setattr(m.submodules, "n%d" % i, n)
        return m

    def set_input(self, i):
        """ helper function to set the input data
        """
        return eq(self.p.data_i, i)

    def __iter__(self):
        """ yield the signals of the input, then those of every output. """
        yield from self.p
        for n in self.n:
            yield from n

    def ports(self):
        return list(self)
161
162
class CombMultiOutPipeline(MultiOutControlBase):
    """ A multi-output Combinatorial block conforming to the Pipeline API

        A single input (p) is passed through the stage and steered to one
        of the n_len outputs (n[muxid]), where muxid is ``n_mux.m_id``.

        Attributes:
        -----------
        p.data_i : stage input data (non-array).  shaped according to ispec
        n.data_o : stage output data array.       shaped according to ospec
    """

    def __init__(self, stage, n_len, n_mux, maskwid=0, routemask=False):
        """ :param stage: provides ``ispec``/``ospec`` and optionally
                          ``process``/``setup``
            :param n_len: number of outputs
            :param n_mux: object whose ``m_id`` selects the active output;
                          added as a submodule if it has ``elaborate``
            :param maskwid: mask width; the mask/stop logic in elaborate()
                            is only generated when this is non-zero
            :param routemask: straight-route masks instead of fanning out
        """
        MultiOutControlBase.__init__(self, n_len=n_len, maskwid=maskwid,
                                     routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.routemask = routemask
        self.n_mux = n_mux

        # set up the input and output data
        self.p.data_i = _spec(stage.ispec, 'data_i') # input type
        for i in range(n_len):
            name = 'data_o_%d' % i
            self.n[i].data_o = _spec(stage.ospec, name) # output type

    def process(self, i):
        """ return ``stage.process(i)`` if the stage defines ``process``,
            otherwise return ``i`` unchanged.
        """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        """ build the combinatorial input-to-n[muxid] steering logic. """
        m = MultiOutControlBase.elaborate(self, platform)

        if hasattr(self.n_mux, "elaborate"): # TODO: identify submodule?
            m.submodules.n_mux = self.n_mux

        # need buffer register conforming to *input* spec
        # (note: r_data is only ever assigned via m.d.comb below)
        r_data = _spec(self.stage.ispec, 'r_data') # input type
        if hasattr(self.stage, "setup"):
            self.stage.setup(m, r_data)

        # multiplexer id taken from n_mux
        muxid = self.n_mux.m_id
        print ("self.n_mux", self.n_mux)
        print ("self.n_mux.m_id", self.n_mux.m_id)

        # rename the selector signal (mutates n_mux's signal in place)
        self.n_mux.m_id.name = "m_id"

        # temporaries
        p_valid_i = Signal(reset_less=True)
        pv = Signal(reset_less=True)
        m.d.comb += p_valid_i.eq(self.p.valid_i_test)
        #m.d.comb += pv.eq(self.p.valid_i) #& self.n[muxid].ready_i)
        # pv: an input transfer is happening (valid and ready both high)
        m.d.comb += pv.eq(self.p.valid_i & self.p.ready_o)

        # all outputs to next stages first initialised to zero (invalid)
        # the only output "active" is then selected by the muxid
        # (the later n[muxid] assignment overrides this default)
        for i in range(len(self.n)):
            m.d.comb += self.n[i].valid_o.eq(0)
        if self.routemask:
            #with m.If(pv):
            m.d.comb += self.n[muxid].valid_o.eq(pv)
            m.d.comb += self.p.ready_o.eq(self.n[muxid].ready_i)
        else:
            data_valid = self.n[muxid].valid_o
            m.d.comb += self.p.ready_o.eq(~data_valid | self.n[muxid].ready_i)
            # NOTE(review): data_valid appears on both sides of this comb
            # assignment (combinational feedback) -- confirm intended.
            m.d.comb += data_valid.eq(p_valid_i | \
                                    (~self.n[muxid].ready_i & data_valid))


        # send data on
        #with m.If(pv):
        m.d.comb += eq(r_data, self.p.data_i)
        m.d.comb += eq(self.n[muxid].data_o, self.process(r_data))

        if self.maskwid:
            if self.routemask: # straight "routing" mode - treat like data
                m.d.comb += self.n[muxid].stop_o.eq(self.p.stop_i)
                with m.If(pv):
                    m.d.comb += self.n[muxid].mask_o.eq(self.p.mask_i)
            else:
                ml = [] # accumulate output masks
                ms = [] # accumulate output stops
                # fan-out mode.
                # conditionally fan-out mask bits, always fan-out stop bits
                # (p.mask_i / p.stop_i are split across all outputs via Cat)
                for i in range(len(self.n)):
                    ml.append(self.n[i].mask_o)
                    ms.append(self.n[i].stop_o)
                m.d.comb += Cat(*ms).eq(self.p.stop_i)
                with m.If(pv):
                    m.d.comb += Cat(*ml).eq(self.p.mask_i)
        return m
253
254
class CombMultiInPipeline(MultiInControlBase):
    """ A multi-input Combinatorial block conforming to the Pipeline API

        Attributes:
        -----------
        p.data_i : StageInput, shaped according to ispec
                   The pipeline inputs (one per p[0] .. p[p_len-1])
        n.data_o : StageOutput, shaped according to ospec
                   The pipeline output
        r_data : input_shape according to ispec
                 Per-input intermediate copies, created in elaborate().
                 r_data[i] is driven *combinatorially* from p[i].data_i
                 only while that input's vr (mask & valid & ready) is
                 high.  The output is process(r_data[mid]), where mid is
                 ``p_mux.m_id``.
    """

    def __init__(self, stage, p_len, p_mux, maskwid=0, routemask=False):
        """ :param stage: provides ``ispec``/``ospec`` and optionally
                          ``process``/``setup``
            :param p_len: number of inputs
            :param p_mux: arbiter providing ``m_id`` (selected input) and
                          ``active``; added as a submodule in elaborate()
            :param maskwid: mask width (0 disables output mask/stop in
                            fan-in mode)
            :param routemask: straight-route masks instead of fan-in
        """
        MultiInControlBase.__init__(self, p_len=p_len, maskwid=maskwid,
                                    routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.p_mux = p_mux

        # set up the input and output data
        for i in range(p_len):
            name = 'data_i_%d' % i
            self.p[i].data_i = _spec(stage.ispec, name) # input type
        self.n.data_o = _spec(stage.ospec, 'data_o')

    def process(self, i):
        """ return ``stage.process(i)`` if defined, else ``i`` unchanged. """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        """ build the combinatorial p[mid]-to-n selection logic.

            NOTE: comb assignments below rely on ordering -- per-input
            defaults are assigned first and then overridden for p[mid].
        """
        m = MultiInControlBase.elaborate(self, platform)

        m.submodules.p_mux = self.p_mux

        # need an array of buffer registers conforming to *input* spec
        r_data = []
        data_valid = []
        p_valid_i = []
        n_ready_in = []
        p_len = len(self.p)
        for i in range(p_len):
            name = 'r_%d' % i
            r = _spec(self.stage.ispec, name) # input type
            r_data.append(r)
            data_valid.append(Signal(name="data_valid", reset_less=True))
            p_valid_i.append(Signal(name="p_valid_i", reset_less=True))
            n_ready_in.append(Signal(name="n_ready_in", reset_less=True))
            if hasattr(self.stage, "setup"):
                print ("setup", self, self.stage, r)
                self.stage.setup(m, r)
        if len(r_data) > 1:
            r_data = Array(r_data)
            p_valid_i = Array(p_valid_i)
            n_ready_in = Array(n_ready_in)
            data_valid = Array(data_valid)

        nirn = Signal(reset_less=True)
        m.d.comb += nirn.eq(~self.n.ready_i)
        mid = self.p_mux.m_id
        print ("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
        # defaults for every input; p[mid] is overridden below
        for i in range(p_len):
            m.d.comb += data_valid[i].eq(0)
            m.d.comb += n_ready_in[i].eq(1)
            m.d.comb += p_valid_i[i].eq(0)
            m.d.comb += self.p[i].ready_o.eq(0)
        p = self.p[mid]
        # NOTE(review): maskedout is 1 bit wide, so a wider mask_i is
        # truncated to bit 0 here -- confirm intended.
        maskedout = Signal(reset_less=True)
        if hasattr(p, "mask_i"):
            m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
        else:
            m.d.comb += maskedout.eq(1)
        m.d.comb += p_valid_i[mid].eq(maskedout & self.p_mux.active)
        m.d.comb += self.p[mid].ready_o.eq(~data_valid[mid] | self.n.ready_i)
        m.d.comb += n_ready_in[mid].eq(nirn & data_valid[mid])
        # output is valid if any input's data_valid is set
        anyvalid = Cat(*[data_valid[i] for i in range(p_len)])
        m.d.comb += self.n.valid_o.eq(anyvalid.bool())
        m.d.comb += data_valid[mid].eq(p_valid_i[mid] | n_ready_in[mid])

        if self.routemask:
            # XXX hack - fixes loop
            m.d.comb += eq(self.n.stop_o, self.p[-1].stop_i)
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(name="vr%d" % i, reset_less=True)
                maskedout = Signal(name="maskedout%d" % i, reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                m.d.comb += vr.eq(maskedout.bool() & p.valid_i & p.ready_o)
                with m.If(vr):
                    m.d.comb += eq(self.n.mask_o, self.p[i].mask_i)
                    m.d.comb += eq(r_data[i], self.p[i].data_i)
        else:
            ml = [] # accumulate output masks
            ms = [] # accumulate output stops
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(reset_less=True)
                maskedout = Signal(reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                m.d.comb += vr.eq(maskedout.bool() & p.valid_i & p.ready_o)
                with m.If(vr):
                    m.d.comb += eq(r_data[i], self.p[i].data_i)
                if self.maskwid:
                    # fan-in: each input contributes its mask (zeroed when
                    # not transferring) and its stop bits, concatenated
                    mlen = len(self.p[i].mask_i)
                    ml.append(Mux(vr, self.p[i].mask_i, Const(0, mlen)))
                    ms.append(self.p[i].stop_i)
            if self.maskwid:
                m.d.comb += self.n.mask_o.eq(Cat(*ml))
                m.d.comb += self.n.stop_o.eq(Cat(*ms))

        m.d.comb += eq(self.n.data_o, self.process(r_data[mid]))

        return m
387
388
class NonCombMultiInPipeline(MultiInControlBase):
    """ A multi-input pipeline block conforming to the Pipeline API

        NOTE(review): despite the class name, every assignment in
        elaborate() is made in the comb domain; no registered (sync)
        behaviour is implemented yet.  Presumably work in progress.

        Attributes:
        -----------
        p.data_i : StageInput, shaped according to ispec
                   The pipeline inputs (one per p[0] .. p[p_len-1])
        n.data_o : StageOutput, shaped according to ospec
                   The pipeline output
        r_data : input_shape according to ispec
                 Per-input intermediate copies, created in elaborate().
                 r_data[i] is driven combinatorially from p[i].data_i
                 while that input's vr (mask & valid & ready) is high.
    """

    def __init__(self, stage, p_len, p_mux, maskwid=0, routemask=False):
        """ :param stage: provides ``ispec``/``ospec`` and optionally
                          ``process``/``setup``
            :param p_len: number of inputs
            :param p_mux: arbiter providing ``m_id`` and ``active``
            :param maskwid: mask width
            :param routemask: straight-route masks instead of fan-in
        """
        MultiInControlBase.__init__(self, p_len=p_len, maskwid=maskwid,
                                    routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.p_mux = p_mux

        # set up the input and output data
        for i in range(p_len):
            name = 'data_i_%d' % i
            self.p[i].data_i = _spec(stage.ispec, name) # input type
        self.n.data_o = _spec(stage.ospec, 'data_o')

    def process(self, i):
        """ return ``stage.process(i)`` if defined, else ``i`` unchanged. """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        """ build the p[mid]-to-n selection logic.

            Previously this referenced ``n_ready_in`` and ``data_valid``
            without ever defining them (NameError on elaboration).
            ``n_ready_in`` is now declared as in CombMultiInPipeline, and
            ``r_busy`` (which replaced ``data_valid`` in the declarations)
            is used consistently.
        """
        m = MultiInControlBase.elaborate(self, platform)

        m.submodules.p_mux = self.p_mux

        # need an array of buffer registers conforming to *input* spec
        r_data = []
        r_busy = []
        p_valid_i = []
        n_ready_in = []
        p_len = len(self.p)
        for i in range(p_len):
            name = 'r_%d' % i
            r = _spec(self.stage.ispec, name) # input type
            r_data.append(r)
            r_busy.append(Signal(name="r_busy%d" % i, reset_less=True))
            p_valid_i.append(Signal(name="p_valid_i%d" % i, reset_less=True))
            n_ready_in.append(Signal(name="n_ready_in%d" % i,
                                     reset_less=True))
            if hasattr(self.stage, "setup"):
                print ("setup", self, self.stage, r)
                self.stage.setup(m, r)
        if len(r_data) > 1:
            r_data = Array(r_data)
            p_valid_i = Array(p_valid_i)
            r_busy = Array(r_busy)
            n_ready_in = Array(n_ready_in)

        nirn = Signal(reset_less=True)
        m.d.comb += nirn.eq(~self.n.ready_i)
        mid = self.p_mux.m_id
        print ("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
        # defaults for every input; p[mid] is overridden below
        for i in range(p_len):
            m.d.comb += r_busy[i].eq(0)
            m.d.comb += n_ready_in[i].eq(1)
            m.d.comb += p_valid_i[i].eq(0)
            m.d.comb += self.p[i].ready_o.eq(n_ready_in[i])
        p = self.p[mid]
        # NOTE(review): maskedout is 1 bit wide, so a wider mask_i is
        # truncated to bit 0 here -- confirm intended.
        maskedout = Signal(reset_less=True)
        if hasattr(p, "mask_i"):
            m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
        else:
            m.d.comb += maskedout.eq(1)
        m.d.comb += p_valid_i[mid].eq(maskedout & self.p_mux.active)
        m.d.comb += self.p[mid].ready_o.eq(~r_busy[mid] | self.n.ready_i)
        m.d.comb += n_ready_in[mid].eq(nirn & r_busy[mid])
        # output is valid if any input is busy
        anyvalid = Cat(*[r_busy[i] for i in range(p_len)])
        m.d.comb += self.n.valid_o.eq(anyvalid.bool())
        m.d.comb += r_busy[mid].eq(p_valid_i[mid] | n_ready_in[mid])

        if self.routemask:
            # XXX hack - fixes loop
            m.d.comb += eq(self.n.stop_o, self.p[-1].stop_i)
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(name="vr%d" % i, reset_less=True)
                maskedout = Signal(name="maskedout%d" % i, reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                m.d.comb += vr.eq(maskedout.bool() & p.valid_i & p.ready_o)
                with m.If(vr):
                    m.d.comb += eq(self.n.mask_o, self.p[i].mask_i)
                    m.d.comb += eq(r_data[i], self.p[i].data_i)
        else:
            ml = [] # accumulate output masks
            ms = [] # accumulate output stops
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(reset_less=True)
                maskedout = Signal(reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                m.d.comb += vr.eq(maskedout.bool() & p.valid_i & p.ready_o)
                with m.If(vr):
                    m.d.comb += eq(r_data[i], self.p[i].data_i)
                if self.maskwid:
                    mlen = len(self.p[i].mask_i)
                    ml.append(Mux(vr, self.p[i].mask_i, Const(0, mlen)))
                    ms.append(self.p[i].stop_i)
            if self.maskwid:
                m.d.comb += self.n.mask_o.eq(Cat(*ml))
                m.d.comb += self.n.stop_o.eq(Cat(*ms))

        m.d.comb += eq(self.n.data_o, self.process(r_data[mid]))

        return m
517
518
class CombMuxOutPipe(CombMultiOutPipeline):
    """ Multi-output combinatorial pipe in which the stage object doubles
        as the n-way output multiplexer: its ``m_id`` is pointed at the
        mux-id field of the incoming data, so each input item picks its
        own output.
    """
    def __init__(self, stage, n_len, maskwid=0, muxidname=None,
                 routemask=False):
        # use the conventional field name unless the caller supplied one
        if not muxidname:
            muxidname = "muxid"

        # HACK: the stage itself is passed in as the n-way multiplexer
        CombMultiOutPipeline.__init__(self, stage,
                                      n_len=n_len,
                                      n_mux=stage,
                                      maskwid=maskwid,
                                      routemask=routemask)

        # HACK: because n_mux *is* the stage, wire its selector straight
        # to the named field of the input data
        selector = getattr(self.p.data_i, muxidname)
        print ("combmuxout", muxidname, selector)
        stage.m_id = selector
532
533
534
class InputPriorityArbiter(Elaboratable):
    """ arbitration module for Input-Mux pipe, based on PriorityEncoder

        Feeds each input's "valid" (optionally gated by its mask/stop
        bits) into a PriorityEncoder and exposes:

        * m_id:   index of the selected input
        * active: high when at least one input is valid
    """
    def __init__(self, pipe, num_rows):
        """ :param pipe: pipeline whose ``p`` inputs are arbitrated; its
                         ``maskwid``/``routemask`` are read in elaborate()
            :param num_rows: number of inputs (must equal len(pipe.p))
            :raises ValueError: if num_rows < 1
        """
        if num_rows < 1:
            # the original float log() also raised ValueError here
            raise ValueError("num_rows must be >= 1, got %r" % (num_rows,))
        self.pipe = pipe
        self.num_rows = num_rows
        # bits needed to index num_rows inputs, i.e. ceil(log2(num_rows)).
        # int(log(n)/log(2)) floored the result (too few bits for
        # non-power-of-two num_rows, truncating the index) and was subject
        # to float rounding; bit_length is exact.  Powers of two give the
        # same width as before.
        self.mmax = (num_rows - 1).bit_length()
        self.m_id = Signal(self.mmax, reset_less=True) # multiplex id
        self.active = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        assert len(self.pipe.p) == self.num_rows, \
                "must declare input to be same size"
        pe = PriorityEncoder(self.num_rows)
        m.submodules.selector = pe

        # connect priority encoder
        in_ready = []
        for i in range(self.num_rows):
            p_valid_i = Signal(reset_less=True)
            if self.pipe.maskwid and not self.pipe.routemask:
                # fan-in mask mode: only count inputs whose mask is set
                # and which are not stopped
                p = self.pipe.p[i]
                maskedout = Signal(reset_less=True)
                m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                m.d.comb += p_valid_i.eq(maskedout.bool() & p.valid_i_test)
            else:
                m.d.comb += p_valid_i.eq(self.pipe.p[i].valid_i_test)
            in_ready.append(p_valid_i)
        m.d.comb += pe.i.eq(Cat(*in_ready)) # array of input "valids"
        m.d.comb += self.active.eq(~pe.n)   # encoder active (one input valid)
        m.d.comb += self.m_id.eq(pe.o)      # output one active input

        return m

    def ports(self):
        return [self.m_id, self.active]
573
574
575
class PriorityCombMuxInPipe(CombMultiInPipeline):
    """ Example use of the combinatorial multi-input pipeline, with the
        inputs arbitrated by an InputPriorityArbiter.
    """

    def __init__(self, stage, p_len=2, maskwid=0, routemask=False):
        # the arbiter keeps a back-reference to this pipe (it reads
        # pipe.p, pipe.maskwid and pipe.routemask in its elaborate), so
        # it can be created before the base-class initialiser runs
        arbiter = InputPriorityArbiter(self, p_len)
        CombMultiInPipeline.__init__(self, stage, p_len, p_mux=arbiter,
                                     maskwid=maskwid,
                                     routemask=routemask)
584
585
if __name__ == '__main__':
    # quick smoke test: build a 2-input priority mux pipe around the
    # example stage and dump it as RTLIL
    from nmutil.test.example_buf_pipe import ExampleStage

    pipe = PriorityCombMuxInPipe(ExampleStage)
    il_text = rtlil.convert(pipe, ports=pipe.ports())
    with open("test_combpipe.il", "w") as out_file:
        out_file.write(il_text)