add mask cancellation to FPDIV and to fpmux unit test
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 6 Aug 2019 11:19:33 +0000 (12:19 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 6 Aug 2019 11:19:33 +0000 (12:19 +0100)
src/ieee754/fpcommon/test/fpmux.py
src/ieee754/fpdiv/test/test_fprsqrt_pipe.py
src/ieee754/fpdiv/test/test_fprsqrt_pipe_16.py
src/nmutil/test/test_inout_unary_mux_cancel_pipe.py

index 77d70a45fa1911b66498a14a262c385e02822048..806b215f36c50f957778adf9f57b42841515fb16 100644 (file)
@@ -11,7 +11,9 @@ from nmigen.cli import verilog, rtlil
 
 
 class MuxInOut:
-    def __init__(self, dut, width, fpkls, fpop, vals, single_op, opcode):
+    def __init__(self, dut, width, fpkls, fpop, vals, single_op, opcode,
+                       cancel=False):
+        self.cancel = cancel # allow (test) cancellation
         self.dut = dut
         self.fpkls = fpkls
         self.fpop = fpop
@@ -19,11 +21,14 @@ class MuxInOut:
         self.opcode = opcode
         self.di = {}
         self.do = {}
+        self.sent = {}
         self.tlen = len(vals) // dut.num_rows
         self.width = width
         for muxid in range(dut.num_rows):
             self.di[muxid] = {}
-            self.do[muxid] = []
+            self.do[muxid] = {}
+            self.sent[muxid] = []
+
             for i in range(self.tlen):
                 if self.single_op:
                     #print ("vals", vals)
@@ -39,17 +44,17 @@ class MuxInOut:
                     res = self.fpop(self.fpkls(op1), self.fpkls(op2))
                     self.di[muxid][i] = (op1, op2)
                 if hasattr(res, "bits"):
-                    self.do[muxid].append(res.bits)
+                    self.do[muxid][i] = res.bits
                 else:
-                    self.do[muxid].append(res) # for FP to INT
+                    self.do[muxid][i] = res # for FP to INT
 
     def send(self, muxid):
+        rs = self.dut.p[muxid]
         for i in range(self.tlen):
             if self.single_op:
                 op1, = self.di[muxid][i]
             else:
                 op1, op2 = self.di[muxid][i]
-            rs = self.dut.p[muxid]
             yield rs.valid_i.eq(1)
             yield rs.data_i.a.eq(op1)
             if self.opcode is not None:
@@ -81,7 +86,15 @@ class MuxInOut:
                 print("send", muxid, i, hex(op1), hex(op2), hex(res.bits),
                               fop1, fop2, res)
 
+            self.sent[muxid].append(i)
+
             yield rs.valid_i.eq(0)
+            if hasattr(rs, "mask_i"):
+                yield rs.mask_i.eq(0) # TEMPORARY HACK
+            # wait until it's received
+            while i in self.sent[muxid]:
+                yield
+
             # wait random period of time before queueing another value
             for i in range(randint(0, 3)):
                 yield
@@ -102,7 +115,21 @@ class MuxInOut:
         #    send = randint(0, send_range) != 0
 
     def rcv(self, muxid):
+        rs = self.dut.p[muxid]
         while True:
+
+            # check cancellation
+            cancel = self.cancel and (randint(0, 2) == 0)
+            if hasattr(rs, "mask_i") and len(self.sent[muxid]) > 0 and cancel:
+                todel = self.sent[muxid].pop()
+                print ("to delete", muxid, self.sent[muxid], todel)
+                if todel in self.do[muxid]:
+                    del self.do[muxid][todel]
+                    yield rs.stop_i.eq(1)
+                print ("left", muxid, self.do[muxid])
+                if len(self.do[muxid]) == 0:
+                    break
+
             #stall_range = randint(0, 3)
             #for j in range(randint(1,10)):
             #    stall = randint(0, stall_range) != 0
@@ -111,6 +138,9 @@ class MuxInOut:
             n = self.dut.n[muxid]
             yield n.ready_i.eq(1)
             yield
+            if hasattr(rs, "mask_i"):
+                yield rs.stop_i.eq(0) # resets cancel mask
+
             o_n_valid = yield n.valid_o
             i_n_ready = yield n.ready_i
             if not o_n_valid or not i_n_ready:
@@ -119,7 +149,11 @@ class MuxInOut:
             out_muxid = yield n.data_o.muxid
             out_z = yield n.data_o.z
 
-            out_i = 0
+            if not self.sent[muxid]:
+                print ("cancelled/recv", muxid, hex(out_z))
+                continue
+
+            out_i = self.sent[muxid].pop()
 
             print("recv", out_muxid, hex(out_z), "expected",
                   hex(self.do[muxid][out_i]))
@@ -127,12 +161,16 @@ class MuxInOut:
             # see if this output has occurred already, delete it if it has
             assert muxid == out_muxid, "out_muxid %d not correct %d" % \
                                        (out_muxid, muxid)
+
             assert self.do[muxid][out_i] == out_z
+
+            print ("senddel", muxid, out_i, self.sent[muxid])
             del self.do[muxid][out_i]
 
             # check if there's any more outputs
             if len(self.do[muxid]) == 0:
                 break
+
         print("recv ended", muxid)
 
 
@@ -262,7 +300,7 @@ def pipe_cornercases_repeat(dut, name, mod, fmod, width, fn, cc, fpfn, count,
 
 
 def runfp(dut, width, name, fpkls, fpop, single_op=False, n_vals=10,
-          vals=None, opcode=None):
+          vals=None, opcode=None, cancel=False):
     vl = rtlil.convert(dut, ports=dut.ports())
     with open("%s.il" % name, "w") as f:
         f.write(vl)
index 0457c52402e5cfa1c6159e8e3ae4b83920a43275..a28cab604c1be44bf3318f8d3c6e6ec8be171255 100644 (file)
@@ -17,26 +17,26 @@ def rsqrt(x):
 
 class TestDivPipe(unittest.TestCase):
     def test_pipe_rsqrt_fp16(self):
-        dut = FPDIVMuxInOut(16, 4)
+        dut = FPDIVMuxInOut(16, 8)
         # don't forget to initialize opcode; don't use magic numbers
         opcode = int(DivPipeCoreOperation.RSqrtRem)
         runfp(dut, 16, "test_fprsqrt_pipe_fp16", Float16, rsqrt,
-              single_op=True, opcode=opcode, n_vals=100)
+              single_op=True, opcode=opcode, n_vals=100, cancel=True)
 
     def test_pipe_rsqrt_fp32(self):
