Implement CR predication
[soc.git] / src / soc / simple / issuer.py
index 8d9bec5e3c3eecef309eda699f08ba6c7364fa35..d2c248d50bef2566784967afa2429089c207f0d6 100644 (file)
@@ -34,6 +34,7 @@ from soc.config.test.test_loadstore import TestMemPspec
 from soc.config.ifetch import ConfigFetchUnit
 from soc.decoder.power_enums import (MicrOp, SVP64PredInt, SVP64PredCR,
                                      SVP64PredMode)
+from soc.consts import (CR, SVP64CROffs)
 from soc.debug.dmi import CoreDebug, DMIInterface
 from soc.debug.jtag import JTAG
 from soc.config.pinouts import get_pinspecs
@@ -56,22 +57,23 @@ def get_insn(f_instr_o, pc):
         return f_instr_o.word_select(pc[2], 32)
 
 # gets state input or reads from state regfile
-def state_get(m, state_i, name, regfile, regnum):
+def state_get(m, core_rst, state_i, name, regfile, regnum):
     comb = m.d.comb
     sync = m.d.sync
     # read the PC
     res = Signal(64, reset_less=True, name=name)
     res_ok_delay = Signal(name="%s_ok_delay" % name)
-    sync += res_ok_delay.eq(~state_i.ok)
-    with m.If(state_i.ok):
-        # incoming override (start from pc_i)
-        comb += res.eq(state_i.data)
-    with m.Else():
-        # otherwise read StateRegs regfile for PC...
-        comb += regfile.ren.eq(1<<regnum)
-    # ... but on a 1-clock delay
-    with m.If(res_ok_delay):
-        comb += res.eq(regfile.data_o)
+    with m.If(~core_rst):
+        sync += res_ok_delay.eq(~state_i.ok)
+        with m.If(state_i.ok):
+            # incoming override (start from pc_i)
+            comb += res.eq(state_i.data)
+        with m.Else():
+            # otherwise read StateRegs regfile for PC...
+            comb += regfile.ren.eq(1<<regnum)
+        # ... but on a 1-clock delay
+        with m.If(res_ok_delay):
+            comb += res.eq(regfile.data_o)
     return res
 
 def get_predint(m, mask, name):
@@ -122,29 +124,29 @@ def get_predcr(m, mask, name):
     invert = Signal(name=name+"crinvert")
     with m.Switch(mask):
         with m.Case(SVP64PredCR.LT.value):
-            comb += idx.eq(0)
-            comb += invert.eq(1)
-        with m.Case(SVP64PredCR.GE.value):
-            comb += idx.eq(0)
+            comb += idx.eq(CR.LT)
             comb += invert.eq(0)
-        with m.Case(SVP64PredCR.GT.value):
-            comb += idx.eq(1)
+        with m.Case(SVP64PredCR.GE.value):
+            comb += idx.eq(CR.LT)
             comb += invert.eq(1)
-        with m.Case(SVP64PredCR.LE.value):
-            comb += idx.eq(1)
+        with m.Case(SVP64PredCR.GT.value):
+            comb += idx.eq(CR.GT)
             comb += invert.eq(0)
-        with m.Case(SVP64PredCR.EQ.value):
-            comb += idx.eq(2)
+        with m.Case(SVP64PredCR.LE.value):
+            comb += idx.eq(CR.GT)
             comb += invert.eq(1)
-        with m.Case(SVP64PredCR.NE.value):
-            comb += idx.eq(1)
+        with m.Case(SVP64PredCR.EQ.value):
+            comb += idx.eq(CR.EQ)
             comb += invert.eq(0)
-        with m.Case(SVP64PredCR.SO.value):
-            comb += idx.eq(3)
+        with m.Case(SVP64PredCR.NE.value):
+            comb += idx.eq(CR.EQ)
             comb += invert.eq(1)
-        with m.Case(SVP64PredCR.NS.value):
-            comb += idx.eq(3)
+        with m.Case(SVP64PredCR.SO.value):
+            comb += idx.eq(CR.SO)
             comb += invert.eq(0)
+        with m.Case(SVP64PredCR.NS.value):
+            comb += idx.eq(CR.SO)
+            comb += invert.eq(1)
     return idx, invert
 
 
