working on code
authorJacob Lifshay <programmerjake@gmail.com>
Sat, 5 Nov 2022 00:30:30 +0000 (17:30 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sat, 5 Nov 2022 00:30:30 +0000 (17:30 -0700)
src/bigint_presentation_code/register_allocator2.py

index c7d9a88a79b873ed705441be57200867da86924d..20ca534b108db69184deaacdf4e96f44fe43d163 100644 (file)
@@ -6,7 +6,7 @@ this uses an algorithm based on:
 """
 
 from itertools import combinations
-from typing import Any, Iterable, Iterator, Mapping, MutableSet
+from typing import Any, Iterable, Iterator, Mapping, MutableMapping, MutableSet, Dict
 
 from cached_property import cached_property
 from nmutil.plain_data import plain_data
@@ -238,136 +238,142 @@ class MergedSSAVal:
 
 
 @final
-class MergedSSAValsMap(Mapping[SSAVal, MergedSSAVal]):
+class SSAValToMergedSSAValMap(Mapping[SSAVal, MergedSSAVal]):
     def __init__(self):
         # type: (...) -> None
-        self.__merge_map = {}  # type: dict[SSAVal, MergedSSAVal]
-        self.__values_set = MergedSSAValsSet(
-            _private_merge_map=self.__merge_map,
-            _private_values_set=OSet())
+        self.__map = {}  # type: dict[SSAVal, MergedSSAVal]
+        self.__ig_node_map = MergedSSAValToIGNodeMap(
+            _private_merged_ssa_val_map=self.__map)
 
     def __getitem__(self, __key):
         # type: (SSAVal) -> MergedSSAVal
-        return self.__merge_map[__key]
+        return self.__map[__key]
 
     def __iter__(self):
         # type: () -> Iterator[SSAVal]
-        return iter(self.__merge_map)
+        return iter(self.__map)
 
     def __len__(self):
         # type: () -> int
-        return len(self.__merge_map)
+        return len(self.__map)
 
     @property
-    def values_set(self):
-        # type: () -> MergedSSAValsSet
-        return self.__values_set
+    def ig_node_map(self):
+        # type: () -> MergedSSAValToIGNodeMap
+        return self.__ig_node_map
 
     def __repr__(self):
         # type: () -> str
-        s = ",\n".join(repr(v) for v in self.__values_set)
-        return f"MergedSSAValsMap({{{s}}})"
+        s = ",\n".join(repr(v) for v in self.__ig_node_map)
+        return f"SSAValToMergedSSAValMap({{{s}}})"
 
 
 @final
-class MergedSSAValsSet(MutableSet[MergedSSAVal]):
-    def __init__(self, *,
-                 _private_merge_map,  # type: dict[SSAVal, MergedSSAVal]
-                 _private_values_set,  # type: OSet[MergedSSAVal]
-                 ):
+class MergedSSAValToIGNodeMap(Mapping[MergedSSAVal, IGNode]):
+    def __init__(
+        self, *,
+        _private_merged_ssa_val_map,  # type: dict[SSAVal, MergedSSAVal]
+    ):
         # type: (...) -> None
-        self.__merge_map = _private_merge_map
-        self.__values_set = _private_values_set
+        self.__merged_ssa_val_map = _private_merged_ssa_val_map
+        self.__map = {}  # type: dict[MergedSSAVal, IGNode]
 
-    @classmethod
-    def _from_iterable(cls, it):
-        # type: (Iterable[MergedSSAVal]) -> OSet[MergedSSAVal]
-        return OSet(it)
-
-    def __contains__(self, value):
-        # type: (MergedSSAVal | Any) -> bool
-        return value in self.__values_set
+    def __getitem__(self, __key):
+        # type: (MergedSSAVal) -> IGNode
+        return self.__map[__key]
 
     def __iter__(self):
         # type: () -> Iterator[MergedSSAVal]
-        return iter(self.__values_set)
+        return iter(self.__map)
 
     def __len__(self):
         # type: () -> int
-        return len(self.__values_set)
+        return len(self.__map)
 
-    def add(self, value):
-        # type: (MergedSSAVal) -> None
-        if value in self:
-            return
+    def add_node(self, merged_ssa_val):
+        # type: (MergedSSAVal) -> IGNode
+        node = self.__map.get(merged_ssa_val, None)
+        if node is not None:
+            return node
         added = 0  # type: int | None
         try:
-            for ssa_val in value.ssa_vals:
-                if ssa_val in self.__merge_map:
+            for ssa_val in merged_ssa_val.ssa_vals:
+                if ssa_val in self.__merged_ssa_val_map:
                     raise ValueError(
                         f"overlapping `MergedSSAVal`s: {ssa_val} is in both "
-                        f"{value} and {self.__merge_map[ssa_val]}")
-                self.__merge_map[ssa_val] = value
+                        f"{merged_ssa_val} and "
+                        f"{self.__merged_ssa_val_map[ssa_val]}")
+                self.__merged_ssa_val_map[ssa_val] = merged_ssa_val
                 added += 1
-            self.__values_set.add(value)
+            retval = IGNode(merged_ssa_val)
+            self.__map[merged_ssa_val] = retval
             added = None
+            return retval
         finally:
             if added is not None:
                 # remove partially added stuff
-                for idx, ssa_val in enumerate(value.ssa_vals):
+                for idx, ssa_val in enumerate(merged_ssa_val.ssa_vals):
                     if idx >= added:
                         break
-                    del self.__merge_map[ssa_val]
-
-    def discard(self, value):
-        # type: (MergedSSAVal) -> None
-        if value not in self:
-            return
+                    del self.__merged_ssa_val_map[ssa_val]
+
+    def merge_into_one_node(self, final_merged_ssa_val):
+        # type: (MergedSSAVal) -> IGNode
+        source_nodes = {}  # type: dict[MergedSSAVal, IGNode]
+        for ssa_val in final_merged_ssa_val.ssa_vals:
+            merged_ssa_val = self.__merged_ssa_val_map[ssa_val]
+            source_nodes[merged_ssa_val] = self.__map[merged_ssa_val]
+            for i in merged_ssa_val.ssa_vals - final_merged_ssa_val.ssa_vals:
+                raise ValueError(
+                    f"SSAVal {i} appears in source IGNode's merged_ssa_val "
+                    f"but not in merged IGNode's merged_ssa_val: "
+                    f"source_node={self.__map[merged_ssa_val]} "
+                    f"final_merged_ssa_val={final_merged_ssa_val}")
+        # FIXME: work on function from here
+        raise NotImplementedError
         self.__values_set.discard(value)
         for ssa_val in value.ssa_val_offsets.keys():
             del self.__merge_map[ssa_val]
 
     def __repr__(self):
         # type: () -> str
-        s = ",\n".join(repr(v) for v in self.__values_set)
-        return f"MergedSSAValsSet({{{s}}})"
+        s = ",\n".join(repr(v) for v in self.__map.values())
+        return f"MergedSSAValToIGNodeMap({{{s}}})"
 
 
 @plain_data(frozen=True)
 @final
-class MergedSSAVals:
-    __slots__ = "fn_analysis", "merge_map", "merged_ssa_vals"
+class InterferenceGraph:
+    __slots__ = "fn_analysis", "merged_ssa_val_map", "nodes"
 
     def __init__(self, fn_analysis, merged_ssa_vals):
         # type: (FnAnalysis, Iterable[MergedSSAVal]) -> None
         self.fn_analysis = fn_analysis
-        self.merge_map = MergedSSAValsMap()
-        self.merged_ssa_vals = self.merge_map.values_set
+        self.merged_ssa_val_map = SSAValToMergedSSAValMap()
+        self.nodes = self.merged_ssa_val_map.ig_node_map
         for i in merged_ssa_vals:
-            self.merged_ssa_vals.add(i)
+            self.nodes.add_node(i)
 
     def merge(self, ssa_val1, ssa_val2, additional_offset=0):
-        # type: (SSAVal, SSAVal, int) -> MergedSSAVal
-        merged1 = self.merge_map[ssa_val1]
-        merged2 = self.merge_map[ssa_val2]
+        # type: (SSAVal, SSAVal, int) -> IGNode
+        merged1 = self.merged_ssa_val_map[ssa_val1]
+        merged2 = self.merged_ssa_val_map[ssa_val2]
         merged = merged1.with_offset_to_match(ssa_val1)
         merged = merged.merged(merged2.with_offset_to_match(
             ssa_val2, additional_offset=additional_offset))
-        self.merged_ssa_vals.remove(merged1)
-        self.merged_ssa_vals.remove(merged2)
-        self.merged_ssa_vals.add(merged)
-        return merged
+        return self.nodes.merge_into_one_node(merged)
 
     @staticmethod
     def minimally_merged(fn_analysis):
-        # type: (FnAnalysis) -> MergedSSAVals
-        retval = MergedSSAVals(fn_analysis=fn_analysis, merged_ssa_vals=())
+        # type: (FnAnalysis) -> InterferenceGraph
+        retval = InterferenceGraph(fn_analysis=fn_analysis, merged_ssa_vals=())
         for op in fn_analysis.fn.ops:
             for inp in op.input_uses:
                 if inp.unspread_start != inp:
                     retval.merge(inp.unspread_start.ssa_val, inp.ssa_val,
                                  additional_offset=inp.reg_offset_in_unspread)
             for out in op.outputs:
+                retval.nodes.add_node(MergedSSAVal(fn_analysis, out))
                 if out.unspread_start != out:
                     retval.merge(out.unspread_start, out,
                                  additional_offset=out.reg_offset_in_unspread)