get pipeline unit tests working for case where prev / next len is 1
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 25 Mar 2019 06:44:19 +0000 (06:44 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 25 Mar 2019 06:44:19 +0000 (06:44 +0000)
src/add/example_buf_pipe.py
src/add/test_buf_pipe.py

index 6c2345c4d77f88e68d16112785c400f970d585a4..1bf6370f84776ba9b37cd20adcc68b196256df50 100644 (file)
@@ -265,11 +265,11 @@ class PipelineBase:
             p.append(PrevControl(in_multi))
         for i in range(n_len):
             n.append(NextControl())
-        if p_len > 0:
+        if p_len > 1:
             self.p = Array(p)
         else:
             self.p = p
-        if n_len > 0:
+        if n_len > 1:
             self.n = Array(n)
         else:
             self.n = n
@@ -316,8 +316,6 @@ class BufferedPipeline(PipelineBase):
         if ever the input is ready and the output is not, processed data
         is stored in a temporary register.
 
-        Argument: stage.  see Stage API above
-
         stage-1   p.i_valid >>in   stage   n.o_valid out>>   stage+1
         stage-1   p.o_ready <<out  stage   n.i_ready <<in    stage+1
         stage-1   p.i_data  >>in   stage   n.o_data  out>>   stage+1
@@ -342,6 +340,17 @@ class BufferedPipeline(PipelineBase):
 
     """
     def __init__(self, stage, n_len=1, p_len=1, p_mux=None, n_mux=None):
+        """ set up a BufferedPipeline (multi-input, multi-output)
+            NOTE: n_len > 1 and p_len > 1 is NOT supported
+
+            Arguments:
+
+            * stage: see Stage API above
+            * p_len: number of inputs (PrevControls + data)
+            * n_len: number of outputs (NextControls + data)
+            * p_mux: optional multiplex selector for incoming data
+            * n_mux: optional multiplex router for outgoing data
+        """
         PipelineBase.__init__(self, stage)
         self.p_mux = p_mux
         self.n_mux = n_mux
@@ -358,7 +367,7 @@ class BufferedPipeline(PipelineBase):
         result = self.stage.ospec()
         r_data = self.stage.ospec()
         if hasattr(self.stage, "setup"):
-            for i in range(p_len):
+            for i in range(len(self.p)):
                 self.stage.setup(m, self.p[i].i_data)
 
         pi = 0 # TODO: use p_mux to decide which to select
@@ -518,32 +527,43 @@ class UnbufferedPipeline(PipelineBase):
             COMBINATORIALLY (no clock dependence).
     """
 
-    def __init__(self, stage):
-        PipelineBase.__init__(self, stage)
+    def __init__(self, stage, p_len=1, n_len=1):
+        PipelineBase.__init__(self, stage, p_len, n_len)
         self._data_valid = Signal()
 
         # set up the input and output data
-        self.p.i_data = stage.ispec() # input type
-        self.n.o_data = stage.ospec() # output type
+        for i in range(p_len):
+            self.p[i].i_data = stage.ispec() # input type
+        for i in range(n_len):
+            self.n[i].o_data = stage.ospec()
 
     def elaborate(self, platform):
         m = Module()
 
-        r_data = self.stage.ispec() # input type
+        r_data = []
         result = self.stage.ospec() # output data
-        if hasattr(self.stage, "setup"):
-            self.stage.setup(m, r_data)
+        for i in range(len(self.p)):
+            r = self.stage.ispec() # input type
+            r_data.append(r)
+            if hasattr(self.stage, "setup"):
+                self.stage.setup(m, r)
+        if len(r_data) > 1:
+            r_data = Array(r_data)
+
+        pi = 0 # TODO: use p_mux to decide which to select
+        ni = 0 # TODO: use n_nux to decide which to select
 
         p_i_valid = Signal(reset_less=True)
-        m.d.comb += p_i_valid.eq(self.p.i_valid_logic())
-        m.d.comb += eq(result, self.stage.process(r_data))
-        m.d.comb += self.n.o_valid.eq(self._data_valid)
-        m.d.comb += self.p.o_ready.eq(~self._data_valid | self.n.i_ready)
+        m.d.comb += p_i_valid.eq(self.p[pi].i_valid_logic())
+        m.d.comb += eq(result, self.stage.process(r_data[pi]))
+        m.d.comb += self.n[ni].o_valid.eq(self._data_valid)
+        m.d.comb += self.p[pi].o_ready.eq(~self._data_valid | \
+                                           self.n[ni].i_ready)
         m.d.sync += self._data_valid.eq(p_i_valid | \
-                                        (~self.n.i_ready & self._data_valid))
-        with m.If(self.p.i_valid & self.p.o_ready):
-            m.d.sync += eq(r_data, self.p.i_data)
-        m.d.comb += eq(self.n.o_data, result)
+                                    (~self.n[ni].i_ready & self._data_valid))
+        with m.If(self.p[pi].i_valid & self.p[pi].o_ready):
+            m.d.sync += eq(r_data[pi], self.p[pi].i_data)
+        m.d.comb += eq(self.n[ni].o_data, result)
         return m
 
 