@@ -408,24 +410,11 @@ class TestIssuerInternal(Elaboratable):
         predmode = rm_dec.predmode
         srcpred, dstpred = rm_dec.srcpred, rm_dec.dstpred
         cr_pred, int_pred = self.cr_pred, self.int_pred   # read regfiles
-
-        # elif predmode == CR:
-        #    CR-src sidx, sinvert = get_predcr(m, srcpred)
-        #    CR-dst didx, dinvert = get_predcr(m, dstpred)
-        #    TODO read CR-src and CR-dst into self.srcmask+dstmask with loop
-        #         has to cope with first one then the other
-        #    for cr_idx = FSM-state-loop(0..VL-1):
-        #        FSM-state-trigger-CR-read:
-        #               cr_ren = (1<<7-(cr_idx+SVP64CROffs.CRPred))
-        #               comb += cr_pred.ren.eq(cr_ren)
-        #        FSM-state-1-clock-later-actual-Read:
-        #               cr_field = Signal(4)
-        #               cr_bit = Signal(1)
-        #               # read the CR field, select the appropriate bit
-        #               comb += cr_field.eq(cr_pred.data_o)
-        #               comb += cr_bit.eq(cr_field.bit_select(idx)))
-        #               # just like in branch BO tests
-        #               comd += self.srcmask[cr_idx].eq(inv ^ cr_bit)
+        # get src/dst step, so we can skip already used mask bits
+        cur_state = self.cur_state
+        srcstep = cur_state.svstate.srcstep
+        dststep = cur_state.svstate.dststep
+        cur_vl = cur_state.svstate.vl
 
         # decode predicates
         sregread, sinvert, sunary, sall1s = get_predint(m, srcpred, 's')
@@ -453,6 +442,11 @@ class TestIssuerInternal(Elaboratable):
                             comb += int_pred.addr.eq(dregread)
                             comb += int_pred.ren.eq(1)
                             m.next = "INT_DST_READ"
+                    with m.Elif(predmode == SVP64PredMode.CR):
+                        # go fetch masks from the CR register file
+                        sync += self.srcmask.eq(0)
+                        sync += self.dstmask.eq(0)
+                        m.next = "CR_READ"
                     with m.Else():
                         sync += self.srcmask.eq(-1)
                         sync += self.dstmask.eq(-1)
@@ -461,7 +455,17 @@ class TestIssuerInternal(Elaboratable):
             with m.State("INT_DST_READ"):
                 # store destination mask
                 inv = Repl(dinvert, 64)
-                sync += self.dstmask.eq(self.int_pred.data_o ^ inv)
+                new_dstmask = Signal(64)
+                with m.If(dunary):
+                    # set selected mask bit for 1<<r3 mode
+                    dst_shift = Signal(range(64))
+                    comb += dst_shift.eq(self.int_pred.data_o & 0b111111)
+                    comb += new_dstmask.eq(1 << dst_shift)
+                with m.Else():
+                    # invert mask if requested
+                    comb += new_dstmask.eq(self.int_pred.data_o ^ inv)
+                # shift-out already used mask bits
+                sync += self.dstmask.eq(new_dstmask >> dststep)
                 # skip fetching source mask register, when zero
                 with m.If(sall1s):
                     sync += self.srcmask.eq(-1)
@@ -475,9 +479,69 @@ class TestIssuerInternal(Elaboratable):
             with m.State("INT_SRC_READ"):
                 # store source mask
                 inv = Repl(sinvert, 64)
-                sync += self.srcmask.eq(self.int_pred.data_o ^ inv)
+                new_srcmask = Signal(64)
+                with m.If(sunary):
+                    # set selected mask bit for 1<<r3 mode
+                    src_shift = Signal(range(64))
+                    comb += src_shift.eq(self.int_pred.data_o & 0b111111)
+                    comb += new_srcmask.eq(1 << src_shift)
+                with m.Else():
+                    # invert mask if requested
+                    comb += new_srcmask.eq(self.int_pred.data_o ^ inv)
+                # shift-out already used mask bits
+                sync += self.srcmask.eq(new_srcmask >> srcstep)
                 m.next = "FETCH_PRED_DONE"
 
