working on splitting out common mul pipe test code
[soc.git] / src / soc / regfile / regfile.py
index 503ce8ee289b920f4d34d6f05bc82190733c0977..30063726137032501e6838f5712d690395d91de6 100644 (file)
@@ -32,9 +32,11 @@ import operator
 
 
 class Register(Elaboratable):
-    def __init__(self, width, writethru=True):
+    def __init__(self, width, writethru=True, synced=True, resetval=0):
         self.width = width
+        self.reset = resetval
         self.writethru = writethru
+        self.synced = synced
         self._rdports = []
         self._wrports = []
 
@@ -54,24 +56,28 @@ class Register(Elaboratable):
 
     def elaborate(self, platform):
         m = Module()
-        self.reg = reg = Signal(self.width, name="reg")
+        self.reg = reg = Signal(self.width, name="reg", reset=self.reset)
+
+        if self.synced:
+            domain = m.d.sync
+        else:
+            domain = m.d.comb
 
         # read ports. has write-through detection (returns data written)
         for rp in self._rdports:
-            with m.If(rp.ren == 1):
+            domain += rp.data_o.eq(0)
+            with m.If(rp.ren):
                 if self.writethru:
                     wr_detect = Signal(reset_less=False)
                     m.d.comb += wr_detect.eq(0)
                     for wp in self._wrports:
                         with m.If(wp.wen):
-                            m.d.comb += rp.data_o.eq(wp.data_i)
+                            domain += rp.data_o.eq(wp.data_i)
                             m.d.comb += wr_detect.eq(1)
                     with m.If(~wr_detect):
-                        m.d.comb += rp.data_o.eq(reg)
+                        domain += rp.data_o.eq(reg)
                 else:
-                    m.d.comb += rp.data_o.eq(reg)
-            with m.Else():
-                m.d.comb += rp.data_o.eq(0)
+                    domain += rp.data_o.eq(reg)
 
         # write ports, delayed by 1 cycle
         for wp in self._wrports:
@@ -101,10 +107,12 @@ class RegFileArray(Elaboratable):
         and read-en signals (per port).
     """
 
-    def __init__(self, width, depth):
+    def __init__(self, width, depth, synced=True):
+        self.synced = synced
         self.width = width
         self.depth = depth
-        self.regs = Array(Register(width) for _ in range(self.depth))
+        self.regs = Array(Register(width, synced=synced) \
+                          for _ in range(self.depth))
         self._rdports = []
         self._wrports = []
 
@@ -149,11 +157,22 @@ class RegFileArray(Elaboratable):
         for i, reg in enumerate(self.regs):
             setattr(m.submodules, "reg_%d" % i, reg)
 
+        if self.synced:
+            domain = m.d.sync
+        else:
+            domain = m.d.comb
+
         for (regs, p) in self._rdports:
             #print (p)
             m.d.comb += self._get_en_sig(regs, 'ren').eq(p.ren)
             ror = ortreereduce(list(regs))
-            m.d.comb += p.data_o.eq(ror)
+            if self.synced:
+                ren_delay = Signal.like(p.ren)
+                m.d.sync += ren_delay.eq(p.ren)
+                with m.If(ren_delay):
+                    m.d.comb += p.data_o.eq(ror)
+            else:
+                m.d.comb += p.data_o.eq(ror)
         for (regs, p) in self._wrports:
             m.d.comb += self._get_en_sig(regs, 'wen').eq(p.wen)
             for r in regs:
@@ -171,7 +190,9 @@ class RegFileArray(Elaboratable):
 
 class RegFileMem(Elaboratable):
     unary = False
-    def __init__(self, width, depth):
+    def __init__(self, width, depth, fwd_bus_mode=False, synced=True):
+        self.fwd_bus_mode = fwd_bus_mode
+        self.synced = synced
         self.width, self.depth = width, depth
         self.memory = Memory(width=width, depth=depth)
         self._rdports = {}
@@ -182,7 +203,11 @@ class RegFileMem(Elaboratable):
         port = RecordObject([("addr", bsz),
                              ("ren", 1),
                              ("data_o", self.width)], name=name)
-        self._rdports[name] = (port, self.memory.read_port(domain="comb"))
+        if self.synced:
+            domain = "sync"
+        else:
+            domain = "comb"
+        self._rdports[name] = (port, self.memory.read_port(domain=domain))
         return port
 
     def write_port(self, name=None):
@@ -202,15 +227,24 @@ class RegFileMem(Elaboratable):
             setattr(m.submodules, "rp_"+name, rport)
             wr_detect = Signal(reset_less=False)
             comb += rport.addr.eq(rp.addr)
-            with m.If(rp.ren):
-                m.d.comb += wr_detect.eq(0)
-                for _, (wp, wport) in self._wrports.items():
-                    addrmatch = Signal(reset_less=False)
-                    m.d.comb += addrmatch.eq(wp.addr == rp.addr)
-                    with m.If(wp.wen & addrmatch):
-                        m.d.comb += rp.data_o.eq(wp.data_i)
-                        m.d.comb += wr_detect.eq(1)
-                with m.If(~wr_detect):
+            if self.fwd_bus_mode:
+                with m.If(rp.ren):
+                    m.d.comb += wr_detect.eq(0)
+                    for _, (wp, wport) in self._wrports.items():
+                        addrmatch = Signal(reset_less=False)
+                        m.d.comb += addrmatch.eq(wp.addr == rp.addr)
+                        with m.If(wp.wen & addrmatch):
+                            m.d.comb += rp.data_o.eq(wp.data_i)
+                            m.d.comb += wr_detect.eq(1)
+                    with m.If(~wr_detect):
+                        m.d.comb += rp.data_o.eq(rport.data)
+            else:
+                if self.synced:
+                    ren_delay = Signal.like(rp.ren)
+                    m.d.sync += ren_delay.eq(rp.ren)
+                    with m.If(ren_delay):
+                        m.d.comb += rp.data_o.eq(rport.data)
+                else:
                     m.d.comb += rp.data_o.eq(rport.data)
 
         # write ports, delayed by one cycle (in the memory itself)
@@ -379,7 +413,7 @@ def test_regfile():
 
     run_simulation(dut, regfile_sim(dut, rp, wp), vcd_name='test_regfile.vcd')
 
-    dut = RegFileMem(32, 8)
+    dut = RegFileMem(32, 8, True, False)
     rp = dut.read_port("rp1")
     wp = dut.write_port("wp1")
     vl = rtlil.convert(dut)#, ports=dut.ports())
@@ -388,7 +422,7 @@ def test_regfile():
 
     run_simulation(dut, regfile_sim(dut, rp, wp), vcd_name='test_regmem.vcd')
 
-    dut = RegFileArray(32, 8)
+    dut = RegFileArray(32, 8, False)
     rp1 = dut.read_port("read1")
     rp2 = dut.read_port("read2")
     wp = dut.write_port("write")