From: Jacob Lifshay Date: Fri, 14 Oct 2022 09:50:28 +0000 (-0700) Subject: try_allocate_registers_without_spilling works! X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=bedf72180fa237857fac79b527691f609a56dd10;p=bigint-presentation-code.git try_allocate_registers_without_spilling works! --- diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index 24b8649..bfcf159 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -6,12 +6,13 @@ from abc import ABCMeta, abstractmethod from collections import defaultdict from enum import Enum, EnumMeta, unique from functools import lru_cache -from typing import (TYPE_CHECKING, AbstractSet, Generic, Iterable, Sequence, - TypeVar, cast) +from typing import TYPE_CHECKING, Generic, Iterable, Sequence, TypeVar, cast from cached_property import cached_property from nmutil.plain_data import fields, plain_data +from bigint_presentation_code.ordered_set import OFSet, OSet + if TYPE_CHECKING: from typing_extensions import final else: @@ -111,8 +112,8 @@ class GPRRange(RegLoc, Sequence["GPRRange"]): def get_subreg_at_offset(self, subreg_type, offset): # type: (RegType, int) -> GPRRange - if not isinstance(subreg_type, GPRRangeType): - raise ValueError(f"subreg_type is not a " + if not isinstance(subreg_type, (GPRRangeType, FixedGPRRangeType)): + raise ValueError(f"subreg_type is not a FixedGPRRangeType or " f"GPRRangeType: {subreg_type}") if offset < 0 or offset + subreg_type.length > self.stop: raise ValueError(f"sub-register offset is out of range: {offset}") @@ -150,30 +151,11 @@ class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta): @final -class RegClass(AbstractSet[RegLoc]): +class RegClass(OFSet[RegLoc]): """ an ordered set of registers. earlier registers are preferred by the register allocator. """ - def __init__(self, regs): - # type: (Iterable[RegLoc]) -> None - - # use dict to maintain order - self.__regs = dict.fromkeys(regs) # type: dict[RegLoc, None] - - def __len__(self): - return len(self.__regs) - - def __iter__(self): - return iter(self.__regs) - - def __contains__(self, v): - # type: (RegLoc) -> bool - return v in self.__regs - - def __hash__(self): - return super()._hash() - @lru_cache(maxsize=None, typed=True) def max_conflicts_with(self, other): # type: (RegClass | RegLoc) -> int @@ -251,12 +233,11 @@ class GPRType(GPRRangeType): @plain_data(frozen=True, unsafe_hash=True) @final -class FixedGPRRangeType(GPRRangeType): +class FixedGPRRangeType(RegType): __slots__ = "reg", def __init__(self, reg): # type: (GPRRange) -> None - super().__init__(length=reg.length) self.reg = reg @property @@ -264,6 +245,11 @@ class FixedGPRRangeType(GPRRangeType): # type: () -> RegClass return RegClass([self.reg]) + @property + def length(self): + # type: () -> int + return self.reg.length + @plain_data(frozen=True, unsafe_hash=True) @final @@ -384,7 +370,9 @@ class SSAVal(Generic[_RegType]): def __hash__(self): return hash((id(self.op), self.arg_name)) - def __repr__(self): + def __repr__(self, long=False): + if not long: + return f"<#{self.op.id}.{self.arg_name}>" fields_list = [] for name in fields(self): v = getattr(self, name, None) @@ -449,17 +437,33 @@ class Op(metaclass=ABCMeta): @cached_property def id(self): + # type: () -> int + # use cached_property rather than done in init so id is usable even if + # init hasn't run retval = Op.__NEXT_ID Op.__NEXT_ID += 1 return retval + def __init__(self): + self.id # initialize + @final def __repr__(self, just_id=False): fields_list = [f"#{self.id}"] + outputs = None + try: + outputs = self.outputs() + except AttributeError: + pass if not just_id: for name in fields(self): v = getattr(self, name, _NOT_SET) - fields_list.append(f"{name}={v!r}") + if ((outputs is None or name in outputs) + and isinstance(v, SSAVal)): + v = v.__repr__(long=True) + else: + v = repr(v) + fields_list.append(f"{name}={v}") fields_str = ', '.join(fields_list) return f"{self.__class__.__name__}({fields_str})" @@ -479,6 +483,7 @@ class OpLoadFromStackSlot(Op): def __init__(self, src): # type: (SSAVal[GPRRangeType]) -> None + super().__init__() self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length)) self.src = src @@ -498,6 +503,7 @@ class OpStoreToStackSlot(Op): def __init__(self, src): # type: (SSAVal[StackSlotType]) -> None + super().__init__() self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots)) self.src = src @@ -520,11 +526,17 @@ class OpCopy(Op, Generic[_RegSrcType, _RegType]): def __init__(self, src, dest_ty=None): # type: (SSAVal[_RegSrcType], _RegType | None) -> None + super().__init__() if dest_ty is None: dest_ty = cast(_RegType, src.ty) if isinstance(src.ty, GPRRangeType) \ + and isinstance(dest_ty, FixedGPRRangeType): + if src.ty.length != dest_ty.reg.length: + raise ValueError(f"incompatible source and destination " + f"types: {src.ty} and {dest_ty}") + elif isinstance(src.ty, FixedGPRRangeType) \ and isinstance(dest_ty, GPRRangeType): - if src.ty.length != dest_ty.length: + if src.ty.reg.length != dest_ty.length: raise ValueError(f"incompatible source and destination " f"types: {src.ty} and {dest_ty}") elif src.ty != dest_ty: @@ -550,6 +562,7 @@ class OpConcat(Op): def __init__(self, sources): # type: (Iterable[SSAVal[GPRRangeType]]) -> None + super().__init__() sources = tuple(sources) self.dest = SSAVal(self, "dest", GPRRangeType( sum(i.ty.length for i in sources))) @@ -575,6 +588,7 @@ class OpSplit(Op): def __init__(self, src, split_indexes): # type: (SSAVal[GPRRangeType], Iterable[int]) -> None + super().__init__() ranges = [] # type: list[GPRRangeType] last = 0 for i in split_indexes: @@ -608,6 +622,7 @@ class OpAddSubE(Op): def __init__(self, RA, RB, CY_in, is_sub): # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None + super().__init__() if RA.ty != RB.ty: raise TypeError(f"source types must match: " f"{RA} doesn't match {RB}") @@ -639,6 +654,7 @@ class OpBigIntMulDiv(Op): def __init__(self, RA, RB, RC, is_div): # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None + super().__init__() self.RT = SSAVal(self, "RT", RA.ty) self.RA = RA self.RB = RB @@ -683,6 +699,7 @@ class OpBigIntShift(Op): def __init__(self, inp, sh, kind): # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None + super().__init__() self.RT = SSAVal(self, "RT", inp.ty) self.inp = inp self.sh = sh @@ -709,6 +726,7 @@ class OpLI(Op): def __init__(self, value, length=1): # type: (int, int) -> None + super().__init__() self.out = SSAVal(self, "out", GPRRangeType(length)) self.value = value @@ -728,6 +746,7 @@ class OpClearCY(Op): def __init__(self): # type: () -> None + super().__init__() self.out = SSAVal(self, "out", CYType()) @@ -746,6 +765,7 @@ class OpLoad(Op): def __init__(self, RA, offset, mem, length=1): # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None + super().__init__() self.RT = SSAVal(self, "RT", GPRRangeType(length)) self.RA = RA self.offset = offset @@ -772,6 +792,7 @@ class OpStore(Op): def __init__(self, RS, RA, offset, mem_in): # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None + super().__init__() self.RS = RS self.RA = RA self.offset = offset @@ -794,6 +815,7 @@ class OpFuncArg(Op): def __init__(self, ty): # type: (FixedGPRRangeType) -> None + super().__init__() self.out = SSAVal(self, "out", ty) @@ -812,6 +834,7 @@ class OpInputMem(Op): def __init__(self): # type: () -> None + super().__init__() self.out = SSAVal(self, "out", GlobalMemType()) @@ -830,7 +853,7 @@ def op_set_to_list(ops): ops_to_pending_input_count_map[op] = input_count worklists[input_count][op] = None retval = [] # type: list[Op] - ready_vals = set() # type: set[SSAVal] + ready_vals = OSet() # type: OSet[SSAVal] while len(worklists[0]) != 0: writing_op = next(iter(worklists[0])) del worklists[0][writing_op] diff --git a/src/bigint_presentation_code/ordered_set.py b/src/bigint_presentation_code/ordered_set.py new file mode 100644 index 0000000..018f97b --- /dev/null +++ b/src/bigint_presentation_code/ordered_set.py @@ -0,0 +1,59 @@ +from typing import AbstractSet, Iterable, MutableSet, TypeVar + +_T_co = TypeVar("_T_co", covariant=True) +_T = TypeVar("_T") + + +class OFSet(AbstractSet[_T_co]): + """ ordered frozen set """ + + def __init__(self, items=()): + # type: (Iterable[_T_co]) -> None + self.__items = {v: None for v in items} + + def __contains__(self, x): + return x in self.__items + + def __iter__(self): + return iter(self.__items) + + def __len__(self): + return len(self.__items) + + def __hash__(self): + return self._hash() + + def __repr__(self): + if len(self) == 0: + return "OFSet()" + return f"OFSet({list(self)})" + + +class OSet(MutableSet[_T]): + """ ordered mutable set """ + + def __init__(self, items=()): + # type: (Iterable[_T]) -> None + self.__items = {v: None for v in items} + + def __contains__(self, x): + return x in self.__items + + def __iter__(self): + return iter(self.__items) + + def __len__(self): + return len(self.__items) + + def add(self, value): + # type: (_T) -> None + self.__items[value] = None + + def discard(self, value): + # type: (_T) -> None + self.__items.pop(value, None) + + def __repr__(self): + if len(self) == 0: + return "OSet()" + return f"OSet({list(self)})" diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py index 75b3422..b44c32d 100644 --- a/src/bigint_presentation_code/register_allocator.py +++ b/src/bigint_presentation_code/register_allocator.py @@ -12,9 +12,10 @@ from nmutil.plain_data import plain_data from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass, RegLoc, RegType, SSAVal) +from bigint_presentation_code.ordered_set import OFSet, OSet if TYPE_CHECKING: - from typing_extensions import Self, final + from typing_extensions import final else: def final(v): return v @@ -103,7 +104,7 @@ class MergedRegSet(Mapping[SSAVal[_RegType], int]): self.__start = start # type: int self.__stop = stop # type: int self.__ty = ty # type: RegType - self.__hash = hash(frozenset(self.items())) + self.__hash = hash(OFSet(self.items())) @staticmethod def from_equality_constraint(constraint_sequence): @@ -181,9 +182,14 @@ class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]): for e in op.get_equality_constraints(): lhs_set = MergedRegSet.from_equality_constraint(e.lhs) rhs_set = MergedRegSet.from_equality_constraint(e.rhs) - lhs_set = merged_sets[e.lhs[0]].with_offset_to_match(lhs_set) - rhs_set = merged_sets[e.rhs[0]].with_offset_to_match(rhs_set) - full_set = MergedRegSet([*lhs_set.items(), *rhs_set.items()]) + items = [] # type: list[tuple[SSAVal, int]] + for i in e.lhs: + s = merged_sets[i].with_offset_to_match(lhs_set) + items.extend(s.items()) + for i in e.rhs: + s = merged_sets[i].with_offset_to_match(rhs_set) + items.extend(s.items()) + full_set = MergedRegSet(items) for val in full_set.keys(): merged_sets[val] = full_set @@ -219,12 +225,12 @@ class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]): else: live_intervals[reg_set] += op_idx self.__live_intervals = live_intervals - live_after = [] # type: list[set[MergedRegSet[_RegType]]] - live_after += (set() for _ in ops) + live_after = [] # type: list[OSet[MergedRegSet[_RegType]]] + live_after += (OSet() for _ in ops) for reg_set, live_interval in self.__live_intervals.items(): for i in live_interval.live_after_op_range: live_after[i].add(reg_set) - self.__live_after = [frozenset(i) for i in live_after] + self.__live_after = [OFSet(i) for i in live_after] @property def merged_reg_sets(self): @@ -237,8 +243,11 @@ class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]): def __iter__(self): return iter(self.__live_intervals) + def __len__(self): + return len(self.__live_intervals) + def reg_sets_live_after(self, op_index): - # type: (int) -> frozenset[MergedRegSet[_RegType]] + # type: (int) -> OFSet[MergedRegSet[_RegType]] return self.__live_after[op_index] def __repr__(self): @@ -256,7 +265,7 @@ class IGNode(Generic[_RegType]): def __init__(self, merged_reg_set, edges=(), reg=None): # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None self.merged_reg_set = merged_reg_set - self.edges = set(edges) + self.edges = OSet(edges) self.reg = reg def add_edge(self, other): @@ -312,6 +321,9 @@ class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]): def __iter__(self): return iter(self.__nodes) + def __len__(self): + return len(self.__nodes) + def __repr__(self): nodes = {} nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()] @@ -347,7 +359,7 @@ def try_allocate_registers_without_spilling(ops): if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0: interference_graph[i].add_edge(interference_graph[j]) - nodes_remaining = set(interference_graph.values()) + nodes_remaining = OSet(interference_graph.values()) def local_colorability_score(node): # type: (IGNode) -> int diff --git a/src/bigint_presentation_code/test_compiler_ir.py b/src/bigint_presentation_code/test_compiler_ir.py index 26d5272..1f30547 100644 --- a/src/bigint_presentation_code/test_compiler_ir.py +++ b/src/bigint_presentation_code/test_compiler_ir.py @@ -9,7 +9,7 @@ class TestCompilerIR(unittest.TestCase): maxDiff = None def test_op_set_to_list(self): - ops = [] # list[Op] + ops = [] # type: list[Op] op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3))) ops.append(op0) op1 = OpCopy(op0.out, GPRType()) diff --git a/src/bigint_presentation_code/test_register_allocator.py b/src/bigint_presentation_code/test_register_allocator.py index 43675a9..8ba74eb 100644 --- a/src/bigint_presentation_code/test_register_allocator.py +++ b/src/bigint_presentation_code/test_register_allocator.py @@ -1,13 +1,173 @@ import unittest -from bigint_presentation_code.compiler_ir import Op +from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, GPRRange, + GPRType, GlobalMem, Op, OpAddSubE, + OpClearCY, OpConcat, OpCopy, + OpFuncArg, OpInputMem, OpLI, + OpLoad, OpStore, XERBit) from bigint_presentation_code.register_allocator import ( - AllocationFailed, allocate_registers, + AllocationFailed, allocate_registers, MergedRegSet, try_allocate_registers_without_spilling) +class TestMergedRegSet(unittest.TestCase): + maxDiff = None + + def test_from_equality_constraint(self): + op0 = OpLI(0, length=1) + op1 = OpLI(0, length=2) + op2 = OpLI(0, length=3) + self.assertEqual(MergedRegSet.from_equality_constraint([ + op0.out, + op1.out, + op2.out, + ]), MergedRegSet({ + op0.out: 0, + op1.out: 1, + op2.out: 3, + }.items())) + self.assertEqual(MergedRegSet.from_equality_constraint([ + op1.out, + op0.out, + op2.out, + ]), MergedRegSet({ + op1.out: 0, + op0.out: 2, + op2.out: 3, + }.items())) + + class TestRegisterAllocator(unittest.TestCase): - pass # no tests yet, just testing importing + maxDiff = None + + def test_try_alloc_fail(self): + ops = [] # type: list[Op] + op0 = OpLI(0, length=52) + ops.append(op0) + op1 = OpLI(0, length=64) + ops.append(op1) + op2 = OpConcat([op0.out, op1.out]) + ops.append(op2) + + reg_assignments = try_allocate_registers_without_spilling(ops) + self.assertEqual( + repr(reg_assignments), + "AllocationFailed(" + "node=IGNode(#0, merged_reg_set=MergedRegSet([" + "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), " + "edges={}, reg=None), " + "live_intervals=LiveIntervals(" + "live_intervals={" + "MergedRegSet([(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]): " + "LiveInterval(first_write=0, last_use=2)}, " + "merged_reg_sets=MergedRegSets(data={" + "<#0.out>: MergedRegSet([" + "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), " + "<#1.out>: MergedRegSet([" + "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), " + "<#2.dest>: MergedRegSet([" + "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])}), " + "reg_sets_live_after={" + "0: OFSet([MergedRegSet([" + "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), " + "1: OFSet([MergedRegSet([" + "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)])]), " + "2: OFSet()}), " + "interference_graph=InterferenceGraph(nodes={" + "...: IGNode(#0, " + "merged_reg_set=MergedRegSet([" + "(<#2.dest>, 0), (<#0.out>, 0), (<#1.out>, 52)]), " + "edges={}, reg=None)}))" + ) + + def test_try_alloc_bigint_inc(self): + ops = [] # type: list[Op] + op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3))) + ops.append(op0) + op1 = OpCopy(op0.out, GPRType()) + ops.append(op1) + arg = op1.dest + op2 = OpInputMem() + ops.append(op2) + mem = op2.out + op3 = OpLoad(arg, offset=0, mem=mem, length=32) + ops.append(op3) + a = op3.RT + op4 = OpLI(1) + ops.append(op4) + b_0 = op4.out + op5 = OpLI(0, length=31) + ops.append(op5) + b_rest = op5.out + op6 = OpConcat([b_0, b_rest]) + ops.append(op6) + b = op6.dest + op7 = OpClearCY() + ops.append(op7) + cy = op7.out + op8 = OpAddSubE(a, b, cy, is_sub=False) + ops.append(op8) + s = op8.RT + op9 = OpStore(s, arg, offset=0, mem_in=mem) + ops.append(op9) + mem = op9.mem_out + + reg_assignments = try_allocate_registers_without_spilling(ops) + + expected_reg_assignments = { + op0.out: GPRRange(start=3, length=1), + op1.dest: GPRRange(start=3, length=1), + op2.out: GlobalMem.GlobalMem, + op3.RT: GPRRange(start=78, length=32), + op4.out: GPRRange(start=46, length=1), + op5.out: GPRRange(start=47, length=31), + op6.dest: GPRRange(start=46, length=32), + op7.out: XERBit.CY, + op8.RT: GPRRange(start=14, length=32), + op8.CY_out: XERBit.CY, + op9.mem_out: GlobalMem.GlobalMem, + } + + self.assertEqual(reg_assignments, expected_reg_assignments) + + def tst_try_alloc_concat(self, expected_regs, expected_dest_reg): + # type: (list[GPRRange], GPRRange) -> None + li_ops = [OpLI(i, reg.length) for i, reg in enumerate(expected_regs)] + ops = [*li_ops] # type: list[Op] + concat = OpConcat([i.out for i in li_ops]) + ops.append(concat) + + reg_assignments = try_allocate_registers_without_spilling(ops) + + expected_reg_assignments = {concat.dest: expected_dest_reg} + for li_op, reg in zip(li_ops, expected_regs): + expected_reg_assignments[li_op.out] = reg + + self.assertEqual(reg_assignments, expected_reg_assignments) + + def test_try_alloc_concat_1(self): + self.tst_try_alloc_concat([GPRRange(3)], GPRRange(3)) + + def test_try_alloc_concat_3(self): + self.tst_try_alloc_concat([GPRRange(3, 3)], GPRRange(3, 3)) + + def test_try_alloc_concat_3_5(self): + self.tst_try_alloc_concat([GPRRange(3, 3), GPRRange(6, 5)], + GPRRange(3, 8)) + + def test_try_alloc_concat_5_3(self): + self.tst_try_alloc_concat([GPRRange(3, 5), GPRRange(8, 3)], + GPRRange(3, 8)) + + def test_try_alloc_concat_1_2_3_4_5_6(self): + self.tst_try_alloc_concat([ + GPRRange(14, 1), + GPRRange(15, 2), + GPRRange(17, 3), + GPRRange(20, 4), + GPRRange(24, 5), + GPRRange(29, 6), + ], GPRRange(14, 21)) if __name__ == "__main__":