From 79a39b6aa2f75bd07d5e411d4605471ac7d741d2 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Fri, 6 Jan 2023 13:20:11 -0800 Subject: [PATCH] add DisjointSets, a disjoint-set data structure --- src/bigint_presentation_code/util.py | 106 ++++++++++++++++++++++++++- 1 file changed, 103 insertions(+), 3 deletions(-) diff --git a/src/bigint_presentation_code/util.py b/src/bigint_presentation_code/util.py index 9f1e5ab..c697e30 100644 --- a/src/bigint_presentation_code/util.py +++ b/src/bigint_presentation_code/util.py @@ -1,9 +1,10 @@ from abc import ABCMeta, abstractmethod from collections import defaultdict -from typing import (AbstractSet, Any, Callable, Iterable, Iterator, Mapping, - MutableSet, TypeVar, overload) +from typing import (AbstractSet, Any, Callable, Generic, Iterable, Iterator, + Mapping, MutableSet, NewType, TypeVar, overload) from bigint_presentation_code.type_util import Self, final +from nmutil.plain_data import plain_data _T_co = TypeVar("_T_co", covariant=True) _T = TypeVar("_T") @@ -13,13 +14,15 @@ __all__ = [ "BaseBitSet", "bit_count", "BitSet", + "DisjointSets", + "DisjointSetsItem", "FBitSet", "FMap", + "Interned", "OFSet", "OSet", "top_set_bit_index", "trailing_zero_count", - "Interned", ] @@ -495,3 +498,100 @@ class FBitSet(BaseBitSet, Interned): def __hash__(self): # type: () -> int return super()._hash() + + +DisjointSetsItem = NewType("DisjointSetsItem", int) + + +@plain_data() +@final +class _DisjointSetsEntry(Generic[_T_co]): + __slots__ = "value", "parent", "rank" + + def __init__(self, value, parent, rank): + # type: (_T_co, DisjointSetsItem, int) -> None + self.value = value + self.parent = parent + self.rank = rank + + +@final +class DisjointSets(Generic[_T_co]): + """ Disjoint-set data structure, aka. union-find or merge-find + https://en.wikipedia.org/wiki/Disjoint-set_data_structure + """ + + def __init__(self): + self.__values = [] # type: list[_DisjointSetsEntry[_T_co]] + + def __entry(self, __key): + # type: (DisjointSetsItem) -> _DisjointSetsEntry[_T_co] + if __key < 0 or __key >= len(self.__values): + raise KeyError(__key) + return self.__values[__key] + + def __getitem__(self, __key): + # type: (DisjointSetsItem) -> _T_co + return self.__entry(__key).value + + def __setitem__(self, __key, __value): + # type: (DisjointSetsItem, _T_co) -> None + self.__entry(__key).value = __value + + def __len__(self): + # type: () -> int + return len(self.__values) + + @property + def representatives(self): + # type: () -> Iterator[DisjointSetsItem] + for i, entry in enumerate(self.__values): + item = DisjointSetsItem(i) + if entry.parent == item: + yield item + + def __iter__(self): + # type: () -> Iterator[DisjointSetsItem] + return map(DisjointSetsItem, range(len(self.__values))) + + def add_new_set(self, value): + # type: (_T_co) -> DisjointSetsItem + item = DisjointSetsItem(len(self.__values)) + self.__values.append(_DisjointSetsEntry( + value=value, parent=item, rank=0)) + return item + + def find_representative(self, item): + # type: (DisjointSetsItem) -> DisjointSetsItem + entry = self.__entry(item) + while entry.parent != item: + parent_entry = self.__values[entry.parent] + item = entry.parent = parent_entry.parent + entry = self.__values[item] + return item + + def merge(self, __x, __y): + # type: (DisjointSetsItem, DisjointSetsItem) -> DisjointSetsItem + __x = self.find_representative(__x) + __y = self.find_representative(__y) + if __x == __y: + return __x + x_entry = self.__values[__x] + y_entry = self.__values[__y] + if x_entry.rank < y_entry.rank: + __x, __y = __y, __x + x_entry, y_entry = y_entry, x_entry + y_entry.parent = __x + if x_entry.rank == y_entry.rank: + x_entry.rank += 1 + return __x + + def __repr__(self): + # type: () -> str + sets = defaultdict( + list) # type: dict[DisjointSetsItem, list[DisjointSetsItem]] + values = {} + for item in self: + sets[self.find_representative(item)].append(item) + values[item] = self[item] + return f"DisjointSets(sets={sets}, values={values})" -- 2.30.2