fix name bug in specallocate
[ieee754fpu.git] / src / add / singlepipe.py
index a977dfabce3d9c86946117f7a6de8b0434a36a3a..4222e4694c8994ff43755e7c8eadf4d35f16dea5 100644 (file)
 
     input acceptance conditions are when:
         * incoming previous-stage strobe (p.i_valid) is HIGH
-        * outgoing previous-stage ready   (p.o_ready) is LOW
+        * outgoing previous-stage ready   (p.ready_o) is LOW
 
     output transmission conditions are when:
         * outgoing next-stage strobe (n.o_valid) is HIGH
     https://github.com/ZipCPU/dbgbus/blob/master/hexbus/rtl/hbdeword.v
 """
 
-from nmigen import Signal, Cat, Const, Mux, Module, Value
+from nmigen import Signal, Cat, Const, Mux, Module, Value, Elaboratable
 from nmigen.cli import verilog, rtlil
 from nmigen.lib.fifo import SyncFIFO, SyncFIFOBuffered
 from nmigen.hdl.ast import ArrayProxy
 from nmigen.hdl.rec import Record, Layout
 
 from abc import ABCMeta, abstractmethod
-from collections.abc import Sequence
+from collections.abc import Sequence, Iterable
+from collections import OrderedDict
 from queue import Queue
+import inspect
+
+
+class Object:
+    def __init__(self):
+        self.fields = OrderedDict()
+
+    def __setattr__(self, k, v):
+        print ("kv", k, v)
+        if (k.startswith('_') or k in ["fields", "name", "src_loc"] or
+           k in dir(Object) or "fields" not in self.__dict__):
+            return object.__setattr__(self, k, v)
+        self.fields[k] = v
+
+    def __getattr__(self, k):
+        if k in self.__dict__:
+            return object.__getattr__(self, k)
+        try:
+            return self.fields[k]
+        except KeyError as e:
+            raise AttributeError(e)
+
+    def __iter__(self):
+        for x in self.fields.values():
+            if isinstance(x, Iterable):
+                yield from x
+            else:
+                yield x
+
+    def eq(self, inp):
+        res = []
+        for (k, o) in self.fields.items():
+            i = getattr(inp, k)
+            print ("eq", o, i)
+            rres = o.eq(i)
+            if isinstance(rres, Sequence):
+                res += rres
+            else:
+                res.append(rres)
+        print (res)
+        return res
+
+    def ports(self):
+        return list(self)
 
 
 class RecordObject(Record):
@@ -184,45 +229,65 @@ class RecordObject(Record):
         Record.__init__(self, layout=layout or [], name=None)
 
     def __setattr__(self, k, v):
-        if k in dir(Record) or "fields" not in self.__dict__:
+        #print (dir(Record))
+        if (k.startswith('_') or k in ["fields", "name", "src_loc"] or
+           k in dir(Record) or "fields" not in self.__dict__):
             return object.__setattr__(self, k, v)
         self.fields[k] = v
+        #print ("RecordObject setattr", k, v)
         if isinstance(v, Record):
             newlayout = {k: (k, v.layout)}
-        else:
+        elif isinstance(v, Value):
             newlayout = {k: (k, v.shape())}
+        else:
+            newlayout = {k: (k, shape(v))}
         self.layout.fields.update(newlayout)
 
     def __iter__(self):
         for x in self.fields.values():
-            yield x
+            if isinstance(x, Iterable):
+                yield from x
+            else:
+                yield x
 
+    def ports(self):
+        return list(self)
 
-class PrevControl:
+
+def _spec(fn, name=None):
+    if name is None:
+        return fn()
+    varnames = dict(inspect.getmembers(fn.__code__))['co_varnames']
+    if 'name' in varnames:
+        return fn(name=name)
+    return fn()
+
+
+class PrevControl(Elaboratable):
     """ contains signals that come *from* the previous stage (both in and out)
         * i_valid: previous stage indicating all incoming data is valid.
                    may be a multi-bit signal, where all bits are required
                    to be asserted to indicate "valid".
