from enum import Enum, EnumMeta, unique
from functools import lru_cache
from typing import (TYPE_CHECKING, AbstractSet, Generic, Iterable, Sequence,
- TypeVar)
+ TypeVar, cast)
-from nmutil.plain_data import plain_data
+from cached_property import cached_property
+from nmutil.plain_data import fields, plain_data
if TYPE_CHECKING:
from typing_extensions import final
return ...
-_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
+_RegType = TypeVar("_RegType", bound=RegType)
@plain_data(frozen=True, eq=False)
return hash(self.length_in_slots)
-@plain_data(frozen=True, eq=False)
+@plain_data(frozen=True, eq=False, repr=False)
@final
-class SSAVal(Generic[_RegT_co]):
- __slots__ = "op", "arg_name", "ty", "arg_index"
+class SSAVal(Generic[_RegType]):
+ __slots__ = "op", "arg_name", "ty",
def __init__(self, op, arg_name, ty):
- # type: (Op, str, _RegT_co) -> None
+ # type: (Op, str, _RegType) -> None
self.op = op
"""the Op that writes this SSAVal"""
def __hash__(self):
return hash((id(self.op), self.arg_name))
+ def __repr__(self):
+ fields_list = []
+ for name in fields(self):
+ v = getattr(self, name, None)
+ if v is not None:
+ if name == "op":
+ v = v.__repr__(just_id=True)
+ else:
+ v = repr(v)
+ fields_list.append(f"{name}={v}")
+ fields_str = ", ".join(fields_list)
+ return f"SSAVal({fields_str})"
+
@final
@plain_data(unsafe_hash=True, frozen=True)
raise ValueError("can't constrain an empty list to be equal")
-@plain_data(unsafe_hash=True, frozen=True)
+class _NotSet:
+ """ helper for __repr__ for when fields aren't set """
+
+ def __repr__(self):
+ return "<not set>"
+
+
+_NOT_SET = _NotSet()
+
+
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
class Op(metaclass=ABCMeta):
__slots__ = ()
if False:
yield ...
- def __init__(self):
- pass
+ __NEXT_ID = 0
+ @cached_property
+ def id(self):
+ retval = Op.__NEXT_ID
+ Op.__NEXT_ID += 1
+ return retval
-@plain_data(unsafe_hash=True, frozen=True)
+ @final
+ def __repr__(self, just_id=False):
+ fields_list = [f"#{self.id}"]
+ if not just_id:
+ for name in fields(self):
+ v = getattr(self, name, _NOT_SET)
+ fields_list.append(f"{name}={v!r}")
+ fields_str = ', '.join(fields_list)
+ return f"{self.__class__.__name__}({fields_str})"
+
+
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoadFromStackSlot(Op):
__slots__ = "dest", "src"
self.src = src
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStoreToStackSlot(Op):
__slots__ = "dest", "src"
self.src = src
-@plain_data(unsafe_hash=True, frozen=True)
+_RegSrcType = TypeVar("_RegSrcType", bound=RegType)
+
+
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
-class OpCopy(Op, Generic[_RegT_co]):
+class OpCopy(Op, Generic[_RegSrcType, _RegType]):
__slots__ = "dest", "src"
def inputs(self):
return {"dest": self.dest}
def __init__(self, src, dest_ty=None):
- # type: (SSAVal[_RegT_co], _RegT_co | None) -> None
+ # type: (SSAVal[_RegSrcType], _RegType | None) -> None
if dest_ty is None:
- dest_ty = src.ty
+ dest_ty = cast(_RegType, src.ty)
if isinstance(src.ty, GPRRangeType) \
and isinstance(dest_ty, GPRRangeType):
if src.ty.length != dest_ty.length:
raise ValueError(f"incompatible source and destination "
f"types: {src.ty} and {dest_ty}")
- self.dest = SSAVal(self, "dest", dest_ty)
+ self.dest = SSAVal(self, "dest", dest_ty) # type: SSAVal[_RegType]
self.src = src
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpConcat(Op):
__slots__ = "dest", "sources"
yield EqualityConstraint([self.dest], [*self.sources])
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpSplit(Op):
__slots__ = "results", "src"
yield EqualityConstraint([*self.results], [self.src])
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpAddSubE(Op):
__slots__ = "RT", "RA", "RB", "CY_in", "CY_out", "is_sub"
yield self.RT, self.RB
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntMulDiv(Op):
__slots__ = "RT", "RA", "RB", "RC", "RS", "is_div"
Sra = "sra"
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpBigIntShift(Op):
__slots__ = "RT", "inp", "sh", "kind"
yield self.RT, self.sh
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLI(Op):
__slots__ = "out", "value"
self.value = value
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpClearCY(Op):
__slots__ = "out",
self.out = SSAVal(self, "out", CYType())
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpLoad(Op):
__slots__ = "RT", "RA", "offset", "mem"
yield self.RT, self.RA
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpStore(Op):
__slots__ = "RS", "RA", "offset", "mem_in", "mem_out"
self.mem_out = SSAVal(self, "mem_out", mem_in.ty)
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpFuncArg(Op):
__slots__ = "out",
self.out = SSAVal(self, "out", ty)
-@plain_data(unsafe_hash=True, frozen=True)
+@plain_data(unsafe_hash=True, frozen=True, repr=False)
@final
class OpInputMem(Op):
__slots__ = "out",
def op_set_to_list(ops):
# type: (Iterable[Op]) -> list[Op]
- worklists = [set()] # type: list[set[Op]]
- input_vals_to_ops_map = defaultdict(set) # type: dict[SSAVal, set[Op]]
+ worklists = [{}] # type: list[dict[Op, None]]
+ inps_to_ops_map = defaultdict(dict) # type: dict[SSAVal, dict[Op, None]]
ops_to_pending_input_count_map = {} # type: dict[Op, int]
for op in ops:
input_count = 0
for val in op.inputs().values():
input_count += 1
- input_vals_to_ops_map[val].add(op)
+ inps_to_ops_map[val][op] = None
while len(worklists) <= input_count:
- worklists.append(set())
+ worklists.append({})
ops_to_pending_input_count_map[op] = input_count
- worklists[input_count].add(op)
+ worklists[input_count][op] = None
retval = [] # type: list[Op]
ready_vals = set() # type: set[SSAVal]
while len(worklists[0]) != 0:
- writing_op = worklists[0].pop()
+ writing_op = next(iter(worklists[0]))
+ del worklists[0][writing_op]
retval.append(writing_op)
for val in writing_op.outputs().values():
if val in ready_vals:
raise ValueError(f"multiple instructions must not write "
f"to the same SSA value: {val}")
ready_vals.add(val)
- for reading_op in input_vals_to_ops_map[val]:
+ for reading_op in inps_to_ops_map[val]:
pending = ops_to_pending_input_count_map[reading_op]
- worklists[pending].remove(reading_op)
+ del worklists[pending][reading_op]
pending -= 1
- worklists[pending].add(reading_op)
+ worklists[pending][reading_op] = None
ops_to_pending_input_count_map[reading_op] = pending
for worklist in worklists:
for op in worklist:
return v
-_RegT_co = TypeVar("_RegT_co", bound=RegType, covariant=True)
+_RegType = TypeVar("_RegType", bound=RegType)
@plain_data(unsafe_hash=True, order=True, frozen=True)
@final
-class MergedRegSet(Mapping[SSAVal[_RegT_co], int]):
+class MergedRegSet(Mapping[SSAVal[_RegType], int]):
def __init__(self, reg_set):
- # type: (Iterable[tuple[SSAVal[_RegT_co], int]] | SSAVal[_RegT_co]) -> None
- self.__items = {} # type: dict[SSAVal[_RegT_co], int]
+ # type: (Iterable[tuple[SSAVal[_RegType], int]] | SSAVal[_RegType]) -> None
+ self.__items = {} # type: dict[SSAVal[_RegType], int]
if isinstance(reg_set, SSAVal):
reg_set = [(reg_set, 0)]
for ssa_val, offset in reg_set:
@staticmethod
def from_equality_constraint(constraint_sequence):
- # type: (list[SSAVal[_RegT_co]]) -> MergedRegSet[_RegT_co]
+ # type: (list[SSAVal[_RegType]]) -> MergedRegSet[_RegType]
if len(constraint_sequence) == 1:
# any type allowed with len = 1
return MergedRegSet(constraint_sequence[0])
return range(self.__start, self.__stop)
def offset_by(self, amount):
- # type: (int) -> MergedRegSet[_RegT_co]
+ # type: (int) -> MergedRegSet[_RegType]
return MergedRegSet((k, v + amount) for k, v in self.items())
def normalized(self):
- # type: () -> MergedRegSet[_RegT_co]
+ # type: () -> MergedRegSet[_RegType]
return self.offset_by(-self.start)
def with_offset_to_match(self, target):
- # type: (MergedRegSet[_RegT_co]) -> MergedRegSet[_RegT_co]
+ # type: (MergedRegSet[_RegType]) -> MergedRegSet[_RegType]
for ssa_val, offset in self.items():
if ssa_val in target:
return self.offset_by(target[ssa_val] - offset)
raise ValueError("can't change offset to match unrelated MergedRegSet")
def __getitem__(self, item):
- # type: (SSAVal[_RegT_co]) -> int
+ # type: (SSAVal[_RegType]) -> int
return self.__items[item]
def __iter__(self):
@final
-class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegT_co]], Generic[_RegT_co]):
+class MergedRegSets(Mapping[SSAVal, MergedRegSet[_RegType]], Generic[_RegType]):
def __init__(self, ops):
# type: (Iterable[Op]) -> None
- merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegT_co]]
+ merged_sets = {} # type: dict[SSAVal, MergedRegSet[_RegType]]
for op in ops:
for val in (*op.inputs().values(), *op.outputs().values()):
if val not in merged_sets:
@final
-class LiveIntervals(Mapping[MergedRegSet[_RegT_co], LiveInterval]):
+class LiveIntervals(Mapping[MergedRegSet[_RegType], LiveInterval]):
def __init__(self, ops):
# type: (list[Op]) -> None
self.__merged_reg_sets = MergedRegSets(ops)
- live_intervals = {} # type: dict[MergedRegSet[_RegT_co], LiveInterval]
+ live_intervals = {} # type: dict[MergedRegSet[_RegType], LiveInterval]
for op_idx, op in enumerate(ops):
for val in op.inputs().values():
live_intervals[self.__merged_reg_sets[val]] += op_idx
else:
live_intervals[reg_set] += op_idx
self.__live_intervals = live_intervals
- live_after = [] # type: list[set[MergedRegSet[_RegT_co]]]
+ live_after = [] # type: list[set[MergedRegSet[_RegType]]]
live_after += (set() for _ in ops)
for reg_set, live_interval in self.__live_intervals.items():
for i in live_interval.live_after_op_range:
return self.__merged_reg_sets
def __getitem__(self, key):
- # type: (MergedRegSet[_RegT_co]) -> LiveInterval
+ # type: (MergedRegSet[_RegType]) -> LiveInterval
return self.__live_intervals[key]
def __iter__(self):
return iter(self.__live_intervals)
def reg_sets_live_after(self, op_index):
- # type: (int) -> frozenset[MergedRegSet[_RegT_co]]
+ # type: (int) -> frozenset[MergedRegSet[_RegType]]
return self.__live_after[op_index]
def __repr__(self):
@final
-class IGNode(Generic[_RegT_co]):
+class IGNode(Generic[_RegType]):
""" interference graph node """
__slots__ = "merged_reg_set", "edges", "reg"
def __init__(self, merged_reg_set, edges=(), reg=None):
- # type: (MergedRegSet[_RegT_co], Iterable[IGNode], RegLoc | None) -> None
+ # type: (MergedRegSet[_RegType], Iterable[IGNode], RegLoc | None) -> None
self.merged_reg_set = merged_reg_set
self.edges = set(edges)
self.reg = reg
@final
-class InterferenceGraph(Mapping[MergedRegSet[_RegT_co], IGNode[_RegT_co]]):
+class InterferenceGraph(Mapping[MergedRegSet[_RegType], IGNode[_RegType]]):
def __init__(self, merged_reg_sets):
- # type: (Iterable[MergedRegSet[_RegT_co]]) -> None
+ # type: (Iterable[MergedRegSet[_RegType]]) -> None
self.__nodes = {i: IGNode(i) for i in merged_reg_sets}
def __getitem__(self, key):
- # type: (MergedRegSet[_RegT_co]) -> IGNode
+ # type: (MergedRegSet[_RegType]) -> IGNode
return self.__nodes[key]
def __iter__(self):
import unittest
-from bigint_presentation_code.compiler_ir import Op, op_set_to_list
+from bigint_presentation_code.compiler_ir import (FixedGPRRangeType, GPRRange, GPRType,
+ Op, OpAddSubE, OpClearCY, OpConcat, OpCopy, OpFuncArg, OpInputMem, OpLI, OpLoad, OpStore,
+ op_set_to_list)
class TestCompilerIR(unittest.TestCase):
- pass # no tests yet, just testing importing
+ maxDiff = None
+
+ def test_op_set_to_list(self):
+ ops = [] # type: list[Op]
+ op0 = OpFuncArg(FixedGPRRangeType(GPRRange(3)))
+ ops.append(op0)
+ op1 = OpCopy(op0.out, GPRType())
+ ops.append(op1)
+ arg = op1.dest
+ op2 = OpInputMem()
+ ops.append(op2)
+ mem = op2.out
+ op3 = OpLoad(arg, offset=0, mem=mem, length=32)
+ ops.append(op3)
+ a = op3.RT
+ op4 = OpLI(1)
+ ops.append(op4)
+ b_0 = op4.out
+ op5 = OpLI(0, length=31)
+ ops.append(op5)
+ b_rest = op5.out
+ op6 = OpConcat([b_0, b_rest])
+ ops.append(op6)
+ b = op6.dest
+ op7 = OpClearCY()
+ ops.append(op7)
+ cy = op7.out
+ op8 = OpAddSubE(a, b, cy, is_sub=False)
+ ops.append(op8)
+ s = op8.RT
+ op9 = OpStore(s, arg, offset=0, mem_in=mem)
+ ops.append(op9)
+ mem = op9.mem_out
+
+ expected_ops = [
+ op7, # OpClearCY()
+ op5, # OpLI(0, length=31)
+ op4, # OpLI(1)
+ op2, # OpInputMem()
+ op0, # OpFuncArg(FixedGPRRangeType(GPRRange(3)))
+ op6, # OpConcat([b_0, b_rest])
+ op1, # OpCopy(op0.out, GPRType())
+ op3, # OpLoad(arg, offset=0, mem=mem, length=32)
+ op8, # OpAddSubE(a, b, cy, is_sub=False)
+ op9, # OpStore(s, arg, offset=0, mem_in=mem)
+ ]
+
+ ops = op_set_to_list(reversed(ops))
+ if ops != expected_ops:
+ self.assertEqual(repr(ops), repr(expected_ops))
if __name__ == "__main__":