update rotator.py to match microwatt rotator.vhdl
[soc.git] / src / soc / fu / shift_rot / rotator.py
index 4d2d659a219a173ea4b8d47fd8cd16a23836304e..b5dc80eef6aa4be023e6e772fe766ea7a01d9163 100644 (file)
@@ -1,9 +1,11 @@
 # Manual translation and adaptation of rotator.vhdl from microwatt into nmigen
 #
 
-from nmigen import (Elaboratable, Signal, Module, Const, Cat,
+from nmigen import (Elaboratable, Signal, Module, Const, Cat, Repl,
                     unsigned, signed)
 from soc.fu.shift_rot.rotl import ROTL
+from nmutil.extend import exts
+
 
 # note BE bit numbering
 def right_mask(m, mask_begin):
@@ -49,6 +51,7 @@ class Rotator(Elaboratable):
         self.arith = Signal(reset_less=True)
         self.clear_left = Signal(reset_less=True)
         self.clear_right = Signal(reset_less=True)
+        self.sign_ext_rs = Signal(reset_less=True)
         # output
         self.result_o = Signal(64, reset_less=True)
         self.carry_out_o = Signal(reset_less=True)
@@ -59,7 +62,6 @@ class Rotator(Elaboratable):
         ra, rs = self.ra, self.rs
 
         # temporaries
-        rot_in = Signal(64, reset_less=True)
         rot_count = Signal(6, reset_less=True)
         rot = Signal(64, reset_less=True)
         sh = Signal(7, reset_less=True)
@@ -68,13 +70,18 @@ class Rotator(Elaboratable):
         mr = Signal(64, reset_less=True)
         ml = Signal(64, reset_less=True)
         output_mode = Signal(2, reset_less=True)
+        hi32 = Signal(32, reset_less=True)
+        repl32 = Signal(64, reset_less=True)
 
         # First replicate bottom 32 bits to both halves if 32-bit
-        comb += rot_in[0:32].eq(rs[0:32])
         with m.If(self.is_32bit):
-            comb += rot_in[32:64].eq(rs[0:32])
+            comb += hi32.eq(rs[0:32])
+        with m.Elif(self.sign_ext_rs):
+            # sign-extend bottom 32 bits
+            comb += hi32.eq(Repl(rs[31], 32))
         with m.Else():
-            comb += rot_in[32:64].eq(rs[32:64])
+            comb += hi32.eq(rs[32:64])
+        comb += repl32.eq(Cat(rs[0:32], hi32))
 
         shift_signed = Signal(signed(6))
         comb += shift_signed.eq(self.shift[0:6])
@@ -87,7 +94,7 @@ class Rotator(Elaboratable):
 
         # ROTL submodule
         m.submodules.rotl = rotl = ROTL(64)
-        comb += rotl.a.eq(rot_in)
+        comb += rotl.a.eq(repl32)
         comb += rotl.b.eq(rot_count)
         comb += rot.eq(rotl.o)
 
@@ -133,7 +140,7 @@ class Rotator(Elaboratable):
         # 10 for rldicl, sr[wd]
         # 1z for sra[wd][i], z = 1 if rs is negative
         with m.If((self.clear_left & ~self.clear_right) | self.right_shift):
-            comb += output_mode.eq(Cat(self.arith & rot_in[63], Const(1, 1)))
+            comb += output_mode.eq(Cat(self.arith & repl32[63], Const(1, 1)))
         with m.Else():
             mbgt = self.clear_right & (mb[0:6] > me[0:6])
             comb += output_mode.eq(Cat(mbgt, Const(0, 1)))
@@ -149,7 +156,7 @@ class Rotator(Elaboratable):
             with m.Case(0b11):
                 comb += self.result_o.eq(rot | ~mr)
                 # Generate carry output for arithmetic shift right of -ve value
-                comb += self.carry_out_o.eq(rs & ~ml)
+                comb += self.carry_out_o.eq((rs & ~ml).bool() & rs[0])
 
         return m