--- /dev/null
+import enum
+from enum import Enum, unique
+from typing import AbstractSet, Iterable, Iterator, NoReturn, Tuple, Union, overload
+
+from cached_property import cached_property
+from nmutil.plain_data import plain_data
+
+from bigint_presentation_code.util import OFSet, OSet, Self, assert_never, final
+from weakref import WeakValueDictionary
+
+
+@final
+class Fn:
+ def __init__(self):
+ self.ops = [] # type: list[Op]
+ op_names = WeakValueDictionary()
+ self.__op_names = op_names # type: WeakValueDictionary[str, Op]
+ self.__next_name_suffix = 2
+
+ def _add_op_with_unused_name(self, op, name=""):
+ # type: (Op, str) -> str
+ if op.fn is not self:
+ raise ValueError("can't add Op to wrong Fn")
+ if hasattr(op, "name"):
+ raise ValueError("Op already named")
+ orig_name = name
+ while True:
+ if name not in self.__op_names:
+ self.__op_names[name] = op
+ return name
+ name = orig_name + str(self.__next_name_suffix)
+ self.__next_name_suffix += 1
+
+ def __repr__(self):
+ return "<Fn>"
+
+
+@unique
+@final
+class RegKind(Enum):
+ GPR = enum.auto()
+ CA = enum.auto()
+ VL_MAXVL = enum.auto()
+
+ @cached_property
+ def only_scalar(self):
+ if self is RegKind.GPR:
+ return False
+ elif self is RegKind.CA or self is RegKind.VL_MAXVL:
+ return True
+ else:
+ assert_never(self)
+
+ @cached_property
+ def reg_count(self):
+ if self is RegKind.GPR:
+ return 128
+ elif self is RegKind.CA or self is RegKind.VL_MAXVL:
+ return 1
+ else:
+ assert_never(self)
+
+ def __repr__(self):
+ return "RegKind." + self._name_
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class OperandType:
+ __slots__ = "kind", "vec"
+
+ def __init__(self, kind, vec):
+ # type: (RegKind, bool) -> None
+ self.kind = kind
+ if kind.only_scalar and vec:
+ raise ValueError(f"kind={kind} must have vec=False")
+ self.vec = vec
+
+ def get_length(self, maxvl):
+ # type: (int) -> int
+ # here's where subvl and elwid would be accounted for
+ if self.vec:
+ return maxvl
+ return 1
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class RegShape:
+ __slots__ = "kind", "length"
+
+ def __init__(self, kind, length=1):
+ # type: (RegKind, int) -> None
+ self.kind = kind
+ if length < 1 or length > kind.reg_count:
+ raise ValueError("invalid length")
+ self.length = length
+
+ def try_concat(self, *others):
+ # type: (*RegShape | Reg | RegClass | None) -> RegShape | None
+ kind = self.kind
+ length = self.length
+ for other in others:
+ if isinstance(other, (Reg, RegClass)):
+ other = other.shape
+ if other is None:
+ return None
+ if other.kind != self.kind:
+ return None
+ length += other.length
+ if length > kind.reg_count:
+ return None
+ return RegShape(kind=kind, length=length)
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class Reg:
+ __slots__ = "shape", "start"
+
+ def __init__(self, shape, start):
+ # type: (RegShape, int) -> None
+ self.shape = shape
+ if start < 0 or start + shape.length > shape.kind.reg_count:
+ raise ValueError("start not in valid range")
+ self.start = start
+
+ @property
+ def kind(self):
+ return self.shape.kind
+
+ @property
+ def length(self):
+ return self.shape.length
+
+ def conflicts(self, other):
+ # type: (Reg) -> bool
+ return (self.kind == other.kind
+ and self.start < other.stop and other.start < self.stop)
+
+ @property
+ def stop(self):
+ return self.start + self.length
+
+ def try_concat(self, *others):
+ # type: (*Reg | None) -> Reg | None
+ shape = self.shape.try_concat(*others)
+ if shape is None:
+ return None
+ stop = self.stop
+ for other in others:
+ assert other is not None, "already caught by RegShape.try_concat"
+ if stop != other.start:
+ return None
+ stop = other.stop
+ return Reg(shape, self.start)
+
+
+@final
+class RegClass(AbstractSet[Reg]):
+ def __init__(self, regs_or_starts=(), shape=None, starts_bitset=0):
+ # type: (Iterable[Reg | int], RegShape | None, int) -> None
+ for reg_or_start in regs_or_starts:
+ if isinstance(reg_or_start, Reg):
+ if shape is None:
+ shape = reg_or_start.shape
+ elif shape != reg_or_start.shape:
+ raise ValueError(f"conflicting RegShapes: {shape} and "
+ f"{reg_or_start.shape}")
+ start = reg_or_start.start
+ else:
+ start = reg_or_start
+ if start < 0:
+ raise ValueError("a Reg's start is out of range")
+ starts_bitset |= 1 << start
+ if starts_bitset == 0:
+ shape = None
+ self.__shape = shape
+ self.__starts_bitset = starts_bitset
+ if shape is None:
+ if starts_bitset != 0:
+ raise ValueError("non-empty RegClass must have non-None shape")
+ return
+ if self.stops_bitset >= 1 << shape.kind.reg_count:
+ raise ValueError("a Reg's start is out of range")
+
+ @property
+ def shape(self):
+ # type: () -> RegShape | None
+ return self.__shape
+
+ @property
+ def starts_bitset(self):
+ # type: () -> int
+ return self.__starts_bitset
+
+ @property
+ def stops_bitset(self):
+ # type: () -> int
+ if self.__shape is None:
+ return 0
+ return self.__starts_bitset << self.__shape.length
+
+ @cached_property
+ def starts(self):
+ # type: () -> OFSet[int]
+ if self.length is None:
+ return OFSet()
+ # TODO: fixme
+ # return OFSet(for i in range(self.length))
+
+ @cached_property
+ def stops(self):
+ # type: () -> OFSet[int]
+ if self.__shape is None:
+ return OFSet()
+ return OFSet(i + self.__shape.length for i in self.__starts)
+
+ @property
+ def kind(self):
+ if self.__shape is None:
+ return None
+ return self.__shape.kind
+
+ @property
+ def length(self):
+ """length of registers in this RegClass, not to be confused with the number of `Reg`s in self"""
+ if self.__shape is None:
+ return None
+ return self.__shape.length
+
+ def concat(self, *others):
+ # type: (*RegClass) -> RegClass
+ shape = self.__shape
+ if shape is None:
+ return RegClass()
+ shape = shape.try_concat(*others)
+ if shape is None:
+ return RegClass()
+ starts = OSet(self.starts)
+ offset = shape.length
+ for other in others:
+ assert other.__shape is not None, \
+ "already caught by RegShape.try_concat"
+ starts &= OSet(i - offset for i in other.starts)
+ offset += other.__shape.length
+ return RegClass(starts, shape=shape)
+
+ def __contains__(self, reg):
+ # type: (Reg) -> bool
+ return reg.shape == self.shape and reg.start in self.starts
+
+ def __iter__(self):
+ # type: () -> Iterator[Reg]
+ if self.shape is None:
+ return
+ for start in self.starts:
+ yield Reg(shape=self.shape, start=start)
+
+ def __len__(self):
+ return len(self.starts)
+
+ def __hash__(self):
+ return super()._hash()
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class Operand:
+ __slots__ = "ty", "regs"
+
+ def __init__(self, ty, regs=None):
+ # type: (OperandType, OFSet[int] | None) -> None
+ pass
+
+
+OT_VGPR = OperandType(RegKind.GPR, vec=True)
+OT_SGPR = OperandType(RegKind.GPR, vec=False)
+OT_CA = OperandType(RegKind.CA, vec=False)
+OT_VL = OperandType(RegKind.VL_MAXVL, vec=False)
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+class TiedOutput:
+ __slots__ = "input_index", "output_index"
+
+ def __init__(self, input_index, output_index):
+ # type: (int, int) -> None
+ self.input_index = input_index
+ self.output_index = output_index
+
+
+# Type alias for the items of `OpProperties.constraints`. `TiedOutput` is the
+# only real member; NoReturn is presumably a placeholder for constraint kinds
+# that don't exist yet -- TODO confirm.
+Constraint = Union[TiedOutput, NoReturn]
+
+
+@plain_data(frozen=True, unsafe_hash=True)
+@final
+class OpProperties:
+ __slots__ = ("demo_asm", "inputs", "outputs", "immediates", "constraints",
+ "is_copy", "is_load_immediate", "has_side_effects")
+
+ def __init__(self, demo_asm, # type: str
+ inputs, # type: Iterable[OperandType]
+ outputs, # type: Iterable[OperandType]
+ immediates, # type: Iterable[range]
+ constraints, # type: Iterable[Constraint]
+ is_copy=False, # type: bool
+ is_load_immediate=False, # type: bool
+ has_side_effects=False, # type: bool
+ ):
+ # type: (...) -> None
+ self.demo_asm = demo_asm
+ self.inputs = tuple(inputs)
+ self.outputs = tuple(outputs)
+ self.immediates = tuple(immediates)
+ self.constraints = tuple(constraints)
+ self.is_copy = is_copy
+ self.is_load_immediate = is_load_immediate
+ self.has_side_effects = has_side_effects
+
+
+@unique
+@final
+class OpKind(Enum):
+ def __init__(self, properties):
+ # type: (OpProperties) -> None
+ super().__init__()
+ self.properties = properties
+
+ SvAddE = OpProperties(
+ demo_asm="sv.adde *RT, *RA, *RB",
+ inputs=(OT_VGPR, OT_VGPR, OT_CA, OT_VL),
+ outputs=(OT_VGPR, OT_CA),
+ immediates=(),
+ constraints=(),
+ )
+ SvSubFE = OpProperties(
+ demo_asm="sv.subfe *RT, *RA, *RB",
+ inputs=(OT_VGPR, OT_VGPR, OT_CA, OT_VL),
+ outputs=(OT_VGPR, OT_CA),
+ immediates=(),
+ constraints=(),
+ )
+ SvMAddEDU = OpProperties(
+ demo_asm="sv.maddedu *RT, *RA, RB, RC",
+ inputs=(OT_VGPR, OT_SGPR, OT_SGPR, OT_VL),
+ outputs=(OT_VGPR, OT_SGPR),
+ immediates=(),
+ constraints=(),
+ )
+ SetVLI = OpProperties(
+ demo_asm="setvl 0, 0, imm, 0, 1, 1",
+ inputs=(),
+ outputs=(OT_VL,),
+ immediates=(range(1, 65),),
+ constraints=(),
+ is_load_immediate=True,
+ )
+ SvLI = OpProperties(
+ demo_asm="sv.addi *RT, 0, imm",
+ inputs=(OT_VL,),
+ outputs=(OT_VGPR,),
+ immediates=(range(-2 ** 15, 2 ** 15),),
+ constraints=(),
+ is_load_immediate=True,
+ )
+ LI = OpProperties(
+ demo_asm="addi RT, 0, imm",
+ inputs=(),
+ outputs=(OT_SGPR,),
+ immediates=(range(-2 ** 15, 2 ** 15),),
+ constraints=(),
+ is_load_immediate=True,
+ )
+ SvMv = OpProperties(
+ demo_asm="sv.or *RT, *src, *src",
+ inputs=(OT_VGPR, OT_VL),
+ outputs=(OT_VGPR,),
+ immediates=(),
+ constraints=(),
+ is_copy=True,
+ )
+ Mv = OpProperties(
+ demo_asm="mv RT, src",
+ inputs=(OT_SGPR,),
+ outputs=(OT_SGPR,),
+ immediates=(),
+ constraints=(),
+ is_copy=True,
+ )
+
+
+@plain_data(frozen=True, unsafe_hash=True, repr=False)
+@final
+class SSAVal:
+ __slots__ = "sliced_op_outputs",
+
+ _SlicedOpOutputIn = Union["tuple[Op, int, int | range | slice]",
+ "tuple[Op, int]", "SSAVal"]
+
+ @staticmethod
+ def __process_sliced_op_outputs(inp):
+ # type: (Iterable[_SlicedOpOutputIn]) -> Iterable[Tuple["Op", int, range]]
+ for v in inp:
+ if isinstance(v, SSAVal):
+ yield from v.sliced_op_outputs
+ continue
+ op = v[0]
+ output_index = v[1]
+ if output_index < 0 or output_index >= len(op.properties.outputs):
+ raise ValueError("invalid output_index")
+ cur_len = op.properties.outputs[output_index].get_length(op.maxvl)
+ slice_ = slice(None) if len(v) == 2 else v[2]
+ if isinstance(slice_, range):
+ slice_ = slice(slice_.start, slice_.stop, slice_.step)
+ if isinstance(slice_, int):
+ # raise exception for out-of-range values
+ idx = range(cur_len)[slice_]
+ range_ = range(idx, idx + 1)
+ else:
+ # raise exception for out-of-range values
+ range_ = range(cur_len)[slice_]
+ if range_.step != 1:
+ raise ValueError("slice step must be 1")
+ if len(range_) == 0:
+ continue
+ yield op, output_index, range_
+
+ def __init__(self, sliced_op_outputs):
+ # type: (Iterable[_SlicedOpOutputIn] | SSAVal) -> None
+ # we have length arg so plain_data.replace works
+ if isinstance(sliced_op_outputs, SSAVal):
+ inp = sliced_op_outputs.sliced_op_outputs
+ else:
+ inp = SSAVal.__process_sliced_op_outputs(sliced_op_outputs)
+ processed = [] # type: list[tuple[Op, int, range]]
+ length = 0
+ for op, output_index, range_ in inp:
+ length += len(range_)
+ if len(processed) == 0:
+ processed.append((op, output_index, range_))
+ continue
+ last_op, last_output_index, last_range_ = processed[-1]
+ if last_op == op and last_output_index == output_index \
+ and last_range_.stop == range_.start:
+ # merge slices
+ range_ = range(last_range_.start, range_.stop)
+ processed[-1] = op, output_index, range_
+ else:
+ processed.append((op, output_index, range_))
+ self.sliced_op_outputs = tuple(processed)
+
+ def __add__(self, other):
+ # type: (SSAVal) -> SSAVal
+ if not isinstance(other, SSAVal):
+ return NotImplemented
+ return SSAVal(self.sliced_op_outputs + other.sliced_op_outputs)
+
+ def __radd__(self, other):
+ # type: (SSAVal) -> SSAVal
+ if isinstance(other, SSAVal):
+ return other.__add__(self)
+ return NotImplemented
+
+ @cached_property
+ def expanded_sliced_op_outputs(self):
+ # type: () -> tuple[tuple[Op, int, int], ...]
+ retval = []
+ for op, output_index, range_ in self.sliced_op_outputs:
+ for i in range_:
+ retval.append((op, output_index, i))
+ # must be tuple to not be modifiable since it's cached
+ return tuple(retval)
+
+ def __getitem__(self, idx):
+ # type: (int | slice) -> SSAVal
+ if isinstance(idx, int):
+ return SSAVal([self.expanded_sliced_op_outputs[idx]])
+ return SSAVal(self.expanded_sliced_op_outputs[idx])
+
+ def __len__(self):
+ return len(self.expanded_sliced_op_outputs)
+
+ def __iter__(self):
+ # type: () -> Iterator[SSAVal]
+ for v in self.expanded_sliced_op_outputs:
+ yield SSAVal([v])
+
+ def __repr__(self):
+ # type: () -> str
+ if len(self.sliced_op_outputs) == 0:
+ return "SSAVal([])"
+ parts = []
+ for op, output_index, range_ in self.sliced_op_outputs:
+ out_len = op.properties.outputs[output_index].get_length(op.maxvl)
+ parts.append(f"<{op.name}#{output_index}>")
+ if range_ != range(out_len):
+ parts[-1] += f"[{range_.start}:{range_.stop}]"
+ return " + ".join(parts)
+
+
+@plain_data(frozen=True, eq=False)
+@final
+class Op:
+ __slots__ = "fn", "kind", "inputs", "immediates", "outputs", "maxvl", "name"
+
+ def __init__(self, fn, kind, inputs, immediates, maxvl, name=""):
+ # type: (Fn, OpKind, Iterable[SSAVal], Iterable[int], int, str) -> None
+ self.fn = fn
+ self.kind = kind
+ self.inputs = list(inputs)
+ self.immediates = list(immediates)
+ self.maxvl = maxvl
+ outputs_len = len(self.properties.outputs)
+ self.outputs = tuple(SSAVal([(self, i)]) for i in range(outputs_len))
+ self.name = fn._add_op_with_unused_name(self, name)
+
+ @property
+ def properties(self):
+ return self.kind.properties
+
+ def __eq__(self, other):
+ if isinstance(other, Op):
+ return self is other
+ return NotImplemented
+
+ def __hash__(self):
+ return object.__hash__(self)