cleanup rotator.py
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 11 May 2020 12:10:26 +0000 (13:10 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 11 May 2020 12:10:26 +0000 (13:10 +0100)
src/soc/alu/rotator.py

index 05b6b19b908d8a8002229d30dc86ad7e91ec2e00..8cea041f112a5916937e73c15f04eaf169f5ef51 100644 (file)
@@ -1,10 +1,12 @@
 # Manual translation and adaptation of rotator.vhdl from microwatt into nmigen
 #
 
+from nmigen import (Elaboratable, Signal, Module, Const, Cat)
 from soc.alu.rotl import ROTL
 
-#note BE bit numbering
+# note BE bit numbering
 def right_mask(m, mask_begin):
+    """ this can be replaced by something like (mask_begin << 1) - 1"""
     ret = Signal(64, name="right_mask", reset_less=True)
     m.d.comb += ret.eq(0)
     for i in range(64):
@@ -13,6 +15,7 @@ def right_mask(m, mask_begin):
     return ret;
 
 def left_mask(m, mask_end):
+    """ this can be replaced by something like ~((mask_end << 1) - 1)"""
     ret = Signal(64, name="left_mask", reset_less=True)
     m.d.comb += ret.eq(0)
     with m.If(mask_end[6] != 0):
@@ -24,6 +27,23 @@ def left_mask(m, mask_end):
 
 
 class Rotator(Elaboratable):
+    """Rotator: covers multiple POWER9 rotate functions
+
+        supported modes:
+
+        * sl[wd]
+        * rlw*, rldic, rldicr, rldimi
+        * rldicl, sr[wd]
+        * sra[wd][i]
+
+        use as follows:
+
+        * shift = RB[0:7]
+        * arith = 1 when is_signed
+        * right_shift = 1 when insn_type is OP_SHR
+        * clear_left = 1 when insn_type is OP_RLC or OP_RLCL
+        * clear_right = 1 when insn_type is OP_RLC or OP_RLCR
+    """
     def __init__(self):
         # input
         self.rs = Signal(64, reset_less=True)
@@ -36,12 +56,13 @@ class Rotator(Elaboratable):
         self.clear_left = Signal(reset_less=True)
         self.clear_right = Signal(reset_less=True)
         # output
-        self.result = Signal(64, reset_less=True)
-        self.carry_out = Signal(reset_less=True)
+        self.result_o = Signal(64, reset_less=True)
+        self.carry_out_o = Signal(reset_less=True)
 
     def elaborate(self, platform):
         m = Module()
         comb = m.d.comb
+        ra, rs = self.ra, self.rs
 
         # temporaries
         repl32 = Signal(64, reset_less=True)
@@ -56,14 +77,14 @@ class Rotator(Elaboratable):
 
         # First replicate bottom 32 bits to both halves if 32-bit
         comb += repl32[0:32].eq(rs[0:32])
-        with m.If(is_32bit):
-            comb += repl32[32:64].eq(rs[:32])
+        with m.If(self.is_32bit):
+            comb += repl32[32:64].eq(rs[0:32])
 
         # Negate shift count for right shifts
-        with m.If(right_shift):
-            comb += rot_count.eq(-signed(shift[0:6]))
+        with m.If(self.right_shift):
+            comb += rot_count.eq(-signed(self.shift[0:6]))
         with m.Else():
-            comb += rot_count.eq(shift[0:6])
+            comb += rot_count.eq(self.shift[0:6])
 
         # ROTL submodule
         m.submodules.rotl = rotl = ROTL(64)
@@ -72,31 +93,31 @@ class Rotator(Elaboratable):
         comb += rot.eq(rotl.o)
 
         # Trim shift count to 6 bits for 32-bit shifts
-        comb += sh.eq(Cat(shift[0:6], shift[6] & ~is_32bit))
+        comb += sh.eq(Cat(shift[0:6], shift[6] & ~self.is_32bit))
 
         # XXX errr... we should already have these, in Fields?  oh well
         # Work out mask begin/end indexes (caution, big-endian bit numbering)
 
         # mask-begin (mb)
-        with m.If(clear_left):
-            with m.If(is_32bit):
-                comb += mb.eq(Cat(insn[6:11], Const(0b01, 2)))
+        with m.If(self.clear_left):
+            with m.If(self.is_32bit):
+                comb += mb.eq(Cat(self.insn[6:11], Const(0b01, 2)))
             with m.Else():
-                comb += mb.eq(Cat(insn[6:11], insn[5], Const(0b0, 1)))
-        with m.Elif(right_shift):
-            # this is basically mb <= sh + (is_32bit? 32: 0);
-            with m.If(is_32bit):
+                comb += mb.eq(Cat(self.insn[6:11], self.insn[5], Const(0b0, 1)))
+        with m.Elif(self.right_shift):
+            # this is basically mb = sh + (is_32bit? 32: 0);
+            with m.If(self.is_32bit):
                 comb += mb.eq(Cat(sh[0:5], ~sh[5], sh[5]))
             with m.Else():
                 comb += mb.eq(sh)
         with m.Else():
-            comb += mb.eq(Cat(Const(0b0, 5), is_32bit, Const(0b0, 1)))
+            comb += mb.eq(Cat(Const(0b0, 5), self.is_32bit, Const(0b0, 1)))
 
         # mask-end (me)
-        with m.If(clear_right & is_32bit):
-            comb += me.eq(Cat(insn[1:6], Const(0b01, 2)))
-        with m.Elif(clear_right & ~clear_left):
-            comb += me.eq(Cat(insn[6:11], insn[5], Const(0b0, 1)))
+        with m.If(self.clear_right & self.is_32bit):
+            comb += me.eq(Cat(self.insn[1:6], Const(0b01, 2)))
+        with m.Elif(self.clear_right & ~self.clear_left):
+            comb += me.eq(Cat(self.insn[6:11], self.insn[5], Const(0b0, 1)))
         with m.Else():
             # effectively, 63 - sh
             comb += me.eq(Cat(~shift[0:6], shift[6]))
@@ -110,27 +131,24 @@ class Rotator(Elaboratable):
         # 0w for rlw*, rldic, rldicr, rldimi, where w = 1 iff mb > me
         # 10 for rldicl, sr[wd]
         # 1z for sra[wd][i], z = 1 if rs is negative
-        with m.If((clear_left & ~clear_right) | right_shift):
-            comb += output_mode[1].eq(1)
-            comb += output_mode[0].eq(arith & repl32[63])
+        with m.If((self.clear_left & ~self.clear_right) | self.right_shift):
+            comb += output_mode.eq(Cat(self.arith & repl32[63], Const(1, 1))
         with m.Else():
-            comb += output_mode[1].eq(0)
-            mbgt = clear_right & unsigned(mb[0:6]) > unsigned(me[0:6])
-            comb += output_mode[0].eq(mbgt)
+            mbgt = self.clear_right & (unsigned(mb[0:6]) > unsigned(me[0:6]))
+            comb += output_mode.eq(Cat(mbgt, Const(0, 1))
 
         # Generate output from rotated input and masks
         with m.Switch(output_mode):
             with m.Case(0b00):
-                comb += result.eq((rot & (mr & ml)) | (ra & ~(mr & ml)))
+                comb += self.result_o.eq((rot & (mr & ml)) | (ra & ~(mr & ml)))
             with m.Case(0b01):
-                comb += result.eq((rot & (mr | ml)) | (ra & ~(mr or ml)))
+                comb += self.result_o.eq((rot & (mr | ml)) | (ra & ~(mr | ml)))
             with m.Case(0b10):
-                comb += result.eq(rot & mr)
+                comb += self.result_o.eq(rot & mr)
             with m.Case(0b11):
-                comb += result.eq(rot | ~mr)
-
-        # Generate carry output for arithmetic shift right of negative value
-        with m.If(output_mode = 0b11):
-            comb += carry_out.eq(rs & ~ml)
+                comb += self.result_o.eq(rot | ~mr)
+                # Generate carry output for arithmetic shift right of -ve value
+                comb += self.carry_out_o.eq(rs & ~ml)
 
         return m
+