add support for checking sprs and msr in unit tests
authorJacob Lifshay <programmerjake@gmail.com>
Tue, 30 May 2023 08:00:01 +0000 (01:00 -0700)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 2 Jun 2023 18:51:19 +0000 (19:51 +0100)
src/openpower/test/algorithms/svp64_utf_8_validation.py
src/openpower/test/fmv_fcvt/fmv_fcvt.py
src/openpower/test/state.py
src/openpower/test/svp64/parallel_prefix_sum.py

index 772352656c899bfee65ed3268bd652654e1da3a0..5174999ce8908fd1a5dd94dc66bfe4e28f17b0bd 100644 (file)
@@ -359,9 +359,11 @@ class SVP64UTF8ValidationTestCase(TestAccumulatorBase):
         for i, v in enumerate(SECOND_BYTE_HIGH_NIBBLE_LUT):
             initial_mem[i + SECOND_BYTE_HIGH_NIBBLE_LUT_ADDR] = int(v), 1
         stop_at_pc = 0x10000000
-        initial_sprs = {8: SelectableInt(stop_at_pc, 64)}
+        initial_sprs = {8: stop_at_pc}
+        sprs = initial_sprs.copy()
+        sprs['SVSHAPE0'] = None
         e = ExpectedState(pc=stop_at_pc, int_regs=4, crregs=0, fp_regs=0,
-                          so=None, ov=None, ca=None)
+                          so=None, ov=None, ca=None, sprs=sprs)
         e.intregs[:3] = initial_regs[:3]
         e.intregs[3] = expected
         with self.subTest(data=data, expected=expected):
index e4803f7549165c339c8db58c3dfbcbc8a4e78c13..dc696ac00426359b11bc8b1761013b7c62131a99 100644 (file)
@@ -136,7 +136,9 @@ class FMvFCvtCases(TestAccumulatorBase):
                 # https://bugs.libre-soc.org/show_bug.cgi?id=1087#c21
                 expected = e.intregs[3]
                 e.pc = 0x700
-                # MSR and other SPRS not tested by ExpectedState
+                e.sprs['SRR0'] = 0  # insn is at address 0
+                e.sprs['SRR1'] = e.msr | (1 << (63 - 43))
+                e.msr = 0x9000000000000001
             lt = bool(expected & (1 << 63))
             gt = not lt and expected != 0
             eq = expected == 0
@@ -150,8 +152,7 @@ class FMvFCvtCases(TestAccumulatorBase):
                 e.fpscr = int(fpscr)
                 self.add_case(
                     _cached_program(*lst), gprs, fpregs=fprs, expected=e,
-                    initial_fpscr=int(initial_fpscr),
-                    initial_msr=(1 << MSR.FE0) | (1 << MSR.FE1))
+                    initial_fpscr=int(initial_fpscr))
 
     def toint(self, inp, expected=None, test_title="", inp_bits=None,
               signed=True, _32bit=True):
index 3841891dadb61c8aa6f9eaa37996c16c3f7f339c..5d61c27753a2d7754c5a590834a71d70ea53dcdb 100644 (file)
@@ -22,10 +22,12 @@ methods, the use of yield from/yield is required.
 """
 
 
-from openpower.decoder.power_enums import XER_bits
+from openpower.decoder.power_enums import XER_bits, SPRfull
 from openpower.decoder.isa.radixmmu import RADIX
 from openpower.util import log
 from openpower.fpscr import FPSCRState
+from openpower.decoder.selectable_int import SelectableInt
+from openpower.consts import DEFAULT_MSR
 import os
 import sys
 from copy import deepcopy
@@ -62,6 +64,59 @@ class StateRunner:
         if False: yield
 
 
+class StateSPRs:
+    KEYS = tuple(i for i in SPRfull if i != SPRfull.XER)
+
+    def __init__(self, values=None):
+        self.__values = {k: 0 for k in StateSPRs.KEYS}
+        if values is not None:
+            for k, v in values.items():
+                self[k] = v
+
+    @staticmethod
+    def __key(k, raise_if_invalid=True):
+        try:
+            if isinstance(k, str):
+                retval = SPRfull.__members__[k]
+            else:
+                retval = SPRfull(k)
+        except (ValueError, KeyError):
+            retval = None
+        if retval == SPRfull.XER:  # XER is not stored in StateSPRs
+            retval = None
+        if retval is None and raise_if_invalid:
+            raise KeyError(k)
+        return retval
+
+    def items(self):
+        for k in StateSPRs.KEYS:
+            yield (k, self[k])
+
+    def __iter__(self):
+        return iter(StateSPRs.KEYS)
+
+    def __len__(self):
+        return len(StateSPRs.KEYS)
+
+    def __contains__(self, k):
+        return self.__key(k, raise_if_invalid=False) is not None
+
+    def __getitem__(self, k):
+        return self.__values[self.__key(k)]
+
+    def __setitem__(self, k, v):
+        k = self.__key(k)
+        if v is not None:
+            v = int(v)
+        self.__values[k] = v
+
+    def nonzero(self):
+        return {k: v for k, v in self.__values.items() if v != 0}
+
+    def __repr__(self):
+        return repr(self.nonzero())
+
+
 class State:
     """State: Base class for the "state" of the Power ISA object to be tested
     including methods to compare various registers and memory between
