From: Jacob Lifshay Date: Tue, 8 Nov 2022 06:55:51 +0000 (-0800) Subject: remove old register allocator and compiler ir X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=d1e6a2834f477210a89d86f76b4683527be9ceec;p=bigint-presentation-code.git remove old register allocator and compiler ir --- diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir.py b/src/bigint_presentation_code/_tests/test_compiler_ir.py deleted file mode 100644 index 820c305..0000000 --- a/src/bigint_presentation_code/_tests/test_compiler_ir.py +++ /dev/null @@ -1,120 +0,0 @@ -import unittest - -from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn, - GlobalMem, GPRRange, GPRType, - OpBigIntAddSub, OpConcat, - OpCopy, OpFuncArg, - OpInputMem, OpLI, OpLoad, - OpSetCA, OpSetVLImm, OpStore, - RegLoc, SSAVal, XERBit, - generate_assembly, - op_set_to_list) - - -class TestCompilerIR(unittest.TestCase): - maxDiff = None - - def test_op_set_to_list(self): - fn = Fn() - op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) - op1 = OpCopy(fn, op0.out, GPRType()) - arg = op1.dest - op2 = OpInputMem(fn) - mem = op2.out - op3 = OpSetVLImm(fn, 32) - vl = op3.out - op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl) - a = op4.RT - op5 = OpLI(fn, 1) - b_0 = op5.out - op6 = OpSetVLImm(fn, 31) - vl = op6.out - op7 = OpLI(fn, 0, vl=vl) - b_rest = op7.out - op8 = OpConcat(fn, [b_0, b_rest]) - b = op8.dest - op9 = OpSetVLImm(fn, 32) - vl = op9.out - op10 = OpSetCA(fn, False) - ca = op10.out - op11 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) - s = op11.out - op12 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) - mem = op12.mem_out - - expected_ops = [ - op10, # OpSetCA(fn, False) - op9, # OpSetVLImm(fn, 32) - op6, # OpSetVLImm(fn, 31) - op5, # OpLI(fn, 1) - op3, # OpSetVLImm(fn, 32) - op2, # OpInputMem(fn) - op0, # OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) - op7, # OpLI(fn, 0, vl=vl) - op1, # OpCopy(fn, op0.out, GPRType()) - op8, # OpConcat(fn, [b_0, b_rest]) - op4, # OpLoad(fn, arg, offset=0, mem=mem, vl=vl) - op11, # OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) - op12, # OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) - ] - - ops = op_set_to_list(fn.ops[::-1]) - if ops != expected_ops: - self.assertEqual(repr(ops), repr(expected_ops)) - - def tst_generate_assembly(self, use_reg_alloc=False): - fn = Fn() - op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) - op1 = OpCopy(fn, op0.out, GPRType()) - arg = op1.dest - op2 = OpInputMem(fn) - mem = op2.out - op3 = OpSetVLImm(fn, 32) - vl = op3.out - op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl) - a = op4.RT - op5 = OpLI(fn, 0, vl=vl) - b = op5.out - op6 = OpSetCA(fn, True) - ca = op6.out - op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) - s = op7.out - op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) - mem = op8.mem_out - - assigned_registers = { - op0.out: GPRRange(start=3, length=1), - op1.dest: GPRRange(start=3, length=1), - op2.out: GlobalMem.GlobalMem, - op3.out: VL.VL_MAXVL, - op4.RT: GPRRange(start=78, length=32), - op5.out: GPRRange(start=46, length=32), - op6.out: XERBit.CA, - op7.out: GPRRange(start=14, length=32), - op7.CA_out: XERBit.CA, - op8.mem_out: GlobalMem.GlobalMem, - } # type: dict[SSAVal, RegLoc] | None - - if use_reg_alloc: - assigned_registers = None - - asm = generate_assembly(fn.ops, assigned_registers) - self.assertEqual(asm, [ - "setvl 0, 0, 32, 0, 1, 1", - "sv.ld *78, 0(3)", - "sv.addi *46, 0, 0", - "subfic 0, 0, -1", - "sv.adde *14, *78, *46", - "sv.std *14, 0(3)", - "bclr 20, 0, 0", - ]) - - def test_generate_assembly(self): - self.tst_generate_assembly() - - def test_generate_assembly_with_register_allocator(self): - self.tst_generate_assembly(use_reg_alloc=True) - - -if __name__ == "__main__": - unittest.main() diff --git a/src/bigint_presentation_code/_tests/test_register_allocator.py b/src/bigint_presentation_code/_tests/test_register_allocator.py deleted file mode 100644 index 1eff254..0000000 --- a/src/bigint_presentation_code/_tests/test_register_allocator.py +++ /dev/null @@ -1,201 +0,0 @@ -import unittest - -from bigint_presentation_code.compiler_ir import (VL, FixedGPRRangeType, Fn, - GlobalMem, GPRRange, GPRType, - OpBigIntAddSub, OpConcat, - OpCopy, OpFuncArg, - OpInputMem, OpLI, OpLoad, - OpSetCA, OpSetVLImm, OpStore, - XERBit) -from bigint_presentation_code.register_allocator import ( - AllocationFailed, MergedRegSet, allocate_registers, - try_allocate_registers_without_spilling) - - -class TestMergedRegSet(unittest.TestCase): - maxDiff = None - - def test_from_equality_constraint(self): - fn = Fn() - li0x1 = OpLI(fn, 0, vl=OpSetVLImm(fn, 1).out) - li0x2 = OpLI(fn, 0, vl=OpSetVLImm(fn, 2).out) - li0x3 = OpLI(fn, 0, vl=OpSetVLImm(fn, 3).out) - self.assertEqual(MergedRegSet.from_equality_constraint([ - li0x1.out, - li0x2.out, - li0x3.out, - ]), MergedRegSet({ - li0x1.out: 0, - li0x2.out: 1, - li0x3.out: 3, - }.items())) - self.assertEqual(MergedRegSet.from_equality_constraint([ - li0x2.out, - li0x1.out, - li0x3.out, - ]), MergedRegSet({ - li0x2.out: 0, - li0x1.out: 2, - li0x3.out: 3, - }.items())) - - -class TestRegisterAllocator(unittest.TestCase): - maxDiff = None - - def test_try_alloc_fail(self): - fn = Fn() - op0 = OpSetVLImm(fn, 52) - op1 = OpLI(fn, 0, vl=op0.out) - op2 = OpSetVLImm(fn, 64) - op3 = OpLI(fn, 0, vl=op2.out) - op4 = OpConcat(fn, [op1.out, op3.out]) - - reg_assignments = try_allocate_registers_without_spilling(fn.ops) - self.assertEqual( - repr(reg_assignments), - "AllocationFailed(" - "node=IGNode(#0, merged_reg_set=MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)]), " - "edges={}, reg=None), " - "live_intervals=LiveIntervals(live_intervals={" - "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]): " - "LiveInterval(first_write=0, last_use=1), " - "MergedRegSet([(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)]): " - "LiveInterval(first_write=1, last_use=4), " - "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]): " - "LiveInterval(first_write=2, last_use=3)}, " - "merged_reg_sets=MergedRegSets(data={" - "<#0.out: KnownVLType(length=52)>: " - "MergedRegSet([(<#0.out: KnownVLType(length=52)>, 0)]), " - "<#1.out: >: MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)]), " - "<#2.out: KnownVLType(length=64)>: " - "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)]), " - "<#3.out: >: MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)]), " - "<#4.dest: >: MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)])}), " - "reg_sets_live_after={" - "0: OFSet([MergedRegSet([" - "(<#0.out: KnownVLType(length=52)>, 0)])]), " - "1: OFSet([MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)])]), " - "2: OFSet([MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)]), " - "MergedRegSet([(<#2.out: KnownVLType(length=64)>, 0)])]), " - "3: OFSet([MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)])]), " - "4: OFSet()}), " - "interference_graph=InterferenceGraph(nodes={" - "...: IGNode(#0, merged_reg_set=MergedRegSet([" - "(<#0.out: KnownVLType(length=52)>, 0)]), edges={}, reg=None), " - "...: IGNode(#1, merged_reg_set=MergedRegSet([" - "(<#4.dest: >, 0), " - "(<#1.out: >, 0), " - "(<#3.out: >, 52)]), edges={}, reg=None), " - "...: IGNode(#2, merged_reg_set=MergedRegSet([" - "(<#2.out: KnownVLType(length=64)>, 0)]), edges={}, reg=None)}))" - ) - - def test_try_alloc_bigint_inc(self): - fn = Fn() - op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) - op1 = OpCopy(fn, op0.out, GPRType()) - arg = op1.dest - op2 = OpInputMem(fn) - mem = op2.out - op3 = OpSetVLImm(fn, 32) - vl = op3.out - op4 = OpLoad(fn, arg, offset=0, mem=mem, vl=vl) - a = op4.RT - op5 = OpLI(fn, 0, vl=vl) - b = op5.out - op6 = OpSetCA(fn, True) - ca = op6.out - op7 = OpBigIntAddSub(fn, a, b, ca, is_sub=False, vl=vl) - s = op7.out - op8 = OpStore(fn, s, arg, offset=0, mem_in=mem, vl=vl) - mem = op8.mem_out - - reg_assignments = try_allocate_registers_without_spilling(fn.ops) - - expected_reg_assignments = { - op0.out: GPRRange(start=3, length=1), - op1.dest: GPRRange(start=3, length=1), - op2.out: GlobalMem.GlobalMem, - op3.out: VL.VL_MAXVL, - op4.RT: GPRRange(start=78, length=32), - op5.out: GPRRange(start=46, length=32), - op6.out: XERBit.CA, - op7.out: GPRRange(start=14, length=32), - op7.CA_out: XERBit.CA, - op8.mem_out: GlobalMem.GlobalMem, - } - - self.assertEqual(reg_assignments, expected_reg_assignments) - - def tst_try_alloc_concat(self, expected_regs, expected_dest_reg): - # type: (list[GPRRange], GPRRange) -> None - fn = Fn() - inputs = [] - expected_reg_assignments = {} - for i, r in enumerate(expected_regs): - vl = OpSetVLImm(fn, r.length).out - expected_reg_assignments[vl] = VL.VL_MAXVL - inp = OpLI(fn, i, vl=vl).out - inputs.append(inp) - expected_reg_assignments[inp] = r - concat = OpConcat(fn, inputs) - expected_reg_assignments[concat.dest] = expected_dest_reg - - reg_assignments = try_allocate_registers_without_spilling(fn.ops) - - for inp, reg in zip(inputs, expected_regs): - expected_reg_assignments[inp] = reg - - self.assertEqual(reg_assignments, expected_reg_assignments) - - def test_try_alloc_concat_1(self): - self.tst_try_alloc_concat([GPRRange(3)], GPRRange(3)) - - def test_try_alloc_concat_3(self): - self.tst_try_alloc_concat([GPRRange(3, 3)], GPRRange(3, 3)) - - def test_try_alloc_concat_3_5(self): - self.tst_try_alloc_concat([GPRRange(3, 3), GPRRange(6, 5)], - GPRRange(3, 8)) - - def test_try_alloc_concat_5_3(self): - self.tst_try_alloc_concat([GPRRange(3, 5), GPRRange(8, 3)], - GPRRange(3, 8)) - - def test_try_alloc_concat_1_2_3_4_5_6(self): - self.tst_try_alloc_concat([ - GPRRange(14, 1), - GPRRange(15, 2), - GPRRange(17, 3), - GPRRange(20, 4), - GPRRange(24, 5), - GPRRange(29, 6), - ], GPRRange(14, 21)) - - -if __name__ == "__main__": - unittest.main() diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py deleted file mode 100644 index c574174..0000000 --- a/src/bigint_presentation_code/compiler_ir.py +++ /dev/null @@ -1,1537 +0,0 @@ -# type: ignore -""" -Compiler IR for Toom-Cook algorithm generator for SVP64 - -This assumes VL != 0 throughout. -""" - -from abc import ABCMeta, abstractmethod -from collections import defaultdict -from enum import Enum, EnumMeta, unique -from functools import lru_cache -from typing import Any, Generic, Iterable, Sequence, Type, TypeVar, cast - -from nmutil.plain_data import fields, plain_data - -from bigint_presentation_code.type_util import final -from bigint_presentation_code.util import FMap, OFSet, OSet - - -class ABCEnumMeta(EnumMeta, ABCMeta): - pass - - -class RegLoc(metaclass=ABCMeta): - __slots__ = () - - @abstractmethod - def conflicts(self, other): - # type: (RegLoc) -> bool - ... - - def get_subreg_at_offset(self, subreg_type, offset): - # type: (RegType, int) -> RegLoc - if self not in subreg_type.reg_class: - raise ValueError(f"register not a member of subreg_type: " - f"reg={self} subreg_type={subreg_type}") - if offset != 0: - raise ValueError(f"non-zero sub-register offset not supported " - f"for register: {self}") - return self - - -GPR_COUNT = 128 - - -@plain_data(frozen=True, unsafe_hash=True, repr=False) -@final -class GPRRange(RegLoc, Sequence["GPRRange"]): - __slots__ = "start", "length" - - def __init__(self, start, length=None): - # type: (int | range, int | None) -> None - if isinstance(start, range): - if length is not None: - raise TypeError("can't specify length when input is a range") - if start.step != 1: - raise ValueError("range must have a step of 1") - length = len(start) - start = start.start - elif length is None: - length = 1 - if length <= 0 or start < 0 or start + length > GPR_COUNT: - raise ValueError("invalid GPRRange") - self.start = start - self.length = length - - @property - def stop(self): - return self.start + self.length - - @property - def step(self): - return 1 - - @property - def range(self): - return range(self.start, self.stop, self.step) - - def __len__(self): - return self.length - - def __getitem__(self, item): - # type: (int | slice) -> GPRRange - return GPRRange(self.range[item]) - - def __contains__(self, value): - # type: (GPRRange) -> bool - return value.start >= self.start and value.stop <= self.stop - - def index(self, sub, start=None, end=None): - # type: (GPRRange, int | None, int | None) -> int - r = self.range[start:end] - if sub.start < r.start or sub.stop > r.stop: - raise ValueError("GPR range not found") - return sub.start - self.start - - def count(self, sub, start=None, end=None): - # type: (GPRRange, int | None, int | None) -> int - r = self.range[start:end] - if len(r) == 0: - return 0 - return int(sub in GPRRange(r)) - - def conflicts(self, other): - # type: (RegLoc) -> bool - if isinstance(other, GPRRange): - return self.stop > other.start and other.stop > self.start - return False - - def get_subreg_at_offset(self, subreg_type, offset): - # type: (RegType, int) -> GPRRange - if not isinstance(subreg_type, (GPRRangeType, FixedGPRRangeType)): - raise ValueError(f"subreg_type is not a FixedGPRRangeType or " - f"GPRRangeType: {subreg_type}") - if offset < 0 or offset + subreg_type.length > self.stop: - raise ValueError(f"sub-register offset is out of range: {offset}") - return GPRRange(self.start + offset, subreg_type.length) - - def __repr__(self): - if self.length == 1: - return f"" - return f"" - - -SPECIAL_GPRS = GPRRange(0), GPRRange(1), GPRRange(2), GPRRange(13) - - -@final -@unique -class XERBit(RegLoc, Enum, metaclass=ABCEnumMeta): - CA = "CA" - - def conflicts(self, other): - # type: (RegLoc) -> bool - if isinstance(other, XERBit): - return self == other - return False - - -@final -@unique -class GlobalMem(RegLoc, Enum, metaclass=ABCEnumMeta): - """singleton representing all non-StackSlot memory -- treated as a single - physical register for register allocation purposes. - """ - GlobalMem = "GlobalMem" - - def conflicts(self, other): - # type: (RegLoc) -> bool - if isinstance(other, GlobalMem): - return self == other - return False - - -@final -@unique -class VL(RegLoc, Enum, metaclass=ABCEnumMeta): - VL_MAXVL = "VL_MAXVL" - """VL and MAXVL""" - - def conflicts(self, other): - # type: (RegLoc) -> bool - if isinstance(other, VL): - return self == other - return False - - -@final -class RegClass(OFSet[RegLoc]): - """ an ordered set of registers. - earlier registers are preferred by the register allocator. - """ - - @lru_cache(maxsize=None, typed=True) - def max_conflicts_with(self, other): - # type: (RegClass | RegLoc) -> int - """the largest number of registers in `self` that a single register - from `other` can conflict with - """ - if isinstance(other, RegClass): - return max(self.max_conflicts_with(i) for i in other) - else: - return sum(other.conflicts(i) for i in self) - - -@plain_data(frozen=True, unsafe_hash=True) -class RegType(metaclass=ABCMeta): - __slots__ = () - - @property - @abstractmethod - def reg_class(self): - # type: () -> RegClass - return ... - - -_RegType = TypeVar("_RegType", bound=RegType) -_RegLoc = TypeVar("_RegLoc", bound=RegLoc) - - -@plain_data(frozen=True, eq=False, repr=False) -@final -class GPRRangeType(RegType): - __slots__ = "length", - - def __init__(self, length=1): - # type: (int) -> None - if length < 1 or length > GPR_COUNT: - raise ValueError("invalid length") - self.length = length - - @staticmethod - @lru_cache(maxsize=None) - def __get_reg_class(length): - # type: (int) -> RegClass - regs = [] - for start in range(GPR_COUNT - length): - reg = GPRRange(start, length) - if any(i in reg for i in SPECIAL_GPRS): - continue - regs.append(reg) - return RegClass(regs) - - @property - @final - def reg_class(self): - # type: () -> RegClass - return GPRRangeType.__get_reg_class(self.length) - - @final - def __eq__(self, other): - if isinstance(other, GPRRangeType): - return self.length == other.length - return False - - @final - def __hash__(self): - return hash(self.length) - - def __repr__(self): - return f"" - - -GPRType = GPRRangeType -"""a length=1 GPRRangeType""" - - -@plain_data(frozen=True, unsafe_hash=True, repr=False) -@final -class FixedGPRRangeType(RegType): - __slots__ = "reg", - - def __init__(self, reg): - # type: (GPRRange) -> None - self.reg = reg - - @property - def reg_class(self): - # type: () -> RegClass - return RegClass([self.reg]) - - @property - def length(self): - # type: () -> int - return self.reg.length - - def __repr__(self): - return f"" - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class CAType(RegType): - __slots__ = () - - @property - def reg_class(self): - # type: () -> RegClass - return RegClass([XERBit.CA]) - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class GlobalMemType(RegType): - __slots__ = () - - @property - def reg_class(self): - # type: () -> RegClass - return RegClass([GlobalMem.GlobalMem]) - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class KnownVLType(RegType): - __slots__ = "length", - - def __init__(self, length): - # type: (int) -> None - if not (0 < length <= 64): - raise ValueError("invalid VL value") - self.length = length - - @property - def reg_class(self): - # type: () -> RegClass - return RegClass([VL.VL_MAXVL]) - - -def assert_vl_is(vl, expected_vl): - # type: (SSAKnownVL | KnownVLType | int | None, int) -> None - if vl is None: - vl = 1 - elif isinstance(vl, SSAVal): - vl = vl.ty.length - elif isinstance(vl, KnownVLType): - vl = vl.length - if vl != expected_vl: - raise ValueError( - f"wrong VL: expected {expected_vl} got {vl}") - - -STACK_SLOT_SIZE = 8 - - -@plain_data(frozen=True, unsafe_hash=True) -@final -class StackSlot(RegLoc): - __slots__ = "start_slot", "length_in_slots", - - def __init__(self, start_slot, length_in_slots): - # type: (int, int) -> None - self.start_slot = start_slot - if length_in_slots < 1: - raise ValueError("invalid length_in_slots") - self.length_in_slots = length_in_slots - - @property - def stop_slot(self): - return self.start_slot + self.length_in_slots - - @property - def start_byte(self): - return self.start_slot * STACK_SLOT_SIZE - - def conflicts(self, other): - # type: (RegLoc) -> bool - if isinstance(other, StackSlot): - return (self.stop_slot > other.start_slot - and other.stop_slot > self.start_slot) - return False - - def get_subreg_at_offset(self, subreg_type, offset): - # type: (RegType, int) -> StackSlot - if not isinstance(subreg_type, StackSlotType): - raise ValueError(f"subreg_type is not a " - f"StackSlotType: {subreg_type}") - if offset < 0 or offset + subreg_type.length_in_slots > self.stop_slot: - raise ValueError(f"sub-register offset is out of range: {offset}") - return StackSlot(self.start_slot + offset, subreg_type.length_in_slots) - - -STACK_SLOT_COUNT = 128 - - -@plain_data(frozen=True, eq=False) -@final -class StackSlotType(RegType): - __slots__ = "length_in_slots", - - def __init__(self, length_in_slots=1): - # type: (int) -> None - if length_in_slots < 1: - raise ValueError("invalid length_in_slots") - self.length_in_slots = length_in_slots - - @staticmethod - @lru_cache(maxsize=None) - def __get_reg_class(length_in_slots): - # type: (int) -> RegClass - regs = [] - for start in range(STACK_SLOT_COUNT - length_in_slots): - reg = StackSlot(start, length_in_slots) - regs.append(reg) - return RegClass(regs) - - @property - def reg_class(self): - # type: () -> RegClass - return StackSlotType.__get_reg_class(self.length_in_slots) - - @final - def __eq__(self, other): - if isinstance(other, StackSlotType): - return self.length_in_slots == other.length_in_slots - return False - - @final - def __hash__(self): - return hash(self.length_in_slots) - - -@plain_data(frozen=True, eq=False, repr=False) -@final -class SSAVal(Generic[_RegType]): - __slots__ = "op", "arg_name", "ty", - - def __init__(self, op, arg_name, ty): - # type: (Op, str, _RegType) -> None - self.op = op - """the Op that writes this SSAVal""" - - self.arg_name = arg_name - """the name of the argument of self.op that writes this SSAVal""" - - self.ty = ty - - def __eq__(self, rhs): - if isinstance(rhs, SSAVal): - return (self.op is rhs.op - and self.arg_name == rhs.arg_name) - return False - - def __hash__(self): - return hash((id(self.op), self.arg_name)) - - def __repr__(self): - return f"<#{self.op.id}.{self.arg_name}: {self.ty}>" - - -SSAGPRRange = SSAVal[GPRRangeType] -SSAGPR = SSAVal[GPRType] -SSAKnownVL = SSAVal[KnownVLType] - - -@final -@plain_data(unsafe_hash=True, frozen=True) -class EqualityConstraint: - __slots__ = "lhs", "rhs" - - def __init__(self, lhs, rhs): - # type: (list[SSAVal], list[SSAVal]) -> None - self.lhs = lhs - self.rhs = rhs - if len(lhs) == 0 or len(rhs) == 0: - raise ValueError("can't constrain an empty list to be equal") - - -@final -class Fn: - __slots__ = "ops", - - def __init__(self): - # type: () -> None - self.ops = [] # type: list[Op] - - def __repr__(self, short=False): - if short: - return "" - ops = ", ".join(op.__repr__(just_id=True) for op in self.ops) - return f"" - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - for op in self.ops: - op.pre_ra_sim(state) - - -class _NotSet: - """ helper for __repr__ for when fields aren't set """ - - def __repr__(self): - return "" - - -_NOT_SET = _NotSet() - - -@final -class AsmContext: - def __init__(self, assigned_registers): - # type: (dict[SSAVal, RegLoc]) -> None - self.__assigned_registers = assigned_registers - - def reg(self, ssa_val, expected_ty): - # type: (SSAVal[Any], Type[_RegLoc]) -> _RegLoc - try: - reg = self.__assigned_registers[ssa_val] - except KeyError as e: - raise ValueError(f"SSAVal not assigned a register: {ssa_val}") - wrong_len = (isinstance(reg, GPRRange) - and reg.length != ssa_val.ty.length) - if not isinstance(reg, expected_ty) or wrong_len: - raise TypeError( - f"SSAVal is assigned a register of the wrong type: " - f"ssa_val={ssa_val} expected_ty={expected_ty} reg={reg}") - return reg - - def gpr_range(self, ssa_val): - # type: (SSAGPRRange | SSAVal[FixedGPRRangeType]) -> GPRRange - return self.reg(ssa_val, GPRRange) - - def stack_slot(self, ssa_val): - # type: (SSAVal[StackSlotType]) -> StackSlot - return self.reg(ssa_val, StackSlot) - - def gpr(self, ssa_val, vec, offset=0): - # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], bool, int) -> str - reg = self.gpr_range(ssa_val).start + offset - return "*" * vec + str(reg) - - def vgpr(self, ssa_val, offset=0): - # type: (SSAGPRRange | SSAVal[FixedGPRRangeType], int) -> str - return self.gpr(ssa_val=ssa_val, vec=True, offset=offset) - - def sgpr(self, ssa_val, offset=0): - # type: (SSAGPR | SSAVal[FixedGPRRangeType], int) -> str - return self.gpr(ssa_val=ssa_val, vec=False, offset=offset) - - def needs_sv(self, *regs): - # type: (*SSAGPRRange | SSAVal[FixedGPRRangeType]) -> bool - for reg in regs: - reg = self.gpr_range(reg) - if reg.length != 1 or reg.start >= 32: - return True - return False - - -GPR_SIZE_IN_BYTES = 8 -GPR_SIZE_IN_BITS = GPR_SIZE_IN_BYTES * 8 -GPR_VALUE_MASK = (1 << GPR_SIZE_IN_BITS) - 1 - - -@plain_data(frozen=True) -@final -class PreRASimState: - __slots__ = ("gprs", "VLs", "CAs", - "global_mems", "stack_slots", - "fixed_gprs") - - def __init__( - self, - gprs, # type: dict[SSAGPRRange, tuple[int, ...]] - VLs, # type: dict[SSAKnownVL, int] - CAs, # type: dict[SSAVal[CAType], bool] - global_mems, # type: dict[SSAVal[GlobalMemType], FMap[int, int]] - stack_slots, # type: dict[SSAVal[StackSlotType], tuple[int, ...]] - fixed_gprs, # type: dict[SSAVal[FixedGPRRangeType], tuple[int, ...]] - ): - # type: (...) -> None - self.gprs = gprs - self.VLs = VLs - self.CAs = CAs - self.global_mems = global_mems - self.stack_slots = stack_slots - self.fixed_gprs = fixed_gprs - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -class Op(metaclass=ABCMeta): - __slots__ = "id", "fn" - - @abstractmethod - def inputs(self): - # type: () -> dict[str, SSAVal] - ... - - @abstractmethod - def outputs(self): - # type: () -> dict[str, SSAVal] - ... - - def get_equality_constraints(self): - # type: () -> Iterable[EqualityConstraint] - if False: - yield ... - - def get_extra_interferences(self): - # type: () -> Iterable[tuple[SSAVal, SSAVal]] - if False: - yield ... - - def __init__(self, fn): - # type: (Fn) -> None - self.id = len(fn.ops) - fn.ops.append(self) - self.fn = fn - - @final - def __repr__(self, just_id=False): - fields_list = [f"#{self.id}"] - outputs = None - try: - outputs = self.outputs() - except AttributeError: - pass - if not just_id: - for name in fields(self): - if name in ("id", "fn"): - continue - v = getattr(self, name, _NOT_SET) - if (outputs is not None and name in outputs - and outputs[name] is v): - fields_list.append(repr(v)) - else: - fields_list.append(f"{name}={v!r}") - fields_str = ', '.join(fields_list) - return f"{self.__class__.__name__}({fields_str})" - - @abstractmethod - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - """get the lines of assembly for this Op""" - ... - - @abstractmethod - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - """simulate op before register allocation""" - ... - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpLoadFromStackSlot(Op): - __slots__ = "dest", "src", "vl" - - def inputs(self): - # type: () -> dict[str, SSAVal] - retval = {"src": self.src} # type: dict[str, SSAVal[Any]] - if self.vl is not None: - retval["vl"] = self.vl - return retval - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"dest": self.dest} - - def __init__(self, fn, src, vl=None): - # type: (Fn, SSAVal[StackSlotType], SSAKnownVL | None) -> None - super().__init__(fn) - self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots)) - self.src = src - self.vl = vl - assert_vl_is(vl, self.dest.ty.length) - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - dest = ctx.gpr(self.dest, vec=self.dest.ty.length != 1) - src = ctx.stack_slot(self.src) - if ctx.needs_sv(self.dest): - return [f"sv.ld {dest}, {src.start_byte}(1)"] - return [f"ld {dest}, {src.start_byte}(1)"] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - """simulate op before register allocation""" - state.gprs[self.dest] = state.stack_slots[self.src] - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpStoreToStackSlot(Op): - __slots__ = "dest", "src", "vl" - - def inputs(self): - # type: () -> dict[str, SSAVal] - retval = {"src": self.src} # type: dict[str, SSAVal[Any]] - if self.vl is not None: - retval["vl"] = self.vl - return retval - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"dest": self.dest} - - def __init__(self, fn, src, vl=None): - # type: (Fn, SSAGPRRange, SSAKnownVL | None) -> None - super().__init__(fn) - self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length)) - self.src = src - self.vl = vl - assert_vl_is(vl, src.ty.length) - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - src = ctx.gpr(self.src, vec=self.src.ty.length != 1) - dest = ctx.stack_slot(self.dest) - if ctx.needs_sv(self.src): - return [f"sv.std {src}, {dest.start_byte}(1)"] - return [f"std {src}, {dest.start_byte}(1)"] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - """simulate op before register allocation""" - state.stack_slots[self.dest] = state.gprs[self.src] - - -_RegSrcType = TypeVar("_RegSrcType", bound=RegType) - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpCopy(Op, Generic[_RegSrcType, _RegType]): - __slots__ = "dest", "src", "vl" - - def inputs(self): - # type: () -> dict[str, SSAVal] - retval = {"src": self.src} # type: dict[str, SSAVal[Any]] - if self.vl is not None: - retval["vl"] = self.vl - return retval - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"dest": self.dest} - - def __init__(self, fn, src, dest_ty=None, vl=None): - # type: (Fn, SSAVal[_RegSrcType], _RegType | None, SSAKnownVL | None) -> None - super().__init__(fn) - if dest_ty is None: - dest_ty = cast(_RegType, src.ty) - if isinstance(src.ty, GPRRangeType) \ - and isinstance(dest_ty, FixedGPRRangeType): - if src.ty.length != dest_ty.reg.length: - raise ValueError(f"incompatible source and destination " - f"types: {src.ty} and {dest_ty}") - length = src.ty.length - elif isinstance(src.ty, FixedGPRRangeType) \ - and isinstance(dest_ty, GPRRangeType): - if src.ty.reg.length != dest_ty.length: - raise ValueError(f"incompatible source and destination " - f"types: {src.ty} and {dest_ty}") - length = src.ty.length - elif src.ty != dest_ty: - raise ValueError(f"incompatible source and destination " - f"types: {src.ty} and {dest_ty}") - elif isinstance(src.ty, StackSlotType): - raise ValueError("can't use OpCopy on stack slots") - elif isinstance(src.ty, (GPRRangeType, FixedGPRRangeType)): - length = src.ty.length - else: - length = 1 - - self.dest = SSAVal(self, "dest", dest_ty) # type: SSAVal[_RegType] - self.src = src - self.vl = vl - assert_vl_is(vl, length) - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - if ctx.reg(self.src, RegLoc) == ctx.reg(self.dest, RegLoc): - return [] - if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and - isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))): - vec = self.dest.ty.length != 1 - dest = ctx.gpr_range(self.dest) # type: ignore - src = ctx.gpr_range(self.src) # type: ignore - dest_s = ctx.gpr(self.dest, vec=vec) # type: ignore - src_s = ctx.gpr(self.src, vec=vec) # type: ignore - mrr = "" - if src.conflicts(dest) and src.start > dest.start: - mrr = "/mrr" - if ctx.needs_sv(self.src, self.dest): # type: ignore - return [f"sv.or{mrr} {dest_s}, {src_s}, {src_s}"] - return [f"or {dest_s}, {src_s}, {src_s}"] - raise NotImplementedError - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - if (isinstance(self.src.ty, (GPRRangeType, FixedGPRRangeType)) and - isinstance(self.dest.ty, (GPRRangeType, FixedGPRRangeType))): - if isinstance(self.src.ty, GPRRangeType): - v = state.gprs[self.src] # type: ignore - else: - v = state.fixed_gprs[self.src] # type: ignore - if isinstance(self.dest.ty, GPRRangeType): - state.gprs[self.dest] = v # type: ignore - else: - state.fixed_gprs[self.dest] = v # type: ignore - elif (isinstance(self.src.ty, FixedGPRRangeType) and - isinstance(self.dest.ty, GPRRangeType)): - state.gprs[self.dest] = state.fixed_gprs[self.src] # type: ignore - elif (isinstance(self.src.ty, GPRRangeType) and - isinstance(self.dest.ty, FixedGPRRangeType)): - state.fixed_gprs[self.dest] = state.gprs[self.src] # type: ignore - elif (isinstance(self.src.ty, CAType) and - self.src.ty == self.dest.ty): - state.CAs[self.dest] = state.CAs[self.src] # type: ignore - elif (isinstance(self.src.ty, KnownVLType) and - self.src.ty == self.dest.ty): - state.VLs[self.dest] = state.VLs[self.src] # type: ignore - elif (isinstance(self.src.ty, GlobalMemType) and - self.src.ty == self.dest.ty): - v = state.global_mems[self.src] # type: ignore - state.global_mems[self.dest] = v # type: ignore - else: - raise NotImplementedError - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpConcat(Op): - __slots__ = "dest", "sources" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {f"sources[{i}]": v for i, v in enumerate(self.sources)} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"dest": self.dest} - - def __init__(self, fn, sources): - # type: (Fn, Iterable[SSAGPRRange]) -> None - super().__init__(fn) - sources = tuple(sources) - self.dest = SSAVal(self, "dest", GPRRangeType( - sum(i.ty.length for i in sources))) - self.sources = sources - - def get_equality_constraints(self): - # type: () -> Iterable[EqualityConstraint] - yield EqualityConstraint([self.dest], [*self.sources]) - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - return [] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - v = [] - for src in self.sources: - v.extend(state.gprs[src]) - state.gprs[self.dest] = tuple(v) - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpSplit(Op): - __slots__ = "results", "src" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {"src": self.src} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {i.arg_name: i for i in self.results} - - def __init__(self, fn, src, split_indexes): - # type: (Fn, SSAGPRRange, Iterable[int]) -> None - super().__init__(fn) - ranges = [] # type: list[GPRRangeType] - last = 0 - for i in split_indexes: - if not (0 < i < src.ty.length): - raise ValueError(f"invalid split index: {i}, must be in " - f"0 < i < {src.ty.length}") - ranges.append(GPRRangeType(i - last)) - last = i - ranges.append(GPRRangeType(src.ty.length - last)) - self.src = src - self.results = tuple( - SSAVal(self, f"results[{i}]", r) for i, r in enumerate(ranges)) - - def get_equality_constraints(self): - # type: () -> Iterable[EqualityConstraint] - yield EqualityConstraint([*self.results], [self.src]) - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - return [] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - rest = state.gprs[self.src] - for dest in reversed(self.results): - state.gprs[dest] = rest[-dest.ty.length:] - rest = rest[:-dest.ty.length] - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpBigIntAddSub(Op): - __slots__ = "out", "lhs", "rhs", "CA_in", "CA_out", "is_sub", "vl" - - def inputs(self): - # type: () -> dict[str, SSAVal] - retval = {} # type: dict[str, SSAVal[Any]] - retval["lhs"] = self.lhs - retval["rhs"] = self.rhs - retval["CA_in"] = self.CA_in - if self.vl is not None: - retval["vl"] = self.vl - return retval - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"out": self.out, "CA_out": self.CA_out} - - def __init__(self, fn, lhs, rhs, CA_in, is_sub, vl=None): - # type: (Fn, SSAGPRRange, SSAGPRRange, SSAVal[CAType], bool, SSAKnownVL | None) -> None - super().__init__(fn) - if lhs.ty != rhs.ty: - raise TypeError(f"source types must match: " - f"{lhs} doesn't match {rhs}") - self.out = SSAVal(self, "out", lhs.ty) - self.lhs = lhs - self.rhs = rhs - self.CA_in = CA_in - self.CA_out = SSAVal(self, "CA_out", CA_in.ty) - self.is_sub = is_sub - self.vl = vl - assert_vl_is(vl, lhs.ty.length) - - def get_extra_interferences(self): - # type: () -> Iterable[tuple[SSAVal, SSAVal]] - yield self.out, self.lhs - yield self.out, self.rhs - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - vec = self.out.ty.length != 1 - out = ctx.gpr(self.out, vec=vec) - RA = ctx.gpr(self.lhs, vec=vec) - RB = ctx.gpr(self.rhs, vec=vec) - mnemonic = "adde" - if self.is_sub: - mnemonic = "subfe" - RA, RB = RB, RA # reorder to match subfe - if ctx.needs_sv(self.out, self.lhs, self.rhs): - return [f"sv.{mnemonic} {out}, {RA}, {RB}"] - return [f"{mnemonic} {out}, {RA}, {RB}"] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - carry = state.CAs[self.CA_in] - out = [] # type: list[int] - for l, r in zip(state.gprs[self.lhs], state.gprs[self.rhs]): - if self.is_sub: - r = r ^ GPR_VALUE_MASK - s = l + r + carry - carry = s != (s & GPR_VALUE_MASK) - out.append(s & GPR_VALUE_MASK) - state.CAs[self.CA_out] = carry - state.gprs[self.out] = tuple(out) - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpBigIntMulDiv(Op): - __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div", "vl" - - def inputs(self): - # type: () -> dict[str, SSAVal] - retval = {} # type: dict[str, SSAVal[Any]] - retval["RA"] = self.RA - retval["RB"] = self.RB - retval["RC"] = self.RC - if self.vl is not None: - retval["vl"] = self.vl - return retval - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"RT": self.RT, "RS": self.RS} - - def __init__(self, fn, RA, RB, RC, is_div, vl): - # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, bool, SSAKnownVL | None) -> None - super().__init__(fn) - self.RT = SSAVal(self, "RT", RA.ty) - self.RA = RA - self.RB = RB - self.RC = RC - self.RS = SSAVal(self, "RS", RC.ty) - self.is_div = is_div - self.vl = vl - assert_vl_is(vl, RA.ty.length) - - def get_equality_constraints(self): - # type: () -> Iterable[EqualityConstraint] - yield EqualityConstraint([self.RC], [self.RS]) - - def get_extra_interferences(self): - # type: () -> Iterable[tuple[SSAVal, SSAVal]] - yield self.RT, self.RA - yield self.RT, self.RB - yield self.RT, self.RC - yield self.RT, self.RS - yield self.RS, self.RA - yield self.RS, self.RB - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - vec = self.RT.ty.length != 1 - RT = ctx.gpr(self.RT, vec=vec) - RA = ctx.gpr(self.RA, vec=vec) - RB = ctx.sgpr(self.RB) - RC = ctx.sgpr(self.RC) - mnemonic = "maddedu" - if self.is_div: - mnemonic = "divmod2du/mrr" - return [f"sv.{mnemonic} {RT}, {RA}, {RB}, {RC}"] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - carry = state.gprs[self.RC][0] - RA = state.gprs[self.RA] - RB = state.gprs[self.RB][0] - RT = [0] * self.RT.ty.length - if self.is_div: - for i in reversed(range(self.RT.ty.length)): - if carry < RB and RB != 0: - div, mod = divmod((carry << 64) | RA[i], RB) - RT[i] = div & GPR_VALUE_MASK - carry = mod & GPR_VALUE_MASK - else: - RT[i] = GPR_VALUE_MASK - carry = 0 - else: - for i in range(self.RT.ty.length): - v = RA[i] * RB + carry - carry = v >> 64 - RT[i] = v & GPR_VALUE_MASK - state.gprs[self.RS] = carry, - state.gprs[self.RT] = tuple(RT) - - -@final -@unique -class ShiftKind(Enum): - Sl = "sl" - Sr = "sr" - Sra = "sra" - - def make_big_int_carry_in(self, fn, inp): - # type: (Fn, SSAGPRRange) -> tuple[SSAGPR, list[Op]] - if self is ShiftKind.Sl or self is ShiftKind.Sr: - li = OpLI(fn, 0) - return li.out, [li] - else: - assert self is ShiftKind.Sra - split = OpSplit(fn, inp, [inp.ty.length - 1]) - shr = OpShiftImm(fn, split.results[1], sh=63, kind=ShiftKind.Sra) - return shr.out, [split, shr] - - def make_big_int_shift(self, fn, inp, sh, vl): - # type: (Fn, SSAGPRRange, SSAGPR, SSAKnownVL | None) -> tuple[SSAGPRRange, list[Op]] - carry_in, ops = self.make_big_int_carry_in(fn, inp) - big_int_shift = OpBigIntShift(fn, inp, sh, carry_in, kind=self, vl=vl) - ops.append(big_int_shift) - return big_int_shift.out, ops - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpBigIntShift(Op): - __slots__ = "out", "inp", "carry_in", "_out_padding", "sh", "kind", "vl" - - def inputs(self): - # type: () -> dict[str, SSAVal] - retval = {} # type: dict[str, SSAVal[Any]] - retval["inp"] = self.inp - retval["sh"] = self.sh - retval["carry_in"] = self.carry_in - if self.vl is not None: - retval["vl"] = self.vl - return retval - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"out": self.out, "_out_padding": self._out_padding} - - def __init__(self, fn, inp, sh, carry_in, kind, vl=None): - # type: (Fn, SSAGPRRange, SSAGPR, SSAGPR, ShiftKind, SSAKnownVL | None) -> None - super().__init__(fn) - self.out = SSAVal(self, "out", inp.ty) - self._out_padding = SSAVal(self, "_out_padding", GPRRangeType()) - self.carry_in = carry_in - self.inp = inp - self.sh = sh - self.kind = kind - self.vl = vl - assert_vl_is(vl, inp.ty.length) - - def get_extra_interferences(self): - # type: () -> Iterable[tuple[SSAVal, SSAVal]] - yield self.out, self.sh - - def get_equality_constraints(self): - # type: () -> Iterable[EqualityConstraint] - if self.kind is ShiftKind.Sl: - yield EqualityConstraint([self.carry_in, self.inp], - [self.out, self._out_padding]) - else: - assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra - yield EqualityConstraint([self.inp, self.carry_in], - [self._out_padding, self.out]) - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - vec = self.out.ty.length != 1 - if self.kind is ShiftKind.Sl: - RT = ctx.gpr(self.out, vec=vec) - RA = ctx.gpr(self.out, vec=vec, offset=-1) - RB = ctx.sgpr(self.sh) - mrr = "/mrr" if vec else "" - return [f"sv.dsld{mrr} {RT}, {RA}, {RB}, 0"] - else: - assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra - RT = ctx.gpr(self.out, vec=vec) - RA = ctx.gpr(self.out, vec=vec, offset=1) - RB = ctx.sgpr(self.sh) - return [f"sv.dsrd {RT}, {RA}, {RB}, 1"] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - out = [0] * self.out.ty.length - carry = state.gprs[self.carry_in][0] - sh = state.gprs[self.sh][0] % 64 - if self.kind is ShiftKind.Sl: - inp = carry, *state.gprs[self.inp] - for i in reversed(range(self.out.ty.length)): - v = inp[i] | (inp[i + 1] << 64) - v <<= sh - out[i] = (v >> 64) & GPR_VALUE_MASK - else: - assert self.kind is ShiftKind.Sr or self.kind is ShiftKind.Sra - inp = *state.gprs[self.inp], carry - for i in range(self.out.ty.length): - v = inp[i] | (inp[i + 1] << 64) - v >>= sh - out[i] = v & GPR_VALUE_MASK - # state.gprs[self._out_padding] is intentionally not written - state.gprs[self.out] = tuple(out) - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpShiftImm(Op): - __slots__ = "out", "inp", "sh", "kind", "ca_out" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {"inp": self.inp} - - def outputs(self): - # type: () -> dict[str, SSAVal] - if self.ca_out is not None: - return {"out": self.out, "ca_out": self.ca_out} - return {"out": self.out} - - def __init__(self, fn, inp, sh, kind): - # type: (Fn, SSAGPR, int, ShiftKind) -> None - super().__init__(fn) - self.out = SSAVal(self, "out", inp.ty) - self.inp = inp - if not (0 <= sh < 64): - raise ValueError("shift amount out of range") - self.sh = sh - self.kind = kind - if self.kind is ShiftKind.Sra: - self.ca_out = SSAVal(self, "ca_out", CAType()) - else: - self.ca_out = None - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - out = ctx.sgpr(self.out) - inp = ctx.sgpr(self.inp) - if self.kind is ShiftKind.Sl: - mnemonic = "rldicr" - args = f"{self.sh}, {63 - self.sh}" - elif self.kind is ShiftKind.Sr: - mnemonic = "rldicl" - v = (64 - self.sh) % 64 - args = f"{v}, {self.sh}" - else: - assert self.kind is ShiftKind.Sra - mnemonic = "sradi" - args = f"{self.sh}" - if ctx.needs_sv(self.out, self.inp): - return [f"sv.{mnemonic} {out}, {inp}, {args}"] - return [f"{mnemonic} {out}, {inp}, {args}"] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - inp = state.gprs[self.inp][0] - if self.kind is ShiftKind.Sl: - assert self.ca_out is None - out = inp << self.sh - elif self.kind is ShiftKind.Sr: - assert self.ca_out is None - out = inp >> self.sh - else: - assert self.kind is ShiftKind.Sra - assert self.ca_out is not None - if inp & (1 << 63): # sign extend - inp -= 1 << 64 - out = inp >> self.sh - ca = inp < 0 and (out << self.sh) != inp - state.CAs[self.ca_out] = ca - state.gprs[self.out] = out, - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpLI(Op): - __slots__ = "out", "value", "vl" - - def inputs(self): - # type: () -> dict[str, SSAVal] - retval = {} # type: dict[str, SSAVal[Any]] - if self.vl is not None: - retval["vl"] = self.vl - return retval - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"out": self.out} - - def __init__(self, fn, value, vl=None): - # type: (Fn, int, SSAKnownVL | None) -> None - super().__init__(fn) - if vl is None: - length = 1 - else: - length = vl.ty.length - self.out = SSAVal(self, "out", GPRRangeType(length)) - if not (-1 << 15 <= value <= (1 << 15) - 1): - raise ValueError(f"value out of range: {value}") - self.value = value - self.vl = vl - assert_vl_is(vl, length) - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - vec = self.out.ty.length != 1 - out = ctx.gpr(self.out, vec=vec) - if ctx.needs_sv(self.out): - return [f"sv.addi {out}, 0, {self.value}"] - return [f"addi {out}, 0, {self.value}"] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - value = self.value & GPR_VALUE_MASK - state.gprs[self.out] = (value,) * self.out.ty.length - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpSetCA(Op): - __slots__ = "out", "value" - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"out": self.out} - - def __init__(self, fn, value): - # type: (Fn, bool) -> None - super().__init__(fn) - self.out = SSAVal(self, "out", CAType()) - self.value = value - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - if self.value: - return ["subfic 0, 0, -1"] - return ["addic 0, 0, 0"] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - state.CAs[self.out] = self.value - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpLoad(Op): - __slots__ = "RT", "RA", "offset", "mem", "vl" - - def inputs(self): - # type: () -> dict[str, SSAVal] - retval = {} # type: dict[str, SSAVal[Any]] - retval["RA"] = self.RA - retval["mem"] = self.mem - if self.vl is not None: - retval["vl"] = self.vl - return retval - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"RT": self.RT} - - def __init__(self, fn, RA, offset, mem, vl=None): - # type: (Fn, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None - super().__init__(fn) - if vl is None: - length = 1 - else: - length = vl.ty.length - self.RT = SSAVal(self, "RT", GPRRangeType(length)) - self.RA = RA - if not (-1 << 15 <= offset <= (1 << 15) - 1): - raise ValueError(f"offset out of range: {offset}") - if offset % 4 != 0: - raise ValueError(f"offset not aligned: {offset}") - self.offset = offset - self.mem = mem - self.vl = vl - assert_vl_is(vl, length) - - def get_extra_interferences(self): - # type: () -> Iterable[tuple[SSAVal, SSAVal]] - if self.RT.ty.length > 1: - yield self.RT, self.RA - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - RT = ctx.gpr(self.RT, vec=self.RT.ty.length != 1) - RA = ctx.sgpr(self.RA) - if ctx.needs_sv(self.RT, self.RA): - return [f"sv.ld {RT}, {self.offset}({RA})"] - return [f"ld {RT}, {self.offset}({RA})"] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - addr = state.gprs[self.RA][0] - addr += self.offset - RT = [0] * self.RT.ty.length - mem = state.global_mems[self.mem] - for i in range(self.RT.ty.length): - cur_addr = (addr + i * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK - if cur_addr % GPR_SIZE_IN_BYTES != 0: - raise ValueError(f"can't load from unaligned address: " - f"{cur_addr:#x}") - for j in range(GPR_SIZE_IN_BYTES): - byte_val = mem.get(cur_addr + j, 0) & 0xFF - RT[i] |= byte_val << (j * 8) - state.gprs[self.RT] = tuple(RT) - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpStore(Op): - __slots__ = "RS", "RA", "offset", "mem_in", "mem_out", "vl" - - def inputs(self): - # type: () -> dict[str, SSAVal] - retval = {} # type: dict[str, SSAVal[Any]] - retval["RS"] = self.RS - retval["RA"] = self.RA - retval["mem_in"] = self.mem_in - if self.vl is not None: - retval["vl"] = self.vl - return retval - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"mem_out": self.mem_out} - - def __init__(self, fn, RS, RA, offset, mem_in, vl=None): - # type: (Fn, SSAGPRRange, SSAGPR, int, SSAVal[GlobalMemType], SSAKnownVL | None) -> None - super().__init__(fn) - self.RS = RS - self.RA = RA - if not (-1 << 15 <= offset <= (1 << 15) - 1): - raise ValueError(f"offset out of range: {offset}") - if offset % 4 != 0: - raise ValueError(f"offset not aligned: {offset}") - self.offset = offset - self.mem_in = mem_in - self.mem_out = SSAVal(self, "mem_out", mem_in.ty) - self.vl = vl - assert_vl_is(vl, RS.ty.length) - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - RS = ctx.gpr(self.RS, vec=self.RS.ty.length != 1) - RA = ctx.sgpr(self.RA) - if ctx.needs_sv(self.RS, self.RA): - return [f"sv.std {RS}, {self.offset}({RA})"] - return [f"std {RS}, {self.offset}({RA})"] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - mem = dict(state.global_mems[self.mem_in]) - addr = state.gprs[self.RA][0] - addr += self.offset - RS = state.gprs[self.RS] - for i in range(self.RS.ty.length): - cur_addr = (addr + i * GPR_SIZE_IN_BYTES) & GPR_VALUE_MASK - if cur_addr % GPR_SIZE_IN_BYTES != 0: - raise ValueError(f"can't store to unaligned address: " - f"{cur_addr:#x}") - for j in range(GPR_SIZE_IN_BYTES): - mem[cur_addr + j] = (RS[i] >> (j * 8)) & 0xFF - state.global_mems[self.mem_out] = FMap(mem) - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpFuncArg(Op): - __slots__ = "out", - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"out": self.out} - - def __init__(self, fn, ty): - # type: (Fn, FixedGPRRangeType) -> None - super().__init__(fn) - self.out = SSAVal(self, "out", ty) - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - return [] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - if self.out not in state.fixed_gprs: - state.fixed_gprs[self.out] = (0,) * self.out.ty.length - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpInputMem(Op): - __slots__ = "out", - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"out": self.out} - - def __init__(self, fn): - # type: (Fn) -> None - super().__init__(fn) - self.out = SSAVal(self, "out", GlobalMemType()) - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - return [] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - if self.out not in state.global_mems: - state.global_mems[self.out] = FMap() - - -@plain_data(unsafe_hash=True, frozen=True, repr=False) -@final -class OpSetVLImm(Op): - __slots__ = "out", - - def inputs(self): - # type: () -> dict[str, SSAVal] - return {} - - def outputs(self): - # type: () -> dict[str, SSAVal] - return {"out": self.out} - - def __init__(self, fn, length): - # type: (Fn, int) -> None - super().__init__(fn) - self.out = SSAVal(self, "out", KnownVLType(length)) - - def get_asm_lines(self, ctx): - # type: (AsmContext) -> list[str] - return [f"setvl 0, 0, {self.out.ty.length}, 0, 1, 1"] - - def pre_ra_sim(self, state): - # type: (PreRASimState) -> None - state.VLs[self.out] = self.out.ty.length - - -def op_set_to_list(ops): - # type: (Iterable[Op]) -> list[Op] - worklists = [{}] # type: list[dict[Op, None]] - inps_to_ops_map = defaultdict(dict) # type: dict[SSAVal, dict[Op, None]] - ops_to_pending_input_count_map = {} # type: dict[Op, int] - for op in ops: - input_count = 0 - for val in op.inputs().values(): - input_count += 1 - inps_to_ops_map[val][op] = None - while len(worklists) <= input_count: - worklists.append({}) - ops_to_pending_input_count_map[op] = input_count - worklists[input_count][op] = None - retval = [] # type: list[Op] - ready_vals = OSet() # type: OSet[SSAVal] - while len(worklists[0]) != 0: - writing_op = next(iter(worklists[0])) - del worklists[0][writing_op] - retval.append(writing_op) - for val in writing_op.outputs().values(): - if val in ready_vals: - raise ValueError(f"multiple instructions must not write " - f"to the same SSA value: {val}") - ready_vals.add(val) - for reading_op in inps_to_ops_map[val]: - pending = ops_to_pending_input_count_map[reading_op] - del worklists[pending][reading_op] - pending -= 1 - worklists[pending][reading_op] = None - ops_to_pending_input_count_map[reading_op] = pending - for worklist in worklists: - for op in worklist: - raise ValueError(f"instruction is part of a dependency loop or " - f"its inputs are never written: {op}") - return retval - - -def generate_assembly(ops, assigned_registers=None): - # type: (list[Op], dict[SSAVal, RegLoc] | None) -> list[str] - if assigned_registers is None: - from bigint_presentation_code.register_allocator import \ - allocate_registers - assigned_registers = allocate_registers(ops) - ctx = AsmContext(assigned_registers) - retval = [] # list[str] - for op in ops: - retval.extend(op.get_asm_lines(ctx)) - retval.append("bclr 20, 0, 0") - return retval diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py deleted file mode 100644 index cc794e9..0000000 --- a/src/bigint_presentation_code/register_allocator.py +++ /dev/null @@ -1,432 +0,0 @@ -""" -Register Allocator for Toom-Cook algorithm generator for SVP64 - -this uses an algorithm based on: -[Retargetable Graph-Coloring Register Allocation for Irregular Architectures](https://user.it.uu.se/~svenolof/wpo/AllocSCOPES2003.20030626b.pdf) -""" - -from itertools import combinations -from typing import Generic, Iterable, Mapping, TypeVar - -from nmutil.plain_data import plain_data - -from bigint_presentation_code.compiler_ir import (GPRRangeType, Op, RegClass, - RegLoc, RegType, SSAVal) -from bigint_presentation_code.type_util import final -from bigint_presentation_code.util import OFSet, OSet - -_RegType = TypeVar("_RegType", bound=RegType) - - -@plain_data(unsafe_hash=True, order=True, frozen=True) -class LiveInterval: - __slots__ = "first_write", "last_use" - - def __init__(self, first_write, last_use=None): - # type: (int, int | None) -> None - if last_use is None: - last_use = first_write - if last_use < first_write: - raise ValueError("uses must be after first_write") - if first_write < 0 or last_use < 0: - raise ValueError("indexes must be nonnegative") - self.first_write = first_write - self.last_use = last_use - - def overlaps(self, other): - # type: (LiveInterval) -> bool - if self.first_write == other.first_write: - return True - return self.last_use > other.first_write \ - and other.last_use > self.first_write - - def __add__(self, use): - # type: (int) -> LiveInterval - last_use = max(self.last_use, use) - return LiveInterval(first_write=self.first_write, last_use=last_use) - - @property - def live_after_op_range(self): - """the range of op indexes where self is live immediately after the - Op at each index - """ - return range(self.first_write, self.last_use) - - -@final -class MergedRegSet(Mapping[SSAVal[_RegType], int]): - def __init__(self, reg_set): - # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None - self.__items = {} # type: dict[SSAVal[_RegType], int] - if isinstance(reg_set, SSAVal): - reg_set = [(reg_set, 0)] - for ssa_val, offset in reg_set: - if ssa_val in self.__items: - other = self.__items[ssa_val] - if offset != other: - raise ValueError( - f"can't merge register sets: conflicting offsets: " - f"for {ssa_val}: {offset} != {other}") - else: - self.__items[ssa_val] = offset - first_item = None - for i in self.__items.items(): - first_item = i - break - if first_item is None: - raise ValueError("can't have empty MergedRegs") - first_ssa_val, start = first_item - ty = first_ssa_val.ty - if isinstance(ty, GPRRangeType): - stop = start + ty.length - for ssa_val, offset in self.__items.items(): - if not isinstance(ssa_val.ty, GPRRangeType): - raise ValueError(f"can't merge incompatible types: " - f"{ssa_val.ty} and {ty}") - stop = max(stop, offset + ssa_val.ty.length) - start = min(start, offset) - ty = GPRRangeType(stop - start) - else: - stop = 1 - for ssa_val, offset in self.__items.items(): - if offset != 0: - raise ValueError(f"can't have non-zero offset " - f"for {ssa_val.ty}") - if ty != ssa_val.ty: - raise ValueError(f"can't merge incompatible types: " - f"{ssa_val.ty} and {ty}") - self.__start = start # type: int - self.__stop = stop # type: int - self.__ty = ty # type: RegType - self.__hash = hash(OFSet(self.items())) - - @staticmethod - def from_equality_constraint(constraint_sequence): - # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType] - if len(constraint_sequence) == 1: - # any type allowed with len = 1 - return MergedRegSet(constraint_sequence[0]) - offset = 0 - retval = [] - for val in constraint_sequence: - if not isinstance(val.ty, GPRRangeType): - raise ValueError("equality constraint sequences must only " - "have SSAVal type GPRRangeType") - retval.append((val, offset)) - offset += val.ty.length - return MergedRegSet(retval) - - @property - def ty(self): - return self.__ty - - @property - def stop(self): - return self.__stop - - @property - def start(self): - return self.__start - - @property - def range(self): - return range(self.__start, self.__stop) - - def offset_by(self, amount): - # type: (int) -> MergedRegSet[_RegType] - return MergedRegSet((k, v + amount) for k, v in self.items()) - - def normalized(self): - # type: () -> MergedRegSet[_RegType] - return self.offset_by(-self.start) - - def with_offset_to_match(self, target): - # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType] - for ssa_val, offset in self.items(): - if ssa_val in target: - return self.offset_by(target[ssa_val] - offset) - raise ValueError("can't change offset to match unrelated MergedRegSet") - - def __getitem__(self, item): - # type: (SSAVal[_RegType]) -> int - return self.__items[item] - - def __iter__(self): - return iter(self.__items) - - def __len__(self): - return len(self.__items) - - def __hash__(self): - return self.__hash - - def __repr__(self): - return f"MergedRegSet({list(self.__items.items())})" - - -@final -class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]): - def __init__(self, ops): - # type: (Iterable[Op]) -> None - merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegType]] - for op in ops: - for val in (*op.inputs().values(), *op.outputs().values()): - if val not in merged_sets: - merged_sets[val] = MergedRegSet(val) - for e in op.get_equality_constraints(): - lhs_set = MergedRegSet.from_equality_constraint(e.lhs) - rhs_set = MergedRegSet.from_equality_constraint(e.rhs) - items = [] # type: list[tuple[SSAVal, int]] - for i in e.lhs: - s = merged_sets[i].with_offset_to_match(lhs_set) - items.extend(s.items()) - for i in e.rhs: - s = merged_sets[i].with_offset_to_match(rhs_set) - items.extend(s.items()) - full_set = MergedRegSet(items) - for val in full_set.keys(): - merged_sets[val] = full_set - - self.__map = {k: v.normalized() for k, v in merged_sets.items()} - - def __getitem__(self, key): - # type: (SSAVal) -> MergedRegSet - return self.__map[key] - - def __iter__(self): - return iter(self.__map) - - def __len__(self): - return len(self.__map) - - def __repr__(self): - return f"MergedRegSets(data={self.__map})" - - -@final -class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]): - def __init__(self, ops): - # type: (list[Op]) -> None - self.__merged_reg_sets = MergedRegSets(ops) - live_intervals = {} # type: dict[MergedRegSet[_RegType], LiveInterval] - for op_idx, op in enumerate(ops): - for val in op.inputs().values(): - live_intervals[self.__merged_reg_sets[val]] += op_idx - for val in op.outputs().values(): - reg_set = self.__merged_reg_sets[val] - if reg_set not in live_intervals: - live_intervals[reg_set] = LiveInterval(op_idx) - else: - live_intervals[reg_set] += op_idx - self.__live_intervals = live_intervals - live_after = [] # type: list[OSet[MergedRegSet[_RegType]]] - live_after += (OSet() for _ in ops) - for reg_set, live_interval in self.__live_intervals.items(): - for i in live_interval.live_after_op_range: - live_after[i].add(reg_set) - self.__live_after = [OFSet(i) for i in live_after] - - @property - def merged_reg_sets(self): - return self.__merged_reg_sets - - def __getitem__(self, key): - # type: (MergedRegSet[_RegType]) -> LiveInterval - return self.__live_intervals[key] - - def __iter__(self): - return iter(self.__live_intervals) - - def __len__(self): - return len(self.__live_intervals) - - def reg_sets_live_after(self, op_index): - # type: (int) -> OFSet[MergedRegSet[_RegType]] - return self.__live_after[op_index] - - def __repr__(self): - reg_sets_live_after = dict(enumerate(self.__live_after)) - return (f"LiveIntervals(live_intervals={self.__live_intervals}, " - f"merged_reg_sets={self.merged_reg_sets}, " - f"reg_sets_live_after={reg_sets_live_after})") - - -@final -class IGNode(Generic[_RegType]): - """ interference graph node """ - __slots__ = "merged_reg_set", "edges", "reg" - - def __init__(self, merged_reg_set, edges=(), reg=None): - # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None - self.merged_reg_set = merged_reg_set - self.edges = OSet(edges) - self.reg = reg - - def add_edge(self, other): - # type: (IGNode) -> None - self.edges.add(other) - other.edges.add(self) - - def __eq__(self, other): - # type: (object) -> bool - if isinstance(other, IGNode): - return self.merged_reg_set == other.merged_reg_set - return NotImplemented - - def __hash__(self): - return hash(self.merged_reg_set) - - def __repr__(self, nodes=None): - # type: (None | dict[IGNode, int]) -> str - if nodes is None: - nodes = {} - if self in nodes: - return f"" - nodes[self] = len(nodes) - edges = "{" + ", ".join(i.__repr__(nodes) for i in self.edges) + "}" - return (f"IGNode(#{nodes[self]}, " - f"merged_reg_set={self.merged_reg_set}, " - f"edges={edges}, " - f"reg={self.reg})") - - @property - def reg_class(self): - # type: () -> RegClass - return self.merged_reg_set.ty.reg_class - - def reg_conflicts_with_neighbors(self, reg): - # type: (RegLoc) -> bool - for neighbor in self.edges: - if neighbor.reg is not None and neighbor.reg.conflicts(reg): - return True - return False - - -@final -class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]): - def __init__(self, merged_reg_sets): - # type: (Iterable[MergedRegSet[_RegType]]) -> None - self.__nodes = {i: IGNode(i) for i in merged_reg_sets} - - def __getitem__(self, key): - # type: (MergedRegSet[_RegType]) -> IGNode - return self.__nodes[key] - - def __iter__(self): - return iter(self.__nodes) - - def __len__(self): - return len(self.__nodes) - - def __repr__(self): - nodes = {} - nodes_text = [f"...: {node.__repr__(nodes)}" for node in self.values()] - nodes_text = ", ".join(nodes_text) - return f"InterferenceGraph(nodes={{{nodes_text}}})" - - -@plain_data() -class AllocationFailed: - __slots__ = "node", "live_intervals", "interference_graph" - - def __init__(self, node, live_intervals, interference_graph): - # type: (IGNode, LiveIntervals, InterferenceGraph) -> None - self.node = node - self.live_intervals = live_intervals - self.interference_graph = interference_graph - - -class AllocationFailedError(Exception): - def __init__(self, msg, allocation_failed): - # type: (str, AllocationFailed) -> None - super().__init__(msg, allocation_failed) - self.allocation_failed = allocation_failed - - -def try_allocate_registers_without_spilling(ops): - # type: (list[Op]) -> dict[SSAVal, RegLoc] | AllocationFailed - - live_intervals = LiveIntervals(ops) - merged_reg_sets = live_intervals.merged_reg_sets - interference_graph = InterferenceGraph(merged_reg_sets.values()) - for op_idx, op in enumerate(ops): - reg_sets = live_intervals.reg_sets_live_after(op_idx) - for i, j in combinations(reg_sets, 2): - if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0: - interference_graph[i].add_edge(interference_graph[j]) - for i, j in op.get_extra_interferences(): - i = merged_reg_sets[i] - j = merged_reg_sets[j] - if i.ty.reg_class.max_conflicts_with(j.ty.reg_class) != 0: - interference_graph[i].add_edge(interference_graph[j]) - - nodes_remaining = OSet(interference_graph.values()) - - def local_colorability_score(node): - # type: (IGNode) -> int - """ returns a positive integer if node is locally colorable, returns - zero or a negative integer if node isn't known to be locally - colorable, the more negative the value, the less colorable - """ - if node not in nodes_remaining: - raise ValueError() - retval = len(node.reg_class) - for neighbor in node.edges: - if neighbor in nodes_remaining: - retval -= node.reg_class.max_conflicts_with(neighbor.reg_class) - return retval - - node_stack = [] # type: list[IGNode] - while True: - best_node = None # type: None | IGNode - best_score = 0 - for node in nodes_remaining: - score = local_colorability_score(node) - if best_node is None or score > best_score: - best_node = node - best_score = score - if best_score > 0: - # it's locally colorable, no need to find a better one - break - - if best_node is None: - break - node_stack.append(best_node) - nodes_remaining.remove(best_node) - - retval = {} # type: dict[SSAVal, RegLoc] - - while len(node_stack) > 0: - node = node_stack.pop() - if node.reg is not None: - if node.reg_conflicts_with_neighbors(node.reg): - return AllocationFailed(node=node, - live_intervals=live_intervals, - interference_graph=interference_graph) - else: - # pick the first non-conflicting register in node.reg_class, since - # register classes are ordered from most preferred to least - # preferred register. - for reg in node.reg_class: - if not node.reg_conflicts_with_neighbors(reg): - node.reg = reg - break - if node.reg is None: - return AllocationFailed(node=node, - live_intervals=live_intervals, - interference_graph=interference_graph) - - for ssa_val, offset in node.merged_reg_set.items(): - retval[ssa_val] = node.reg.get_subreg_at_offset(ssa_val.ty, offset) - - return retval - - -def allocate_registers(ops): - # type: (list[Op]) -> dict[SSAVal, RegLoc] - retval = try_allocate_registers_without_spilling(ops) - if isinstance(retval, AllocationFailed): - # TODO: implement spilling - raise AllocationFailedError( - "spilling required but not yet implemented", retval) - return retval