"""
from itertools import combinations
-from typing import Any, Iterable, Iterator, Mapping, MutableSet
+from typing import Any, Iterable, Iterator, Mapping, MutableMapping, MutableSet, Dict
from cached_property import cached_property
from nmutil.plain_data import plain_data
@final
-class MergedSSAValsMap(Mapping[SSAVal, MergedSSAVal]):
+class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
def __init__(self):
# type: (...) -> None
- self.__merge_map = {} # type: dict[SSAVal, MergedSSAVal]
- self.__values_set = MergedSSAValsSet(
- _private_merge_map=self.__merge_map,
- _private_values_set=OSet())
+ self.__map = {} # type: dict[SSAVal, MergedSSAVal]
+ self.__ig_node_map = MergedSSAValToIGNodeMap(
+ _private_merged_ssa_val_map=self.__map)
def __getitem__(self, __key):
# type: (SSAVal) -> MergedSSAVal
- return self.__merge_map[__key]
+ return self.__map[__key]
def __iter__(self):
# type: () -> Iterator[SSAVal]
- return iter(self.__merge_map)
+ return iter(self.__map)
def __len__(self):
# type: () -> int
- return len(self.__merge_map)
+ return len(self.__map)
@property
- def values_set(self):
- # type: () -> MergedSSAValsSet
- return self.__values_set
+ def ig_node_map(self):
+ # type: () -> MergedSSAValToIGNodeMap
+ return self.__ig_node_map
def __repr__(self):
# type: () -> str
- s = ",\n".join(repr(v) for v in self.__values_set)
- return f"MergedSSAValsMap({{{s}}})"
+ s = ",\n".join(repr(v) for v in self.__ig_node_map)
+ return f"SSAValToMergedSSAValMap({{{s}}})"
@final
-class MergedSSAValsSet(MutableSet[MergedSSAVal]):
- def __init__(self, *,
- _private_merge_map, # type: dict[SSAVal, MergedSSAVal]
- _private_values_set, # type: OSet[MergedSSAVal]
- ):
+class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, IGNode]):
+ def __init__(
+ self, *,
+ _private_merged_ssa_val_map, # type: dict[SSAVal, MergedSSAVal]
+ ):
# type: (...) -> None
- self.__merge_map = _private_merge_map
- self.__values_set = _private_values_set
+ self.__merged_ssa_val_map = _private_merged_ssa_val_map
+ self.__map = {} # type: dict[MergedSSAVal, IGNode]
- @classmethod
- def _from_iterable(cls, it):
- # type: (Iterable[MergedSSAVal]) -> OSet[MergedSSAVal]
- return OSet(it)
-
- def __contains__(self, value):
- # type: (MergedSSAVal | Any) -> bool
- return value in self.__values_set
+ def __getitem__(self, __key):
+ # type: (MergedSSAVal) -> IGNode
+ return self.__map[__key]
def __iter__(self):
# type: () -> Iterator[MergedSSAVal]
- return iter(self.__values_set)
+ return iter(self.__map)
def __len__(self):
# type: () -> int
- return len(self.__values_set)
+ return len(self.__map)
- def add(self, value):
- # type: (MergedSSAVal) -> None
- if value in self:
- return
+ def add_node(self, merged_ssa_val):
+ # type: (MergedSSAVal) -> IGNode
+ node = self.__map.get(merged_ssa_val, None)
+ if node is not None:
+ return node
added = 0 # type: int | None
try:
- for ssa_val in value.ssa_vals:
- if ssa_val in self.__merge_map:
+ for ssa_val in merged_ssa_val.ssa_vals:
+ if ssa_val in self.__merged_ssa_val_map:
raise ValueError(
f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
- f"{value} and {self.__merge_map[ssa_val]}")
- self.__merge_map[ssa_val] = value
+ f"{merged_ssa_val} and "
+ f"{self.__merged_ssa_val_map[ssa_val]}")
+ self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
added += 1
- self.__values_set.add(value)
+ retval = IGNode(merged_ssa_val)
+ self.__map[merged_ssa_val] = retval
added = None
+ return retval
finally:
if added is not None:
# remove partially added stuff
- for idx, ssa_val in enumerate(value.ssa_vals):
+ for idx, ssa_val in enumerate(merged_ssa_val.ssa_vals):
if idx >= added:
break
- del self.__merge_map[ssa_val]
-
- def discard(self, value):
- # type: (MergedSSAVal) -> None
- if value not in self:
- return
+ del self.__merged_ssa_val_map[ssa_val]
+
+ def merge_into_one_node(self, final_merged_ssa_val):
+ # type: (MergedSSAVal) -> IGNode
+ source_nodes = {} # type: dict[MergedSSAVal, IGNode]
+ for ssa_val in final_merged_ssa_val.ssa_vals:
+ merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
+ source_nodes[merged_ssa_val] = self.__map[merged_ssa_val]
+ for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
+ raise ValueError(
+ f"SSAVal {i} appears in source IGNode's merged_ssa_val "
+ f"but not in merged IGNode's merged_ssa_val: "
+ f"source_node={self.__map[merged_ssa_val]} "
+ f"final_merged_ssa_val={final_merged_ssa_val}")
+ # FIXME: work on function from here
+ raise NotImplementedError
self.__values_set.discard(value)
for ssa_val in value.ssa_val_offsets.keys():
del self.__merge_map[ssa_val]
def __repr__(self):
# type: () -> str
- s = ",\n".join(repr(v) for v in self.__values_set)
- return f"MergedSSAValsSet({{{s}}})"
+ s = ",\n".join(repr(v) for v in self.__map.values())
+ return f"MergedSSAValToIGNodeMap({{{s}}})"
@plain_data(frozen=True)
@final
-class MergedSSAVals:
- __slots__ = "fn_analysis", "merge_map", "merged_ssa_vals"
+class InterferenceGraph:
+ __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"
def __init__(self, fn_analysis, merged_ssa_vals):
# type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
self.fn_analysis = fn_analysis
- self.merge_map = MergedSSAValsMap()
- self.merged_ssa_vals = self.merge_map.values_set
+ self.merged_ssa_val_map = SSAValToMergedSSAValMap()
+ self.nodes = self.merged_ssa_val_map.ig_node_map
for i in merged_ssa_vals:
- self.merged_ssa_vals.add(i)
+ self.nodes.add_node(i)
def merge(self, ssa_val1, ssa_val2, additional_offset=0):
- # type: (SSAVal, SSAVal, int) -> MergedSSAVal
- merged1 = self.merge_map[ssa_val1]
- merged2 = self.merge_map[ssa_val2]
+ # type: (SSAVal, SSAVal, int) -> IGNode
+ merged1 = self.merged_ssa_val_map[ssa_val1]
+ merged2 = self.merged_ssa_val_map[ssa_val2]
merged = merged1.with_offset_to_match(ssa_val1)
merged = merged.merged(merged2.with_offset_to_match(
ssa_val2, additional_offset=additional_offset))
- self.merged_ssa_vals.remove(merged1)
- self.merged_ssa_vals.remove(merged2)
- self.merged_ssa_vals.add(merged)
- return merged
+ return self.nodes.merge_into_one_node(merged)
@staticmethod
def minimally_merged(fn_analysis):
- # type: (FnAnalysis) -> MergedSSAVals
- retval = MergedSSAVals(fn_analysis=fn_analysis, merged_ssa_vals=())
+ # type: (FnAnalysis) -> InterferenceGraph
+ retval = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
for op in fn_analysis.fn.ops:
for inp in op.input_uses:
if inp.unspread_start != inp:
retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
additional_offset=inp.reg_offset_in_unspread)
for out in op.outputs:
+ retval.nodes.add_node(MergedSSAVal(fn_analysis, out))
if out.unspread_start != out:
retval.merge(out.unspread_start, out,
additional_offset=out.reg_offset_in_unspread)