Integrate rotator.py into shift_rot unit
authorMichael Nolan <mtnolan2640@gmail.com>
Wed, 13 May 2020 17:45:06 +0000 (13:45 -0400)
committerMichael Nolan <mtnolan2640@gmail.com>
Wed, 13 May 2020 17:45:06 +0000 (13:45 -0400)
src/soc/shift_rot/main_stage.py
src/soc/shift_rot/pipeline.py
src/soc/shift_rot/rotator.py
src/soc/shift_rot/test/test_pipe_caller.py

index 09c32855377cfa5292abeafce4988fa788352b96..aa2b79a59c6f1beba8dc7a94caf1d71fcb3dc9b5 100644 (file)
@@ -5,11 +5,11 @@
 # output stage
 from nmigen import (Module, Signal, Cat, Repl, Mux, Const)
 from nmutil.pipemodbase import PipeModBase
-from soc.alu.pipe_data import ALUInputData, ALUOutputData
+from soc.alu.pipe_data import ALUOutputData
+from soc.shift_rot.pipe_data import ShiftRotInputData
 from ieee754.part.partsig import PartitionedSignal
 from soc.decoder.power_enums import InternalOp
-from soc.shift_rot.maskgen import MaskGen
-from soc.shift_rot.rotl import ROTL
+from soc.shift_rot.rotator import Rotator
 
 from soc.decoder.power_fields import DecodeFields
 from soc.decoder.power_fieldsn import SignalBitRange
@@ -22,7 +22,7 @@ class ShiftRotMainStage(PipeModBase):
         self.fields.create_specs()
 
     def ispec(self):
-        return ALUInputData(self.pspec)
+        return ShiftRotInputData(self.pspec)
 
     def ospec(self):
         return ALUOutputData(self.pspec) # TODO: ALUIntermediateData
@@ -30,94 +30,39 @@ class ShiftRotMainStage(PipeModBase):
     def elaborate(self, platform):
         m = Module()
         comb = m.d.comb
-
-
-        fields = self.fields.instrs['M']
-        mb = Signal(fields['MB'][0:-1].shape())
-        comb += mb.eq(fields['MB'][0:-1])
-        me = Signal(fields['ME'][0:-1].shape())
-        comb += me.eq(fields['ME'][0:-1])
-
-        # check if op is 32-bit, and get sign bit from operand a
-        is_32bit = Signal(reset_less=True)
-        sign_bit = Signal(reset_less=True)
-        comb += is_32bit.eq(self.i.ctx.op.is_32bit)
-        comb += sign_bit.eq(Mux(is_32bit, self.i.a[31], self.i.a[63]))
-
-        # Signals for rotates and shifts
-        rotl_out = Signal.like(self.i.a)
-        mask = Signal.like(self.i.a)
-        m.submodules.maskgen = maskgen = MaskGen(64)
-        m.submodules.rotl = rotl = ROTL(64)
-        m.submodules.rotl32 = rotl32 = ROTL(32)
-        rotate_amt = Signal.like(rotl.b)
-
+        m.submodules.rotator = rotator = Rotator()
         comb += [
-            rotl.a.eq(self.i.a),
-            rotl.b.eq(rotate_amt),
-            rotl32.a.eq(self.i.a[0:32]),
-            rotl32.b.eq(rotate_amt)]
-
-        with m.If(is_32bit):
-            comb += rotl_out.eq(Cat(rotl32.o, Repl(0, 32)))
-        with m.Else():
-            comb += rotl_out.eq(rotl.o)
-
-        ##########################
-        # main switch-statement for handling arithmetic and logic operations
+            rotator.rs.eq(self.i.rs),
+            rotator.ra.eq(self.i.ra),
+            rotator.shift.eq(self.i.rb),
+            rotator.insn.eq(self.i.ctx.op.insn),
+            rotator.is_32bit.eq(self.i.ctx.op.is_32bit),
+            rotator.arith.eq(self.i.ctx.op.is_signed),
+        ]
+
+        # Defaults
+        comb += [rotator.right_shift.eq(0),
+                 rotator.clear_left.eq(0),
+                 rotator.clear_right.eq(0)]
+
+        comb += [self.o.o.eq(rotator.result_o),
+                 self.o.carry_out.eq(rotator.carry_out_o)]
 
         with m.Switch(self.i.ctx.op.insn_type):