-        * o_ready: output to next stage indicating readiness to accept data
+        * ready_o: output to next stage indicating readiness to accept data
         * i_data : an input - added by the user of this class
     """
 
     def __init__(self, i_width=1, stage_ctl=False):
         self.stage_ctl = stage_ctl
         self.i_valid = Signal(i_width, name="p_i_valid") # prev   >>in  self
-        self._o_ready = Signal(name="p_o_ready") # prev   <<out self
+        self._ready_o = Signal(name="p_ready_o") # prev   <<out self
         self.i_data = None # XXX MUST BE ADDED BY USER
         if stage_ctl:
-            self.s_o_ready = Signal(name="p_s_o_rdy") # prev   <<out self
+            self.s_ready_o = Signal(name="p_s_o_rdy") # prev   <<out self
         self.trigger = Signal(reset_less=True)
 
     @property
-    def o_ready(self):
+    def ready_o(self):
         """ public-facing API: indicates (externally) that stage is ready
         """
         if self.stage_ctl:
-            return self.s_o_ready # set dynamically by stage
-        return self._o_ready      # return this when not under dynamic control
+            return self.s_ready_o # set dynamically by stage
+        return self._ready_o      # return this when not under dynamic control
 
     def _connect_in(self, prev, direct=False, fn=None):
         """ internal helper function to connect stage to an input source.
@@ -231,7 +296,7 @@ class PrevControl:
         i_valid = prev.i_valid if direct else prev.i_valid_test
         i_data = fn(prev.i_data) if fn is not None else prev.i_data
         return [self.i_valid.eq(i_valid),
-                prev.o_ready.eq(self.o_ready),
+                prev.ready_o.eq(self.ready_o),
                 eq(self.i_data, i_data),
                ]
 
@@ -249,32 +314,35 @@ class PrevControl:
         # when stage indicates not ready, incoming data
         # must "appear" to be not ready too
         if self.stage_ctl:
-            i_valid = i_valid & self.s_o_ready
+            i_valid = i_valid & self.s_ready_o
 
         return i_valid
 
     def elaborate(self, platform):
         m = Module()
-        m.d.comb += self.trigger.eq(self.i_valid_test & self.o_ready)
+        m.d.comb += self.trigger.eq(self.i_valid_test & self.ready_o)
         return m
 
     def eq(self, i):
         return [self.i_data.eq(i.i_data),
-                self.o_ready.eq(i.o_ready),
+                self.ready_o.eq(i.ready_o),
                 self.i_valid.eq(i.i_valid)]
 
-    def ports(self):
-        res = [self.i_valid, self.o_ready]
+    def __iter__(self):
+        yield self.i_valid
+        yield self.ready_o
         if hasattr(self.i_data, "ports"):
-            res += self.i_data.ports()
+            yield from self.i_data.ports()
         elif isinstance(self.i_data, Sequence):
-            res += self.i_data
+            yield from self.i_data
         else:
-            res.append(self.i_data)
-        return res
+            yield self.i_data
+
+    def ports(self):
+        return list(self)
 
 
-class NextControl:
+class NextControl(Elaboratable):
     """ contains the signals that go *to* the next stage (both in and out)
         * o_valid: output indicating to next stage that data is valid
         * i_ready: input from next stage indicating that it can accept data
@@ -301,7 +369,7 @@ class NextControl:
             use this when connecting stage-to-stage
         """
         return [nxt.i_valid.eq(self.o_valid),
-                self.i_ready.eq(nxt.o_ready),
+                self.i_ready.eq(nxt.ready_o),
                 eq(nxt.i_data, self.o_data),
                ]
 
@@ -321,15 +389,18 @@ class NextControl:
         m.d.comb += self.trigger.eq(self.i_ready_test & self.o_valid)
         return m
 
-    def ports(self):
-        res = [self.i_ready, self.o_valid]
+    def __iter__(self):
+        yield self.i_ready
+        yield self.o_valid
         if hasattr(self.o_data, "ports"):
-            res += self.o_data.ports()
+            yield from self.o_data.ports()
         elif isinstance(self.o_data, Sequence):