@@ -71,8 +126,20 @@ class State:
 
     GPRs and CRs - stored as lists
     XERs/PC - simple members
+        SO/CA[32]/OV[32] are stored in so/ca/ov members,
+        xer_other is all other XER bits.
+    SPRs - stored in self.sprs as a StateSPRs
     memory - stored as a dictionary {location: data}
     """
+
+    @property
+    def sprs(self):
+        return self.__sprs
+
+    @sprs.setter
+    def sprs(self, value):
+        self.__sprs = StateSPRs(value)
+
     def get_state(self):
         yield from self.get_fpscr()
         yield from self.get_fpregs()
@@ -80,6 +147,8 @@ class State:
         yield from self.get_crregs()
         yield from self.get_xregs()
         yield from self.get_pc()
+        yield from self.get_msr()
+        yield from self.get_sprs()
         yield from self.get_mem()
 
     def compare(self, s2):
@@ -127,6 +196,11 @@ class State:
         if self.ca is not None and s2.ca is not None:
             self.dut.assertEqual(self.ca, s2.ca, "ca mismatch (%s != %s) %s" %
                 (self.state_type, s2.state_type, repr(self.code)))
+        if self.xer_other is not None and s2.xer_other is not None:
+            self.dut.assertEqual(
+                hex(self.xer_other), hex(s2.xer_other),
+                "xer_other mismatch (%s != %s) %s" %
+                (self.state_type, s2.state_type, repr(self.code)))
 
         # pc
         self.dut.assertEqual(self.pc, s2.pc, "pc mismatch (%s != %s) %s" %
@@ -155,6 +229,25 @@ class State:
                 finally:
                     self.dut.maxDiff = old_max_diff
 
+        for spr in self.sprs:
+            spr1 = self.sprs[spr]
+            spr2 = s2.sprs[spr]
+
+            if spr1 == spr2:
+                continue
+
+            if spr1 is not None and spr2 is not None:
+                # if not explicitly ignored
+
+                self.dut.fail(
+                    f"{spr1:#x} != {spr2:#x}: {spr} mismatch "
+                    f"({self.state_type} != {s2.state_type}) {self.code!r}\n")
+
+        if self.msr is not None and s2.msr is not None:
+            self.dut.assertEqual(
+                hex(self.msr), hex(s2.msr), "msr mismatch (%s != %s) %s" %
+                (self.state_type, s2.state_type, repr(self.code)))
+
     def compare_mem(self, s2):
         # copy dics to preserve state mem then pad empty locs since
         # different Power ISA objects may differ how theystore memory
@@ -206,6 +299,20 @@ class State:
             sout.write("%se.ov = 0x%x\n" % (lindent, self.ov))
         if(self.ca != 0):
             sout.write("%se.ca = 0x%x\n" % (lindent, self.ca))
+        if self.xer_other != 0:
+            sout.write("%se.xer_other = 0x%x\n" % (lindent, self.xer_other))
+
+        # FPSCR
+        if self.fpscr != 0:
+            sout.write(f"{lindent}e.fpscr = {self.fpscr:#x}\n")
+
+        # SPRs
+        for k, v in self.sprs.nonzero().items():
+            sout.write(f"{lindent}e.sprs[{k.name!r}] = {v:#x}\n")
+
+        # MSR
+        if self.msr != 0:
+            sout.write(f"{lindent}e.msr = {self.msr:#x}\n")
 
         if sout != sys.stdout:
             sout.close()
@@ -238,9 +345,15 @@ class SimState(State):
     def get_fpscr(self):
         if False:
             yield
-        self.fpscr = self.sim.fpscr.value
+        self.fpscr = int(self.sim.fpscr)
         log("class sim fpscr", hex(self.fpscr))
 
+    def get_msr(self):
+        if False:
+            yield
+        self.msr = int(self.sim.msr)
+        log("class sim msr", hex(self.msr))
+
     def get_intregs(self):
         if False:
             yield
@@ -264,9 +377,34 @@ class SimState(State):
         self.ca32 = self.sim.spr['XER'][XER_bits['CA32']].value
         self.ov = self.ov | (self.ov32 << 1)
         self.ca = self.ca | (self.ca32 << 1)
+        xer_other = SelectableInt(self.sim.spr['XER'])
+        for i in 'SO', 'OV', 'OV32', 'CA', 'CA32':
+            xer_other[XER_bits[i]] = 0
+        self.xer_other = int(xer_other)
         self.xregs.extend((self.so, self.ov, self.ca))
         log("class sim xregs", list(map(hex, self.xregs)))
 
+    def get_sprs(self):
+        if False:
+            yield
+        self.sprs = StateSPRs()
+        for spr in self.sprs:
+            # hacky workaround to workaround luke's hack in caller.py that
+            # aliases HSRR[01] to SRR[01] -- we temporarily clear SRR[01] while
+            # trying to read HSRR[01]
+            clear_srr = spr == SPRfull.HSRR0 or spr == SPRfull.HSRR1
+            if clear_srr:
+                old_srr0 = self.sim.spr['SRR0']
+                old_srr1 = self.sim.spr['SRR1']
+                self.sim.spr['SRR0'] = 0
+                self.sim.spr['SRR1'] = 0
+
+            self.sprs[spr] = self.sim.spr[spr.name]  # setitem converts to int
+
+            if clear_srr:
+                self.sim.spr['SRR0'] = old_srr0
+                self.sim.spr['SRR1'] = old_srr1
+
     def get_pc(self):
         if False:
             yield
@@ -298,7 +436,8 @@ class ExpectedState(State):
     see openpower/test/shift_rot/shift_rot_cases2.py for examples
     """
     def __init__(self, int_regs=None, pc=0, crregs=None,
-                 so=0, ov=0, ca=0, fp_regs=None, fpscr=0):
+                 so=0, ov=0, ca=0, fp_regs=None, fpscr=0, sprs=None,
+                 msr=DEFAULT_MSR, xer_other=0):
         if fp_regs is None:
             fp_regs = 32
         if isinstance(fp_regs, int):