index 49d53935b6cff156f0534ce5844a067494ed4dff..0811b3a62f5bb66f5c279bc7cf25ec9f7da7c071 100644 (file)
@@ -189,17 +189,17 @@ class Test5:
                     send = True
                 else:
                     send = randint(0, send_range) != 0
-                o_p_ready = yield self.dut.p.o_ready
+                o_p_ready = yield self.dut.p[0].o_ready
                 if not o_p_ready:
                     yield
                     continue
                 if send and self.i != len(self.data):
-                    yield self.dut.p.i_valid.eq(1)
+                    yield self.dut.p[0].i_valid.eq(1)
                     for v in self.dut.set_input(self.data[self.i]):
                         yield v
                     self.i += 1
                 else:
-                    yield self.dut.p.i_valid.eq(0)
+                    yield self.dut.p[0].i_valid.eq(0)
                 yield
 
     def rcv(self):
@@ -207,19 +207,19 @@ class Test5:
             stall_range = randint(0, 3)
             for j in range(randint(1,10)):
                 stall = randint(0, stall_range) != 0
-                yield self.dut.n.i_ready.eq(stall)
+                yield self.dut.n[0].i_ready.eq(stall)
                 yield
-                o_n_valid = yield self.dut.n.o_valid
-                i_n_ready = yield self.dut.n.i_ready
+                o_n_valid = yield self.dut.n[0].o_valid
+                i_n_ready = yield self.dut.n[0].i_ready
                 if not o_n_valid or not i_n_ready:
                     continue
-                if isinstance(self.dut.n.o_data, Record):
+                if isinstance(self.dut.n[0].o_data, Record):
                     o_data = {}
-                    dod = self.dut.n.o_data
+                    dod = self.dut.n[0].o_data
                     for k, v in dod.fields.items():
                         o_data[k] = yield v
                 else:
-                    o_data = yield self.dut.n.o_data
+                    o_data = yield self.dut.n[0].o_data
                 self.resultfn(o_data, self.data[self.o], self.i, self.o)
                 self.o += 1
                 if self.o == len(self.data):
@@ -582,9 +582,9 @@ if __name__ == '__main__':
     test = Test5(dut, test6_resultfn)
     run_simulation(dut, [test.send, test.rcv], vcd_name="test_ltcomb6.vcd")
 
-    ports = [dut.p.i_valid, dut.n.i_ready,
-             dut.n.o_valid, dut.p.o_ready] + \
-             list(dut.p.i_data) + [dut.n.o_data]
+    ports = [dut.p[0].i_valid, dut.n[0].i_ready,
+             dut.n[0].o_valid, dut.p[0].o_ready] + \
+             list(dut.p[0].i_data) + [dut.n[0].o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_ltcomb_pipe.il", "w") as f:
         f.write(vl)
@@ -595,10 +595,10 @@ if __name__ == '__main__':
     test = Test5(dut, test7_resultfn, data=data)
     run_simulation(dut, [test.send, test.rcv], vcd_name="test_addrecord.vcd")
 
-    ports = [dut.p.i_valid, dut.n.i_ready,
-             dut.n.o_valid, dut.p.o_ready,
-             dut.p.i_data.src1, dut.p.i_data.src2,
-             dut.n.o_data.src1, dut.n.o_data.src2]
+    ports = [dut.p[0].i_valid, dut.n[0].i_ready,
+             dut.n[0].o_valid, dut.p[0].o_ready,
+             dut.p[0].i_data.src1, dut.p[0].i_data.src2,
+             dut.n[0].o_data.src1, dut.n[0].o_data.src2]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_recordcomb_pipe.il", "w") as f:
         f.write(vl)
@@ -611,9 +611,9 @@ if __name__ == '__main__':
 
     print ("test 9")
     dut = ExampleBufPipeChain2()
-    ports = [dut.p.i_valid, dut.n.i_ready,
-             dut.n.o_valid, dut.p.o_ready] + \
-             [dut.p.i_data] + [dut.n.o_data]
+    ports = [dut.p[0].i_valid, dut.n[0].i_ready,
+             dut.n[0].o_valid, dut.p[0].o_ready] + \
+             [dut.p[0].i_data] + [dut.n[0].o_data]
     vl = rtlil.convert(dut, ports=ports)
     with open("test_bufpipechain2.il", "w") as f:
         f.write(vl)