convert shift_rot pipeline to XLEN=32/64
[soc.git] / src / soc / fu / shift_rot / rotator.py
index b5dc80eef6aa4be023e6e772fe766ea7a01d9163..eac042fedcece092fec572ed75dd9759f852728e 100644 (file)
@@ -1,22 +1,28 @@
 # Manual translation and adaptation of rotator.vhdl from microwatt into nmigen
 #
+from nmigen.compat.sim import run_simulation
 
 from nmigen import (Elaboratable, Signal, Module, Const, Cat, Repl,
                     unsigned, signed)
 from soc.fu.shift_rot.rotl import ROTL
+from nmigen.back.pysim import Settle
 from nmutil.extend import exts
+from nmutil.mask import Mask
 
 
 # note BE bit numbering
-def right_mask(m, mask_begin):
-    ret = Signal(64, name="right_mask", reset_less=True)
-    with m.If(mask_begin <= 64):
-        m.d.comb += ret.eq((1<<(64-mask_begin)) - 1)
+def right_mask(m, mask_begin, width):
+    ret = Signal(width, name="right_mask", reset_less=True)
+    with m.If(mask_begin <= width):
+        m.d.comb += ret.eq((1 << (width-mask_begin)) - 1)
+    with m.Else():
+        m.d.comb += ret.eq(0)
     return ret
 
-def left_mask(m, mask_end):
-    ret = Signal(64, name="left_mask", reset_less=True)
-    m.d.comb += ret.eq(~((1<<(63-mask_end)) - 1))
+
+def left_mask(m, mask_end, width):
+    ret = Signal(width, name="left_mask", reset_less=True)
+    m.d.comb += ret.eq(~((1 << (width-1-mask_end)) - 1))
     return ret
 
 
@@ -38,13 +44,16 @@ class Rotator(Elaboratable):
         * clear_left = 1 when insn_type is OP_RLC or OP_RLCL
         * clear_right = 1 when insn_type is OP_RLC or OP_RLCR
     """
-    def __init__(self):
+
+    def __init__(self, width):
+        self.width = width
         # input
         self.me = Signal(5, reset_less=True)        # ME field
         self.mb = Signal(5, reset_less=True)        # MB field
-        self.mb_extra = Signal(1, reset_less=True)  # extra bit of mb in MD-form
-        self.ra = Signal(64, reset_less=True)       # RA
-        self.rs = Signal(64, reset_less=True)       # RS
+        # extra bit of mb in MD-form
+        self.mb_extra = Signal(1, reset_less=True)
+        self.ra = Signal(width, reset_less=True)       # RA
+        self.rs = Signal(width, reset_less=True)       # RS
         self.shift = Signal(7, reset_less=True)     # RB[0:7]
         self.is_32bit = Signal(reset_less=True)
         self.right_shift = Signal(reset_less=True)
@@ -53,10 +62,11 @@ class Rotator(Elaboratable):
         self.clear_right = Signal(reset_less=True)
         self.sign_ext_rs = Signal(reset_less=True)
         # output
-        self.result_o = Signal(64, reset_less=True)
+        self.result_o = Signal(width, reset_less=True)
         self.carry_out_o = Signal(reset_less=True)
 
     def elaborate(self, platform):
+        width = self.width
         m = Module()
         comb = m.d.comb
         ra, rs = self.ra, self.rs
@@ -67,11 +77,11 @@ class Rotator(Elaboratable):
         sh = Signal(7, reset_less=True)
         mb = Signal(7, reset_less=True)
         me = Signal(7, reset_less=True)
-        mr = Signal(64, reset_less=True)
-        ml = Signal(64, reset_less=True)
+        mr = Signal(width, reset_less=True)
+        ml = Signal(width, reset_less=True)
         output_mode = Signal(2, reset_less=True)
         hi32 = Signal(32, reset_less=True)
-        repl32 = Signal(64, reset_less=True)
+        repl32 = Signal(width, reset_less=True)
 
         # First replicate bottom 32 bits to both halves if 32-bit
         with m.If(self.is_32bit):
@@ -80,7 +90,8 @@ class Rotator(Elaboratable):
             # sign-extend bottom 32 bits
             comb += hi32.eq(Repl(rs[31], 32))
         with m.Else():
-            comb += hi32.eq(rs[32:64])
+            if width == 64:
+                comb += hi32.eq(rs[32:64])
         comb += repl32.eq(Cat(rs[0:32], hi32))
 
         shift_signed = Signal(signed(6))
@@ -93,7 +104,7 @@ class Rotator(Elaboratable):
             comb += rot_count.eq(self.shift[0:6])
 
         # ROTL submodule
-        m.submodules.rotl = rotl = ROTL(64)
+        m.submodules.rotl = rotl = ROTL(width)
         comb += rotl.a.eq(repl32)
         comb += rotl.b.eq(rot_count)
         comb += rot.eq(rotl.o)
@@ -131,8 +142,19 @@ class Rotator(Elaboratable):
             comb += me.eq(Cat(~sh[0:6], sh[6]))
 
         # Calculate left and right masks
-        comb += mr.eq(right_mask(m, mb))
-        comb += ml.eq(left_mask(m, me))
+        m.submodules.right_mask = right_mask = Mask(width)
+        with m.If(mb <= width):
+            comb += right_mask.shift.eq(width-mb)
+            comb += mr.eq(right_mask.mask)
+        with m.Else():
+            comb += mr.eq(0)
+        #comb += mr.eq(right_mask(m, mb))
+
+        m.submodules.left_mask = left_mask = Mask(width)
+        comb += left_mask.shift.eq(width-1-me)
+        comb += ml.eq(~left_mask.mask)
+        #comb += ml.eq(left_mask(m, me))
+
 
         # Work out output mode
         # 00 for sl[wd]
@@ -140,7 +162,8 @@ class Rotator(Elaboratable):
         # 10 for rldicl, sr[wd]
         # 1z for sra[wd][i], z = 1 if rs is negative
         with m.If((self.clear_left & ~self.clear_right) | self.right_shift):
-            comb += output_mode.eq(Cat(self.arith & repl32[63], Const(1, 1)))
+            comb += output_mode.eq(Cat(self.arith &
+                                       repl32[width-1], Const(1, 1)))
         with m.Else():
             mbgt = self.clear_right & (mb[0:6] > me[0:6])
             comb += output_mode.eq(Cat(mbgt, Const(0, 1)))
@@ -156,7 +179,25 @@ class Rotator(Elaboratable):
             with m.Case(0b11):
                 comb += self.result_o.eq(rot | ~mr)
                 # Generate carry output for arithmetic shift right of -ve value
-                comb += self.carry_out_o.eq((rs & ~ml).bool() & rs[0])
+                comb += self.carry_out_o.eq((rs & ~ml).bool())
 
         return m
 
+
+if __name__ == '__main__':
+
+    m = Module()
+    comb = m.d.comb
+    mr = Signal(64)
+    mb = Signal(6)
+    comb += mr.eq(left_mask(m, mb, 64))
+
+    def loop():
+        for i in range(64):
+            yield mb.eq(63-i)
+            yield Settle()
+            res = yield mr
+            print(i, hex(res))
+
+    run_simulation(m, [loop()],
+                   vcd_name="test_mask.vcd")