power_enums: unify predicates classes
authorDmitry Selyutin <ghostmansd@gmail.com>
Wed, 16 Nov 2022 19:47:40 +0000 (22:47 +0300)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Fri, 2 Jun 2023 18:51:16 +0000 (19:51 +0100)
src/openpower/decoder/power_enums.py
src/openpower/decoder/power_insn.py

index 4a72e551c7710e556795bad2ce9be31998d26320..72532ec7ce1b1b844813b7c0faf7c4376a4eb5b2 100644 (file)
@@ -291,156 +291,213 @@ class SVP64PredMode(Enum):
     ALWAYS = 0
     INT = 1
     CR = 2
+    RC1 = 3
 
 
 @unique
 class SVP64PredInt(Enum):
-    ALWAYS = 0
-    R3_UNARY = 1
-    R3 = 2
-    R3_N = 3
-    R10 = 4
-    R10_N = 5
-    R30 = 6
-    R30_N = 7
+    ALWAYS = 0b000
+    R3_UNARY = 0b001
+    R3 = 0b010
+    R3_N = 0b011
+    R10 = 0b100
+    R10_N = 0b101
+    R30 = 0b110
+    R30_N = 0b111
+
+    @classmethod
+    def _missing_(cls, desc):
+        if isinstance(desc, str):
+            value = desc
+            values = {
+                "^r3": cls.R3_UNARY,
+                "r3": cls.R3,
+                "~r3": cls.R3_N,
+                "r10": cls.R10,
+                "~r10": cls.R10_N,
+                "r30": cls.R30,
+                "~r30": cls.R30_N,
+            }
+            if value.startswith("~"):
+                value = f"~{value[1:].strip()}"
+            elif "<<" in value: # 1 << r3
+                (lhs, _, rhs) = value.partition("<<")
+                lhs = lhs.strip().lower()
+                rhs = rhs.strip().lower()
+                if (lhs == "1") and (rhs in ("r3", "%r3")):
+                    value = "^r3"
+
+            return values.get(value)
+
+        return super()._missing_(desc)
+
+    def __str__(self):
+        return {
+            self.__class__.ALWAYS: "",
+            self.__class__.R3_UNARY: "^r3",
+            self.__class__.R3: "r3",
+            self.__class__.R3_N: "~r3",
+            self.__class__.R10: "r10",
+            self.__class__.R10_N: "~r10",
+            self.__class__.R30: "r30",
+            self.__class__.R30_N: "~r30",
+        }[self]
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({str(self)})"
+
+    def __int__(self):
+        return self.value
+
+    @property
+    def mode(self):
+        return SVP64PredMode.INT
+
+    @property
+    def inv(self):
+        return (self.value & 0b1)
+
+    @property
+    def state(self):
+        return (self.value >> 1)
 
 
-@unique
 class SVP64PredCR(Enum):
     LT = 0
     GE = 1
+    NL = GE
     GT = 2
     LE = 3
+    NG = LE
     EQ = 4
     NE = 5
     SO = 6
+    UN = SO
     NS = 7
+    NU = NS
 
+    @classmethod
+    def _missing_(cls, desc):
+        if isinstance(desc, str):
+            name = desc.upper()
+            return cls.__members__.get(name)
 
-@unique
-class SVP64RMMode(Enum):
-    NORMAL = 0
-    MAPREDUCE = 1
-    FFIRST = 2
-    SATURATE = 3
-    PREDRES = 4
-    BRANCH = 5
+        return super()._missing_(desc)
 
+    def __int__(self):
+        return self.value
 
-@unique
-class SVP64BCPredMode(Enum):
-    NONE = 0
-    MASKZERO = 1
-    MASKONE = 2
+    @property
+    def mode(self):
+        return SVP64PredMode.CR
+
+    @property
+    def inv(self):
+        return (self.value & 0b1)
+
+    @property
+    def state(self):
+        return (self.value >> 1)
 
 
 @unique
-class SVP64Predicate(Enum):
-    # Integer
-    BITSEL_R3 = ("^r3", True, 0b00)
-    R3 = ("r3", False, 0b01)
-    R3_INV = ("~r3", True, 0b01)
-    R10 = ("r10", False, 0b10)
-    R10_INV = ("~r10", True, 0b10)
-    R30 = ("r30", False, 0b11)
-    R30_INV = ("~r30", True, 0b11)
-
-    # CR
-    LT = ("lt", False, 0b00)
-    NL = ("nl", True, 0b00)
-    GE = ("ge", True, 0b00)
-    GT = ("gt", False, 0b01)
-    NG = ("ng", True, 0b01)
-    LE = ("le", True, 0b01)
-    EQ = ("eq", False, 0b10)
-    NE = ("ne", True, 0b10)
-    SO = ("so", False, 0b11)
-    UN = ("un", False, 0b11)
-    NS = ("ns", True, 0b11)
-    NU = ("nu", True, 0b11)
-
-    # RC1
-    RC1 = ("RC1", False, 0b1)
-    RC1_INV = ("~RC1", True, 0b1)
+class SVP64PredRC1(Enum):
+    RC1 = 0
+    RC1_N = 1
 
     @classmethod
     def _missing_(cls, desc):
         if isinstance(desc, str):