+            # fetch masks from the CR register file
+            # implements the following loop:
+            # idx, inv = get_predcr(mask)
+            # mask = 0
+            # for cr_idx in range(vl):
+            #     cr = crl[cr_idx + SVP64CROffs.CRPred]  # takes one cycle to complete
+            #     if cr[idx] ^ inv:
+            #         mask |= 1 << cr_idx
+            # return mask
+            with m.State("CR_READ"):
+                # the CR index to be read, which will be ready by the next cycle
+                cr_idx = Signal.like(cur_vl, reset_less=True)
+                # submit the read operation to the regfile
+                with m.If(cr_idx != cur_vl):
+                    # the CR read port is unary ...
+                    # ren = 1 << cr_idx
+                    # ... in MSB0 convention ...
+                    # ren = 1 << (7 - cr_idx)
+                    # ... and with an offset:
+                    # ren = 1 << (7 - off - cr_idx)
+                    comb += cr_pred.ren.eq(1 << (7 - SVP64CROffs.CRPred - cr_idx))
+                    # signal data valid in the next cycle
+                    cr_read = Signal(reset_less=True)
+                    sync += cr_read.eq(1)
+                    # load the next index
+                    sync += cr_idx.eq(cr_idx + 1)
+                with m.Else():
+                    # exit on loop end
+                    sync += cr_read.eq(0)
+                    sync += cr_idx.eq(0)
+                    m.next = "FETCH_PRED_DONE"
+                with m.If(cr_read):
+                    # compensate for the one cycle delay on the regfile
+                    cur_cr_idx = Signal.like(cur_vl)
+                    comb += cur_cr_idx.eq(cr_idx - 1)
+                    # read the CR field, select the appropriate bit
+                    cr_field = Signal(4)
+                    scr_bit = Signal()
+                    dcr_bit = Signal()
+                    comb += cr_field.eq(cr_pred.data_o)
+                    comb += scr_bit.eq(cr_field.bit_select(sidx, 1) ^ scrinvert)
+                    comb += dcr_bit.eq(cr_field.bit_select(didx, 1) ^ dcrinvert)
+                    # set the corresponding mask bit
+                    bit_to_set = Signal.like(self.srcmask)
+                    comb += bit_to_set.eq(1 << cur_cr_idx)
+                    with m.If(scr_bit):
+                        sync += self.srcmask.eq(self.srcmask | bit_to_set)
+                    with m.If(dcr_bit):
+                        sync += self.dstmask.eq(self.dstmask | bit_to_set)
+
             with m.State("FETCH_PRED_DONE"):
                 comb += pred_mask_valid_o.eq(1)
                 with m.If(pred_mask_ready_i):
@@ -537,7 +601,6 @@ class TestIssuerInternal(Elaboratable):
                         m.next = "INSN_WAIT"
                 with m.Else():
                     # tell core it's stopped, and acknowledge debug handshake
-                    comb += core.core_stopped_i.eq(1)
                     comb += dbg.core_stopped_i.eq(1)
                     # while stopped, allow updating the PC and SVSTATE
                     with m.If(self.pc_i.ok):
@@ -600,12 +663,15 @@ class TestIssuerInternal(Elaboratable):
                             # append guard bit, in case the mask is all zeros
                             pri_enc_src = PriorityEncoder(65)
                             m.submodules.pri_enc_src = pri_enc_src
-                            comb += pri_enc_src.i.eq(Cat(self.srcmask, 1))
+                            comb += pri_enc_src.i.eq(Cat(self.srcmask,
+                                                         Const(1, 1)))
                             comb += src_delta.eq(pri_enc_src.o)
                         # apply delta to srcstep
                         comb += skip_srcstep.eq(cur_srcstep + src_delta)
                         # shift-out all leading zeros from the mask
                         # plus the leading "one" bit
+                        # TODO count leading zeros and shift-out the zero
+                        #      bits, in the same step, in hardware
                         sync += self.srcmask.eq(self.srcmask >> (src_delta+1))
 
                         # same as above, but for dststep
@@ -614,7 +680,8 @@ class TestIssuerInternal(Elaboratable):
                         with m.If(~pred_dst_zero):
                             pri_enc_dst = PriorityEncoder(65)
                             m.submodules.pri_enc_dst = pri_enc_dst
-                            comb += pri_enc_dst.i.eq(Cat(self.dstmask, 1))
+                            comb += pri_enc_dst.i.eq(Cat(self.dstmask,
+                                                         Const(1, 1)))
                             comb += dst_delta.eq(pri_enc_dst.o)
                         comb += skip_dststep.eq(cur_dststep + dst_delta)
                         sync += self.dstmask.eq(self.dstmask >> (dst_delta+1))