-        dut = FPDIVMuxInOut(32, 4)
+        dut = FPDIVMuxInOut(32, 8)
         # don't forget to initialize opcode; don't use magic numbers
         opcode = int(DivPipeCoreOperation.RSqrtRem)
         runfp(dut, 32, "test_fprsqrt_pipe_fp32", Float32, rsqrt,
-              single_op=True, opcode=opcode, n_vals=100)
+              single_op=True, opcode=opcode, n_vals=100, cancel=True)
 
     @unittest.skip("rsqrt not implemented for fp64")
     def test_pipe_rsqrt_fp64(self):
-        dut = FPDIVMuxInOut(64, 4)
+        dut = FPDIVMuxInOut(64, 8)
         # don't forget to initialize opcode; don't use magic numbers
         opcode = int(DivPipeCoreOperation.RSqrtRem)
         runfp(dut, 64, "test_fprsqrt_pipe_fp64", Float64, rsqrt,
-              single_op=True, opcode=opcode, n_vals=100)
+              single_op=True, opcode=opcode, n_vals=100, cancel=True)
 
 
 if __name__ == '__main__':
index d8e4b2d11d98145f4483217eca5b7d1c76eb1ec3..4ca5612276958b8ad9f0c78d3dc107b7ff531e8f 100644 (file)
@@ -19,7 +19,7 @@ def rsqrt(x):
 
 class TestDivPipe(unittest.TestCase):
     def test_pipe_rsqrt_fp16(self):
-        dut = FPDIVMuxInOut(16, 4)
+        dut = FPDIVMuxInOut(16, 8)
         # don't forget to initialize opcode; don't use magic numbers
         opcode = int(DivPipeCoreOperation.RSqrtRem)
         run_pipe_fp(dut, 16, "rsqrt16", unit_test_half, Float16, None,
index 3bdf701839c7b1d918f4e47fc767d528c3ba5884..235eafe3970ac72fb34ddec29384e74696ef221a 100644 (file)
@@ -84,6 +84,7 @@ class InputTest:
             self.sent[muxid].append(i)
 
             yield rs.valid_i.eq(0)
+            yield rs.mask_i.eq(0)
             # wait until it's received
             while i in self.do[muxid]:
                 yield