WIP rewriting compiler IR so regalloc works correctly
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 27 Oct 2022 08:02:11 +0000 (01:02 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 27 Oct 2022 08:02:11 +0000 (01:02 -0700)
src/bigint_presentation_code/compiler_ir2.py [new file with mode: 0644]

diff --git a/src/bigint_presentation_code/compiler_ir2.py b/src/bigint_presentation_code/compiler_ir2.py
new file mode 100644 (file)
index 0000000..eacceb4
--- /dev/null
@@ -0,0 +1,528 @@
+import enum
+from enum import Enum, unique
+from typing import AbstractSet, Iterable, Iterator, NoReturn, Tuple, Union, overload
+
+from cached_property import cached_property
+from nmutil.plain_data import plain_data
+
+from bigint_presentation_code.util import OFSet, OSet, Self, assert_never, final
+from weakref import WeakValueDictionary
+
+
+@final
+class Fn:
+    def __init__(self):
+        self.ops = []  # type: list[Op]
+        op_names = WeakValueDictionary()
+        self.__op_names = op_names  # type: WeakValueDictionary[str, Op]
+        self.__next_name_suffix = 2
+
+    def _add_op_with_unused_name(self, op, name=""):
+        # type: (Op, str) -> str
+        if op.fn is not self:
+            raise ValueError("can't add Op to wrong Fn")
+        if hasattr(op, "name"):
+            raise ValueError("Op already named")
+        orig_name = name
+        while True:
+            if name not in self.__op_names:
+                self.__op_names[name] = op
+                return name
+            name = orig_name + str(self.__next_name_suffix)
+            self.__next_name_suffix += 1
+
+    def __repr__(self):
+        return "<Fn>"
+
+
+@unique
+@final
+class RegKind(Enum):
+    GPR = enum.auto()
+    CA = enum.auto()
+    VL_MAXVL = enum.auto()
+
+    @cached_property
+    def only_scalar(self):
+        if self is RegKind.GPR:
+            return False
+        elif self is RegKind.CA or self is RegKind.VL_MAXVL:
+            return True
+        else:
+            assert_never(self)
+
+    @cached_property
+    def reg_count(self):
+        if self is RegKind.GPR:
+            return 128
+        elif self is RegKind.CA or self is RegKind.VL_MAXVL:
+            return 1
+        else:
+            assert_never(self)
+
+    def __repr__(self):
+        return "RegKind." + self._name_
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class OperandType:
+    __slots__ = "kind", "vec"
+
+    def __init__(self, kind, vec):
+        # type: (RegKind, bool) -> None
+        self.kind = kind
+        if kind.only_scalar and vec:
+            raise ValueError(f"kind={kind} must have vec=False")
+        self.vec = vec
+
+    def get_length(self, maxvl):
+        # type: (int) -> int
+        # here's where subvl and elwid would be accounted for
+        if self.vec:
+            return maxvl
+        return 1
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class RegShape:
+    __slots__ = "kind", "length"
+
+    def __init__(self, kind, length=1):
+        # type: (RegKind, int) -> None
+        self.kind = kind
+        if length < 1 or length > kind.reg_count:
+            raise ValueError("invalid length")
+        self.length = length
+
+    def try_concat(self, *others):
+        # type: (*RegShape | Reg | RegClass | None) -> RegShape | None
+        kind = self.kind
+        length = self.length
+        for other in others:
+            if isinstance(other, (Reg, RegClass)):
+                other = other.shape
+            if other is None:
+                return None
+            if other.kind != self.kind:
+                return None
+            length += other.length
+        if length > kind.reg_count:
+            return None
+        return RegShape(kind=kind, length=length)
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class Reg:
+    __slots__ = "shape", "start"
+
+    def __init__(self, shape, start):
+        # type: (RegShape, int) -> None
+        self.shape = shape
+        if start < 0 or start + shape.length > shape.kind.reg_count:
+            raise ValueError("start not in valid range")
+        self.start = start
+
+    @property
+    def kind(self):
+        return self.shape.kind
+
+    @property
+    def length(self):
+        return self.shape.length
+
+    def conflicts(self, other):
+        # type: (Reg) -> bool
+        return (self.kind == other.kind
+                and self.start < other.stop and other.start < self.stop)
+
+    @property
+    def stop(self):
+        return self.start + self.length
+
+    def try_concat(self, *others):
+        # type: (*Reg | None) -> Reg | None
+        shape = self.shape.try_concat(*others)
+        if shape is None:
+            return None
+        stop = self.stop
+        for other in others:
+            assert other is not None, "already caught by RegShape.try_concat"
+            if stop != other.start:
+                return None
+            stop = other.stop
+        return Reg(shape, self.start)
+
+
+@final
+class RegClass(AbstractSet[Reg]):
+    def __init__(self, regs_or_starts=(), shape=None, starts_bitset=0):
+        # type: (Iterable[Reg | int], RegShape | None, int) -> None
+        for reg_or_start in regs_or_starts:
+            if isinstance(reg_or_start, Reg):
+                if shape is None:
+                    shape = reg_or_start.shape
+                elif shape != reg_or_start.shape:
+                    raise ValueError(f"conflicting RegShapes: {shape} and "
+                                     f"{reg_or_start.shape}")
+                start = reg_or_start.start
+            else:
+                start = reg_or_start
+            if start < 0:
+                raise ValueError("a Reg's start is out of range")
+            starts_bitset |= 1 << start
+        if starts_bitset == 0:
+            shape = None
+        self.__shape = shape
+        self.__starts_bitset = starts_bitset
+        if shape is None:
+            if starts_bitset != 0:
+                raise ValueError("non-empty RegClass must have non-None shape")
+            return
+        if self.stops_bitset >= 1 << shape.kind.reg_count:
+            raise ValueError("a Reg's start is out of range")
+
+    @property
+    def shape(self):
+        # type: () -> RegShape | None
+        return self.__shape
+
+    @property
+    def starts_bitset(self):
+        # type: () -> int
+        return self.__starts_bitset
+
+    @property
+    def stops_bitset(self):
+        # type: () -> int
+        if self.__shape is None:
+            return 0
+        return self.__starts_bitset << self.__shape.length
+
+    @cached_property
+    def starts(self):
+        # type: () -> OFSet[int]
+        if self.length is None:
+            return OFSet()
+        # TODO: fixme
+        # return OFSet(for i in range(self.length))
+
+    @cached_property
+    def stops(self):
+        # type: () -> OFSet[int]
+        if self.__shape is None:
+            return OFSet()
+        return OFSet(i + self.__shape.length for i in self.__starts)
+
+    @property
+    def kind(self):
+        if self.__shape is None:
+            return None
+        return self.__shape.kind
+
+    @property
+    def length(self):
+        """length of registers in this RegClass, not to be confused with the number of `Reg`s in self"""
+        if self.__shape is None:
+            return None
+        return self.__shape.length
+
+    def concat(self, *others):
+        # type: (*RegClass) -> RegClass
+        shape = self.__shape
+        if shape is None:
+            return RegClass()
+        shape = shape.try_concat(*others)
+        if shape is None:
+            return RegClass()
+        starts = OSet(self.starts)
+        offset = shape.length
+        for other in others:
+            assert other.__shape is not None, \
+                "already caught by RegShape.try_concat"
+            starts &= OSet(i - offset for i in other.starts)
+            offset += other.__shape.length
+        return RegClass(starts, shape=shape)
+
+    def __contains__(self, reg):
+        # type: (Reg) -> bool
+        return reg.shape == self.shape and reg.start in self.starts
+
+    def __iter__(self):
+        # type: () -> Iterator[Reg]
+        if self.shape is None:
+            return
+        for start in self.starts:
+            yield Reg(shape=self.shape, start=start)
+
+    def __len__(self):
+        return len(self.starts)
+
+    def __hash__(self):
+        return super()._hash()
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class Operand:
+    __slots__ = "ty", "regs"
+
+    def __init__(self, ty, regs=None):
+        # type: (OperandType, OFSet[int] | None) -> None
+        pass
+
+
+OT_VGPR = OperandType(RegKind.GPR, vec=True)
+OT_SGPR = OperandType(RegKind.GPR, vec=False)
+OT_CA = OperandType(RegKind.CA, vec=False)
+OT_VL = OperandType(RegKind.VL_MAXVL, vec=False)
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class TiedOutput:
+    __slots__ = "input_index", "output_index"
+
+    def __init__(self, input_index, output_index):
+        # type: (int, int) -> None
+        self.input_index = input_index
+        self.output_index = output_index
+
+
+Constraint = Union[TiedOutput, NoReturn]
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class OpProperties:
+    __slots__ = ("demo_asm", "inputs", "outputs", "immediates", "constraints",
+                 "is_copy", "is_load_immediate", "has_side_effects")
+
+    def __init__(self, demo_asm,  # type: str
+                 inputs,  # type: Iterable[OperandType]
+                 outputs,  # type: Iterable[OperandType]
+                 immediates,  # type: Iterable[range]
+                 constraints,  # type: Iterable[Constraint]
+                 is_copy=False,  # type: bool
+                 is_load_immediate=False,  # type: bool
+                 has_side_effects=False,  # type: bool
+                 ):
+        # type: (...) -> None
+        self.demo_asm = demo_asm
+        self.inputs = tuple(inputs)
+        self.outputs = tuple(outputs)
+        self.immediates = tuple(immediates)
+        self.constraints = tuple(constraints)
+        self.is_copy = is_copy
+        self.is_load_immediate = is_load_immediate
+        self.has_side_effects = has_side_effects
+
+
+@unique
+@final
+class OpKind(Enum):
+    def __init__(self, properties):
+        # type: (OpProperties) -> None
+        super().__init__()
+        self.properties = properties
+
+    SvAddE = OpProperties(
+        demo_asm="sv.adde *RT, *RA, *RB",
+        inputs=(OT_VGPR, OT_VGPR, OT_CA, OT_VL),
+        outputs=(OT_VGPR, OT_CA),
+        immediates=(),
+        constraints=(),
+    )
+    SvSubFE = OpProperties(
+        demo_asm="sv.subfe *RT, *RA, *RB",
+        inputs=(OT_VGPR, OT_VGPR, OT_CA, OT_VL),
+        outputs=(OT_VGPR, OT_CA),
+        immediates=(),
+        constraints=(),
+    )
+    SvMAddEDU = OpProperties(
+        demo_asm="sv.maddedu *RT, *RA, RB, RC",
+        inputs=(OT_VGPR, OT_SGPR, OT_SGPR, OT_VL),
+        outputs=(OT_VGPR, OT_SGPR),
+        immediates=(),
+        constraints=(),
+    )
+    SetVLI = OpProperties(
+        demo_asm="setvl 0, 0, imm, 0, 1, 1",
+        inputs=(),
+        outputs=(OT_VL,),
+        immediates=(range(1, 65),),
+        constraints=(),
+        is_load_immediate=True,
+    )
+    SvLI = OpProperties(
+        demo_asm="sv.addi *RT, 0, imm",
+        inputs=(OT_VL,),
+        outputs=(OT_VGPR,),
+        immediates=(range(-2 ** 15, 2 ** 15),),
+        constraints=(),
+        is_load_immediate=True,
+    )
+    LI = OpProperties(
+        demo_asm="addi RT, 0, imm",
+        inputs=(),
+        outputs=(OT_SGPR,),
+        immediates=(range(-2 ** 15, 2 ** 15),),
+        constraints=(),
+        is_load_immediate=True,
+    )
+    SvMv = OpProperties(
+        demo_asm="sv.or *RT, *src, *src",
+        inputs=(OT_VGPR, OT_VL),
+        outputs=(OT_VGPR,),
+        immediates=(),
+        constraints=(),
+        is_copy=True,
+    )
+    Mv = OpProperties(
+        demo_asm="mv RT, src",
+        inputs=(OT_SGPR,),
+        outputs=(OT_SGPR,),
+        immediates=(),
+        constraints=(),
+        is_copy=True,
+    )
+
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@final
+class SSAVal:
+    __slots__ = "sliced_op_outputs",
+
+    _SlicedOpOutputIn = Union["tuple[Op, int, int | range | slice]",
+                              "tuple[Op, int]", "SSAVal"]
+
+    @staticmethod
+    def __process_sliced_op_outputs(inp):
+        # type: (Iterable[_SlicedOpOutputIn]) -> Iterable[Tuple["Op", int, range]]
+        for v in inp:
+            if isinstance(v, SSAVal):
+                yield from v.sliced_op_outputs
+                continue
+            op = v[0]
+            output_index = v[1]
+            if output_index < 0 or output_index >= len(op.properties.outputs):
+                raise ValueError("invalid output_index")
+            cur_len = op.properties.outputs[output_index].get_length(op.maxvl)
+            slice_ = slice(None) if len(v) == 2 else v[2]
+            if isinstance(slice_, range):
+                slice_ = slice(slice_.start, slice_.stop, slice_.step)
+            if isinstance(slice_, int):
+                # raise exception for out-of-range values
+                idx = range(cur_len)[slice_]
+                range_ = range(idx, idx + 1)
+            else:
+                # raise exception for out-of-range values
+                range_ = range(cur_len)[slice_]
+                if range_.step != 1:
+                    raise ValueError("slice step must be 1")
+                if len(range_) == 0:
+                    continue
+            yield op, output_index, range_
+
+    def __init__(self, sliced_op_outputs):
+        # type: (Iterable[_SlicedOpOutputIn] | SSAVal) -> None
+        # we have length arg so plain_data.replace works
+        if isinstance(sliced_op_outputs, SSAVal):
+            inp = sliced_op_outputs.sliced_op_outputs
+        else:
+            inp = SSAVal.__process_sliced_op_outputs(sliced_op_outputs)
+        processed = []  # type: list[tuple[Op, int, range]]
+        length = 0
+        for op, output_index, range_ in inp:
+            length += len(range_)
+            if len(processed) == 0:
+                processed.append((op, output_index, range_))
+                continue
+            last_op, last_output_index, last_range_ = processed[-1]
+            if last_op == op and last_output_index == output_index \
+                    and last_range_.stop == range_.start:
+                # merge slices
+                range_ = range(last_range_.start, range_.stop)
+                processed[-1] = op, output_index, range_
+            else:
+                processed.append((op, output_index, range_))
+        self.sliced_op_outputs = tuple(processed)
+
+    def __add__(self, other):
+        # type: (SSAVal) -> SSAVal
+        if not isinstance(other, SSAVal):
+            return NotImplemented
+        return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs)
+
+    def __radd__(self, other):
+        # type: (SSAVal) -> SSAVal
+        if isinstance(other, SSAVal):
+            return other.__add__(self)
+        return NotImplemented
+
+    @cached_property
+    def expanded_sliced_op_outputs(self):
+        # type: () -> tuple[tuple[Op, int, int], ...]
+        retval = []
+        for op, output_index, range_ in self.sliced_op_outputs:
+            for i in range_:
+                retval.append((op, output_index, i))
+        # must be tuple to not be modifiable since it's cached
+        return tuple(retval)
+
+    def __getitem__(self, idx):
+        # type: (int | slice) -> SSAVal
+        if isinstance(idx, int):
+            return SSAVal([self.expanded_sliced_op_outputs[idx]])
+        return SSAVal(self.expanded_sliced_op_outputs[idx])
+
+    def __len__(self):
+        return len(self.expanded_sliced_op_outputs)
+
+    def __iter__(self):
+        # type: () -> Iterator[SSAVal]
+        for v in self.expanded_sliced_op_outputs:
+            yield SSAVal([v])
+
+    def __repr__(self):
+        # type: () -> str
+        if len(self.sliced_op_outputs) == 0:
+            return "SSAVal([])"
+        parts = []
+        for op, output_index, range_ in self.sliced_op_outputs:
+            out_len = op.properties.outputs[output_index].get_length(op.maxvl)
+            parts.append(f"<{op.name}#{output_index}>")
+            if range_ != range(out_len):
+                parts[-1] += f"[{range_.start}:{range_.stop}]"
+        return " + ".join(parts)
+
+
+@plain_data(frozen=True, eq=False)
+@final
+class Op:
+    __slots__ = "fn", "kind", "inputs", "immediates", "outputs", "maxvl", "name"
+
+    def __init__(self, fn, kind, inputs, immediates, maxvl, name=""):
+        # type: (Fn, OpKind, Iterable[SSAVal], Iterable[int], int, str) -> None
+        self.fn = fn
+        self.kind = kind
+        self.inputs = list(inputs)
+        self.immediates = list(immediates)
+        self.maxvl = maxvl
+        outputs_len = len(self.properties.outputs)
+        self.outputs = tuple(SSAVal([(self, i)]) for i in range(outputs_len))
+        self.name = fn._add_op_with_unused_name(self, name)
+
+    @property
+    def properties(self):
+        return self.kind.properties
+
+    def __eq__(self, other):
+        if isinstance(other, Op):
+            return self is other
+        return NotImplemented
+
+    def __hash__(self):
+        return object.__hash__(self)