speed up ==, hash, <, >, <=, and >= for plain_data
[nmutil.git] / src / nmutil / multipipe.py
1 """ Combinatorial Multi-input and Multi-output multiplexer blocks
2 conforming to Pipeline API
3
4 This work is funded through NLnet under Grant 2019-02-012
5
6 License: LGPLv3+
7
8 Multi-input is complex because if any one input is ready, the output
9 can be ready, and the decision comes from a separate module.
10
11 Multi-output is simple (pretty much identical to UnbufferedPipeline),
12 and the selection is just a mux. The only proviso (difference) being:
13 the outputs not being selected have to have their o_ready signals
14 DEASSERTED.
15
16 https://bugs.libre-soc.org/show_bug.cgi?id=538
17 """
18
19 from math import log
20 from nmigen import Signal, Cat, Const, Mux, Module, Array, Elaboratable
21 from nmigen.cli import verilog, rtlil
22 from nmigen.lib.coding import PriorityEncoder
23 from nmigen.hdl.rec import Record, Layout
24 from nmutil.stageapi import _spec
25
26 from collections.abc import Sequence
27
28 from nmutil.nmoperator import eq
29 from nmutil.iocontrol import NextControl, PrevControl
30
31
class MultiInControlBase(Elaboratable):
    """ Common functions for Pipeline API (multiple inputs, one output)
    """
    def __init__(self, in_multi=None, p_len=1, maskwid=0, routemask=False):
        """ Multi-input Control class.  Conforms to same API as ControlBase...
            mostly.  has additional indices to the *multiple* input stages

            * p: contains ready/valid to the previous stages PLURAL
            * n: contains ready/valid to the next stage

            User must also:
            * add i_data members to PrevControl and
            * add o_data member to NextControl

            :param in_multi: passed through to every PrevControl
            :param p_len: number of input (previous-stage) ports
            :param maskwid: mask width of each input port
            :param routemask: if True the output mask has the same width as
                              one input mask ("straight route" mode); if
                              False it is ``maskwid * p_len`` bits wide
                              (fan-in mode)
        """
        self.routemask = routemask
        # set up input and output IO ACK (prev/next ready/valid)
        print ("multi_in", self, maskwid, p_len, routemask)
        self.p = Array([PrevControl(in_multi, maskwid=maskwid)
                        for _ in range(p_len)])
        if routemask:
            nmaskwid = maskwid # straight route mask mode
        else:
            nmaskwid = maskwid * p_len # fan-in mode
        self.n = NextControl(maskwid=nmaskwid) # masks fan in (Cat)

    def connect_to_next(self, nxt, p_idx=0):
        """ helper function to connect to the next stage data/valid/ready.

            :param nxt: next stage; its ``p`` is indexed with ``p_idx``
            :param p_idx: which of ``nxt.p`` to connect to
        """
        return self.n.connect_to_next(nxt.p[p_idx])

    def _connect_in(self, prev, idx=0, prev_idx=None):
        """ helper function to connect stage to an input source.  do not
            use to connect stage-to-stage!

            :param prev: object providing ``p`` (indexed if prev_idx given)
            :param idx: which of this stage's inputs ``self.p`` to connect
            :param prev_idx: optional index into ``prev.p``
        """
        if prev_idx is None:
            return self.p[idx]._connect_in(prev.p)
        return self.p[idx]._connect_in(prev.p[prev_idx])

    def _connect_out(self, nxt, nxt_idx=None):
        """ helper function to connect stage to an output source.  do not
            use to connect stage-to-stage!

            :param nxt: object providing ``n``
            :param nxt_idx: optional index into ``nxt.n`` (when ``nxt`` has
                            several outputs); if None, ``nxt.n`` is used
                            directly
        """
        # nxt_idx is an optional parameter (mirroring prev_idx in
        # _connect_in): it used to be an undefined name here, so every
        # call raised NameError.
        if nxt_idx is None:
            return self.n._connect_out(nxt.n)
        return self.n._connect_out(nxt.n[nxt_idx])

    def set_input(self, i, idx=0):
        """ helper function to set the input data of input port ``idx``
        """
        return eq(self.p[idx].i_data, i)

    def elaborate(self, platform):
        """ Register every input control as submodule p<N> and the output
            control as submodule n.
        """
        m = Module()
        for i, p in enumerate(self.p):
            setattr(m.submodules, "p%d" % i, p)
        m.submodules.n = self.n
        return m

    def __iter__(self):
        """ yield the signals of all input controls, then the output's """
        for p in self.p:
            yield from p
        yield from self.n

    def ports(self):
        return list(self)