-            res += self.o_data
+            yield from self.o_data
         else:
-            res.append(self.o_data)
-        return res
+            yield self.o_data
+
+    def ports(self):
+        return list(self)
 
 
 class Visitor2:
@@ -419,9 +490,9 @@ class Visitor:
         if not isinstance(i, Sequence):
             i = [i]
         for ai in i:
-            print ("iterate", ai)
+            #print ("iterate", ai)
             if isinstance(ai, Record):
-                print ("record", list(ai.layout))
+                #print ("record", list(ai.layout))
                 yield from self.record_iter(ai)
             elif isinstance(ai, ArrayProxy) and not isinstance(ai, Value):
                 yield from self.array_iter(ai)
@@ -438,7 +509,7 @@ class Visitor:
                 val = getattr(val, field_name)
             else:
                 val = val[field_name] # dictionary-style specification
-            print ("recidx", idx, field_name, field_shape, val)
+            #print ("recidx", idx, field_name, field_shape, val)
             yield from self.iterate(val)
 
     def array_iter(self, ai):
@@ -460,6 +531,16 @@ def eq(o, i):
     return res
 
 
+def shape(i):
+    #print ("shape", i)
+    r = 0
+    for part in list(i):
+        #print ("shape?", part)
+        s, _ = part.shape()
+        r += s
+    return r, False
+
+
 def cat(i):
     """ flattens a compound structure recursively using Cat
     """
@@ -545,20 +626,22 @@ class StageChain(StageCls):
         self.specallocate = specallocate
 
     def ispec(self):
-        return self.chain[0].ispec()
+        return _spec(self.chain[0].ispec, "chainin")
 
     def ospec(self):
-        return self.chain[-1].ospec()
+        return _spec(self.chain[-1].ospec, "chainout")
 
     def _specallocate_setup(self, m, i):
         for (idx, c) in enumerate(self.chain):
             if hasattr(c, "setup"):
                 c.setup(m, i)               # stage may have some module stuff
-            o = self.chain[idx].ospec()     # last assignment survives
+            ofn = self.chain[idx].ospec     # last assignment survives
+            o = _spec(ofn, 'chainin%d' % idx)
             m.d.comb += eq(o, c.process(i)) # process input into "o"
             if idx == len(self.chain)-1:
                 break
-            i = self.chain[idx+1].ispec()   # new input on next loop
+            ifn = self.chain[idx+1].ispec   # new input on next loop
+            i = _spec(ifn, 'chainin%d' % (idx+1))
             m.d.comb += eq(i, o)            # assign to next input
         return o                            # last loop is the output
 
@@ -579,7 +662,7 @@ class StageChain(StageCls):
         return self.o # conform to Stage API: return last-loop output
 
 
-class ControlBase:
+class ControlBase(Elaboratable):
     """ Common functions for Pipeline API
     """
     def __init__(self, stage=None, in_multi=None, stage_ctl=False):
@@ -600,8 +683,8 @@ class ControlBase:
 
         # set up the input and output data
         if stage is not None:
-            self.p.i_data = stage.ispec() # input type
-            self.n.o_data = stage.ospec()
+            self.p.i_data = _spec(stage.ispec, "i_data") # input type
+            self.n.o_data = _spec(stage.ospec, "o_data") # output type
 
     def connect_to_next(self, nxt):
         """ helper function to connect to the next stage data/valid/ready.
@@ -660,12 +743,12 @@ class ControlBase:
 
         # connect front of chain to ourselves
         front = pipechain[0]
-        self.p.i_data = front.stage.ispec()
+        self.p.i_data = _spec(front.stage.ispec, "chainin")
         eqs += front._connect_in(self)
 
         # connect end of chain to ourselves
         end = pipechain[-1]
-        self.n.o_data = end.stage.ospec()
+        self.n.o_data = _spec(end.stage.ospec, "chainout")
         eqs += end._connect_out(self)
 
         return eqs
@@ -681,13 +764,19 @@ class ControlBase:
         """
         return eq(self.p.i_data, i)
 
