13aaca3f21700f776715692b6246e6e040878d76
[nmutil.git] / src / nmutil / multipipe.py
1 """ Combinatorial Multi-input and Multi-output multiplexer blocks
2 conforming to Pipeline API
3
4 This work is funded through NLnet under Grant 2019-02-012
5
6 License: LGPLv3+
7
8 Multi-input is complex because if any one input is ready, the output
9 can be ready, and the decision comes from a separate module.
10
Multi-output is simple (pretty much identical to UnbufferedPipeline),
and the selection is just a mux.  The only proviso (difference) being:
the outputs not being selected have to have their o_valid signals
DEASSERTED.
15
16 https://bugs.libre-soc.org/show_bug.cgi?id=538
17 """
18
19 from math import log
20 from nmigen import Signal, Cat, Const, Mux, Module, Array, Elaboratable
21 from nmigen.cli import verilog, rtlil
22 from nmigen.lib.coding import PriorityEncoder
23 from nmigen.hdl.rec import Record, Layout
24 from nmutil.stageapi import _spec
25
26 from collections.abc import Sequence
27
28 from nmutil.nmoperator import eq
29 from nmutil.iocontrol import NextControl, PrevControl
30
31
class MultiInControlBase(Elaboratable):
    """ Common functions for Pipeline API (multi-input variant)

        Holds an Array of PrevControl objects (``p``, one per input) and a
        single NextControl (``n``).
    """
    def __init__(self, in_multi=None, p_len=1, maskwid=0, routemask=False):
        """ Multi-input Control class.  Conforms to same API as ControlBase...
            mostly.  has additional indices to the *multiple* input stages

            * p: contains ready/valid to the previous stages PLURAL
            * n: contains ready/valid to the next stage

            User must also:
            * add i_data members to PrevControl and
            * add o_data member to NextControl

            :param in_multi: passed through to every PrevControl
            :param p_len: number of inputs (length of ``self.p``)
            :param maskwid: mask width of each input PrevControl
            :param routemask: if True the output mask has the same width as
                              each input ("straight route" mode); otherwise
                              the input masks are fanned in side by side
                              (``maskwid * p_len`` bits)
        """
        self.routemask = routemask
        # set up input and output IO ACK (prev/next ready/valid)
        print ("multi_in", self, maskwid, p_len, routemask)
        self.p = Array([PrevControl(in_multi, maskwid=maskwid)
                        for _ in range(p_len)])
        if routemask:
            nmaskwid = maskwid # straight route mask mode
        else:
            nmaskwid = maskwid * p_len # fan-in mode
        self.n = NextControl(maskwid=nmaskwid) # masks fan in (Cat)

    def connect_to_next(self, nxt, p_idx=0):
        """ helper function to connect to the next stage data/valid/ready.

            :param nxt: next stage; its input ``nxt.p[p_idx]`` is connected
            :param p_idx: which of the next stage's inputs to connect to
        """
        return self.n.connect_to_next(nxt.p[p_idx])

    def _connect_in(self, prev, idx=0, prev_idx=None):
        """ helper function to connect stage to an input source.  do not
            use to connect stage-to-stage!

            :param prev: object with a ``p`` attribute (single control or
                         indexable sequence of them)
            :param idx: which of this stage's inputs to connect
            :param prev_idx: index into ``prev.p``; ``None`` connects to
                             ``prev.p`` directly
        """
        if prev_idx is None:
            return self.p[idx]._connect_in(prev.p)
        return self.p[idx]._connect_in(prev.p[prev_idx])

    def _connect_out(self, nxt, nxt_idx=None):
        """ helper function to connect stage to an output source.  do not
            use to connect stage-to-stage!

            :param nxt: object with an ``n`` attribute (single control or
                        indexable sequence of them)
            :param nxt_idx: index into ``nxt.n``; ``None`` (default)
                            connects to ``nxt.n`` directly
        """
        # nxt_idx used to be read without being a parameter (NameError)
        if nxt_idx is None:
            return self.n._connect_out(nxt.n)
        return self.n._connect_out(nxt.n[nxt_idx])

    def set_input(self, i, idx=0):
        """ helper function to set the input data of input ``idx``
        """
        return eq(self.p[idx].i_data, i)

    def elaborate(self, platform):
        """ Register every input control (as p0, p1, ...) and the output
            control (as n) as submodules.
        """
        m = Module()
        for i, p in enumerate(self.p):
            setattr(m.submodules, "p%d" % i, p)
        m.submodules.n = self.n
        return m

    def __iter__(self):
        """ Iterate over the IO of every input control, then the output. """
        for p in self.p:
            yield from p
        yield from self.n

    def ports(self):
        """ Return everything yielded by ``__iter__`` as a list. """
        return list(self)