-            members = {}
-            for (name, member) in cls.__members__.items():
-                members[str(member)] = member.name
-
-            member = desc
-            if "RC1" not in member:
-                member = member.lower()
-
-            if member.startswith("~"):
-                member = f"~{member[1:].strip()}"
-            elif "<<" in member:
-                # 1 << r3
-                (lhs, _, rhs) = member.partition("<<")
-                lhs = lhs.strip().lower()
-                rhs = rhs.strip().lower()
-                if (lhs == "1") and (rhs in ("r3", "%r3")):
-                    member = "^r3"
-            member = members.get(member, member)
-            return cls[member]
+            value = desc.upper()
+            if value.startswith("~"):
+                value = f"~{value[1:].strip()}"
+
+            return cls.__members__.get(value)
 
         return super()._missing_(desc)
 
-    def __str__(self):
-        return self.value[0]
+    def __int__(self):
+        return 1
 
     @property
-    def type(self):
-        return SVP64PredicateType(self)
+    def mode(self):
+        return SVP64PredMode.RC1
 
     @property
     def inv(self):
-        return self.value[1]
+        return (self is SVP64PredRC1.RC1_N)
 
     @property
     def state(self):
-        return self.value[2]
+        return 1
+
+
+class SVP64Pred(Enum):
+    ALWAYS = SVP64PredInt.ALWAYS
+    R3_UNARY = SVP64PredInt.R3_UNARY
+    R3 = SVP64PredInt.R3
+    R3_N = SVP64PredInt.R3_N
+    R10 = SVP64PredInt.R10
+    R10_N = SVP64PredInt.R10_N
+    R30 = SVP64PredInt.R30
+    R30_N = SVP64PredInt.R30_N
+
+    LT = SVP64PredCR.LT
+    GE = SVP64PredCR.GE
+    GT = SVP64PredCR.GT
+    LE = SVP64PredCR.LE
+    EQ = SVP64PredCR.EQ
+    NE = SVP64PredCR.NE
+    SO = SVP64PredCR.SO
+    NS = SVP64PredCR.NS
+
+    RC1 = SVP64PredRC1.RC1
+    RC1_N = SVP64PredRC1.RC1_N
+
+    @classmethod
+    def _missing_(cls, desc):
+        if isinstance(desc, str):
+            values = {item.value:item for item in cls}
+            for subcls in (SVP64PredInt, SVP64PredCR, SVP64PredRC1):
+                try:
+                    return values.get(subcls(desc))
+                except ValueError:
+                    pass
+            return None
+
+        return super()._missing_(desc)
+
+    def __int__(self):
+        return int(self.value)
+
+    @property
+    def mode(self):
+        return self.value.mode
 
     @property
-    def mask(self):
-        return ((int(self.state) << 1) | (int(self.inv) << 0))
+    def inv(self):
+        return self.value.inv
 
+    @property
+    def state(self):
+        return self.value.state
 
-class SVP64PredicateType(Enum):
-    INTEGER = auto()
-    BITSEL_R3 = INTEGER
-    R3 = INTEGER
-    R3_INV = INTEGER
-    R10 = INTEGER
-    R10_INV = INTEGER
-    R30 = INTEGER
-    R30_INV = INTEGER
 
-    CR = auto()
-    LT = CR
-    NL = CR
-    GE = CR
-    GT = CR
-    NG = CR
-    LE = CR
-    EQ = CR
-    NE = CR
-    SO = CR
-    UN = CR
-    NS = CR
-    NU = CR
-
-    RC1 = auto()
-    RC1_INV = RC1
+@unique
+class SVP64RMMode(Enum):
+    NORMAL = 0
+    MAPREDUCE = 1
+    FFIRST = 2
+    SATURATE = 3
+    PREDRES = 4
+    BRANCH = 5
 
-    @classmethod
-    def _missing_(cls, desc):
-        if isinstance(desc, SVP64Predicate):
-            return cls.__members__.get(desc.name)
 