+    def __iter__(self):
+        yield from self.p
+        yield from self.n
+
     def ports(self):
-        return self.p.ports() + self.n.ports()
+        return list(self)
 
-    def _elaborate(self, platform):
+    def elaborate(self, platform):
         """ handles case where stage has dynamic ready/valid functions
         """
         m = Module()
+        m.submodules.p = self.p
+        m.submodules.n = self.n
 
         if self.stage is not None and hasattr(self.stage, "setup"):
             self.stage.setup(m, self.p.i_data)
@@ -696,7 +785,7 @@ class ControlBase:
             return m
 
         # intercept the previous (outgoing) "ready", combine with stage ready
-        m.d.comb += self.p.s_o_ready.eq(self.p._o_ready & self.stage.d_ready)
+        m.d.comb += self.p.s_ready_o.eq(self.p._ready_o & self.stage.d_ready)
 
         # intercept the next (incoming) "ready" and combine it with data valid
         sdv = self.stage.d_valid(self.n.i_ready)
@@ -713,7 +802,7 @@ class BufferedHandshake(ControlBase):
         Argument: stage.  see Stage API above
 
         stage-1   p.i_valid >>in   stage   n.o_valid out>>   stage+1
-        stage-1   p.o_ready <<out  stage   n.i_ready <<in    stage+1
+        stage-1   p.ready_o <<out  stage   n.i_ready <<in    stage+1
         stage-1   p.i_data  >>in   stage   n.o_data  out>>   stage+1
                               |             |
                             process --->----^
@@ -736,10 +825,10 @@ class BufferedHandshake(ControlBase):
     """
 
     def elaborate(self, platform):
-        self.m = ControlBase._elaborate(self, platform)
+        self.m = ControlBase.elaborate(self, platform)
 
-        result = self.stage.ospec()
-        r_data = self.stage.ospec()
+        result = _spec(self.stage.ospec, "r_tmp")
+        r_data = _spec(self.stage.ospec, "r_data")
 
         # establish some combinatorial temporaries
         o_n_validn = Signal(reset_less=True)
@@ -754,19 +843,19 @@ class BufferedHandshake(ControlBase):
         self.m.d.comb += [p_i_valid.eq(self.p.i_valid_test),
                      o_n_validn.eq(~self.n.o_valid),
                      n_i_ready.eq(self.n.i_ready_test),
-                     nir_por.eq(n_i_ready & self.p._o_ready),
-                     nir_por_n.eq(n_i_ready & ~self.p._o_ready),
+                     nir_por.eq(n_i_ready & self.p._ready_o),
+                     nir_por_n.eq(n_i_ready & ~self.p._ready_o),
                      nir_novn.eq(n_i_ready | o_n_validn),
                      nirn_novn.eq(~n_i_ready & o_n_validn),
                      npnn.eq(nir_por | nirn_novn),
-                     por_pivn.eq(self.p._o_ready & ~p_i_valid)
+                     por_pivn.eq(self.p._ready_o & ~p_i_valid)
         ]
 
         # store result of processing in combinatorial temporary
         self.m.d.comb += eq(result, self.stage.process(self.p.i_data))
 
         # if not in stall condition, update the temporary register
-        with self.m.If(self.p.o_ready): # not stalled
+        with self.m.If(self.p.ready_o): # not stalled
             self.m.d.sync += eq(r_data, result) # update buffer
 
         # data pass-through conditions
@@ -783,7 +872,7 @@ class BufferedHandshake(ControlBase):
                               eq(self.n.o_data, o_data), # flush buffer
                              ]
         # output ready conditions
-        self.m.d.sync += self.p._o_ready.eq(nir_novn | por_pivn)
+        self.m.d.sync += self.p._ready_o.eq(nir_novn | por_pivn)
 
         return self.m
 
@@ -795,7 +884,7 @@ class SimpleHandshake(ControlBase):
         Argument: stage.  see Stage API above
 
         stage-1   p.i_valid >>in   stage   n.o_valid out>>   stage+1
-        stage-1   p.o_ready <<out  stage   n.i_ready <<in    stage+1
+        stage-1   p.ready_o <<out  stage   n.i_ready <<in    stage+1
         stage-1   p.i_data  >>in   stage   n.o_data  out>>   stage+1
                               |             |
                               +--process->--^
@@ -831,25 +920,25 @@ class SimpleHandshake(ControlBase):
     """
 
     def elaborate(self, platform):
