from collections import defaultdict
from enum import Enum, EnumMeta, unique
from functools import lru_cache
-from typing import (TYPE_CHECKING, AbstractSet, Generic, Iterable, Sequence,
- TypeVar, cast)
+from typing import TYPE_CHECKING, Generic, Iterable, Sequence, TypeVar, cast
from cached_property import cached_property
from nmutil.plain_data import fields, plain_data
+from bigint_presentation_code.ordered_set import OFSet, OSet
+
if TYPE_CHECKING:
from typing_extensions import final
else:
def get_subreg_at_offset(self, subreg_type, offset):
# type: (RegType, int) -> GPRRange
- if not isinstance(subreg_type, GPRRangeType):
- raise ValueError(f"subreg_type is not a "
+ if not isinstance(subreg_type, (GPRRangeType, FixedGPRRangeType)):
+ raise ValueError(f"subreg_type is not a FixedGPRRangeType or "
f"GPRRangeType: {subreg_type}")
if offset < 0 or offset + subreg_type.length > self.stop:
raise ValueError(f"sub-register offset is out of range: {offset}")
@final
-class RegClass(AbstractSet[RegLoc]):
+class RegClass(OFSet[RegLoc]):
""" an ordered set of registers.
earlier registers are preferred by the register allocator.
"""
- def __init__(self, regs):
- # type: (Iterable[RegLoc]) -> None
-
- # use dict to maintain order
- self.__regs = dict.fromkeys(regs) # type: dict[RegLoc, None]
-
- def __len__(self):
- return len(self.__regs)
-
- def __iter__(self):
- return iter(self.__regs)
-
- def __contains__(self, v):
- # type: (RegLoc) -> bool
- return v in self.__regs
-
- def __hash__(self):
- return super()._hash()
-
@lru_cache(maxsize=None, typed=True)
def max_conflicts_with(self, other):
# type: (RegClass | RegLoc) -> int
@plain_data(frozen=True, unsafe_hash=True)
@final
-class FixedGPRRangeType(GPRRangeType):
+class FixedGPRRangeType(RegType):
__slots__ = "reg",
def __init__(self, reg):
# type: (GPRRange) -> None
- super().__init__(length=reg.length)
self.reg = reg
    @property
    def reg_class(self):
        # type: () -> RegClass
return RegClass([self.reg])
+ @property
+ def length(self):
+ # type: () -> int
+ return self.reg.length
+
@plain_data(frozen=True, unsafe_hash=True)
@final
def __hash__(self):
return hash((id(self.op), self.arg_name))
- def __repr__(self):
+ def __repr__(self, long=False):
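+        # the compact default form <#id.name> identifies the producing op by
+        # its id and this value's argument name; long=True falls through to
+        # the full field-by-field repr below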
+ if not long:
+ return f"<#{self.op.id}.{self.arg_name}>"
fields_list = []
for name in fields(self):
v = getattr(self, name, None)
@cached_property
def id(self):
+ # type: () -> int
+        # use a cached_property rather than assigning in __init__ so that id
+        # is usable even if __init__ hasn't run yet
retval = Op.__NEXT_ID
Op.__NEXT_ID += 1
return retval
+    def __init__(self):
+        # type: () -> None
+        self.id  # allocate this op's id now so ids follow creation order
+
@final
def __repr__(self, just_id=False):
fields_list = [f"#{self.id}"]
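+        # outputs() may raise AttributeError when __repr__ runs before the
+        # subclass __init__ has assigned its fields; in that case every
+        # SSAVal field is printed in long form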
+ outputs = None
+ try:
+ outputs = self.outputs()
+ except AttributeError:
+ pass
if not just_id:
for name in fields(self):
v = getattr(self, name, _NOT_SET)
- fields_list.append(f"{name}={v!r}")
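+                # this op's own outputs are printed in full; SSAVal inputs
+                # produced by other ops keep the compact <#id.name> form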
+ if ((outputs is None or name in outputs)
+ and isinstance(v, SSAVal)):
+ v = v.__repr__(long=True)
+ else:
+ v = repr(v)
+ fields_list.append(f"{name}={v}")
fields_str = ', '.join(fields_list)
return f"{self.__class__.__name__}({fields_str})"
def __init__(self, src):
# type: (SSAVal[GPRRangeType]) -> None
+ super().__init__()
self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
self.src = src
def __init__(self, src):
# type: (SSAVal[StackSlotType]) -> None
+ super().__init__()
self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
self.src = src
def __init__(self, src, dest_ty=None):
# type: (SSAVal[_RegSrcType], _RegType | None) -> None
+ super().__init__()
if dest_ty is None:
dest_ty = cast(_RegType, src.ty)
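+        # copies between a plain GPR range and a fixed GPR range are allowed
+        # whenever the lengths match; any other type mismatch is rejected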
if isinstance(src.ty, GPRRangeType) \
+ and isinstance(dest_ty, FixedGPRRangeType):
+ if src.ty.length != dest_ty.reg.length:
+ raise ValueError(f"incompatible source and destination "
+ f"types: {src.ty} and {dest_ty}")
+ elif isinstance(src.ty, FixedGPRRangeType) \
and isinstance(dest_ty, GPRRangeType):
- if src.ty.length != dest_ty.length:
+ if src.ty.reg.length != dest_ty.length:
raise ValueError(f"incompatible source and destination "
f"types: {src.ty} and {dest_ty}")
elif src.ty != dest_ty:
def __init__(self, sources):
# type: (Iterable[SSAVal[GPRRangeType]]) -> None
+ super().__init__()
sources = tuple(sources)
self.dest = SSAVal(self, "dest", GPRRangeType(
sum(i.ty.length for i in sources)))
def __init__(self, src, split_indexes):
# type: (SSAVal[GPRRangeType], Iterable[int]) -> None
+ super().__init__()
ranges = [] # type: list[GPRRangeType]
last = 0
for i in split_indexes:
def __init__(self, RA, RB, CY_in, is_sub):
# type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
+ super().__init__()
if RA.ty != RB.ty:
raise TypeError(f"source types must match: "
f"{RA} doesn't match {RB}")
def __init__(self, RA, RB, RC, is_div):
# type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
+ super().__init__()
self.RT = SSAVal(self, "RT", RA.ty)
self.RA = RA
self.RB = RB
def __init__(self, inp, sh, kind):
# type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
+ super().__init__()
self.RT = SSAVal(self, "RT", inp.ty)
self.inp = inp
self.sh = sh
def __init__(self, value, length=1):
# type: (int, int) -> None
+ super().__init__()
self.out = SSAVal(self, "out", GPRRangeType(length))
self.value = value
def __init__(self):
# type: () -> None
+ super().__init__()
self.out = SSAVal(self, "out", CYType())
def __init__(self, RA, offset, mem, length=1):
# type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
+ super().__init__()
self.RT = SSAVal(self, "RT", GPRRangeType(length))
self.RA = RA
self.offset = offset
def __init__(self, RS, RA, offset, mem_in):
# type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
+ super().__init__()
self.RS = RS
self.RA = RA
self.offset = offset
def __init__(self, ty):
# type: (FixedGPRRangeType) -> None
+ super().__init__()
self.out = SSAVal(self, "out", ty)
def __init__(self):
# type: () -> None
+ super().__init__()
self.out = SSAVal(self, "out", GlobalMemType())
ops_to_pending_input_count_map[op] = input_count
worklists[input_count][op] = None
retval = [] # type: list[Op]
- ready_vals = set() # type: set[SSAVal]
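+    # OSet preserves insertion order, avoiding any dependence on the
+    # per-run iteration order of a plain set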
+ ready_vals = OSet() # type: OSet[SSAVal]
while len(worklists[0]) != 0:
writing_op = next(iter(worklists[0]))
del worklists[0][writing_op]
--- /dev/null
+from typing import AbstractSet, Iterable, MutableSet, TypeVar
+
+_T_co = TypeVar("_T_co", covariant=True)
+_T = TypeVar("_T")
+
+
+class OFSet(AbstractSet[_T_co]):
+ """ ordered frozen set """
+
+ def __init__(self, items=()):
+ # type: (Iterable[_T_co]) -> None
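+        # a dict preserves insertion order (guaranteed since Python 3.7), so
+        # its keys serve as the ordered set; the values are never used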
+ self.__items = {v: None for v in items}
+
+ def __contains__(self, x):
+ return x in self.__items
+
+ def __iter__(self):
+ return iter(self.__items)
+
+ def __len__(self):
+ return len(self.__items)
+
+ def __hash__(self):
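+        # Set._hash() from collections.abc computes the hash the same way
+        # frozenset does, so OFSet hashes compatibly with frozenset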
+ return self._hash()
+
+ def __repr__(self):
+ if len(self) == 0:
+ return "OFSet()"
+ return f"OFSet({list(self)})"
+
+
+class OSet(MutableSet[_T]):
+ """ ordered mutable set """
+
+ def __init__(self, items=()):
+ # type: (Iterable[_T]) -> None
+ self.__items = {v: None for v in items}
+
+ def __contains__(self, x):
+ return x in self.__items
+
+ def __iter__(self):
+ return iter(self.__items)
+
+ def __len__(self):
+ return len(self.__items)
+
+ def add(self, value):
+ # type: (_T) -> None
+ self.__items[value] = None
+
+ def discard(self, value):
+ # type: (_T) -> None
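+        # unlike remove(), discard() must not raise when value is absent,
+        # hence pop() with a default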
+ self.__items.pop(value, None)
+
+ def __repr__(self):
+ if len(self) == 0:
+ return "OSet()"
+ return f"OSet({list(self)})"
from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass,
RegLoc, RegType, SSAVal)
+from bigint_presentation_code.ordered_set import OFSet, OSet
if TYPE_CHECKING:
- from typing_extensions import Self, final
+ from typing_extensions import final
else:
def final(v):
return v
self.__start = start # type: int
self.__stop = stop # type: int
self.__ty = ty # type: RegType
- self.__hash = hash(frozenset(self.items()))
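+        # OFSet hashes like the equivalent frozenset (via Set._hash()), so
+        # the hash value is unchanged; only iteration order becomes fixed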
+ self.__hash = hash(OFSet(self.items()))
@staticmethod
def from_equality_constraint(constraint_sequence):
for e in op.get_equality_constraints():
lhs_set = MergedRegSet.from_equality_constraint(e.lhs)
rhs_set = MergedRegSet.from_equality_constraint(e.rhs)
- lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set)
- rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set)
- full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()])
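+            # every SSAVal on either side of the equality constraint may
+            # already belong to a merged set; pull in each of those sets
+            # (offset to line up with the constraint's layout), not just the
+            # sets of the first lhs/rhs elements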
+ items = [] # type: list[tuple[SSAVal, int]]
+ for i in e.lhs:
+ s = merged_sets[i].with_offset_to_match(lhs_set)
+ items.extend(s.items())
+ for i in e.rhs:
+ s = merged_sets[i].with_offset_to_match(rhs_set)
+ items.extend(s.items())
+ full_set = MergedRegSet(items)
for val in full_set.keys():
merged_sets[val] = full_set
else:
live_intervals[reg_set] += op_idx
self.__live_intervals = live_intervals
- live_after = [] # type: list[set[MergedRegSet[_RegType]]]
- live_after += (set() for _ in ops)
+ live_after = [] # type: list[OSet[MergedRegSet[_RegType]]]
+ live_after += (OSet() for _ in ops)
for reg_set, live_interval in self.__live_intervals.items():
for i in live_interval.live_after_op_range:
live_after[i].add(reg_set)
- self.__live_after = [frozenset(i) for i in live_after]
+ self.__live_after = [OFSet(i) for i in live_after]
@property
def merged_reg_sets(self):
def __iter__(self):
return iter(self.__live_intervals)
+ def __len__(self):
+ return len(self.__live_intervals)
+
def reg_sets_live_after(self, op_index):
- # type: (int) -> frozenset[MergedRegSet[_RegType]]
+ # type: (int) -> OFSet[MergedRegSet[_RegType]]
return self.__live_after[op_index]
def __repr__(self):
def __init__(self, merged_reg_set, edges=(), reg=None):
# type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
self.merged_reg_set = merged_reg_set
- self.edges = set(edges)
+ self.edges = OSet(edges)
self.reg = reg
def add_edge(self, other):
def __iter__(self):
return iter(self.__nodes)
+ def __len__(self):
+ return len(self.__nodes)
+
def __repr__(self):
nodes = {}
nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()]
if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0:
interference_graph[i].add_edge(interference_graph[j])
- nodes_remaining = set(interference_graph.values())
+ nodes_remaining = OSet(interference_graph.values())
def local_colorability_score(node):
# type: (IGNode) -> int
maxDiff = None
def test_op_set_to_list(self):
- ops = [] # list[Op]
+ ops = [] # type: list[Op]
op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3)))
ops.append(op0)
op1 = OpCopy(op0.out, GPRType())
import unittest
-from bigint_presentation_code.compiler_ir import Op
+from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, GPRRange,
+ GPRType, GlobalMem, Op, OpAddSubE,
+ OpClearCY, OpConcat, OpCopy,
+ OpFuncArg, OpInputMem, OpLI,
+ OpLoad, OpStore, XERBit)
from bigint_presentation_code.register_allocator import (
- AllocationFailed, allocate_registers,
+ AllocationFailed, allocate_registers, MergedRegSet,
try_allocate_registers_without_spilling)
+class TestMergedRegSet(unittest.TestCase):
+ maxDiff = None
+
+ def test_from_equality_constraint(self):
+ op0 = OpLI(0, length=1)
+ op1 = OpLI(0, length=2)
+ op2 = OpLI(0, length=3)
+ self.assertEqual(MergedRegSet.from_equality_constraint([
+ op0.out,
+ op1.out,
+ op2.out,
+ ]), MergedRegSet({
+ op0.out: 0,
+ op1.out: 1,
+ op2.out: 3,
+ }.items()))
+ self.assertEqual(MergedRegSet.from_equality_constraint([
+ op1.out,
+ op0.out,
+ op2.out,
+ ]), MergedRegSet({
+ op1.out: 0,
+ op0.out: 2,
+ op2.out: 3,
+ }.items()))
+
+
class TestRegisterAllocator(unittest.TestCase):
- pass # no tests yet, just testing importing
+ maxDiff = None
+
+ def test_try_alloc_fail(self):
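+        # the concat result and its two OpLI sources form one merged set
+        # needing 52 + 64 = 116 contiguous GPRs, which is expected not to
+        # fit, so the allocator should return an AllocationFailed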
+ ops = [] # type: list[Op]
+ op0 = OpLI(0, length=52)
+ ops.append(op0)
+ op1 = OpLI(0, length=64)
+ ops.append(op1)
+ op2 = OpConcat([op0.out, op1.out])
+ ops.append(op2)
+
+ reg_assignments = try_allocate_registers_without_spilling(ops)
+ self.assertEqual(
+ repr(reg_assignments),
+ "AllocationFailed("
+ "node=IGNode(#0, merged_reg_set=MergedRegSet(["
+ "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+ "edges={}, reg=None), "
+ "live_intervals=LiveIntervals("
+ "live_intervals={"
+ "MergedRegSet([(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]): "
+ "LiveInterval(first_write=0, last_use=2)}, "
+ "merged_reg_sets=MergedRegSets(data={"
+ "<#0.out>: MergedRegSet(["
+ "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+ "<#1.out>: MergedRegSet(["
+ "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+ "<#2.dest>: MergedRegSet(["
+ "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])}), "
+ "reg_sets_live_after={"
+ "0: OFSet([MergedRegSet(["
+ "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), "
+ "1: OFSet([MergedRegSet(["
+ "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), "
+ "2: OFSet()}), "
+ "interference_graph=InterferenceGraph(nodes={"
+ "...: IGNode(#0, "
+ "merged_reg_set=MergedRegSet(["
+ "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), "
+ "edges={}, reg=None)}))"
+ )
+
+ def test_try_alloc_bigint_inc(self):
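+        # IR for incrementing a 32-word bigint in memory: load it, add the
+        # constant 1 (zero-extended to 32 words) with a carry chain, then
+        # store the sum back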
+ ops = [] # type: list[Op]
+ op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3)))
+ ops.append(op0)
+ op1 = OpCopy(op0.out, GPRType())
+ ops.append(op1)
+ arg = op1.dest
+ op2 = OpInputMem()
+ ops.append(op2)
+ mem = op2.out
+ op3 = OpLoad(arg, offset=0, mem=mem, length=32)
+ ops.append(op3)
+ a = op3.RT
+ op4 = OpLI(1)
+ ops.append(op4)
+ b_0 = op4.out
+ op5 = OpLI(0, length=31)
+ ops.append(op5)
+ b_rest = op5.out
+ op6 = OpConcat([b_0, b_rest])
+ ops.append(op6)
+ b = op6.dest
+ op7 = OpClearCY()
+ ops.append(op7)
+ cy = op7.out
+ op8 = OpAddSubE(a, b, cy, is_sub=False)
+ ops.append(op8)
+ s = op8.RT
+ op9 = OpStore(s, arg, offset=0, mem_in=mem)
+ ops.append(op9)
+ mem = op9.mem_out
+
+ reg_assignments = try_allocate_registers_without_spilling(ops)
+
+ expected_reg_assignments = {
+ op0.out: GPRRange(start=3, length=1),
+ op1.dest: GPRRange(start=3, length=1),
+ op2.out: GlobalMem.GlobalMem,
+ op3.RT: GPRRange(start=78, length=32),
+ op4.out: GPRRange(start=46, length=1),
+ op5.out: GPRRange(start=47, length=31),
+ op6.dest: GPRRange(start=46, length=32),
+ op7.out: XERBit.CY,
+ op8.RT: GPRRange(start=14, length=32),
+ op8.CY_out: XERBit.CY,
+ op9.mem_out: GlobalMem.GlobalMem,
+ }
+
+ self.assertEqual(reg_assignments, expected_reg_assignments)
+
+ def tst_try_alloc_concat(self, expected_regs, expected_dest_reg):
+ # type: (list[GPRRange], GPRRange) -> None
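+        # helper, not named test_* so unittest won't collect it directly:
+        # allocate an OpConcat of OpLI outputs and check that each value
+        # gets the expected GPR range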
+ li_ops = [OpLI(i, reg.length) for i, reg in enumerate(expected_regs)]
+ ops = [*li_ops] # type: list[Op]
+ concat = OpConcat([i.out for i in li_ops])
+ ops.append(concat)
+
+ reg_assignments = try_allocate_registers_without_spilling(ops)
+
+ expected_reg_assignments = {concat.dest: expected_dest_reg}
+ for li_op, reg in zip(li_ops, expected_regs):
+ expected_reg_assignments[li_op.out] = reg
+
+ self.assertEqual(reg_assignments, expected_reg_assignments)
+
+ def test_try_alloc_concat_1(self):
+ self.tst_try_alloc_concat([GPRRange(3)], GPRRange(3))
+
+ def test_try_alloc_concat_3(self):
+ self.tst_try_alloc_concat([GPRRange(3, 3)], GPRRange(3, 3))
+
+ def test_try_alloc_concat_3_5(self):
+ self.tst_try_alloc_concat([GPRRange(3, 3), GPRRange(6, 5)],
+ GPRRange(3, 8))
+
+ def test_try_alloc_concat_5_3(self):
+ self.tst_try_alloc_concat([GPRRange(3, 5), GPRRange(8, 3)],
+ GPRRange(3, 8))
+
+ def test_try_alloc_concat_1_2_3_4_5_6(self):
+ self.tst_try_alloc_concat([
+ GPRRange(14, 1),
+ GPRRange(15, 2),
+ GPRRange(17, 3),
+ GPRRange(20, 4),
+ GPRRange(24, 5),
+ GPRRange(29, 6),
+ ], GPRRange(14, 21))
if __name__ == "__main__":