raise ValueError("can't constrain an empty list to be equal")
+@final
+class Fn:
+ __slots__ = "ops",
+
+ def __init__(self):
+ # type: () -> None
+ self.ops = [] # type: list[Op]
+
+ def __repr__(self, short=False):
+ if short:
+ return "<Fn>"
+ ops = ", ".join(op.__repr__(just_id=True) for op in self.ops)
+ return f"<Fn([{ops}])>"
+
+
class _NotSet:
""" helper for __repr__ for when fields aren't set """
@plain_data(unsafe_hash=True, frozen=True, repr=False)
class Op(metaclass=ABCMeta):
- __slots__ = ()
+ __slots__ = "id", "fn"
@abstractmethod
def inputs(self):
if False:
yield ...
- __NEXT_ID = 0
-
- @cached_property
- def id(self):
- # type: () -> int
- # use cached_property rather than done in init so id is usable even if
- # init hasn't run
- retval = Op.__NEXT_ID
- Op.__NEXT_ID += 1
- return retval
-
- def __init__(self):
- self.id # initialize
+ def __init__(self, fn):
+ # type: (Fn) -> None
+ self.id = len(fn.ops)
+ fn.ops.append(self)
+ self.fn = fn
@final
def __repr__(self, just_id=False):
if ((outputs is None or name in outputs)
and isinstance(v, SSAVal)):
v = v.__repr__(long=True)
+ elif isinstance(v, Fn):
+ v = v.__repr__(short=True)
else:
v = repr(v)
fields_list.append(f"{name}={v}")
# type: () -> dict[str, SSAVal]
return {"dest": self.dest}
- def __init__(self, src):
- # type: (SSAVal[GPRRangeType]) -> None
- super().__init__()
+ def __init__(self, fn, src):
+ # type: (Fn, SSAVal[GPRRangeType]) -> None
+ super().__init__(fn)
self.dest = SSAVal(self, "dest", StackSlotType(src.ty.length))
self.src = src
# type: () -> dict[str, SSAVal]
return {"dest": self.dest}
- def __init__(self, src):
- # type: (SSAVal[StackSlotType]) -> None
- super().__init__()
+ def __init__(self, fn, src):
+ # type: (Fn, SSAVal[StackSlotType]) -> None
+ super().__init__(fn)
self.dest = SSAVal(self, "dest", GPRRangeType(src.ty.length_in_slots))
self.src = src
# type: () -> dict[str, SSAVal]
return {"dest": self.dest}
- def __init__(self, src, dest_ty=None):
- # type: (SSAVal[_RegSrcType], _RegType | None) -> None
- super().__init__()
+ def __init__(self, fn, src, dest_ty=None):
+ # type: (Fn, SSAVal[_RegSrcType], _RegType | None) -> None
+ super().__init__(fn)
if dest_ty is None:
dest_ty = cast(_RegType, src.ty)
if isinstance(src.ty, GPRRangeType) \
# type: () -> dict[str, SSAVal]
return {"dest": self.dest}
- def __init__(self, sources):
- # type: (Iterable[SSAVal[GPRRangeType]]) -> None
- super().__init__()
+ def __init__(self, fn, sources):
+ # type: (Fn, Iterable[SSAVal[GPRRangeType]]) -> None
+ super().__init__(fn)
sources = tuple(sources)
self.dest = SSAVal(self, "dest", GPRRangeType(
sum(i.ty.length for i in sources)))
# type: () -> dict[str, SSAVal]
return {i.arg_name: i for i in self.results}
- def __init__(self, src, split_indexes):
- # type: (SSAVal[GPRRangeType], Iterable[int]) -> None
- super().__init__()
+ def __init__(self, fn, src, split_indexes):
+ # type: (Fn, SSAVal[GPRRangeType], Iterable[int]) -> None
+ super().__init__(fn)
ranges = [] # type: list[GPRRangeType]
last = 0
for i in split_indexes:
# type: () -> dict[str, SSAVal]
return {"RT": self.RT, "CY_out": self.CY_out}
- def __init__(self, RA, RB, CY_in, is_sub):
- # type: (SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
- super().__init__()
+ def __init__(self, fn, RA, RB, CY_in, is_sub):
+ # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRRangeType], SSAVal[CYType], bool) -> None
+ super().__init__(fn)
if RA.ty != RB.ty:
raise TypeError(f"source types must match: "
f"{RA} doesn't match {RB}")
# type: () -> dict[str, SSAVal]
return {"RT": self.RT, "RS": self.RS}
- def __init__(self, RA, RB, RC, is_div):
- # type: (SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
- super().__init__()
+ def __init__(self, fn, RA, RB, RC, is_div):
+ # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], SSAVal[GPRType], bool) -> None
+ super().__init__(fn)
self.RT = SSAVal(self, "RT", RA.ty)
self.RA = RA
self.RB = RB
# type: () -> dict[str, SSAVal]
return {"RT": self.RT}
- def __init__(self, inp, sh, kind):
- # type: (SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
- super().__init__()
+ def __init__(self, fn, inp, sh, kind):
+ # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], ShiftKind) -> None
+ super().__init__(fn)
self.RT = SSAVal(self, "RT", inp.ty)
self.inp = inp
self.sh = sh
# type: () -> dict[str, SSAVal]
return {"out": self.out}
- def __init__(self, value, length=1):
- # type: (int, int) -> None
- super().__init__()
+ def __init__(self, fn, value, length=1):
+ # type: (Fn, int, int) -> None
+ super().__init__(fn)
self.out = SSAVal(self, "out", GPRRangeType(length))
self.value = value
# type: () -> dict[str, SSAVal]
return {"out": self.out}
- def __init__(self):
- # type: () -> None
- super().__init__()
+ def __init__(self, fn):
+ # type: (Fn) -> None
+ super().__init__(fn)
self.out = SSAVal(self, "out", CYType())
# type: () -> dict[str, SSAVal]
return {"RT": self.RT}
- def __init__(self, RA, offset, mem, length=1):
- # type: (SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
- super().__init__()
+ def __init__(self, fn, RA, offset, mem, length=1):
+ # type: (Fn, SSAVal[GPRType], int, SSAVal[GlobalMemType], int) -> None
+ super().__init__(fn)
self.RT = SSAVal(self, "RT", GPRRangeType(length))
self.RA = RA
self.offset = offset
# type: () -> dict[str, SSAVal]
return {"mem_out": self.mem_out}
- def __init__(self, RS, RA, offset, mem_in):
- # type: (SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
- super().__init__()
+ def __init__(self, fn, RS, RA, offset, mem_in):
+ # type: (Fn, SSAVal[GPRRangeType], SSAVal[GPRType], int, SSAVal[GlobalMemType]) -> None
+ super().__init__(fn)
self.RS = RS
self.RA = RA
self.offset = offset
# type: () -> dict[str, SSAVal]
return {"out": self.out}
- def __init__(self, ty):
- # type: (FixedGPRRangeType) -> None
- super().__init__()
+ def __init__(self, fn, ty):
+ # type: (Fn, FixedGPRRangeType) -> None
+ super().__init__(fn)
self.out = SSAVal(self, "out", ty)
# type: () -> dict[str, SSAVal]
return {"out": self.out}
- def __init__(self):
- # type: () -> None
- super().__init__()
+ def __init__(self, fn):
+ # type: (Fn) -> None
+ super().__init__(fn)
self.out = SSAVal(self, "out", GlobalMemType())
import unittest
-from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, GPRRange, GPRType,
+from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, Fn, GPRRange, GPRType,
Op, OpAddSubE, OpClearCY, OpConcat, OpCopy, OpFuncArg, OpInputMem, OpLI, OpLoad, OpStore,
op_set_to_list)
maxDiff = None
def test_op_set_to_list(self):
- ops = [] # type: list[Op]
- op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3)))
- ops.append(op0)
- op1 = OpCopy(op0.out, GPRType())
- ops.append(op1)
+ fn = Fn()
+ op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
+ op1 = OpCopy(fn, op0.out, GPRType())
arg = op1.dest
- op2 = OpInputMem()
- ops.append(op2)
+ op2 = OpInputMem(fn)
mem = op2.out
- op3 = OpLoad(arg, offset=0, mem=mem, length=32)
- ops.append(op3)
+ op3 = OpLoad(fn, arg, offset=0, mem=mem, length=32)
a = op3.RT
- op4 = OpLI(1)
- ops.append(op4)
+ op4 = OpLI(fn, 1)
b_0 = op4.out
- op5 = OpLI(0, length=31)
- ops.append(op5)
+ op5 = OpLI(fn, 0, length=31)
b_rest = op5.out
- op6 = OpConcat([b_0, b_rest])
- ops.append(op6)
+ op6 = OpConcat(fn, [b_0, b_rest])
b = op6.dest
- op7 = OpClearCY()
- ops.append(op7)
+ op7 = OpClearCY(fn)
cy = op7.out
- op8 = OpAddSubE(a, b, cy, is_sub=False)
- ops.append(op8)
+ op8 = OpAddSubE(fn, a, b, cy, is_sub=False)
s = op8.RT
- op9 = OpStore(s, arg, offset=0, mem_in=mem)
- ops.append(op9)
+ op9 = OpStore(fn, s, arg, offset=0, mem_in=mem)
mem = op9.mem_out
expected_ops = [
op9, # OpStore(s, arg, offset=0, mem_in=mem)
]
- ops = op_set_to_list(reversed(ops))
+ ops = op_set_to_list(fn.ops[::-1])
if ops != expected_ops:
self.assertEqual(repr(ops), repr(expected_ops))
import unittest
-from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, GPRRange,
+from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, Fn, GPRRange,
GPRType, GlobalMem, Op, OpAddSubE,
OpClearCY, OpConcat, OpCopy,
OpFuncArg, OpInputMem, OpLI,
maxDiff = None
def test_from_equality_constraint(self):
- op0 = OpLI(0, length=1)
- op1 = OpLI(0, length=2)
- op2 = OpLI(0, length=3)
+ fn = Fn()
+ op0 = OpLI(fn, 0, length=1)
+ op1 = OpLI(fn, 0, length=2)
+ op2 = OpLI(fn, 0, length=3)
self.assertEqual(MergedRegSet.from_equality_constraint([
op0.out,
op1.out,
maxDiff = None
def test_try_alloc_fail(self):
- ops = [] # type: list[Op]
- op0 = OpLI(0, length=52)
- ops.append(op0)
- op1 = OpLI(0, length=64)
- ops.append(op1)
- op2 = OpConcat([op0.out, op1.out])
- ops.append(op2)
-
- reg_assignments = try_allocate_registers_without_spilling(ops)
+ fn = Fn()
+ op0 = OpLI(fn, 0, length=52)
+ op1 = OpLI(fn, 0, length=64)
+ op2 = OpConcat(fn, [op0.out, op1.out])
+
+ reg_assignments = try_allocate_registers_without_spilling(fn.ops)
self.assertEqual(
repr(reg_assignments),
"AllocationFailed("
)
def test_try_alloc_bigint_inc(self):
- ops = [] # type: list[Op]
- op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3)))
- ops.append(op0)
- op1 = OpCopy(op0.out, GPRType())
- ops.append(op1)
+ fn = Fn()
+ op0 = OpFuncArg(fn, FixedGPRRangeType(GPRRange(3)))
+ op1 = OpCopy(fn, op0.out, GPRType())
arg = op1.dest
- op2 = OpInputMem()
- ops.append(op2)
+ op2 = OpInputMem(fn)
mem = op2.out
- op3 = OpLoad(arg, offset=0, mem=mem, length=32)
- ops.append(op3)
+ op3 = OpLoad(fn, arg, offset=0, mem=mem, length=32)
a = op3.RT
- op4 = OpLI(1)
- ops.append(op4)
+ op4 = OpLI(fn, 1)
b_0 = op4.out
- op5 = OpLI(0, length=31)
- ops.append(op5)
+ op5 = OpLI(fn, 0, length=31)
b_rest = op5.out
- op6 = OpConcat([b_0, b_rest])
- ops.append(op6)
+ op6 = OpConcat(fn, [b_0, b_rest])
b = op6.dest
- op7 = OpClearCY()
- ops.append(op7)
+ op7 = OpClearCY(fn)
cy = op7.out
- op8 = OpAddSubE(a, b, cy, is_sub=False)
- ops.append(op8)
+ op8 = OpAddSubE(fn, a, b, cy, is_sub=False)
s = op8.RT
- op9 = OpStore(s, arg, offset=0, mem_in=mem)
- ops.append(op9)
+ op9 = OpStore(fn, s, arg, offset=0, mem_in=mem)
mem = op9.mem_out
- reg_assignments = try_allocate_registers_without_spilling(ops)
+ reg_assignments = try_allocate_registers_without_spilling(fn.ops)
expected_reg_assignments = {
op0.out: GPRRange(start=3, length=1),
def tst_try_alloc_concat(self, expected_regs, expected_dest_reg):
# type: (list[GPRRange], GPRRange) -> None
- li_ops = [OpLI(i, reg.length) for i, reg in enumerate(expected_regs)]
- ops = [*li_ops] # type: list[Op]
- concat = OpConcat([i.out for i in li_ops])
- ops.append(concat)
+ fn = Fn()
+ li_ops = [OpLI(fn, i, r.length) for i, r in enumerate(expected_regs)]
+ concat = OpConcat(fn, [i.out for i in li_ops])
- reg_assignments = try_allocate_registers_without_spilling(ops)
+ reg_assignments = try_allocate_registers_without_spilling(fn.ops)
expected_reg_assignments = {concat.dest: expected_dest_reg}
for li_op, reg in zip(li_ops, expected_regs):