98 return list(self)
99
100
class MultiOutControlBase(Elaboratable):
    """ Common functions for Pipeline API (one input, multiple outputs)
    """
    def __init__(self, n_len=1, in_multi=None, maskwid=0, routemask=False):
        """ Multi-output Control class.  Conforms to same API as ControlBase...
            mostly.  has additional indices to the multiple *output* stages
            [MultiInControlBase has multiple *input* stages]

            * p: contains ready/valid to the previous stage
            * n: contains ready/valid to the next stages PLURAL

            User must also:
            * add i_data member to PrevControl and
            * add o_data members to NextControl

            :param n_len: number of output (next-stage) ports
            :param in_multi: passed through to the PrevControl
            :param maskwid: mask width of each output port
            :param routemask: if True the input mask has the same width as
                              one output mask ("straight route" mode); if
                              False it is ``maskwid * n_len`` bits wide
                              (fan-out mode)
        """

        if routemask:
            nmaskwid = maskwid # straight route mask mode
        else:
            nmaskwid = maskwid * n_len # fan-out mode

        # set up input and output IO ACK (prev/next ready/valid)
        self.p = PrevControl(in_multi, maskwid=nmaskwid)
        self.n = Array([NextControl(maskwid=maskwid) for _ in range(n_len)])

    def connect_to_next(self, nxt, n_idx=0):
        """ helper function to connect to the next stage data/valid/ready.

            :param nxt: next stage (its single ``p`` is connected)
            :param n_idx: which of this stage's outputs ``self.n`` to use
        """
        return self.n[n_idx].connect_to_next(nxt.p)

    def _connect_in(self, prev, idx=0):
        """ helper function to connect stage to an input source.  do not
            use to connect stage-to-stage!

            This stage has a single input port, ``self.p``, so ``idx`` is
            not used; it is kept so existing callers remain valid.
        """
        # the input side is the PrevControl self.p; self.n[...] are the
        # NextControl *outputs* and are the wrong side to connect here.
        return self.p._connect_in(prev.p)

    def _connect_out(self, nxt, idx=0, nxt_idx=None):
        """ helper function to connect stage to an output source.  do not
            use to connect stage-to-stage!

            :param nxt: object providing ``n``
            :param idx: which of this stage's outputs ``self.n`` to connect
            :param nxt_idx: optional index into ``nxt.n``
        """
        if nxt_idx is None:
            return self.n[idx]._connect_out(nxt.n)
        return self.n[idx]._connect_out(nxt.n[nxt_idx])

    def elaborate(self, platform):
        """ Register the input control as submodule p and every output
            control as submodule n<N>.
        """
        m = Module()
        m.submodules.p = self.p
        for i, n in enumerate(self.n):
            setattr(m.submodules, "n%d" % i, n)
        return m

    def set_input(self, i):
        """ helper function to set the input data
        """
        return eq(self.p.i_data, i)

    def __iter__(self):
        """ yield the input control's signals, then every output's """
        yield from self.p
        for n in self.n:
            yield from n

    def ports(self):
        return list(self)
