From df75844fb207370a516bc9927beeb463c729e8e0 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Fri, 14 Oct 2022 17:13:26 -0700 Subject: [PATCH] add Fn class rather than global for generating op ids --- src/bigint_presentation_code/compiler_ir.py | 121 ++++++++++-------- .../test_compiler_ir.py | 36 ++---- .../test_register_allocator.py | 67 ++++------ 3 files changed, 105 insertions(+), 119 deletions(-) diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index bfcf159..aa37fa4 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -399,6 +399,21 @@ class EqualityConstraint: raise ValueError("can't constrain an empty list to be equal") +@final +class Fn: + __slots__ = "ops", + + def __init__(self): + # type: () -> None + self.ops = [] # type: list[Op] + + def __repr__(self, short=False): + if short: + return "" + ops = ", ".join(op.__repr__(just_id=True) for op in self.ops) + return f"" + + class _NotSet: """ helper for __repr__ for when fields aren't set """ @@ -411,7 +426,7 @@ _NOT_SET = _NotSet() @plain_data(unsafe_hash=True, frozen=True, repr=False) class Op(metaclass=ABCMeta): - __slots__ = () + __slots__ = "id", "fn" @abstractmethod def inputs(self): @@ -433,19 +448,11 @@ class Op(metaclass=ABCMeta): if False: yield ... - __NEXT_ID = 0 - - @cached_property - def id(self): - # type: () -> int - # use cached_property rather than done in init so id is usable even if - # init hasn't run - retval = Op.__NEXT_ID - Op.__NEXT_ID += 1 - return retval - - def __init__(self): - self.id # initialize + def __init__(self, fn): + # type: (Fn) -> None + self.id = len(fn.ops) + fn.ops.append(self) + self.fn = fn @final def __repr__(self, just_id=False): @@ -461,6 +468,8 @@ class Op(metaclass=ABCMeta): if ((outputs is None or name in outputs) and isinstance(v, SSAVal)): v = v.__repr__(long=True) + elif isinstance(v, Fn): + v = v.__repr__(short=True) else: v = repr(v) fields_list.append(f"{name}={v}") @@ -481,9 +490,9 @@ class OpLoadFromStackSlot(Op): # type: () -> dict[str, SSAVal] return {"dest": self.dest} - def __init__(self, src): - # type: (SSAVal[GPRRangeType]) -> None - super().__init__() + def __init__(self, fn, src): + # type: (Fn, SSAVal[GPRRangeType]) -> None + super().__init__(fn) self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length)) self.src = src @@ -501,9 +510,9 @@ class OpStoreToStackSlot(Op): # type: () -> dict[str, SSAVal] return {"dest": self.dest} - def __init__(self, src): - # type: (SSAVal[StackSlotType]) -> None - super().__init__() + def __init__(self, fn, src): + # type: (Fn, SSAVal[StackSlotType]) -> None + super().__init__(fn) self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots)) self.src = src @@ -524,9 +533,9 @@ class OpCopy(Op, Generic[_RegSrcType, _RegType]): # type: () -> dict[str, SSAVal] return {"dest": self.dest} - def __init__(self, src, dest_ty=None): - # type: (SSAVal[_RegSrcType], _RegType | None) -> None - super().__init__() + def __init__(self, fn, src, dest_ty=None): + # type: (Fn, SSAVal[_RegSrcType], _RegType | None) -> None + super().__init__(fn) if dest_ty is None: dest_ty = cast(_RegType, src.ty) if isinstance(src.ty, GPRRangeType) \ @@ -560,9 +569,9 @@ class OpConcat(Op): # type: () -> dict[str, SSAVal] return {"dest": self.dest} - def __init__(self, sources): - # type: (Iterable[SSAVal[GPRRangeType]]) -> None - super().__init__() + def __init__(self, fn, sources): + # type: (Fn, Iterable[SSAVal[GPRRangeType]]) -> None + super().__init__(fn) sources = tuple(sources) self.dest = SSAVal(self, "dest", GPRRangeType( sum(i.ty.length for i in sources))) @@ -586,9 +595,9 @@ class OpSplit(Op): # type: () -> dict[str, SSAVal] return {i.arg_name: i for i in self.results} - def __init__(self, src, split_indexes): - # type: (SSAVal[GPRRangeType], Iterable[int]) -> None - super().__init__() + def __init__(self, fn, src, split_indexes): + # type: (Fn, SSAVal[GPRRangeType], Iterable[int]) -> None + super().__init__(fn) ranges = [] # type: list[GPRRangeType] last = 0 for i in split_indexes: @@ -620,9 +629,9 @@ class OpAddSubE(Op): # type: () -> dict[str, SSAVal] return {"RT": self.RT, "CY_out": self.CY_out} - def __init__(self, RA, RB, CY_in, is_sub): - # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None - super().__init__() + def __init__(self, fn, RA, RB, CY_in, is_sub): + # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None + super().__init__(fn) if RA.ty != RB.ty: raise TypeError(f"source types must match: " f"{RA} doesn't match {RB}") @@ -652,9 +661,9 @@ class OpBigIntMulDiv(Op): # type: () -> dict[str, SSAVal] return {"RT": self.RT, "RS": self.RS} - def __init__(self, RA, RB, RC, is_div): - # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None - super().__init__() + def __init__(self, fn, RA, RB, RC, is_div): + # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None + super().__init__(fn) self.RT = SSAVal(self, "RT", RA.ty) self.RA = RA self.RB = RB @@ -697,9 +706,9 @@ class OpBigIntShift(Op): # type: () -> dict[str, SSAVal] return {"RT": self.RT} - def __init__(self, inp, sh, kind): - # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None - super().__init__() + def __init__(self, fn, inp, sh, kind): + # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None + super().__init__(fn) self.RT = SSAVal(self, "RT", inp.ty) self.inp = inp self.sh = sh @@ -724,9 +733,9 @@ class OpLI(Op): # type: () -> dict[str, SSAVal] return {"out": self.out} - def __init__(self, value, length=1): - # type: (int, int) -> None - super().__init__() + def __init__(self, fn, value, length=1): + # type: (Fn, int, int) -> None + super().__init__(fn) self.out = SSAVal(self, "out", GPRRangeType(length)) self.value = value @@ -744,9 +753,9 @@ class OpClearCY(Op): # type: () -> dict[str, SSAVal] return {"out": self.out} - def __init__(self): - # type: () -> None - super().__init__() + def __init__(self, fn): + # type: (Fn) -> None + super().__init__(fn) self.out = SSAVal(self, "out", CYType()) @@ -763,9 +772,9 @@ class OpLoad(Op): # type: () -> dict[str, SSAVal] return {"RT": self.RT} - def __init__(self, RA, offset, mem, length=1): - # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None - super().__init__() + def __init__(self, fn, RA, offset, mem, length=1): + # type: (Fn, SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None + super().__init__(fn) self.RT = SSAVal(self, "RT", GPRRangeType(length)) self.RA = RA self.offset = offset @@ -790,9 +799,9 @@ class OpStore(Op): # type: () -> dict[str, SSAVal] return {"mem_out": self.mem_out} - def __init__(self, RS, RA, offset, mem_in): - # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None - super().__init__() + def __init__(self, fn, RS, RA, offset, mem_in): + # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None + super().__init__(fn) self.RS = RS self.RA = RA self.offset = offset @@ -813,9 +822,9 @@ class OpFuncArg(Op): # type: () -> dict[str, SSAVal] return {"out": self.out} - def __init__(self, ty): - # type: (FixedGPRRangeType) -> None - super().__init__() + def __init__(self, fn, ty): + # type: (Fn, FixedGPRRangeType) -> None + super().__init__(fn) self.out = SSAVal(self, "out", ty) @@ -832,9 +841,9 @@ class OpInputMem(Op): # type: () -> dict[str, SSAVal] return {"out": self.out} - def __init__(self): - # type: () -> None - super().__init__() + def __init__(self, fn): + # type: (Fn) -> None + super().__init__(fn) self.out = SSAVal(self, "out", GlobalMemType()) diff --git a/src/bigint_presentation_code/test_compiler_ir.py b/src/bigint_presentation_code/test_compiler_ir.py index 1f30547..ff52641 100644 --- a/src/bigint_presentation_code/test_compiler_ir.py +++ b/src/bigint_presentation_code/test_compiler_ir.py @@ -1,6 +1,6 @@ import unittest -from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, GPRRange, GPRType, +from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, Fn, GPRRange, GPRType, Op, OpAddSubE, OpClearCY, OpConcat, OpCopy, OpFuncArg, OpInputMem, OpLI, OpLoad, OpStore, op_set_to_list) @@ -9,35 +9,25 @@ class TestCompilerIR(unittest.TestCase): maxDiff = None def test_op_set_to_list(self): - ops = [] # type: list[Op] - op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3))) - ops.append(op0) - op1 = OpCopy(op0.out, GPRType()) - ops.append(op1) + fn = Fn() + op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) + op1 = OpCopy(fn, op0.out, GPRType()) arg = op1.dest - op2 = OpInputMem() - ops.append(op2) + op2 = OpInputMem(fn) mem = op2.out - op3 = OpLoad(arg, offset=0, mem=mem, length=32) - ops.append(op3) + op3 = OpLoad(fn, arg, offset=0, mem=mem, length=32) a = op3.RT - op4 = OpLI(1) - ops.append(op4) + op4 = OpLI(fn, 1) b_0 = op4.out - op5 = OpLI(0, length=31) - ops.append(op5) + op5 = OpLI(fn, 0, length=31) b_rest = op5.out - op6 = OpConcat([b_0, b_rest]) - ops.append(op6) + op6 = OpConcat(fn, [b_0, b_rest]) b = op6.dest - op7 = OpClearCY() - ops.append(op7) + op7 = OpClearCY(fn) cy = op7.out - op8 = OpAddSubE(a, b, cy, is_sub=False) - ops.append(op8) + op8 = OpAddSubE(fn, a, b, cy, is_sub=False) s = op8.RT - op9 = OpStore(s, arg, offset=0, mem_in=mem) - ops.append(op9) + op9 = OpStore(fn, s, arg, offset=0, mem_in=mem) mem = op9.mem_out expected_ops = [ @@ -53,7 +43,7 @@ class TestCompilerIR(unittest.TestCase): op9, # OpStore(s, arg, offset=0, mem_in=mem) ] - ops = op_set_to_list(reversed(ops)) + ops = op_set_to_list(fn.ops[::-1]) if ops != expected_ops: self.assertEqual(repr(ops), repr(expected_ops)) diff --git a/src/bigint_presentation_code/test_register_allocator.py b/src/bigint_presentation_code/test_register_allocator.py index 8ba74eb..bdc1938 100644 --- a/src/bigint_presentation_code/test_register_allocator.py +++ b/src/bigint_presentation_code/test_register_allocator.py @@ -1,6 +1,6 @@ import unittest -from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, GPRRange, +from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, Fn, GPRRange, GPRType, GlobalMem, Op, OpAddSubE, OpClearCY, OpConcat, OpCopy, OpFuncArg, OpInputMem, OpLI, @@ -14,9 +14,10 @@ class TestMergedRegSet(unittest.TestCase): maxDiff = None def test_from_equality_constraint(self): - op0 = OpLI(0, length=1) - op1 = OpLI(0, length=2) - op2 = OpLI(0, length=3) + fn = Fn() + op0 = OpLI(fn, 0, length=1) + op1 = OpLI(fn, 0, length=2) + op2 = OpLI(fn, 0, length=3) self.assertEqual(MergedRegSet.from_equality_constraint([ op0.out, op1.out, @@ -41,15 +42,12 @@ class TestRegisterAllocator(unittest.TestCase): maxDiff = None def test_try_alloc_fail(self): - ops = [] # type: list[Op] - op0 = OpLI(0, length=52) - ops.append(op0) - op1 = OpLI(0, length=64) - ops.append(op1) - op2 = OpConcat([op0.out, op1.out]) - ops.append(op2) - - reg_assignments = try_allocate_registers_without_spilling(ops) + fn = Fn() + op0 = OpLI(fn, 0, length=52) + op1 = OpLI(fn, 0, length=64) + op2 = OpConcat(fn, [op0.out, op1.out]) + + reg_assignments = try_allocate_registers_without_spilling(fn.ops) self.assertEqual( repr(reg_assignments), "AllocationFailed(" @@ -81,38 +79,28 @@ class TestRegisterAllocator(unittest.TestCase): ) def test_try_alloc_bigint_inc(self): - ops = [] # type: list[Op] - op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3))) - ops.append(op0) - op1 = OpCopy(op0.out, GPRType()) - ops.append(op1) + fn = Fn() + op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3))) + op1 = OpCopy(fn, op0.out, GPRType()) arg = op1.dest - op2 = OpInputMem() - ops.append(op2) + op2 = OpInputMem(fn) mem = op2.out - op3 = OpLoad(arg, offset=0, mem=mem, length=32) - ops.append(op3) + op3 = OpLoad(fn, arg, offset=0, mem=mem, length=32) a = op3.RT - op4 = OpLI(1) - ops.append(op4) + op4 = OpLI(fn, 1) b_0 = op4.out - op5 = OpLI(0, length=31) - ops.append(op5) + op5 = OpLI(fn, 0, length=31) b_rest = op5.out - op6 = OpConcat([b_0, b_rest]) - ops.append(op6) + op6 = OpConcat(fn, [b_0, b_rest]) b = op6.dest - op7 = OpClearCY() - ops.append(op7) + op7 = OpClearCY(fn) cy = op7.out - op8 = OpAddSubE(a, b, cy, is_sub=False) - ops.append(op8) + op8 = OpAddSubE(fn, a, b, cy, is_sub=False) s = op8.RT - op9 = OpStore(s, arg, offset=0, mem_in=mem) - ops.append(op9) + op9 = OpStore(fn, s, arg, offset=0, mem_in=mem) mem = op9.mem_out - reg_assignments = try_allocate_registers_without_spilling(ops) + reg_assignments = try_allocate_registers_without_spilling(fn.ops) expected_reg_assignments = { op0.out: GPRRange(start=3, length=1), @@ -132,12 +120,11 @@ class TestRegisterAllocator(unittest.TestCase): def tst_try_alloc_concat(self, expected_regs, expected_dest_reg): # type: (list[GPRRange], GPRRange) -> None - li_ops = [OpLI(i, reg.length) for i, reg in enumerate(expected_regs)] - ops = [*li_ops] # type: list[Op] - concat = OpConcat([i.out for i in li_ops]) - ops.append(concat) + fn = Fn() + li_ops = [OpLI(fn, i, r.length) for i, r in enumerate(expected_regs)] + concat = OpConcat(fn, [i.out for i in li_ops]) - reg_assignments = try_allocate_registers_without_spilling(ops) + reg_assignments = try_allocate_registers_without_spilling(fn.ops) expected_reg_assignments = {concat.dest: expected_dest_reg} for li_op, reg in zip(li_ops, expected_regs): -- 2.30.2