-        return super()._missing_(desc)
+@unique
+class SVP64BCPredMode(Enum):
+    NONE = 0
+    MASKZERO = 1
+    MASKONE = 2
 
 
 @unique
index c8b59e2e1e04e977b04fd312cb506ec74f4f161c..417652eca01d7805dba46fe1ed17712fb99abdaa 100644 (file)
@@ -38,9 +38,9 @@ from openpower.decoder.power_enums import (
     SVP64RMMode as _SVP64RMMode,
     SVExtraRegType as _SVExtraRegType,
     SVExtraReg as _SVExtraReg,
-    SVP64Predicate as _SVP64Predicate,
-    SVP64PredicateType as _SVP64PredicateType,
     SVP64SubVL as _SVP64SubVL,
+    SVP64Pred as _SVP64Pred,
+    SVP64PredMode as _SVP64PredMode,
 )
 from openpower.decoder.selectable_int import (
     SelectableInt as _SelectableInt,
@@ -2508,7 +2508,7 @@ class SpecifierSubVL(Specifier):
 @_dataclasses.dataclass(eq=True, frozen=True)
 class SpecifierPredicate(Specifier):
     mode: str
-    pred: _SVP64Predicate
+    pred: _SVP64Pred
 
     @classmethod
     def match(cls, desc, record, mode_match, pred_match):
@@ -2518,7 +2518,7 @@ class SpecifierPredicate(Specifier):
         if not mode_match(mode):
             return None
 
-        pred = _SVP64Predicate(pred.strip())
+        pred = _SVP64Pred(pred.strip())
         if not pred_match(pred):
             raise ValueError(pred)
 
@@ -2531,9 +2531,9 @@ class SpecifierFFPR(SpecifierPredicate):
     def match(cls, desc, record, mode):
         return super().match(desc=desc, record=record,
             mode_match=lambda mode_arg: mode_arg == mode,
-            pred_match=lambda pred_arg: pred_arg.type in (
-                _SVP64PredicateType.CR,
-                _SVP64PredicateType.RC1,
+            pred_match=lambda pred_arg: pred_arg.mode in (
+                _SVP64PredMode.CR,
+                _SVP64PredMode.RC1,
             ))
 
     def assemble(self, insn):
@@ -2599,9 +2599,9 @@ class SpecifierMask(SpecifierPredicate):
     def match(cls, desc, record, mode):
         return super().match(desc=desc, record=record,
             mode_match=lambda mode_arg: mode_arg == mode,
-            pred_match=lambda pred_arg: pred_arg.type in (
-                _SVP64PredicateType.INTEGER,
-                _SVP64PredicateType.CR,
+            pred_match=lambda pred_arg: pred_arg.mode in (
+                _SVP64PredMode.INT,
+                _SVP64PredMode.CR,
             ))
 
     def assemble(self, insn):
@@ -2625,7 +2625,7 @@ class SpecifierM(SpecifierMask):
             spec.validate(others=items)
 
     def assemble(self, insn):
-        insn.prefix.rm.mask = self.pred.mask
+        insn.prefix.rm.mask = int(self.pred)
 
 
 @_dataclasses.dataclass(eq=True, frozen=True)
@@ -2638,7 +2638,7 @@ class SpecifierSM(SpecifierMask):
         if self.record.svp64.ptype is _SVPType.P1:
             raise ValueError("source-mask on non-twin predicate")
 
-        if self.pred.type is _SVP64PredicateType.CR:
+        if self.pred.mode is _SVP64PredMode.CR:
             twin = None
             items = list(others)
             while items:
@@ -2653,7 +2653,7 @@ class SpecifierSM(SpecifierMask):
                 raise ValueError(f"predicate masks mismatch: {self!r} vs {twin!r}")
 
     def assemble(self, insn):
-        insn.prefix.rm.smask = self.pred.mask
+        insn.prefix.rm.smask = int(self.pred)
 
 
 @_dataclasses.dataclass(eq=True, frozen=True)
@@ -2666,7 +2666,7 @@ class SpecifierDM(SpecifierMask):
         if self.record.svp64.ptype is _SVPType.P1:
             raise ValueError("dest-mask on non-twin predicate")
 
-        if self.pred.type is _SVP64PredicateType.CR:
+        if self.pred.mode is _SVP64PredMode.CR:
             twin = None
             items = list(others)
             while items:
@@ -2681,7 +2681,7 @@ class SpecifierDM(SpecifierMask):
                 raise ValueError(f"predicate masks mismatch: {self!r} vs {twin!r}")
 
     def assemble(self, insn):
-        insn.prefix.rm.mask = self.pred.mask
+        insn.prefix.rm.mask = int(self.pred)
 
 
 class Specifiers(tuple):