167
168
class CombMultiOutPipeline(MultiOutControlBase):
    """ A multi-output Combinatorial block conforming to the Pipeline API

        One input (``p``) is routed to one of several outputs (``n``); the
        output is selected by ``n_mux.m_id``.  Every output's o_valid is
        first driven to zero and only the selected one is then driven from
        the input handshake.

        Attributes:
        -----------
        p.i_data : stage input data (non-array).  shaped according to ispec
        n.o_data : stage output data array.       shaped according to ospec
    """

    def __init__(self, stage, n_len, n_mux, maskwid=0, routemask=False):
        """
        :param stage: provides ``ispec``/``ospec``; optional ``setup`` and
                      ``process`` are used if present
        :param n_len: number of output ports
        :param n_mux: object whose ``m_id`` signal selects the output port;
                      added as submodule ``n_mux`` if it has ``elaborate``
        :param maskwid: mask width (0 disables mask/stop handling)
        :param routemask: True: straight-route mask mode; False: fan-out
        """
        MultiOutControlBase.__init__(self, n_len=n_len, maskwid=maskwid,
                                     routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.routemask = routemask
        self.n_mux = n_mux

        # set up the input and output data
        self.p.i_data = _spec(stage.ispec, 'i_data') # input type
        for i in range(n_len):
            name = 'o_data_%d' % i
            self.n[i].o_data = _spec(stage.ospec, name) # output type

    def process(self, i):
        """ Return stage.process(i) if the stage defines ``process``,
            otherwise pass ``i`` through unchanged.
        """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        # NOTE: statement order below matters -- defaults are assigned
        # first and later comb assignments override them for the selected
        # output (see "all outputs ... first initialised to zero").
        m = MultiOutControlBase.elaborate(self, platform)

        if hasattr(self.n_mux, "elaborate"): # TODO: identify submodule?
            m.submodules.n_mux = self.n_mux

        # need buffer register conforming to *input* spec
        r_data = _spec(self.stage.ispec, 'r_data') # input type
        if hasattr(self.stage, "setup"):
            self.stage.setup(m, r_data)

        # multiplexer id taken from n_mux
        muxid = self.n_mux.m_id
        print ("self.n_mux", self.n_mux)
        print ("self.n_mux.m_id", self.n_mux.m_id)

        # this renames the Signal object owned by n_mux (it is not a copy)
        self.n_mux.m_id.name = "m_id"

        # temporaries
        p_i_valid = Signal(reset_less=True)
        pv = Signal(reset_less=True)
        m.d.comb += p_i_valid.eq(self.p.i_valid_test)
        #m.d.comb += pv.eq(self.p.i_valid) #& self.n[muxid].i_ready)
        # pv: input handshake (valid and ready) happening this cycle
        m.d.comb += pv.eq(self.p.i_valid & self.p.o_ready)

        # all outputs to next stages first initialised to zero (invalid)
        # the only output "active" is then selected by the muxid
        for i in range(len(self.n)):
            m.d.comb += self.n[i].o_valid.eq(0)
        if self.routemask:
            #with m.If(pv):
            m.d.comb += self.n[muxid].o_valid.eq(pv)
            m.d.comb += self.p.o_ready.eq(self.n[muxid].i_ready)
        else:
            data_valid = self.n[muxid].o_valid
            m.d.comb += self.p.o_ready.eq(~data_valid | self.n[muxid].i_ready)
            # NOTE(review): data_valid aliases n[muxid].o_valid and appears
            # on both sides of this comb assignment -- looks like a
            # combinational feedback path; confirm this is intended.
            m.d.comb += data_valid.eq(p_i_valid | \
                                    (~self.n[muxid].i_ready & data_valid))


        # send data on
        #with m.If(pv):
        # input data is copied (comb) into r_data, and the processed result
        # drives only the selected output's o_data
        m.d.comb += eq(r_data, self.p.i_data)
        m.d.comb += eq(self.n[muxid].o_data, self.process(r_data))

        if self.maskwid:
            if self.routemask: # straight "routing" mode - treat like data
                m.d.comb += self.n[muxid].stop_o.eq(self.p.stop_i)
                with m.If(pv):
                    m.d.comb += self.n[muxid].mask_o.eq(self.p.mask_i)
            else:
                ml = [] # accumulate output masks
                ms = [] # accumulate output stops
                # fan-out mode.
                # conditionally fan-out mask bits, always fan-out stop bits
                for i in range(len(self.n)):
                    ml.append(self.n[i].mask_o)
                    ms.append(self.n[i].stop_o)
                m.d.comb += Cat(*ms).eq(self.p.stop_i)
                with m.If(pv):
                    m.d.comb += Cat(*ml).eq(self.p.mask_i)
        return m
259
260
class CombMultiInPipeline(MultiInControlBase):
    """ A multi-input Combinatorial block conforming to the Pipeline API

        Several inputs (``p``) feed a single output (``n``); the input
        forwarded to the output is selected by ``p_mux.m_id``.

        Attributes:
        -----------
        p.i_data : StageInput, shaped according to ispec
            The pipeline input (one per input port)
        n.o_data : StageOutput, shaped according to ospec
            The pipeline output
        r_data : input_shape according to ispec
            (local to elaborate) a per-input copy of p[i].i_data.  It is
            assigned in the comb domain while that input's
            maskedout & i_valid & o_ready condition holds; it is not a
            synchronous register.
    """

    def __init__(self, stage, p_len, p_mux, maskwid=0, routemask=False):
        """
        :param stage: provides ``ispec``/``ospec``; optional ``setup`` and
                      ``process`` are used if present
        :param p_len: number of input ports
        :param p_mux: arbiter providing ``m_id`` (selected input index) and
                      ``active``; always added as submodule ``p_mux``
        :param maskwid: mask width (0 disables mask/stop handling)
        :param routemask: True: straight-route mask mode; False: fan-in
        """
        MultiInControlBase.__init__(self, p_len=p_len, maskwid=maskwid,
                                    routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.p_mux = p_mux

        # set up the input and output data
        for i in range(p_len):
            name = 'i_data_%d' % i
            self.p[i].i_data = _spec(stage.ispec, name) # input type
        self.n.o_data = _spec(stage.ospec, 'o_data')

    def process(self, i):
        """ Return stage.process(i) if the stage defines ``process``,
            otherwise pass ``i`` through unchanged.
        """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        # NOTE: statement order below matters -- per-input defaults are
        # assigned first and then overridden for the selected input (mid).
        m = MultiInControlBase.elaborate(self, platform)

        m.submodules.p_mux = self.p_mux

        # need an array of buffer registers conforming to *input* spec
        r_data = []
        data_valid = []
        p_i_valid = []
        n_i_readyn = []
        p_len = len(self.p)
        for i in range(p_len):
            name = 'r_%d' % i
            r = _spec(self.stage.ispec, name) # input type
            r_data.append(r)
            data_valid.append(Signal(name="data_valid", reset_less=True))
            p_i_valid.append(Signal(name="p_i_valid", reset_less=True))
            n_i_readyn.append(Signal(name="n_i_readyn", reset_less=True))
            if hasattr(self.stage, "setup"):
                print ("setup", self, self.stage, r)
                self.stage.setup(m, r)
        # wrap as Arrays so they can be indexed by the Signal ``mid``.
        # NOTE(review): with p_len == 1 they stay plain lists yet are still
        # indexed by ``mid`` below -- check whether p_len == 1 is supported.
        if len(r_data) > 1:
            r_data = Array(r_data)
            p_i_valid = Array(p_i_valid)
            n_i_readyn = Array(n_i_readyn)
            data_valid = Array(data_valid)

        # nirn: next stage is NOT ready
        nirn = Signal(reset_less=True)
        m.d.comb += nirn.eq(~self.n.i_ready)
        mid = self.p_mux.m_id
        print ("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
        # defaults for every input: not holding data, not valid, not ready
        for i in range(p_len):
            m.d.comb += data_valid[i].eq(0)
            m.d.comb += n_i_readyn[i].eq(1)
            m.d.comb += p_i_valid[i].eq(0)
            #m.d.comb += self.p[i].o_ready.eq(~data_valid[i] | self.n.i_ready)
            m.d.comb += self.p[i].o_ready.eq(0)
        # overrides for the selected input only
        p = self.p[mid]
        # NOTE(review): maskedout is a 1-bit Signal, so a multi-bit mask_i
        # is truncated to its bit 0 here -- confirm intended.
        maskedout = Signal(reset_less=True)
        if hasattr(p, "mask_i"):
            m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
        else:
            m.d.comb += maskedout.eq(1)
        m.d.comb += p_i_valid[mid].eq(maskedout & self.p_mux.active)
        m.d.comb += self.p[mid].o_ready.eq(~data_valid[mid] | self.n.i_ready)
        m.d.comb += n_i_readyn[mid].eq(nirn & data_valid[mid])
        # NOTE: this Signal is never used -- anyvalid is rebound to a Cat
        # just below (its width also comes from the leftover loop var i).
        anyvalid = Signal(i, reset_less=True)
        av = []
        for i in range(p_len):
            av.append(data_valid[i])
        anyvalid = Cat(*av)
        # output valid when any input is holding valid data
        m.d.comb += self.n.o_valid.eq(anyvalid.bool())
        # selected input holds data if newly valid, or if it already held
        # data and the next stage is not ready.
        # NOTE(review): data_valid[mid] feeds back into itself through
        # n_i_readyn[mid] in the comb domain -- confirm intended.
        m.d.comb += data_valid[mid].eq(p_i_valid[mid] | \
                                    (n_i_readyn[mid] ))

        if self.routemask:
            # straight-route mode: output stop taken from the last input,
            # output mask from whichever input's vr condition holds
            # XXX hack - fixes loop
            m.d.comb += eq(self.n.stop_o, self.p[-1].stop_i)
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(name="vr%d" % i, reset_less=True)
                maskedout = Signal(name="maskedout%d" % i, reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                # vr: this input is handshaking (and not masked out)
                m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                #m.d.comb += vr.eq(p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(self.n.mask_o, self.p[i].mask_i)
                    m.d.comb += eq(r_data[i], self.p[i].i_data)
        else:
            # fan-in mode: output mask is the Cat of every input mask (each
            # zeroed unless that input's vr holds); output stop is the Cat
            # of every input stop
            ml = [] # accumulate output masks
            ms = [] # accumulate output stops
            for i in range(p_len):
                # NOTE: vr is created twice; only the second Signal is used
                vr = Signal(reset_less=True)
                p = self.p[i]
                vr = Signal(reset_less=True)
                maskedout = Signal(reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(r_data[i], self.p[i].i_data)
                if self.maskwid:
                    mlen = len(self.p[i].mask_i)
                    # NOTE: s and e are computed but never used
                    s = mlen*i
                    e = mlen*(i+1)
                    ml.append(Mux(vr, self.p[i].mask_i, Const(0, mlen)))
                    ms.append(self.p[i].stop_i)
            if self.maskwid:
                m.d.comb += self.n.mask_o.eq(Cat(*ml))
                m.d.comb += self.n.stop_o.eq(Cat(*ms))

        # output data: the selected input's copy, through stage.process
        m.d.comb += eq(self.n.o_data, self.process(r_data[mid]))

        return m
393
394
class NonCombMultiInPipeline(MultiInControlBase):
    """ A multi-input pipeline block conforming to the Pipeline API

        NOTE(review): elaborate() is unfinished.  It uses ``n_i_readyn`` and
        ``data_valid`` but never defines them (CombMultiInPipeline.elaborate
        builds both as local lists), so elaborating this class raises
        NameError when the first of them is reached.  Do not use until it
        is completed.

        Attributes:
        -----------
        p.i_data : StageInput, shaped according to ispec
            The pipeline input (one per input port)
        n.o_data : StageOutput, shaped according to ospec
            The pipeline output
        r_data : input_shape according to ispec
            A temporary (buffered) copy of a prior (valid) input.
            This is HELD if the output is not ready.  It is updated
            SYNCHRONOUSLY.
            NOTE(review): as written, r_data is assigned in the comb
            domain, not synchronously.
    """

    def __init__(self, stage, p_len, p_mux, maskwid=0, routemask=False):
        """
        :param stage: provides ``ispec``/``ospec``; optional ``setup`` and
                      ``process`` are used if present
        :param p_len: number of input ports
        :param p_mux: arbiter providing ``m_id`` and ``active``; always
                      added as submodule ``p_mux``
        :param maskwid: mask width (0 disables mask/stop handling)
        :param routemask: True: straight-route mask mode; False: fan-in
        """
        MultiInControlBase.__init__(self, p_len=p_len, maskwid=maskwid,
                                    routemask=routemask)
        self.stage = stage
        self.maskwid = maskwid
        self.p_mux = p_mux

        # set up the input and output data
        for i in range(p_len):
            name = 'i_data_%d' % i
            self.p[i].i_data = _spec(stage.ispec, name) # input type
        self.n.o_data = _spec(stage.ospec, 'o_data')

    def process(self, i):
        """ Return stage.process(i) if the stage defines ``process``,
            otherwise pass ``i`` through unchanged.
        """
        if hasattr(self.stage, "process"):
            return self.stage.process(i)
        return i

    def elaborate(self, platform):
        m = MultiInControlBase.elaborate(self, platform)

        m.submodules.p_mux = self.p_mux

        # need an array of buffer registers conforming to *input* spec
        r_data = []
        r_busy = []
        p_i_valid = []
        p_len = len(self.p)
        for i in range(p_len):
            name = 'r_%d' % i
            r = _spec(self.stage.ispec, name) # input type
            r_data.append(r)
            r_busy.append(Signal(name="r_busy%d" % i, reset_less=True))
            p_i_valid.append(Signal(name="p_i_valid%d" % i, reset_less=True))
            if hasattr(self.stage, "setup"):
                print ("setup", self, self.stage, r)
                self.stage.setup(m, r)
        if len(r_data) > 1:
            r_data = Array(r_data)
            p_i_valid = Array(p_i_valid)
            r_busy = Array(r_busy)

        nirn = Signal(reset_less=True)
        m.d.comb += nirn.eq(~self.n.i_ready)
        mid = self.p_mux.m_id
        print ("CombMuxIn mid", self, self.stage, self.routemask, mid, p_len)
        for i in range(p_len):
            # NOTE(review): r_busy is only ever driven to 0 in this method
            m.d.comb += r_busy[i].eq(0)
            # NOTE(review): n_i_readyn is not defined in this method ->
            # NameError here
            m.d.comb += n_i_readyn[i].eq(1)
            m.d.comb += p_i_valid[i].eq(0)
            m.d.comb += self.p[i].o_ready.eq(n_i_readyn[i])
        p = self.p[mid]
        maskedout = Signal(reset_less=True)
        if hasattr(p, "mask_i"):
            m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
        else:
            m.d.comb += maskedout.eq(1)
        m.d.comb += p_i_valid[mid].eq(maskedout & self.p_mux.active)
        # NOTE(review): data_valid is not defined in this method either
        m.d.comb += self.p[mid].o_ready.eq(~data_valid[mid] | self.n.i_ready)
        m.d.comb += n_i_readyn[mid].eq(nirn & data_valid[mid])
        # NOTE: this Signal is never used -- anyvalid is rebound below
        anyvalid = Signal(i, reset_less=True)
        av = []
        for i in range(p_len):
            av.append(data_valid[i])
        anyvalid = Cat(*av)
        m.d.comb += self.n.o_valid.eq(anyvalid.bool())
        m.d.comb += data_valid[mid].eq(p_i_valid[mid] | \
                                    (n_i_readyn[mid] ))

        if self.routemask:
            # XXX hack - fixes loop
            m.d.comb += eq(self.n.stop_o, self.p[-1].stop_i)
            for i in range(p_len):
                p = self.p[i]
                vr = Signal(name="vr%d" % i, reset_less=True)
                maskedout = Signal(name="maskedout%d" % i, reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                #m.d.comb += vr.eq(p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(self.n.mask_o, self.p[i].mask_i)
                    m.d.comb += eq(r_data[i], self.p[i].i_data)
        else:
            ml = [] # accumulate output masks
            ms = [] # accumulate output stops
            for i in range(p_len):
                # NOTE: vr is created twice; only the second Signal is used
                vr = Signal(reset_less=True)
                p = self.p[i]
                vr = Signal(reset_less=True)
                maskedout = Signal(reset_less=True)
                if hasattr(p, "mask_i"):
                    m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                else:
                    m.d.comb += maskedout.eq(1)
                m.d.comb += vr.eq(maskedout.bool() & p.i_valid & p.o_ready)
                with m.If(vr):
                    m.d.comb += eq(r_data[i], self.p[i].i_data)
                if self.maskwid:
                    mlen = len(self.p[i].mask_i)
                    # NOTE: s and e are computed but never used
                    s = mlen*i
                    e = mlen*(i+1)
                    ml.append(Mux(vr, self.p[i].mask_i, Const(0, mlen)))
                    ms.append(self.p[i].stop_i)
            if self.maskwid:
                m.d.comb += self.n.mask_o.eq(Cat(*ml))
                m.d.comb += self.n.stop_o.eq(Cat(*ms))

        m.d.comb += eq(self.n.o_data, self.process(r_data[mid]))

        return m
523
524
class CombMuxOutPipe(CombMultiOutPipeline):
    """ Combinatorial multi-output pipe whose stage doubles as its own
        n-way output multiplexer.

        The output selector is read from a field of the incoming data,
        ``p.i_data.<muxidname>`` (field name defaults to "muxid"), and is
        installed on the stage as ``stage.m_id``.
    """
    def __init__(self, stage, n_len, maskwid=0, muxidname=None,
                 routemask=False):
        if not muxidname:
            muxidname = "muxid"
        # the stage object is handed in as the n-way multiplexer too
        super().__init__(stage, n_len=n_len, n_mux=stage,
                         maskwid=maskwid, routemask=routemask)

        # route by the mux id carried inside the input data
        selector = getattr(self.p.i_data, muxidname)
        print ("combmuxout", muxidname, selector)
        stage.m_id = selector
538
539
540
class InputPriorityArbiter(Elaboratable):
    """ arbitration module for Input-Mux pipe, based on PriorityEncoder

        Each input's valid (gated by its mask/stop bits in fan-in mask
        mode) drives one bit of a PriorityEncoder.  ``m_id`` carries the
        encoder's selected index and ``active`` is high when the encoder
        sees at least one valid input.
    """
    def __init__(self, pipe, num_rows):
        """
        :param pipe: the multi-input pipeline being arbitrated; its ``p``
                     ports, ``maskwid`` and ``routemask`` are read in
                     elaborate()
        :param num_rows: number of inputs; must equal ``len(pipe.p)``
        :raises ValueError: if num_rows is less than 1
        """
        if num_rows < 1:
            raise ValueError("num_rows must be at least 1, got %r"
                             % (num_rows,))
        self.pipe = pipe
        self.num_rows = num_rows
        # enough bits to hold every index 0..num_rows-1.  int(log2(n)) is
        # one bit short whenever n is not a power of two (3 rows -> 1 bit,
        # so index 2 would be truncated to 0 and select the wrong input).
        self.mmax = (num_rows - 1).bit_length()
        self.m_id = Signal(self.mmax, reset_less=True) # multiplex id
        self.active = Signal(reset_less=True)

    def elaborate(self, platform):
        m = Module()

        assert len(self.pipe.p) == self.num_rows, \
                "must declare input to be same size"
        pe = PriorityEncoder(self.num_rows)
        m.submodules.selector = pe

        # connect priority encoder
        in_ready = []
        for i in range(self.num_rows):
            p_i_valid = Signal(reset_less=True)
            if self.pipe.maskwid and not self.pipe.routemask:
                p = self.pipe.p[i]
                # NOTE(review): maskedout is 1 bit wide, so a multi-bit
                # mask_i is truncated to bit 0 before .bool() -- confirm.
                maskedout = Signal(reset_less=True)
                m.d.comb += maskedout.eq(p.mask_i & ~p.stop_i)
                m.d.comb += p_i_valid.eq(maskedout.bool() & p.i_valid_test)
            else:
                m.d.comb += p_i_valid.eq(self.pipe.p[i].i_valid_test)
            in_ready.append(p_i_valid)
        m.d.comb += pe.i.eq(Cat(*in_ready)) # array of input "valids"
        m.d.comb += self.active.eq(~pe.n)   # encoder active (one input valid)
        m.d.comb += self.m_id.eq(pe.o)       # output one active input

        return m

    def ports(self):
        return [self.m_id, self.active]
579
580
581
class PriorityCombMuxInPipe(CombMultiInPipeline):
    """ Combinatorial multi-input pipe whose input selection is made by an
        InputPriorityArbiter -- an example use of CombMultiInPipeline.
    """

    def __init__(self, stage, p_len=2, maskwid=0, routemask=False):
        # the arbiter only stores a reference to this pipe at construction
        # time, so it can be built before the base class is initialised
        arbiter = InputPriorityArbiter(self, p_len)
        super().__init__(stage, p_len, arbiter,
                         maskwid=maskwid, routemask=routemask)
590
591
if __name__ == '__main__':
    # smoke run: build the default 2-input priority mux pipe around
    # ExampleStage and write its RTLIL to test_combpipe.il
    from nmutil.test.example_buf_pipe import ExampleStage
    pipe = PriorityCombMuxInPipe(ExampleStage)
    rtlil_text = rtlil.convert(pipe, ports=pipe.ports())
    with open("test_combpipe.il", "w") as il_file:
        il_file.write(rtlil_text)