write-release moves out of "ALU valid" due to using alu_pulse
[soc.git] / src / soc / regfile / regfile.py
index 486b727c8bd04fe77d1c2a91098a708a161b1fed..d0a7b5402668bdaf568f68906d825a6d3f6eb291 100644 (file)
@@ -12,8 +12,9 @@ primary focus here is on *unary* indexing.
 Links:
 
 * https://libre-soc.org/3d_gpu/architecture/regfile
+* https://bugs.libre-soc.org/show_bug.cgi?id=345
 * https://bugs.libre-soc.org/show_bug.cgi?id=351
-* https://bugs.libre-soc.org/show_bug.cgi?id=349
+* https://bugs.libre-soc.org/show_bug.cgi?id=352
 """
 
 from nmigen.compat.sim import run_simulation
@@ -21,9 +22,9 @@ from nmigen.cli import verilog, rtlil
 
 from nmigen import Cat, Const, Array, Signal, Elaboratable, Module
 from nmutil.iocontrol import RecordObject
+from nmutil.util import treereduce
 
 from math import log
-from functools import reduce
 import operator
 
 
@@ -67,7 +68,7 @@ class Register(Elaboratable):
                 else:
                     m.d.comb += rp.data_o.eq(reg)
 
-        # write ports, don't allow write to address 0 (ignore it)
+        # write ports: reg is updated in the sync domain (next clock edge)
         for wp in self._wrports:
             with m.If(wp.wen):
                 m.d.sync += reg.eq(wp.data_i)
@@ -83,16 +84,8 @@ class Register(Elaboratable):
     def ports(self):
         res = list(self)
 
-def treereduce(tree, attr="data_o"):
-    #print ("treereduce", tree)
-    if not isinstance(tree, list):
-        return tree
-    if len(tree) == 1:
-        return getattr(tree[0], attr)
-    if len(tree) == 2:
-        return getattr(tree[0], attr) | getattr(tree[1], attr)
-    split = len(tree) // 2
-    return treereduce(tree[:split], attr) | treereduce(tree[split:], attr)
+def ortreereduce(tree, attr="data_o"):
+    return treereduce(tree, operator.or_, lambda x: getattr(x, attr))
 
 
 class RegFileArray(Elaboratable):
@@ -107,11 +100,22 @@ class RegFileArray(Elaboratable):
         self._rdports = []
         self._wrports = []
 
-    def read_port(self, name=None):
+    def read_reg_port(self, name=None):
+        regs = []
+        for i in range(self.depth):
+            port = self.regs[i].read_port("%s%d" % (name, i))
+            regs.append(port)
+        return regs
+
+    def write_reg_port(self, name=None):
         regs = []
         for i in range(self.depth):
-            port = self.regs[i].read_port(name)
+            port = self.regs[i].write_port("%s%d" % (name, i))
             regs.append(port)
+        return regs
+
+    def read_port(self, name=None):
+        regs = self.read_reg_port(name)
         regs = Array(regs)
         port = RecordObject([("ren", self.depth),
                              ("data_o", self.width)], name)
@@ -119,10 +123,7 @@ class RegFileArray(Elaboratable):
         return port
 
     def write_port(self, name=None):
-        regs = []
-        for i in range(self.depth):
-            port = self.regs[i].write_port(name)
-            regs.append(port)
+        regs = self.write_reg_port(name)
         regs = Array(regs)
         port = RecordObject([("wen", self.depth),
                              ("data_i", self.width)])
@@ -143,7 +144,7 @@ class RegFileArray(Elaboratable):
         for (regs, p) in self._rdports:
             #print (p)
             m.d.comb += self._get_en_sig(regs, 'ren').eq(p.ren)
-            ror = treereduce(list(regs))
+            ror = ortreereduce(list(regs))
             m.d.comb += p.data_o.eq(ror)
         for (regs, p) in self._wrports:
             m.d.comb += self._get_en_sig(regs, 'wen').eq(p.wen)
@@ -202,9 +203,9 @@ class RegFile(Elaboratable):
                 with m.If(~wr_detect):
                     m.d.comb += rp.data_o.eq(regs[rp.raddr])
 
-        # write ports, don't allow write to address 0 (ignore it)
+        # write ports: regs[waddr] is updated in the sync domain (next edge)
         for wp in self._wrports:
-            with m.If(wp.wen & (wp.waddr != Const(0, bsz))):
+            with m.If(wp.wen):
                 m.d.sync += regs[wp.waddr].eq(wp.data_i)
 
         return m