-            #### shift left ####
             with m.Case(InternalOp.OP_SHL):
-                comb += maskgen.mb.eq(Mux(is_32bit, 32, 0))
-                comb += maskgen.me.eq(63-self.i.b[0:6])
-                comb += rotate_amt.eq(self.i.b[0:6])
-                with m.If(is_32bit):
-                    with m.If(self.i.b[5]):
-                        comb += mask.eq(0)
-                    with m.Else():
-                        comb += mask.eq(maskgen.o)
-                with m.Else():
-                    with m.If(self.i.b[6]):
-                        comb += mask.eq(0)
-                    with m.Else():
-                        comb += mask.eq(maskgen.o)
-                comb += self.o.o.eq(rotl_out & mask)
-
-            #### shift right ####
+                comb += [rotator.right_shift.eq(0),
+                        rotator.clear_left.eq(0),
+                        rotator.clear_right.eq(0)]
             with m.Case(InternalOp.OP_SHR):
-                comb += maskgen.mb.eq(Mux(is_32bit, 32, 0) + self.i.b[0:6])
-                comb += maskgen.me.eq(63)
-                comb += rotate_amt.eq(64-self.i.b[0:6])
-                with m.If(is_32bit):
-                    with m.If(self.i.b[5]):
-                        comb += mask.eq(0)
-                    with m.Else():
-                        comb += mask.eq(maskgen.o)
-                with m.Else():
-                    with m.If(self.i.b[6]):
-                        comb += mask.eq(0)
-                    with m.Else():
-                        comb += mask.eq(maskgen.o)
-                with m.If(self.i.ctx.op.is_signed):
-                    out = rotl_out & mask | Mux(sign_bit, ~mask, 0)
-                    cout = sign_bit & ((rotl_out & mask) != 0)
-                    comb += self.o.o.eq(out)
-                    comb += self.o.carry_out.eq(cout)
-                with m.Else():
-                    comb += self.o.o.eq(rotl_out & mask)
-
-            with m.Case(InternalOp.OP_RLC):
-                with m.If(self.i.ctx.op.imm_data.imm_ok):
-                    comb += rotate_amt.eq(self.i.ctx.op.imm_data.imm[0:5])
-                with m.Else():
-                    comb += rotate_amt.eq(self.i.b[0:5])
-                comb += maskgen.mb.eq(mb+32)
-                comb += maskgen.me.eq(me+32)
-                comb += mask.eq(maskgen.o)
-                comb += self.o.o.eq((rotl_out & mask) | (self.i.b & ~mask))
+                comb += [rotator.right_shift.eq(1),
+                        rotator.clear_left.eq(0),
+                        rotator.clear_right.eq(0)]
                 
 
+
+
+
+
         ###### sticky overflow and context, both pass-through #####
 
         comb += self.o.so.eq(self.i.so)
index eb62013aae39a60ce847d3826a6cd994f955a0d4..1080aa8debdaa56c571b8ab43c7638d080cfd92b 100644 (file)
@@ -1,12 +1,12 @@
 from nmutil.singlepipe import ControlBase
 from nmutil.pipemodbase import PipeModBaseChain
-from soc.alu.input_stage import ALUInputStage
+from soc.shift_rot.input_stage import ShiftRotInputStage
 from soc.shift_rot.main_stage import ShiftRotMainStage
 from soc.alu.output_stage import ALUOutputStage
 
 class ShiftRotStages(PipeModBaseChain):
     def get_chain(self):
-        inp = ALUInputStage(self.pspec)
+        inp = ShiftRotInputStage(self.pspec)
         main = ShiftRotMainStage(self.pspec)
         out = ALUOutputStage(self.pspec)
         return [inp, main, out]
index 7681692e8b2bc1ec569231f254285d16b369fc09..3d90f8bed485e9cefedfb6711560948fb36ba5dd 100644 (file)
@@ -1,29 +1,25 @@
 # Manual translation and adaptation of rotator.vhdl from microwatt into nmigen
 #
 