-        self.m = m = ControlBase._elaborate(self, platform)
+        self.m = m = ControlBase.elaborate(self, platform)
 
         r_busy = Signal()
-        result = self.stage.ospec()
+        result = _spec(self.stage.ospec, "r_tmp")
 
         # establish some combinatorial temporaries
         n_i_ready = Signal(reset_less=True, name="n_i_rdy_data")
-        p_i_valid_p_o_ready = Signal(reset_less=True)
+        p_i_valid_p_ready_o = Signal(reset_less=True)
         p_i_valid = Signal(reset_less=True)
         m.d.comb += [p_i_valid.eq(self.p.i_valid_test),
                      n_i_ready.eq(self.n.i_ready_test),
-                     p_i_valid_p_o_ready.eq(p_i_valid & self.p.o_ready),
+                     p_i_valid_p_ready_o.eq(p_i_valid & self.p.ready_o),
         ]
 
         # store result of processing in combinatorial temporary
         m.d.comb += eq(result, self.stage.process(self.p.i_data))
 
         # previous valid and ready
-        with m.If(p_i_valid_p_o_ready):
+        with m.If(p_i_valid_p_ready_o):
             o_data = self._postprocess(result)
             m.d.sync += [r_busy.eq(1),      # output valid
                          eq(self.n.o_data, o_data), # update output
@@ -864,7 +953,7 @@ class SimpleHandshake(ControlBase):
 
         m.d.comb += self.n.o_valid.eq(r_busy)
         # if next is ready, so is previous
-        m.d.comb += self.p._o_ready.eq(n_i_ready)
+        m.d.comb += self.p._ready_o.eq(n_i_ready)
 
         return self.m
 
@@ -884,7 +973,7 @@ class UnbufferedPipeline(ControlBase):
         Argument: stage.  see Stage API, above
 
         stage-1   p.i_valid >>in   stage   n.o_valid out>>   stage+1
-        stage-1   p.o_ready <<out  stage   n.i_ready <<in    stage+1
+        stage-1   p.ready_o <<out  stage   n.i_ready <<in    stage+1
         stage-1   p.i_data  >>in   stage   n.o_data  out>>   stage+1
                               |             |
                             r_data        result
@@ -939,21 +1028,21 @@ class UnbufferedPipeline(ControlBase):
     """
 
     def elaborate(self, platform):
-        self.m = m = ControlBase._elaborate(self, platform)
+        self.m = m = ControlBase.elaborate(self, platform)
 
         data_valid = Signal() # is data valid or not
-        r_data = self.stage.ospec() # output type
+        r_data = _spec(self.stage.ospec, "r_tmp") # output type
 
         # some temporaries
         p_i_valid = Signal(reset_less=True)
         pv = Signal(reset_less=True)
         buf_full = Signal(reset_less=True)
         m.d.comb += p_i_valid.eq(self.p.i_valid_test)
-        m.d.comb += pv.eq(self.p.i_valid & self.p.o_ready)
+        m.d.comb += pv.eq(self.p.i_valid & self.p.ready_o)
         m.d.comb += buf_full.eq(~self.n.i_ready_test & data_valid)
 
         m.d.comb += self.n.o_valid.eq(data_valid)
-        m.d.comb += self.p._o_ready.eq(~data_valid | self.n.i_ready_test)
+        m.d.comb += self.p._ready_o.eq(~data_valid | self.n.i_ready_test)
         m.d.sync += data_valid.eq(p_i_valid | buf_full)
 
         with m.If(pv):
@@ -978,7 +1067,7 @@ class UnbufferedPipeline2(ControlBase):
         Argument: stage.  see Stage API, above
 
         stage-1   p.i_valid >>in   stage   n.o_valid out>>   stage+1
-        stage-1   p.o_ready <<out  stage   n.i_ready <<in    stage+1
+        stage-1   p.ready_o <<out  stage   n.i_ready <<in    stage+1
         stage-1   p.i_data  >>in   stage   n.o_data  out>>   stage+1
                               |             |    |
                               +- process-> buf <-+
@@ -1025,17 +1114,17 @@ class UnbufferedPipeline2(ControlBase):
     """
 
     def elaborate(self, platform):
-        self.m = m = ControlBase._elaborate(self, platform)
+        self.m = m = ControlBase.elaborate(self, platform)
 
         buf_full = Signal() # is data valid or not
-        buf = self.stage.ospec() # output type
+        buf = _spec(self.stage.ospec, "r_tmp") # output type
 
         # some temporaries
         p_i_valid = Signal(reset_less=True)
         m.d.comb += p_i_valid.eq(self.p.i_valid_test)
 
         m.d.comb += self.n.o_valid.eq(buf_full | p_i_valid)
-        m.d.comb += self.p._o_ready.eq(~buf_full)
+        m.d.comb += self.p._ready_o.eq(~buf_full)
         m.d.sync += buf_full.eq(~self.n.i_ready_test & self.n.o_valid)
 
         o_data = Mux(buf_full, buf, self.stage.process(self.p.i_data))
@@ -1091,18 +1180,18 @@ class PassThroughHandshake(ControlBase):
     """
 
     def elaborate(self, platform):
-        self.m = m = ControlBase._elaborate(self, platform)
+        self.m = m = ControlBase.elaborate(self, platform)
 
-        r_data = self.stage.ospec() # output type
+        r_data = _spec(self.stage.ospec, "r_tmp") # output type
 
         # temporaries
         p_i_valid = Signal(reset_less=True)
         pvr = Signal(reset_less=True)
         m.d.comb += p_i_valid.eq(self.p.i_valid_test)
-        m.d.comb += pvr.eq(p_i_valid & self.p.o_ready)
+        m.d.comb += pvr.eq(p_i_valid & self.p.ready_o)
 
-        m.d.comb += self.p.o_ready.eq(~self.n.o_valid |  self.n.i_ready_test)
-        m.d.sync += self.n.o_valid.eq(p_i_valid       | ~self.p.o_ready)
+        m.d.comb += self.p.ready_o.eq(~self.n.o_valid |  self.n.i_ready_test)
+        m.d.sync += self.n.o_valid.eq(p_i_valid       | ~self.p.ready_o)
 
         odata = Mux(pvr, self.stage.process(self.p.i_data), r_data)
         m.d.sync += eq(r_data, odata)
@@ -1163,10 +1252,10 @@ class FIFOControl(ControlBase):
         ControlBase.__init__(self, stage, in_multi, stage_ctl)
 
     def elaborate(self, platform):
-        self.m = m = ControlBase._elaborate(self, platform)
+        self.m = m = ControlBase.elaborate(self, platform)
 
         # make a FIFO with a signal of equal width to the o_data.
-        (fwidth, _) = self.n.o_data.shape()
+        (fwidth, _) = shape(self.n.o_data)
         if self.buffered:
             fifo = SyncFIFOBuffered(fwidth, self.fdepth)
         else:
@@ -1174,14 +1263,14 @@ class FIFOControl(ControlBase):
         m.submodules.fifo = fifo
 
         # store result of processing in combinatorial temporary
-        result = self.stage.ospec()
+        result = _spec(self.stage.ospec, "r_temp")
         m.d.comb += eq(result, self.stage.process(self.p.i_data))
 
         # connect previous rdy/valid/data - do cat on i_data
         # NOTE: cannot do the PrevControl-looking trick because
         # of need to process the data.  shaaaame....
         m.d.comb += [fifo.we.eq(self.p.i_valid_test),
-                     self.p.o_ready.eq(fifo.writable),
+                     self.p.ready_o.eq(fifo.writable),
                      eq(fifo.din, cat(result)),
                    ]