From 8f1f632f8cd2bcfd0fdb6d047d8c4eae4d422d73 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Fri, 14 Oct 2022 01:05:57 -0700 Subject: [PATCH] test_op_set_to_list works --- setup.py | 11 ++ src/bigint_presentation_code/compiler_ir.py | 117 ++++++++++++------ .../register_allocator.py | 42 +++---- .../test_compiler_ir.py | 55 +++++++- 4 files changed, 165 insertions(+), 60 deletions(-) diff --git a/setup.py b/setup.py index 36d91f4..8db62c9 100644 --- a/setup.py +++ b/setup.py @@ -5,9 +5,19 @@ README = Path(__file__).with_name('README.md').read_text("UTF-8") version = '0.0.1' +cprop = "git+https://git.libre-soc.org/git/cached-property.git@1.5.2" \ + "#egg=cached-property-1.5.2" + install_requires = [ "libresoc-nmutil", 'libresoc-openpower-isa', + # git url needed for having `pip3 install -e .` install from libre-soc git + 'cached-property@'+cprop, +] + +# git url needed for having `setup.py develop` install from libre-soc git +dependency_links = [ + cprop, ] setup( @@ -30,4 +40,5 @@ setup( include_package_data=True, zip_safe=False, install_requires=install_requires, + dependency_links=dependency_links, ) diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index f4f09f7..24b8649 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -7,9 +7,10 @@ from collections import defaultdict from enum import Enum, EnumMeta, unique from functools import lru_cache from typing import (TYPE_CHECKING, AbstractSet, Generic, Iterable, Sequence, - TypeVar) + TypeVar, cast) -from nmutil.plain_data import plain_data +from cached_property import cached_property +from nmutil.plain_data import fields, plain_data if TYPE_CHECKING: from typing_extensions import final @@ -196,7 +197,7 @@ class RegType(metaclass=ABCMeta): return ... -_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True) +_RegType = TypeVar("_RegType", bound=RegType) @plain_data(frozen=True, eq=False) @@ -359,13 +360,13 @@ class StackSlotType(RegType): return hash(self.length_in_slots) -@plain_data(frozen=True, eq=False) +@plain_data(frozen=True, eq=False, repr=False) @final -class SSAVal(Generic[_RegT_co]): - __slots__ = "op", "arg_name", "ty", "arg_index" +class SSAVal(Generic[_RegType]): + __slots__ = "op", "arg_name", "ty", def __init__(self, op, arg_name, ty): - # type: (Op, str, _RegT_co) -> None + # type: (Op, str, _RegType) -> None self.op = op """the Op that writes this SSAVal""" @@ -383,6 +384,19 @@ class SSAVal(Generic[_RegT_co]): def __hash__(self): return hash((id(self.op), self.arg_name)) + def __repr__(self): + fields_list = [] + for name in fields(self): + v = getattr(self, name, None) + if v is not None: + if name == "op": + v = v.__repr__(just_id=True) + else: + v = repr(v) + fields_list.append(f"{name}={v}") + fields_str = ", ".join(fields_list) + return f"SSAVal({fields_str})" + @final @plain_data(unsafe_hash=True, frozen=True) @@ -397,7 +411,17 @@ class EqualityConstraint: raise ValueError("can't constrain an empty list to be equal") -@plain_data(unsafe_hash=True, frozen=True) +class _NotSet: + """ helper for __repr__ for when fields aren't set """ + + def __repr__(self): + return "" + + +_NOT_SET = _NotSet() + + +@plain_data(unsafe_hash=True, frozen=True, repr=False) class Op(metaclass=ABCMeta): __slots__ = () @@ -421,11 +445,26 @@ class Op(metaclass=ABCMeta): if False: yield ... - def __init__(self): - pass + __NEXT_ID = 0 + @cached_property + def id(self): + retval = Op.__NEXT_ID + Op.__NEXT_ID += 1 + return retval -@plain_data(unsafe_hash=True, frozen=True) + @final + def __repr__(self, just_id=False): + fields_list = [f"#{self.id}"] + if not just_id: + for name in fields(self): + v = getattr(self, name, _NOT_SET) + fields_list.append(f"{name}={v!r}") + fields_str = ', '.join(fields_list) + return f"{self.__class__.__name__}({fields_str})" + + +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpLoadFromStackSlot(Op): __slots__ = "dest", "src" @@ -444,7 +483,7 @@ class OpLoadFromStackSlot(Op): self.src = src -@plain_data(unsafe_hash=True, frozen=True) +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpStoreToStackSlot(Op): __slots__ = "dest", "src" @@ -463,9 +502,12 @@ class OpStoreToStackSlot(Op): self.src = src -@plain_data(unsafe_hash=True, frozen=True) +_RegSrcType = TypeVar("_RegSrcType", bound=RegType) + + +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final -class OpCopy(Op, Generic[_RegT_co]): +class OpCopy(Op, Generic[_RegSrcType, _RegType]): __slots__ = "dest", "src" def inputs(self): @@ -477,9 +519,9 @@ class OpCopy(Op, Generic[_RegT_co]): return {"dest": self.dest} def __init__(self, src, dest_ty=None): - # type: (SSAVal[_RegT_co], _RegT_co | None) -> None + # type: (SSAVal[_RegSrcType], _RegType | None) -> None if dest_ty is None: - dest_ty = src.ty + dest_ty = cast(_RegType, src.ty) if isinstance(src.ty, GPRRangeType) \ and isinstance(dest_ty, GPRRangeType): if src.ty.length != dest_ty.length: @@ -489,11 +531,11 @@ class OpCopy(Op, Generic[_RegT_co]): raise ValueError(f"incompatible source and destination " f"types: {src.ty} and {dest_ty}") - self.dest = SSAVal(self, "dest", dest_ty) + self.dest = SSAVal(self, "dest", dest_ty) # type: SSAVal[_RegType] self.src = src -@plain_data(unsafe_hash=True, frozen=True) +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpConcat(Op): __slots__ = "dest", "sources" @@ -518,7 +560,7 @@ class OpConcat(Op): yield EqualityConstraint([self.dest], [*self.sources]) -@plain_data(unsafe_hash=True, frozen=True) +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpSplit(Op): __slots__ = "results", "src" @@ -551,7 +593,7 @@ class OpSplit(Op): yield EqualityConstraint([*self.results], [self.src]) -@plain_data(unsafe_hash=True, frozen=True) +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpAddSubE(Op): __slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub" @@ -582,7 +624,7 @@ class OpAddSubE(Op): yield self.RT, self.RB -@plain_data(unsafe_hash=True, frozen=True) +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpBigIntMulDiv(Op): __slots__ = "RT", "RA", "RB", "RC", "RS", "is_div" @@ -626,7 +668,7 @@ class ShiftKind(Enum): Sra = "sra" -@plain_data(unsafe_hash=True, frozen=True) +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpBigIntShift(Op): __slots__ = "RT", "inp", "sh", "kind" @@ -652,7 +694,7 @@ class OpBigIntShift(Op): yield self.RT, self.sh -@plain_data(unsafe_hash=True, frozen=True) +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpLI(Op): __slots__ = "out", "value" @@ -671,7 +713,7 @@ class OpLI(Op): self.value = value -@plain_data(unsafe_hash=True, frozen=True) +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpClearCY(Op): __slots__ = "out", @@ -689,7 +731,7 @@ class OpClearCY(Op): self.out = SSAVal(self, "out", CYType()) -@plain_data(unsafe_hash=True, frozen=True) +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpLoad(Op): __slots__ = "RT", "RA", "offset", "mem" @@ -715,7 +757,7 @@ class OpLoad(Op): yield self.RT, self.RA -@plain_data(unsafe_hash=True, frozen=True) +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpStore(Op): __slots__ = "RS", "RA", "offset", "mem_in", "mem_out" @@ -737,7 +779,7 @@ class OpStore(Op): self.mem_out = SSAVal(self, "mem_out", mem_in.ty) -@plain_data(unsafe_hash=True, frozen=True) +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpFuncArg(Op): __slots__ = "out", @@ -755,7 +797,7 @@ class OpFuncArg(Op): self.out = SSAVal(self, "out", ty) -@plain_data(unsafe_hash=True, frozen=True) +@plain_data(unsafe_hash=True, frozen=True, repr=False) @final class OpInputMem(Op): __slots__ = "out", @@ -775,33 +817,34 @@ class OpInputMem(Op): def op_set_to_list(ops): # type: (Iterable[Op]) -> list[Op] - worklists = [set()] # type: list[set[Op]] - input_vals_to_ops_map = defaultdict(set) # type: dict[SSAVal, set[Op]] + worklists = [{}] # type: list[dict[Op, None]] + inps_to_ops_map = defaultdict(dict) # type: dict[SSAVal, dict[Op, None]] ops_to_pending_input_count_map = {} # type: dict[Op, int] for op in ops: input_count = 0 for val in op.inputs().values(): input_count += 1 - input_vals_to_ops_map[val].add(op) + inps_to_ops_map[val][op] = None while len(worklists) <= input_count: - worklists.append(set()) + worklists.append({}) ops_to_pending_input_count_map[op] = input_count - worklists[input_count].add(op) + worklists[input_count][op] = None retval = [] # type: list[Op] ready_vals = set() # type: set[SSAVal] while len(worklists[0]) != 0: - writing_op = worklists[0].pop() + writing_op = next(iter(worklists[0])) + del worklists[0][writing_op] retval.append(writing_op) for val in writing_op.outputs().values(): if val in ready_vals: raise ValueError(f"multiple instructions must not write " f"to the same SSA value: {val}") ready_vals.add(val) - for reading_op in input_vals_to_ops_map[val]: + for reading_op in inps_to_ops_map[val]: pending = ops_to_pending_input_count_map[reading_op] - worklists[pending].remove(reading_op) + del worklists[pending][reading_op] pending -= 1 - worklists[pending].add(reading_op) + worklists[pending][reading_op] = None ops_to_pending_input_count_map[reading_op] = pending for worklist in worklists: for op in worklist: diff --git a/src/bigint_presentation_code/register_allocator.py b/src/bigint_presentation_code/register_allocator.py index 297e3e5..75b3422 100644 --- a/src/bigint_presentation_code/register_allocator.py +++ b/src/bigint_presentation_code/register_allocator.py @@ -20,7 +20,7 @@ else: return v -_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True) +_RegType = TypeVar("_RegType", bound=RegType) @plain_data(unsafe_hash=True, order=True, frozen=True) @@ -59,10 +59,10 @@ class LiveInterval: @final -class MergedRegSet(Mapping[SSAVal[_RegT_co], int]): +class MergedRegSet(Mapping[SSAVal[_RegType], int]): def __init__(self, reg_set): - # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None - self.__items = {} # type: dict[SSAVal[_RegT_co], int] + # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None + self.__items = {} # type: dict[SSAVal[_RegType], int] if isinstance(reg_set, SSAVal): reg_set = [(reg_set, 0)] for ssa_val, offset in reg_set: @@ -107,7 +107,7 @@ class MergedRegSet(Mapping[SSAVal[_RegT_co], int]): @staticmethod def from_equality_constraint(constraint_sequence): - # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co] + # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType] if len(constraint_sequence) == 1: # any type allowed with len = 1 return MergedRegSet(constraint_sequence[0]) @@ -138,22 +138,22 @@ class MergedRegSet(Mapping[SSAVal[_RegT_co], int]): return range(self.__start, self.__stop) def offset_by(self, amount): - # type: (int) -> MergedRegSet[_RegT_co] + # type: (int) -> MergedRegSet[_RegType] return MergedRegSet((k, v + amount) for k, v in self.items()) def normalized(self): - # type: () -> MergedRegSet[_RegT_co] + # type: () -> MergedRegSet[_RegType] return self.offset_by(-self.start) def with_offset_to_match(self, target): - # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co] + # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType] for ssa_val, offset in self.items(): if ssa_val in target: return self.offset_by(target[ssa_val] - offset) raise ValueError("can't change offset to match unrelated MergedRegSet") def __getitem__(self, item): - # type: (SSAVal[_RegT_co]) -> int + # type: (SSAVal[_RegType]) -> int return self.__items[item] def __iter__(self): @@ -170,10 +170,10 @@ class MergedRegSet(Mapping[SSAVal[_RegT_co], int]): @final -class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]): +class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]): def __init__(self, ops): # type: (Iterable[Op]) -> None - merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegT_co]] + merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegType]] for op in ops: for val in (*op.inputs().values(), *op.outputs().values()): if val not in merged_sets: @@ -204,11 +204,11 @@ class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]): @final -class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]): +class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]): def __init__(self, ops): # type: (list[Op]) -> None self.__merged_reg_sets = MergedRegSets(ops) - live_intervals = {} # type: dict[MergedRegSet[_RegT_co], LiveInterval] + live_intervals = {} # type: dict[MergedRegSet[_RegType], LiveInterval] for op_idx, op in enumerate(ops): for val in op.inputs().values(): live_intervals[self.__merged_reg_sets[val]] += op_idx @@ -219,7 +219,7 @@ class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]): else: live_intervals[reg_set] += op_idx self.__live_intervals = live_intervals - live_after = [] # type: list[set[MergedRegSet[_RegT_co]]] + live_after = [] # type: list[set[MergedRegSet[_RegType]]] live_after += (set() for _ in ops) for reg_set, live_interval in self.__live_intervals.items(): for i in live_interval.live_after_op_range: @@ -231,14 +231,14 @@ class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]): return self.__merged_reg_sets def __getitem__(self, key): - # type: (MergedRegSet[_RegT_co]) -> LiveInterval + # type: (MergedRegSet[_RegType]) -> LiveInterval return self.__live_intervals[key] def __iter__(self): return iter(self.__live_intervals) def reg_sets_live_after(self, op_index): - # type: (int) -> frozenset[MergedRegSet[_RegT_co]] + # type: (int) -> frozenset[MergedRegSet[_RegType]] return self.__live_after[op_index] def __repr__(self): @@ -249,12 +249,12 @@ class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]): @final -class IGNode(Generic[_RegT_co]): +class IGNode(Generic[_RegType]): """ interference graph node """ __slots__ = "merged_reg_set", "edges", "reg" def __init__(self, merged_reg_set, edges=(), reg=None): - # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None + # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None self.merged_reg_set = merged_reg_set self.edges = set(edges) self.reg = reg @@ -300,13 +300,13 @@ class IGNode(Generic[_RegT_co]): @final -class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]): +class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]): def __init__(self, merged_reg_sets): - # type: (Iterable[MergedRegSet[_RegT_co]]) -> None + # type: (Iterable[MergedRegSet[_RegType]]) -> None self.__nodes = {i: IGNode(i) for i in merged_reg_sets} def __getitem__(self, key): - # type: (MergedRegSet[_RegT_co]) -> IGNode + # type: (MergedRegSet[_RegType]) -> IGNode return self.__nodes[key] def __iter__(self): diff --git a/src/bigint_presentation_code/test_compiler_ir.py b/src/bigint_presentation_code/test_compiler_ir.py index 231f2fb..26d5272 100644 --- a/src/bigint_presentation_code/test_compiler_ir.py +++ b/src/bigint_presentation_code/test_compiler_ir.py @@ -1,10 +1,61 @@ import unittest -from bigint_presentation_code.compiler_ir import Op, op_set_to_list +from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, GPRRange, GPRType, + Op, OpAddSubE, OpClearCY, OpConcat, OpCopy, OpFuncArg, OpInputMem, OpLI, OpLoad, OpStore, + op_set_to_list) class TestCompilerIR(unittest.TestCase): - pass # no tests yet, just testing importing + maxDiff = None + + def test_op_set_to_list(self): + ops = [] # list[Op] + op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3))) + ops.append(op0) + op1 = OpCopy(op0.out, GPRType()) + ops.append(op1) + arg = op1.dest + op2 = OpInputMem() + ops.append(op2) + mem = op2.out + op3 = OpLoad(arg, offset=0, mem=mem, length=32) + ops.append(op3) + a = op3.RT + op4 = OpLI(1) + ops.append(op4) + b_0 = op4.out + op5 = OpLI(0, length=31) + ops.append(op5) + b_rest = op5.out + op6 = OpConcat([b_0, b_rest]) + ops.append(op6) + b = op6.dest + op7 = OpClearCY() + ops.append(op7) + cy = op7.out + op8 = OpAddSubE(a, b, cy, is_sub=False) + ops.append(op8) + s = op8.RT + op9 = OpStore(s, arg, offset=0, mem_in=mem) + ops.append(op9) + mem = op9.mem_out + + expected_ops = [ + op7, # OpClearCY() + op5, # OpLI(0, length=31) + op4, # OpLI(1) + op2, # OpInputMem() + op0, # OpFuncArg(FixedGPRRangeType(GPRRange(3))) + op6, # OpConcat([b_0, b_rest]) + op1, # OpCopy(op0.out, GPRType()) + op3, # OpLoad(arg, offset=0, mem=mem, length=32) + op8, # OpAddSubE(a, b, cy, is_sub=False) + op9, # OpStore(s, arg, offset=0, mem_in=mem) + ] + + ops = op_set_to_list(reversed(ops)) + if ops != expected_ops: + self.assertEqual(repr(ops), repr(expected_ops)) if __name__ == "__main__": -- 2.30.2