@@ -319,6 +458,9 @@ class ExpectedState(State):
         self.so = so
         self.ov = ov
         self.ca = ca
+        self.xer_other = xer_other
+        self.sprs = StateSPRs(sprs)
+        self.msr = msr
 
     def get_fpregs(self):
         if False: yield
@@ -332,6 +474,15 @@ class ExpectedState(State):
         if False: yield
     def get_pc(self):
         if False: yield
+
+    def get_msr(self):
+        if False:
+            yield
+
+    def get_sprs(self):
+        if False:
+            yield
+
     def get_mem(self):
         if False: yield
 
index 983e1969f807d92eadfc26cc6b6ad29ee0f7b32a..74b408f45d52b597cf761bf6c4aaaaf6a1d85bf8 100644 (file)
@@ -25,6 +25,8 @@ class ParallelPrefixSumCases(TestAccumulatorBase):
             "sv.add *10, *10, *10",
         ])), False)
         e = ExpectedState(pc=0x10, int_regs=gprs)
+        e.sprs['SVSHAPE0'] = 0x1c00000a
+        e.sprs['SVSHAPE1'] = 0x1c00000e
         for i, v in enumerate(expected):
             e.intregs[i + 10] = v
         self.add_case(prog, gprs, expected=e)
@@ -48,6 +50,8 @@ class ParallelPrefixSumCases(TestAccumulatorBase):
             "sv.subf *10, *10, *10",
         ])), False)
         e = ExpectedState(pc=0x10, int_regs=gprs)
+        e.sprs['SVSHAPE0'] = 0x1c00000a
+        e.sprs['SVSHAPE1'] = 0x1c00000e
         for i, v in enumerate(expected):
             e.intregs[i + 10] = v
         self.add_case(prog, gprs, expected=e)