addr_split.py: shift bytes not bits
authorTobias Platen <tplaten@posteo.de>
Sat, 8 Aug 2020 11:54:42 +0000 (13:54 +0200)
committerTobias Platen <tplaten@posteo.de>
Sat, 8 Aug 2020 11:54:42 +0000 (13:54 +0200)
src/soc/scoreboard/addr_split.py

index aa99f63c9c097c2e8b5723f7b8ad899db857bcd9..b3156595300170a9692837809b2b150301c19360 100644 (file)
@@ -8,7 +8,7 @@ Links:
 
 from soc.experiment.pimem import PortInterface
 
-from nmigen import Elaboratable, Module, Signal, Record, Array, Const
+from nmigen import Elaboratable, Module, Signal, Record, Array, Const, Cat
 from nmutil.latch import SRLatch, latchregister
 from nmigen.back.pysim import Simulator, Delay
 from nmigen.cli import verilog, rtlil
@@ -56,13 +56,29 @@ class LDLatch(Elaboratable):
     def ports(self):
         return list(self)
 
+def byteExpand(signal):
+    if(type(signal)==int):
+        ret = 0
+        shf = 0
+        while(signal>0):
+            bit = signal & 1
+            ret |= (0xFF * bit) << shf
+            signal = signal >> 1
+            shf += 8
+        return ret
+    lst = []
+    for i in range(len(signal)):
+        bit = signal[i]
+        for j in range(8):
+            lst += [bit]
+    return Cat(*lst)
 
 class LDSTSplitter(Elaboratable):
 
     def __init__(self, dwidth, awidth, dlen):
         self.dwidth, self.awidth, self.dlen = dwidth, awidth, dlen
         # cline_wid = 8<<dlen # cache line width: bytes (8) times (2^^dlen)
-        cline_wid = dwidth  # TODO: make this bytes not bits
+        cline_wid = dwidth*8  # convert bytes to bits
 
         self.pi =  PortInterface()
 
@@ -75,11 +91,10 @@ class LDSTSplitter(Elaboratable):
         self.is_ld_i = self.pi.is_ld_i #Signal(reset_less=True)
         self.is_st_i = self.pi.is_st_i #Signal(reset_less=True)
 
-        self.ld_data_o = LDData(dwidth, "ld_data_o") #port.ld
-        self.st_data_i = LDData(dwidth, "st_data_i") #port.st
+        self.ld_data_o = LDData(dwidth*8, "ld_data_o") #port.ld
+        self.st_data_i = LDData(dwidth*8, "st_data_i") #port.st
 
         self.exc = Signal(reset_less=True) # pi.exc TODO
-
         # TODO : create/connect two outgoing port interfaces
 
         self.sld_valid_o = Signal(2, reset_less=True)
@@ -98,10 +113,11 @@ class LDSTSplitter(Elaboratable):
         dlen = self.dlen
         mlen = 1 << dlen
         mzero = Const(0, mlen)
-        m.submodules.ld1 = ld1 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
-        m.submodules.ld2 = ld2 = LDLatch(self.dwidth, self.awidth-dlen, mlen)
+        m.submodules.ld1 = ld1 = LDLatch(self.dwidth*8, self.awidth-dlen, mlen)
+        m.submodules.ld2 = ld2 = LDLatch(self.dwidth*8, self.awidth-dlen, mlen)
         m.submodules.lenexp = lenexp = LenExpand(self.dlen)
 
+        # FIXME bytes not bits
         # set up len-expander, len to mask.  ld1 gets first bit, ld2 gets rest
         comb += lenexp.addr_i.eq(self.addr_i)
         comb += lenexp.len_i.eq(self.len_i)
@@ -124,6 +140,11 @@ class LDSTSplitter(Elaboratable):
         comb += ashift1.eq(self.addr_i[:self.dlen])
         comb += ashift2.eq((1 << dlen)-ashift1)
 