-from nmigen import (Elaboratable, Signal, Module, Const, Cat)
-from soc.alu.rotl import ROTL
+from nmigen import (Elaboratable, Signal, Module, Const, Cat,
+                    unsigned, signed)
+from soc.shift_rot.rotl import ROTL
 
 # note BE bit numbering
 def right_mask(m, mask_begin):
     """ this can be replaced by something like (mask_begin << 1) - 1"""
     ret = Signal(64, name="right_mask", reset_less=True)
-    m.d.comb += ret.eq(0)
-    for i in range(64):
-        with m.If(i >= unsigned(mask_begin)): # set from i upwards
-            m.d.comb += ret[63 - i].eq(1)
-    return ret;
+    with m.If(mask_begin > 64):
+        m.d.comb += ret.eq(0)
+    with m.Else():
+        m.d.comb += ret.eq((1<<(64-mask_begin)) - 1)
+    return ret
 
 def left_mask(m, mask_end):
     """ this can be replaced by something like ~((mask_end << 1) - 1)"""
     ret = Signal(64, name="left_mask", reset_less=True)
-    m.d.comb += ret.eq(0)
-    with m.If(mask_end[6] != 0):
-        return ret
-    for i in range(64):
-        with m.If(i <= unsigned(mask_end)): # set from i downwards
-            m.d.comb += ret[63 - i].eq(1)
-    return ret;
+    m.d.comb += ret.eq(~((1<<(63-mask_end)) - 1))
+    return ret
 
 
 class Rotator(Elaboratable):
@@ -65,7 +61,7 @@ class Rotator(Elaboratable):
         ra, rs = self.ra, self.rs
 
         # temporaries
-        repl32 = Signal(64, reset_less=True)
+        rot_in = Signal(64, reset_less=True)
         rot_count = Signal(6, reset_less=True)
         rot = Signal(64, reset_less=True)
         sh = Signal(7, reset_less=True)
@@ -76,24 +72,29 @@ class Rotator(Elaboratable):
         output_mode = Signal(2, reset_less=True)
 
         # First replicate bottom 32 bits to both halves if 32-bit
-        comb += repl32[0:32].eq(rs[0:32])
+        comb += rot_in[0:32].eq(rs[0:32])
         with m.If(self.is_32bit):
-            comb += repl32[32:64].eq(rs[0:32])
+            comb += rot_in[32:64].eq(rs[0:32])
+        with m.Else():
+            comb += rot_in[32:64].eq(rs[32:64])
+
+        shift_signed = Signal(signed(6))
+        comb += shift_signed.eq(self.shift[0:6])
 
         # Negate shift count for right shifts
         with m.If(self.right_shift):
-            comb += rot_count.eq(-signed(self.shift[0:6]))
+            comb += rot_count.eq(-shift_signed)
         with m.Else():
             comb += rot_count.eq(self.shift[0:6])
 
         # ROTL submodule
         m.submodules.rotl = rotl = ROTL(64)
-        comb += rotl.a.eq(repl32)
+        comb += rotl.a.eq(rot_in)
         comb += rotl.b.eq(rot_count)
         comb += rot.eq(rotl.o)
 
         # Trim shift count to 6 bits for 32-bit shifts
-        comb += sh.eq(Cat(shift[0:6], shift[6] & ~self.is_32bit))
+        comb += sh.eq(Cat(self.shift[0:6], self.shift[6] & ~self.is_32bit))
 
         # XXX errr... we should already have these, in Fields?  oh well
         # Work out mask begin/end indexes (caution, big-endian bit numbering)
@@ -120,7 +121,7 @@ class Rotator(Elaboratable):
             comb += me.eq(Cat(self.insn[6:11], self.insn[5], Const(0b0, 1)))
         with m.Else():
             # effectively, 63 - sh
-            comb += me.eq(Cat(~shift[0:6], shift[6]))
+            comb += me.eq(Cat(~self.shift[0:6], self.shift[6]))
 
         # Calculate left and right masks
         comb += mr.eq(right_mask(m, mb))
