class MuxInOut:
- def __init__(self, dut, width, fpkls, fpop, vals, single_op, opcode):
+ def __init__(self, dut, width, fpkls, fpop, vals, single_op, opcode,
+ cancel=False):
+ self.cancel = cancel # allow (test) cancellation
self.dut = dut
self.fpkls = fpkls
self.fpop = fpop
self.opcode = opcode
self.di = {}
self.do = {}
+ self.sent = {}
self.tlen = len(vals) // dut.num_rows
self.width = width
for muxid in range(dut.num_rows):
self.di[muxid] = {}
- self.do[muxid] = []
+ self.do[muxid] = {}
+ self.sent[muxid] = []
+
for i in range(self.tlen):
if self.single_op:
#print ("vals", vals)
res = self.fpop(self.fpkls(op1), self.fpkls(op2))
self.di[muxid][i] = (op1, op2)
if hasattr(res, "bits"):
- self.do[muxid].append(res.bits)
+ self.do[muxid][i] = res.bits
else:
- self.do[muxid].append(res) # for FP to INT
+ self.do[muxid][i] = res # for FP to INT
def send(self, muxid):
+ rs = self.dut.p[muxid]
for i in range(self.tlen):
if self.single_op:
op1, = self.di[muxid][i]
else:
op1, op2 = self.di[muxid][i]
- rs = self.dut.p[muxid]
yield rs.valid_i.eq(1)
yield rs.data_i.a.eq(op1)
if self.opcode is not None:
print("send", muxid, i, hex(op1), hex(op2), hex(res.bits),
fop1, fop2, res)
+ self.sent[muxid].append(i)
+
yield rs.valid_i.eq(0)
+ if hasattr(rs, "mask_i"):
+ yield rs.mask_i.eq(0) # TEMPORARY HACK
+ # wait until it's received
+ while i in self.sent[muxid]:
+ yield
+
# wait random period of time before queueing another value
for i in range(randint(0, 3)):
yield
# send = randint(0, send_range) != 0
def rcv(self, muxid):
+ rs = self.dut.p[muxid]
while True:
+
+ # check cancellation
+ cancel = self.cancel and (randint(0, 2) == 0)
+ if hasattr(rs, "mask_i") and len(self.sent[muxid]) > 0 and cancel:
+ todel = self.sent[muxid].pop()
+ print ("to delete", muxid, self.sent[muxid], todel)
+ if todel in self.do[muxid]:
+ del self.do[muxid][todel]
+ yield rs.stop_i.eq(1)
+ print ("left", muxid, self.do[muxid])
+ if len(self.do[muxid]) == 0:
+ break
+
#stall_range = randint(0, 3)
#for j in range(randint(1,10)):
# stall = randint(0, stall_range) != 0
n = self.dut.n[muxid]
yield n.ready_i.eq(1)
yield
+ if hasattr(rs, "mask_i"):
+ yield rs.stop_i.eq(0) # resets cancel mask
+
o_n_valid = yield n.valid_o
i_n_ready = yield n.ready_i
if not o_n_valid or not i_n_ready:
out_muxid = yield n.data_o.muxid
out_z = yield n.data_o.z
- out_i = 0
+ if not self.sent[muxid]:
+ print ("cancelled/recv", muxid, hex(out_z))
+ continue
+
+ out_i = self.sent[muxid].pop()
print("recv", out_muxid, hex(out_z), "expected",
hex(self.do[muxid][out_i]))
# see if this output has occurred already, delete it if it has
assert muxid == out_muxid, "out_muxid %d not correct %d" % \
(out_muxid, muxid)
+
assert self.do[muxid][out_i] == out_z
+
+ print ("senddel", muxid, out_i, self.sent[muxid])
del self.do[muxid][out_i]
# check if there's any more outputs
if len(self.do[muxid]) == 0:
break
+
print("recv ended", muxid)
def runfp(dut, width, name, fpkls, fpop, single_op=False, n_vals=10,
- vals=None, opcode=None):
+ vals=None, opcode=None, cancel=False):
vl = rtlil.convert(dut, ports=dut.ports())
with open("%s.il" % name, "w") as f:
f.write(vl)
class TestDivPipe(unittest.TestCase):
def test_pipe_rsqrt_fp16(self):
- dut = FPDIVMuxInOut(16, 4)
+ dut = FPDIVMuxInOut(16, 8)
# don't forget to initialize opcode; don't use magic numbers
opcode = int(DivPipeCoreOperation.RSqrtRem)
runfp(dut, 16, "test_fprsqrt_pipe_fp16", Float16, rsqrt,
- single_op=True, opcode=opcode, n_vals=100)
+ single_op=True, opcode=opcode, n_vals=100, cancel=True)
def test_pipe_rsqrt_fp32(self):
- dut = FPDIVMuxInOut(32, 4)
+ dut = FPDIVMuxInOut(32, 8)
# don't forget to initialize opcode; don't use magic numbers
opcode = int(DivPipeCoreOperation.RSqrtRem)
runfp(dut, 32, "test_fprsqrt_pipe_fp32", Float32, rsqrt,
- single_op=True, opcode=opcode, n_vals=100)
+ single_op=True, opcode=opcode, n_vals=100, cancel=True)
@unittest.skip("rsqrt not implemented for fp64")
def test_pipe_rsqrt_fp64(self):
- dut = FPDIVMuxInOut(64, 4)
+ dut = FPDIVMuxInOut(64, 8)
# don't forget to initialize opcode; don't use magic numbers
opcode = int(DivPipeCoreOperation.RSqrtRem)
runfp(dut, 64, "test_fprsqrt_pipe_fp64", Float64, rsqrt,
- single_op=True, opcode=opcode, n_vals=100)
+ single_op=True, opcode=opcode, n_vals=100, cancel=True)
if __name__ == '__main__':