+        #expand masks
+        mask1 = byteExpand(mask1)
+        mask2 = byteExpand(mask2)
+        mzero = byteExpand(mzero)
+
         with m.If(self.is_ld_i):
             # set up connections to LD-split.  note: not active if mask is zero
             for i, (ld, mask) in enumerate(((ld1, mask1),
@@ -139,6 +160,8 @@ class LDSTSplitter(Elaboratable):
                 comb += self.valid_o.eq(self.sld_valid_o[0])
             with m.Else():
                 comb += self.valid_o.eq(self.sld_valid_o.all())
+            ## debug output -- output mask2 and mzero
+            ## guess second port is invalid
 
             # all bits valid (including when data error occurs!) decode ld1/ld2
             with m.If(self.valid_o):
@@ -147,8 +170,8 @@ class LDSTSplitter(Elaboratable):
 
                 # note that data from LD1 will be in *cache-line* byte position
                 # likewise from LD2 but we *know* it is at the start of the line
-                comb += self.ld_data_o.data.eq((ld1.ld_o.data >> ashift1) |
-                                               (ld2.ld_o.data << ashift2))
+                comb += self.ld_data_o.data.eq((ld1.ld_o.data >> (ashift1*8)) |
+                                               (ld2.ld_o.data << (ashift2*8)))
 
         with m.If(self.is_st_i):
             for i, (ld, mask) in enumerate(((ld1, mask1),
@@ -159,8 +182,8 @@ class LDSTSplitter(Elaboratable):
                 comb += self.sld_valid_o[i].eq(ld.valid_o)
                 comb += self.sst_data_o[i].data.eq(ld.ld_o.data)
 
-            comb += ld1.ld_i.eq((self.st_data_i << ashift1) & mask1)
-            comb += ld2.ld_i.eq((self.st_data_i >> ashift2) & mask2)
+            comb += ld1.ld_i.eq((self.st_data_i << (ashift1*8)) & mask1)
+            comb += ld2.ld_i.eq((self.st_data_i >> (ashift2*8)) & mask2)
 
             # sort out valid: mask2 zero we ignore 2nd LD
             with m.If(mask2 == mzero):
@@ -196,21 +219,25 @@ def sim(dut):
 
     sim = Simulator(dut)
     sim.add_clock(1e-6)
-    data = 0b11010011
-    dlen = 4  # 4 bits
-    addr = 0b1100
+    data = 0x0102030405060708A1A2A3A4A5A6A7A8
+    dlen = 16  # data length in bytes
+    addr = 0b1110
     ld_len = 8
     ldm = ((1 << ld_len)-1)
+    ldme = byteExpand(ldm)
     dlm = ((1 << dlen)-1)
-    data = data & ldm  # truncate data to be tested, mask to within ld len
-    print("ldm", ldm, bin(data & ldm))
+    data = data & ldme  # truncate data to be tested, mask to within ld len
+    print("ldm", ldm, hex(data & ldme))
     print("dlm", dlm, bin(addr & dlm))
+
     dmask = ldm << (addr & dlm)
     print("dmask", bin(dmask))
     dmask1 = dmask >> (1 << dlen)
     print("dmask1", bin(dmask1))
     dmask = dmask & ((1 << (1 << dlen))-1)
     print("dmask", bin(dmask))
+    dmask1 = byteExpand(dmask1)
+    dmask = byteExpand(dmask)
 
     def send_ld():
         print("send_ld")
@@ -224,11 +251,14 @@ def sim(dut):
             if valid_o:
                 break
             yield
+        exc = yield dut.exc
         ld_data_o = yield dut.ld_data_o.data
         yield dut.is_ld_i.eq(0)
         yield
 
-        print(bin(ld_data_o), bin(data))
+        print(exc)
+        assert exc==0
+        print(hex(ld_data_o), hex(data))
         assert ld_data_o == data
 
     def lds():
@@ -239,13 +269,14 @@ def sim(dut):
                 break
             yield
 
-        shf = addr & dlm
+        shf = (addr & dlm)*8  #shift bytes not bits
+        print("shf",shf/8.0)
         shfdata = (data << shf)
         data1 = shfdata & dmask
-        print("ld data1", bin(data), bin(data1), shf, bin(dmask))
+        print("ld data1", hex(data), hex(data1), shf,shf/8.0, hex(dmask))
 
-        data2 = (shfdata >> 16) & dmask1
-        print("ld data2", 1 << dlen, bin(data >> (1 << dlen)), bin(data2))
+        data2 = (shfdata >> 128) & dmask1
+        print("ld data2", 1 << dlen, hex(data >> (1 << dlen)), hex(data2))
         yield dut.sld_data_i[0].data.eq(data1)
         yield dut.sld_valid_i[0].eq(1)
         yield