99
100
class MultiOutControlBase(Elaboratable):
    """ Common functions for Pipeline API (multi-output variant)

        Holds a single PrevControl (``p``) and an Array of NextControl
        objects (``n``, one per output).
    """
    def __init__(self, n_len=1, in_multi=None, maskwid=0, routemask=False):
        """ Multi-output Control class.  Conforms to same API as ControlBase...
            mostly.  has additional indices to the multiple *output* stages
            [MultiInControlBase has multiple *input* stages]

            * p: contains ready/valid to the previous stage
            * n: contains ready/valid to the next stages PLURAL

            User must also:
            * add i_data member to PrevControl and
            * add o_data members to NextControl

            :param n_len: number of outputs (length of ``self.n``)
            :param in_multi: passed through to the input PrevControl
            :param maskwid: mask width of each output NextControl
            :param routemask: if True the input mask has the same width as
                              each output ("straight route" mode); otherwise
                              it is ``maskwid * n_len`` bits (fan-out mode)
        """

        if routemask:
            nmaskwid = maskwid # straight route mask mode
        else:
            nmaskwid = maskwid * n_len # fan-out mode

        # set up input and output IO ACK (prev/next ready/valid)
        self.p = PrevControl(in_multi, maskwid=nmaskwid)
        self.n = Array([NextControl(maskwid=maskwid) for _ in range(n_len)])

    def connect_to_next(self, nxt, n_idx=0):
        """ helper function to connect output ``n_idx`` to the next stage's
            data/valid/ready.
        """
        return self.n[n_idx].connect_to_next(nxt.p)

    def _connect_in(self, prev, idx=0, prev_idx=None):
        """ helper function to connect stage to an input source.  do not
            use to connect stage-to-stage!

            The input side is the single PrevControl ``self.p`` (the
            NextControls in ``self.n`` are the *outputs*), so ``idx`` is
            accepted for API compatibility only and is ignored.

            :param prev_idx: index into ``prev.p``; ``None`` (default)
                             connects to ``prev.p`` directly
        """
        if prev_idx is None:
            return self.p._connect_in(prev.p)
        return self.p._connect_in(prev.p[prev_idx])

    def _connect_out(self, nxt, idx=0, nxt_idx=None):
        """ helper function to connect stage to an output source.  do not
            use to connect stage-to-stage!

            :param idx: which of this stage's outputs to connect
            :param nxt_idx: index into ``nxt.n``; ``None`` connects to
                            ``nxt.n`` directly
        """
        if nxt_idx is None:
            return self.n[idx]._connect_out(nxt.n)
        return self.n[idx]._connect_out(nxt.n[nxt_idx])

    def elaborate(self, platform):
        """ Register the input control (as p) and every output control
            (as n0, n1, ...) as submodules.
        """
        m = Module()
        m.submodules.p = self.p
        for i, n in enumerate(self.n):
            setattr(m.submodules, "n%d" % i, n)
        return m

    def set_input(self, i):
        """ helper function to set the input data
        """
        return eq(self.p.i_data, i)

    def __iter__(self):
        """ Iterate over the IO of the input control, then every output. """
        yield from self.p
        for n in self.n:
            yield from n

    def ports(self):
        """ Return everything yielded by ``__iter__`` as a list. """
        return list(self)