@@ -132,10 +133,10 @@ class Rotator(Elaboratable):
         # 10 for rldicl, sr[wd]
         # 1z for sra[wd][i], z = 1 if rs is negative
         with m.If((self.clear_left & ~self.clear_right) | self.right_shift):
-            comb += output_mode.eq(Cat(self.arith & repl32[63], Const(1, 1))
+            comb += output_mode.eq(Cat(self.arith & rot_in[63], Const(1, 1)))
         with m.Else():
-            mbgt = self.clear_right & (unsigned(mb[0:6]) > unsigned(me[0:6]))
-            comb += output_mode.eq(Cat(mbgt, Const(0, 1))
+            mbgt = self.clear_right & (mb[0:6] > me[0:6])
+            comb += output_mode.eq(Cat(mbgt, Const(0, 1)))
 
         # Generate output from rotated input and masks
         with m.Switch(output_mode):
index 6a836d787eb92d741f3d1f9e443aeebfdadcd7d3..888de74cc9987e1320dde5b0b243d0f778d8e193 100644 (file)
@@ -41,27 +41,29 @@ def set_alu_inputs(alu, dec2, sim):
     reg3_ok = yield dec2.e.read_reg3.ok
     if reg3_ok:
         reg3_sel = yield dec2.e.read_reg3.data
-        inputs.append(sim.gpr(reg3_sel).value)
+        data3 = sim.gpr(reg3_sel).value
+    else:
+        data3 = 0
     reg1_ok = yield dec2.e.read_reg1.ok
     if reg1_ok:
         reg1_sel = yield dec2.e.read_reg1.data
-        inputs.append(sim.gpr(reg1_sel).value)
+        data1 = sim.gpr(reg1_sel).value
+    else:
+        data1 = 0
     reg2_ok = yield dec2.e.read_reg2.ok
+    imm_ok = yield dec2.e.imm_data.ok
     if reg2_ok:
         reg2_sel = yield dec2.e.read_reg2.data
-        inputs.append(sim.gpr(reg2_sel).value)
+        data2 = sim.gpr(reg2_sel).value
+    elif imm_ok:
+        data2 = yield dec2.e.imm_data.imm
+    else:
+        data2 = 0
 
-    print(inputs)
+    yield alu.p.data_i.ra.eq(data1)
+    yield alu.p.data_i.rb.eq(data2)
+    yield alu.p.data_i.rs.eq(data3)
 
-    if len(inputs) == 0:
-        yield alu.p.data_i.a.eq(0)
-        yield alu.p.data_i.b.eq(0)
-    if len(inputs) == 1:
-        yield alu.p.data_i.a.eq(inputs[0])
-        yield alu.p.data_i.b.eq(0)
-    if len(inputs) == 2:
-        yield alu.p.data_i.a.eq(inputs[0])
-        yield alu.p.data_i.b.eq(inputs[1])
 
 def set_extra_alu_inputs(alu, dec2, sim):
     carry = 1 if sim.spr['XER'][XER_bits['CA']] else 0
@@ -120,6 +122,14 @@ class ALUTestCase(FHDLTestCase):
         print(initial_regs[1], initial_regs[2])
         self.run_tst_program(Program(lst), initial_regs)
 
+    def test_shift_once(self):
+        lst = ["sraw 3, 1, 2"]
+        initial_regs = [0] * 32
+        initial_regs[1] = 0xdeadbeefcafec0de
+        initial_regs[2] = 53
+        print(initial_regs[1], initial_regs[2])
+        self.run_tst_program(Program(lst), initial_regs)
+
     def test_rlwinm(self):
         for i in range(10):
             mb = random.randint(0,31)
@@ -219,8 +229,8 @@ class TestRunner(FHDLTestCase):
                     if out_reg_valid:
                         write_reg_idx = yield pdecode2.e.write_reg.data
                         expected = simulator.gpr(write_reg_idx).value
-                        print(f"expected {expected:x}, actual: {alu_out:x}")
-                        self.assertEqual(expected, alu_out)
+                        msg = f"expected {expected:x}, actual: {alu_out:x}"
+                        self.assertEqual(expected, alu_out, msg)
                     yield from self.check_extra_alu_outputs(alu, pdecode2,
                                                             simulator)