Skip leading zero bits on predicate masks
authorCesar Strauss <cestrauss@gmail.com>
Tue, 30 Mar 2021 11:57:48 +0000 (08:57 -0300)
committerCesar Strauss <cestrauss@gmail.com>
Tue, 30 Mar 2021 11:57:48 +0000 (08:57 -0300)
The PRED_SKIP state moves src/dst step to the next non-zero bit on the
mask.
The leading zeros on the mask (plus the set bit) are shifted out, while
the shifted amount is added to the step.
If the new step value would increase past VL, the loop is ended.

src/soc/simple/issuer.py

index 9b7ed9c59d4f23a4ecfe37ef9e48137c5d1e98bb..e2f66a50fa1142bad706ea8a1210dcf1cbf89b05 100644 (file)
@@ -16,11 +16,13 @@ improved.
 """
 
 from nmigen import (Elaboratable, Module, Signal, ClockSignal, ResetSignal,
-                    ClockDomain, DomainRenamer, Mux, Const, Repl)
+                    ClockDomain, DomainRenamer, Mux, Const, Repl, Cat)
 from nmigen.cli import rtlil
 from nmigen.cli import main
 import sys
 
+from nmigen.lib.coding import PriorityEncoder
+
 from soc.decoder.power_decoder import create_pdecode
 from soc.decoder.power_decoder2 import PowerDecode2, SVP64PrefixDecoder
 from soc.decoder.decode2execute1 import IssuerDecode2ToOperand
@@ -571,41 +573,64 @@ class TestIssuerInternal(Elaboratable):
             with m.State("MASK_WAIT"):
                 comb += pred_mask_ready_i.eq(1) # ready to receive the masks
                 with m.If(pred_mask_valid_o): # predication masks are ready
-                    # with m.If(is_svp64_mode):
-                    #    TODO advance src/dst step to "skip" over predicated-out
-                    #    from self.srcmask and self.dstmask
-                    #    https://bugs.libre-soc.org/show_bug.cgi?id=617#c3
-                    #    but still without exceeding VL in either case
-                    # IMPORTANT: when changing src/dest step, have to
-                    # jump to m.next = "DECODE_SV" to deal with the change in
-                    # SVSTATE
-
-                    with m.If(is_svp64_mode):
-                        if self.svp64_en:
-                            pred_src_zero = pdecode2.rm_dec.pred_sz
-                            pred_dst_zero = pdecode2.rm_dec.pred_dz
-
-                        """
-                        TODO: actually, can use
-                        PriorityEncoder(self.srcmask | (1<<cur_srcstep))
-
-                        if not pred_src_zero:
-                            if (((1<<cur_srcstep) & self.srcmask) == 0) and
-                                  (cur_srcstep != vl):
-                                comb += update_svstate.eq(1)
-                                comb += new_svstate.srcstep.eq(next_srcstep)
+                    m.next = "PRED_SKIP"
 
-                        if not pred_dst_zero:
-                            if (((1<<cur_dststep) & self.dstmask) == 0) and
-                                  (cur_dststep != vl):
-                                comb += new_svstate.dststep.eq(next_dststep)
-                                comb += update_svstate.eq(1)
-
-                        if update_svstate:
+            # skip zeros in predicate
+            with m.State("PRED_SKIP"):
+                with m.If(~is_svp64_mode):
+                    m.next = "DECODE_SV"  # nothing to do
+                with m.Else():
+                    if self.svp64_en:
+                        pred_src_zero = pdecode2.rm_dec.pred_sz
+                        pred_dst_zero = pdecode2.rm_dec.pred_dz
+
+                        # new srcstep, after skipping zeros
+                        skip_srcstep = Signal.like(cur_srcstep)
+                        # value to be added to the current srcstep
+                        src_delta = Signal.like(cur_srcstep)
+                        # add leading zeros to srcstep, if not in zero mode
+                        with m.If(~pred_src_zero):
+                            # priority encoder (count leading zeros)
+                            # append guard bit, in case the mask is all zeros
+                            pri_enc_src = PriorityEncoder(65)
+                            m.submodules.pri_enc_src = pri_enc_src
+                            comb += pri_enc_src.i.eq(Cat(self.srcmask, 1))
+                            comb += src_delta.eq(pri_enc_src.o)
+                        # apply delta to srcstep
+                        comb += skip_srcstep.eq(cur_srcstep + src_delta)
+                        # shift-out all leading zeros from the mask
+                        # plus the leading "one" bit
+                        sync += self.srcmask.eq(self.srcmask >> (src_delta+1))
+
+                        # same as above, but for dststep
+                        skip_dststep = Signal.like(cur_dststep)
+                        dst_delta = Signal.like(cur_dststep)
+                        with m.If(~pred_dst_zero):
+                            pri_enc_dst = PriorityEncoder(65)
+                            m.submodules.pri_enc_dst = pri_enc_dst
+                            comb += pri_enc_dst.i.eq(Cat(self.dstmask, 1))
+                            comb += dst_delta.eq(pri_enc_dst.o)
+                        comb += skip_dststep.eq(cur_dststep + dst_delta)
+                        sync += self.dstmask.eq(self.dstmask >> (dst_delta+1))
+
+                        # TODO: initialize mask[VL]=1 to avoid passing past VL
+                        with m.If((skip_srcstep >= cur_vl) |
+                                  (skip_dststep >= cur_vl)):
+                            # end of VL loop. Update PC and reset src/dst step
+                            comb += self.state_w_pc.wen.eq(1 << StateRegs.PC)
+                            comb += self.state_w_pc.data_i.eq(nia)
+                            comb += new_svstate.srcstep.eq(0)
+                            comb += new_svstate.dststep.eq(0)
+                            comb += update_svstate.eq(1)
+                            # go back to Issue
+                            m.next = "ISSUE_START"
+                        with m.Else():
+                            # update new src/dst step
+                            comb += new_svstate.srcstep.eq(skip_srcstep)
+                            comb += new_svstate.dststep.eq(skip_dststep)
+                            comb += update_svstate.eq(1)
+                            # proceed to Decode
                             m.next = "DECODE_SV"
-                        """
-
-                    m.next = "DECODE_SV"
 
             # after src/dst step have been updated, we are ready
             # to decode the instruction
@@ -666,7 +691,8 @@ class TestIssuerInternal(Elaboratable):
                             comb += new_svstate.srcstep.eq(next_srcstep)
                             comb += new_svstate.dststep.eq(next_dststep)
                             comb += update_svstate.eq(1)
-                            m.next = "DECODE_SV"
+                            # return to mask skip loop
+                            m.next = "PRED_SKIP"
 
                 with m.Else():
                     comb += core.core_stopped_i.eq(1)