168
class CombMultiOutPipeline(MultiOutControlBase):
    """ A multi-output Combinatorial block conforming to the Pipeline API

        One input (``p``) is routed to one of several outputs (``n[i]``),
        selected by ``n_mux.m_id``.

        Attributes:
        -----------
        p.i_data : stage input data (non-array).  shaped according to ispec
        n.o_data : stage output data array.       shaped according to ospec
    """

    def __init__(self, stage, n_len, n_mux, maskwid=0, routemask=False):
        """ :param stage: provides ``ispec``/``ospec``; may optionally
                          provide ``process`` and ``setup``
            :param n_len: number of outputs
            :param n_mux: object whose ``m_id`` selects the active output;
                          added as a submodule if it has ``elaborate``
            :param maskwid: mask width (0 disables mask/stop handling)
            :param routemask: straight-route mask mode instead of fan-out
        """
        MultiOutControlBase.__init__(self, n_len=n_len, maskwid=maskwid,
                                     routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.routemask = routemask
        self.n_mux = n_mux

        # set up the input and output data
        self.p.i_data = _spec(stage.ispec, 'i_data') # input type
        for i in range(n_len):
            name = 'o_data_%d' % i
            self.n[i].o_data = _spec(stage.ospec, name) # output type

    def process(self, i):
        """ Apply ``stage.process`` if the stage has one, else pass ``i``
            through unchanged.
        """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        """ Route the single input to the output selected by n_mux.m_id. """
        m = MultiOutControlBase.elaborate(self, platform)

        if hasattr(self.n_mux, "elaborate"): # TODO: identify submodule?
            m.submodules.n_mux = self.n_mux

        # need buffer register conforming to *input* spec
        # (NOTE(review): r_data is only ever driven via m.d.comb below, so
        # it behaves as a combinatorial temporary, not a register)
        r_data = _spec(self.stage.ispec, 'r_data') # input type
        if hasattr(self.stage, "setup"):
            self.stage.setup(m, r_data)

        # multiplexer id taken from n_mux
        muxid = self.n_mux.m_id
        print ("self.n_mux", self.n_mux)
        print ("self.n_mux.m_id", self.n_mux.m_id)

        # renames the (shared) selector Signal object in place
        self.n_mux.m_id.name = "m_id"

        # temporaries
        p_i_valid = Signal(reset_less=True)
        pv = Signal(reset_less=True)
        m.d.comb += p_i_valid.eq(self.p.i_valid_test)
        #m.d.comb += pv.eq(self.p.i_valid) #& self.n[muxid].i_ready)
        # pv: input handshake completes this cycle (valid and ready)
        m.d.comb += pv.eq(self.p.i_valid & self.p.o_ready)

        # all outputs to next stages first initialised to zero (invalid)
        # the only output "active" is then selected by the muxid
        for i in range(len(self.n)):
            m.d.comb += self.n[i].o_valid.eq(0)
        if self.routemask:
            #with m.If(pv):
            m.d.comb += self.n[muxid].o_valid.eq(pv)
            m.d.comb += self.p.o_ready.eq(self.n[muxid].i_ready)
        else:
            data_valid = self.n[muxid].o_valid
            m.d.comb += self.p.o_ready.eq(self.n[muxid].i_ready)
            # NOTE(review): data_valid appears on both sides of this comb
            # assignment, which looks like a combinational feedback loop;
            # confirm this is intended.
            m.d.comb += data_valid.eq(p_i_valid | \
                                    (~self.n[muxid].i_ready & data_valid))


        # send data on
        #with m.If(pv):
        m.d.comb += eq(r_data, self.p.i_data)
        #m.d.comb += eq(self.n[muxid].o_data, self.process(r_data))
        # each output's o_data is assigned only while muxid selects it
        for i in range(len(self.n)):
            with m.If(muxid == i):
                m.d.comb += eq(self.n[i].o_data, self.process(r_data))

        if self.maskwid:
            if self.routemask: # straight "routing" mode - treat like data
                m.d.comb += self.n[muxid].stop_o.eq(self.p.stop_i)
                with m.If(pv):
                    m.d.comb += self.n[muxid].mask_o.eq(self.p.mask_i)
            else:
                ml = [] # accumulate output masks
                ms = [] # accumulate output stops
                # fan-out mode.
                # conditionally fan-out mask bits, always fan-out stop bits
                for i in range(len(self.n)):
                    ml.append(self.n[i].mask_o)
                    ms.append(self.n[i].stop_o)
                m.d.comb += Cat(*ms).eq(self.p.stop_i)
                with m.If(pv):
                    m.d.comb += Cat(*ml).eq(self.p.mask_i)
        return m
263
class CombMultiInPipeline(MultiInControlBase):
    """ A multi-input Combinatorial block conforming to the Pipeline API

        Several inputs (``p[i]``) are arbitrated by ``p_mux`` onto a single
        output (``n``).

        Attributes:
        -----------
        p[i].i_data : StageInput, shaped according to ispec
            The pipeline inputs (one per entry of ``p``)
        n.o_data : StageOutput, shaped according to ospec
            The pipeline output
        r_data : (local to elaborate) one temporary per input, shaped
            according to ispec.  Driven combinatorially (m.d.comb) with
            ``process(p[i].i_data)`` while that input is valid, ready and
            not masked out.
    """

    def __init__(self, stage, p_len, p_mux, maskwid=0, routemask=False):
        """ :param stage: provides ``ispec``/``ospec``; may optionally
                          provide ``process`` and ``setup``
            :param p_len: number of inputs
            :param p_mux: arbiter providing ``m_id`` (selected input index)
                          and ``active``; added as a submodule
            :param maskwid: mask width (0 disables mask/stop handling)
            :param routemask: straight-route mask mode instead of fan-in
        """
        MultiInControlBase.__init__(self, p_len=p_len, maskwid=maskwid,
                                    routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.p_mux = p_mux

        # set up the input and output data
        for i in range(p_len):
            name = 'i_data_%d' % i
            self.p[i].i_data = _spec(stage.ispec, name) # input type
        self.n.o_data = _spec(stage.ospec, 'o_data')

    def process(self, i):
        """ Apply ``stage.process`` if the stage has one, else pass ``i``
            through unchanged.
        """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        """ Forward the input selected by ``p_mux.m_id`` to the output. """
        m = MultiInControlBase.elaborate(self, platform)

        m.submodules.p_mux = self.p_mux

        # per-input temporaries conforming to *input* spec
        r_data = []
        data_valid = []
        p_i_valid = []
        n_i_readyn = []
        p_len = len(self.p)
        for i in range(p_len):
            name = 'r_%d' % i
            r = _spec(self.stage.ispec, name) # input type
            r_data.append(r)
            data_valid.append(Signal(name="data_valid", reset_less=True))
            p_i_valid.append(Signal(name="p_i_valid", reset_less=True))
            n_i_readyn.append(Signal(name="n_i_readyn", reset_less=True))
            if hasattr(self.stage, "setup"):
                print ("setup", self, self.stage, r)
                self.stage.setup(m, r)
        # always create Arrays (even of len 1): they are indexed by the
        # mux-id Signal below
        p_i_valid = Array(p_i_valid)
        n_i_readyn = Array(n_i_readyn)
        data_valid = Array(data_valid)

        nirn = Signal(reset_less=True)
        m.d.comb += nirn.eq(~self.n.i_ready)
        mid = self.p_mux.m_id
        print ("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
        # defaults for every input; overridden below for the selected one
        for i in range(p_len):
            m.d.comb += data_valid[i].eq(0)
            m.d.comb += n_i_readyn[i].eq(1)
            m.d.comb += p_i_valid[i].eq(0)
            m.d.comb += self.p[i].o_ready.eq(0)
        p = self.p[mid]
        maskedout = Signal(reset_less=True)
        if hasattr(p, "mask_i"):
            m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
        else:
            m.d.comb += maskedout.eq(1)
        m.d.comb += p_i_valid[mid].eq(maskedout & self.p_mux.active &
                                      self.p[mid].i_valid)
        m.d.comb += self.p[mid].o_ready.eq(~data_valid[mid] | self.n.i_ready)
        m.d.comb += n_i_readyn[mid].eq(nirn & data_valid[mid])
        # output is valid when any input's data_valid is set
        anyvalid = Cat(*[data_valid[i] for i in range(p_len)])
        m.d.comb += self.n.o_valid.eq(anyvalid.bool())
        m.d.comb += data_valid[mid].eq(p_i_valid[mid] | n_i_readyn[mid])

        if self.routemask:
            # XXX hack - fixes loop (stop is taken from the last input)
            m.d.comb += eq(self.n.stop_o, self.p[-1].stop_i)
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(name="vr%d" % i, reset_less=True)
                maskedout = Signal(name="maskedout%d" % i, reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                # vr: this input is valid, ready and not masked out
                m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(self.n.mask_o, self.p[i].mask_i)
                    m.d.comb += eq(r_data[i], self.process(self.p[i].i_data))
                    with m.If(mid == i):
                        m.d.comb += eq(self.n.o_data, r_data[i])
        else:
            ml = [] # accumulate output masks
            ms = [] # accumulate output stops
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(reset_less=True)
                maskedout = Signal(reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                # vr: this input is valid, ready and not masked out
                m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(r_data[i], self.process(self.p[i].i_data))
                    with m.If(mid == i):
                        m.d.comb += eq(self.n.o_data, r_data[i])
                if self.maskwid:
                    # fan-in: each input contributes its mask slot (zeroed
                    # unless vr) and its stop bits
                    mlen = len(self.p[i].mask_i)
                    ml.append(Mux(vr, self.p[i].mask_i, Const(0, mlen)))
                    ms.append(self.p[i].stop_i)
            if self.maskwid:
                m.d.comb += self.n.mask_o.eq(Cat(*ml))
                m.d.comb += self.n.stop_o.eq(Cat(*ms))

        return m
402
class NonCombMultiInPipeline(MultiInControlBase):
    """ A multi-input pipeline block conforming to the Pipeline API

        Several inputs (``p[i]``) are arbitrated by ``p_mux`` onto a single
        output (``n``); ``process`` is applied to the selected input's
        captured data on the way out.

        Attributes:
        -----------
        p[i].i_data : StageInput, shaped according to ispec
            The pipeline inputs (one per entry of ``p``)
        n.o_data : StageOutput, shaped according to ospec
            The pipeline output
        r_data : (local to elaborate) one temporary per input, shaped
            according to ispec, driven combinatorially (m.d.comb) from
            ``p[i].i_data`` while that input is valid, ready and not
            masked out.
    """

    def __init__(self, stage, p_len, p_mux, maskwid=0, routemask=False):
        """ :param stage: provides ``ispec``/``ospec``; may optionally
                          provide ``process`` and ``setup``
            :param p_len: number of inputs
            :param p_mux: arbiter providing ``m_id`` and ``active``
            :param maskwid: mask width (0 disables mask/stop handling)
            :param routemask: straight-route mask mode instead of fan-in
        """
        MultiInControlBase.__init__(self, p_len=p_len, maskwid=maskwid,
                                    routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.p_mux = p_mux

        # set up the input and output data
        for i in range(p_len):
            name = 'i_data_%d' % i
            self.p[i].i_data = _spec(stage.ispec, name) # input type
        self.n.o_data = _spec(stage.ospec, 'o_data')

    def process(self, i):
        """ Apply ``stage.process`` if the stage has one, else pass ``i``
            through unchanged.
        """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        """ Forward the input selected by ``p_mux.m_id`` to the output. """
        m = MultiInControlBase.elaborate(self, platform)

        m.submodules.p_mux = self.p_mux

        # per-input temporaries conforming to *input* spec.
        # data_valid / n_i_readyn were previously used without ever being
        # declared (NameError); they follow CombMultiInPipeline.
        r_data = []
        r_busy = []  # NOTE(review): only ever defaulted to 0; unused
        p_i_valid = []
        data_valid = []
        n_i_readyn = []
        p_len = len(self.p)
        for i in range(p_len):
            name = 'r_%d' % i
            r = _spec(self.stage.ispec, name) # input type
            r_data.append(r)
            r_busy.append(Signal(name="r_busy%d" % i, reset_less=True))
            p_i_valid.append(Signal(name="p_i_valid%d" % i, reset_less=True))
            data_valid.append(Signal(name="data_valid%d" % i,
                                     reset_less=True))
            n_i_readyn.append(Signal(name="n_i_readyn%d" % i,
                                     reset_less=True))
            if hasattr(self.stage, "setup"):
                print ("setup", self, self.stage, r)
                self.stage.setup(m, r)
        # always create Arrays (even of len 1): they are indexed by the
        # mux-id Signal below, which a plain list does not support
        r_data = Array(r_data)
        p_i_valid = Array(p_i_valid)
        r_busy = Array(r_busy)
        data_valid = Array(data_valid)
        n_i_readyn = Array(n_i_readyn)

        nirn = Signal(reset_less=True)
        m.d.comb += nirn.eq(~self.n.i_ready)
        mid = self.p_mux.m_id
        print ("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
        # defaults for every input; overridden below for the selected one
        for i in range(p_len):
            m.d.comb += r_busy[i].eq(0)
            m.d.comb += data_valid[i].eq(0)
            m.d.comb += n_i_readyn[i].eq(1)
            m.d.comb += p_i_valid[i].eq(0)
            m.d.comb += self.p[i].o_ready.eq(n_i_readyn[i])
        p = self.p[mid]
        maskedout = Signal(reset_less=True)
        if hasattr(p, "mask_i"):
            m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
        else:
            m.d.comb += maskedout.eq(1)
        m.d.comb += p_i_valid[mid].eq(maskedout & self.p_mux.active)
        m.d.comb += self.p[mid].o_ready.eq(~data_valid[mid] | self.n.i_ready)
        m.d.comb += n_i_readyn[mid].eq(nirn & data_valid[mid])
        # output is valid when any input's data_valid is set
        anyvalid = Cat(*[data_valid[i] for i in range(p_len)])
        m.d.comb += self.n.o_valid.eq(anyvalid.bool())
        m.d.comb += data_valid[mid].eq(p_i_valid[mid] | n_i_readyn[mid])

        if self.routemask:
            # XXX hack - fixes loop (stop is taken from the last input)
            m.d.comb += eq(self.n.stop_o, self.p[-1].stop_i)
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(name="vr%d" % i, reset_less=True)
                maskedout = Signal(name="maskedout%d" % i, reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                # vr: this input is valid, ready and not masked out
                m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(self.n.mask_o, self.p[i].mask_i)
                    m.d.comb += eq(r_data[i], self.p[i].i_data)
        else:
            ml = [] # accumulate output masks
            ms = [] # accumulate output stops
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(reset_less=True)
                maskedout = Signal(reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                # vr: this input is valid, ready and not masked out
                m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(r_data[i], self.p[i].i_data)
                if self.maskwid:
                    mlen = len(self.p[i].mask_i)
                    ml.append(Mux(vr, self.p[i].mask_i, Const(0, mlen)))
                    ms.append(self.p[i].stop_i)
            if self.maskwid:
                m.d.comb += self.n.mask_o.eq(Cat(*ml))
                m.d.comb += self.n.stop_o.eq(Cat(*ms))

        m.d.comb += eq(self.n.o_data, self.process(r_data[mid]))

        return m
531
532
class CombMuxOutPipe(CombMultiOutPipeline):
    """ Multi-output combinatorial pipe in which the stage itself also acts
        as the n-way output multiplexer.

        The selector is the input-data field named by ``muxidname``
        (``"muxid"`` when not given); it is installed as ``stage.m_id``.
    """
    def __init__(self, stage, n_len, maskwid=0, muxidname=None,
                 routemask=False):
        if not muxidname:
            muxidname = "muxid"

        # HACK: the stage doubles as the n-way multiplexer
        CombMultiOutPipeline.__init__(self, stage, n_len=n_len,
                                      n_mux=stage, maskwid=maskwid,
                                      routemask=routemask)

        # HACK: since the n-mux *is* the stage, point its selector at the
        # muxid field carried in the input data
        selector = getattr(self.p.i_data, muxidname)
        print ("combmuxout", muxidname, selector)
        stage.m_id = selector
546
547
548
class InputPriorityArbiter(Elaboratable):
    """ arbitration module for Input-Mux pipe, based on PriorityEncoder

        One "valid" bit per input of ``pipe`` is fed into a PriorityEncoder.
        ``m_id`` carries the encoder's output index, and ``active`` is high
        when the encoder reports a valid input (``~pe.n``).
    """
    def __init__(self, pipe, num_rows):
        """ :param pipe: multi-input pipeline whose ``p`` array is arbitrated
                         (must have exactly ``num_rows`` entries)
            :param num_rows: number of inputs, must be >= 1
            :raises ValueError: if ``num_rows`` < 1
        """
        if num_rows < 1:
            raise ValueError("num_rows must be >= 1, got %r" % (num_rows,))
        self.pipe = pipe
        self.num_rows = num_rows
        # bits needed to hold any index 0..num_rows-1, i.e. ceil(log2()).
        # floor(log2()) would truncate the mux id whenever num_rows is not
        # a power of two (e.g. 3 rows need 2 bits, not 1)
        self.mmax = (num_rows - 1).bit_length()
        self.m_id = Signal(self.mmax, reset_less=True) # multiplex id
        self.active = Signal(reset_less=True)

    def elaborate(self, platform):
        """ Build the priority encoder over the inputs' valid bits. """
        m = Module()

        assert len(self.pipe.p) == self.num_rows, \
            "must declare input to be same size"
        pe = PriorityEncoder(self.num_rows)
        m.submodules.selector = pe

        # connect priority encoder: one "valid" bit per input
        in_valid = []
        for i in range(self.num_rows):
            p_i_valid = Signal(reset_less=True)
            if self.pipe.maskwid and not self.pipe.routemask:
                # fan-in mask mode: an input only competes when some mask
                # bit is set with the corresponding stop bit clear
                p = self.pipe.p[i]
                maskedout = Signal(reset_less=True)
                m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                m.d.comb += p_i_valid.eq(maskedout.bool() & p.i_valid_test)
            else:
                m.d.comb += p_i_valid.eq(self.pipe.p[i].i_valid_test)
            in_valid.append(p_i_valid)
        m.d.comb += pe.i.eq(Cat(*in_valid)) # array of input "valids"
        m.d.comb += self.active.eq(~pe.n) # encoder active (one input valid)
        m.d.comb += self.m_id.eq(pe.o) # output one active input

        return m

    def ports(self):
        """ Return the arbiter's output signals. """
        return [self.m_id, self.active]
586 return [self.m_id, self.active]
587
588
589
class PriorityCombMuxInPipe(CombMultiInPipeline):
    """ an example of how to use the combinatorial pipeline.

        The inputs are arbitrated by an InputPriorityArbiter that is built
        over this pipe's own ``p`` array.
    """

    def __init__(self, stage, p_len=2, maskwid=0, routemask=False):
        CombMultiInPipeline.__init__(self, stage, p_len,
                                     InputPriorityArbiter(self, p_len),
                                     maskwid=maskwid, routemask=routemask)
598
599
if __name__ == '__main__':
    from nmutil.test.example_buf_pipe import ExampleStage

    # elaborate a default (2-input) priority mux around ExampleStage and
    # write the resulting RTLIL to disk
    dut = PriorityCombMuxInPipe(ExampleStage)
    rtlil_text = rtlil.convert(dut, ports=dut.ports())
    with open("test_combpipe.il", "w") as f:
        f.write(rtlil_text)