@@ -628,6 +695,8 @@ class TestIssuerInternal(Elaboratable):
                             comb += new_svstate.srcstep.eq(0)
                             comb += new_svstate.dststep.eq(0)
                             comb += update_svstate.eq(1)
+                            # synchronize with the simulator
+                            comb += self.insn_done.eq(1)
                             # go back to Issue
                             m.next = "ISSUE_START"
                         with m.Else():
@@ -706,7 +775,6 @@ class TestIssuerInternal(Elaboratable):
                             m.next = "PRED_SKIP"
 
                 with m.Else():
-                    comb += core.core_stopped_i.eq(1)
                     comb += dbg.core_stopped_i.eq(1)
                     # while stopped, allow updating the PC and SVSTATE
                     with m.If(self.pc_i.ok):
@@ -861,6 +929,10 @@ class TestIssuerInternal(Elaboratable):
         # set up peripherals and core
         core_rst = self.setup_peripherals(m)
 
+        # reset current state if core reset requested
+        with m.If(core_rst):
+            m.d.sync += self.cur_state.eq(0)
+
         # PC and instruction from I-Memory
         comb += self.pc_o.eq(cur_state.pc)
         pc_changed = Signal() # note write to PC
@@ -868,9 +940,11 @@ class TestIssuerInternal(Elaboratable):
 
         # read state either from incoming override or from regfile
         # TODO: really should be doing MSR in the same way
-        pc = state_get(m, self.pc_i, "pc",                  # read PC
+        pc = state_get(m, core_rst, self.pc_i,
+                            "pc",                  # read PC
                             self.state_r_pc, StateRegs.PC)
-        svstate = state_get(m, self.svstate_i, "svstate",   # read SVSTATE
+        svstate = state_get(m, core_rst, self.svstate_i,
+                            "svstate",   # read SVSTATE
                             self.state_r_sv, StateRegs.SVSTATE)
 
         # don't write pc every cycle
@@ -882,7 +956,7 @@ class TestIssuerInternal(Elaboratable):
 
         # address of the next instruction, in the absence of a branch
         # depends on the instruction size
-        nia = Signal(64, reset_less=True)
+        nia = Signal(64)
 
         # connect up debug signals
         # TODO comb += core.icache_rst_i.eq(dbg.icache_rst_o)
@@ -953,6 +1027,10 @@ class TestIssuerInternal(Elaboratable):
                          exec_insn_valid_i, exec_insn_ready_o,
                          exec_pc_valid_o, exec_pc_ready_i)
 
+        # whatever was done above, over-ride it if core reset is held
+        with m.If(core_rst):
+            sync += nia.eq(0)
+
         # this bit doesn't have to be in the FSM: connect up to read
         # regfiles on demand from DMI
         self.do_dmi(m, dbg)
@@ -1116,6 +1194,7 @@ class TestIssuer(Elaboratable):
         self.pll_en = hasattr(pspec, "use_pll") and pspec.use_pll
         if self.pll_en:
             self.pll_18_o = Signal(reset_less=True)
+            self.clk_sel_i = Signal(reset_less=True)
 
     def elaborate(self, platform):
         m = Module()
@@ -1144,6 +1223,9 @@ class TestIssuer(Elaboratable):
             # output 18 mhz PLL test signal
             comb += self.pll_18_o.eq(pll.pll_18_o)
 
+            # input to pll clock selection
+            comb += Cat(pll.sel_a0_i, pll.sel_a1_i).eq(self.clk_sel_i)
+
             # now wire up ResetSignals.  don't mind them being in this domain
             pll_rst = ResetSignal("pllclk")
             comb += pll_rst.eq(ResetSignal())
@@ -1167,9 +1249,9 @@ class TestIssuer(Elaboratable):
         ports.append(ClockSignal())
         ports.append(ResetSignal())
         if self.pll_en:
-            ports.append(self.pll.clk_sel_i)
+            ports.append(self.clk_sel_i)
             ports.append(self.pll_18_o)
-            ports.append(self.pll.pll_lck_o)
+            ports.append(self.pll.pll_ana_o)
         return ports