1 from abc
import ABCMeta
, abstractmethod
2 from collections
import defaultdict
3 from typing
import (AbstractSet
, Any
, Callable
, Iterable
, Iterator
, Mapping
,
4 MutableSet
, TypeVar
, overload
)
6 from bigint_presentation_code
.type_util
import Self
, final
# Covariant type variable; used as the element type of the frozen OFSet and
# as the value type of the frozen FMap (read-only containers).
_T_co = TypeVar("_T_co", covariant=True)
21 "trailing_zero_count",
26 class _InternedMeta(ABCMeta
):
27 def __call__(self
, *args
: Any
, **kwds
: Any
) -> Any
:
28 return super().__call
__(*args
, **kwds
)._Interned
__intern
()
31 class Interned(metaclass
=_InternedMeta
):
32 def __init_intern(self
):
33 # type: (Self) -> Self
35 old_hash
= cls
.__hash
__
36 old_hash
= getattr(old_hash
, "_Interned__old_hash", old_hash
)
38 old_eq
= getattr(old_eq
, "_Interned__old_eq", old_eq
)
42 return self
._Interned
__hash
# type: ignore
43 __hash__
._Interned
__old
_hash
= old_hash
# type: ignore
44 cls
.__hash
__ = __hash__
46 def __eq__(self
, # type: Self
48 *, __eq
=old_eq
, # type: Callable[[Self, Any], bool]
51 if self
.__class
__ is __other
.__class
__:
52 return self
is __other
53 return __eq(self
, __other
)
54 __eq__
._Interned
__old
_eq
= old_eq
# type: ignore
57 table
= defaultdict(list) # type: dict[int, list[Self]]
59 def __intern(self
, # type: Self
60 *, __hash
=old_hash
, # type: Callable[[Self], int]
61 __eq
=old_eq
, # type: Callable[[Self, Any], bool]
62 __table
=table
, # type: dict[int, list[Self]]
63 __NotImplemented
=NotImplemented, # type: Any
70 if v
is not __NotImplemented
and v
:
72 self
.__dict
__["_Interned__hash"] = h
75 cls
._Interned
__intern
= __intern
78 _Interned__intern
= __init_intern
81 class OFSet(AbstractSet
[_T_co
], Interned
):
82 """ ordered frozen set """
83 __slots__
= "__items", "__dict__", "__weakref__"
85 def __init__(self
, items
=()):
86 # type: (Iterable[_T_co]) -> None
88 if isinstance(items
, OFSet
):
89 self
.__items
= items
.__items
91 self
.__items
= {v
: None for v
in items
}
93 def __contains__(self
, x
):
95 return x
in self
.__items
98 # type: () -> Iterator[_T_co]
99 return iter(self
.__items
)
103 return len(self
.__items
)
113 return f
"OFSet({list(self)})"
116 class OSet(MutableSet
[_T
]):
117 """ ordered mutable set """
118 __slots__
= "__items", "__dict__"
120 def __init__(self
, items
=()):
121 # type: (Iterable[_T]) -> None
123 self
.__items
= {v
: None for v
in items
}
125 def __contains__(self
, x
):
126 # type: (Any) -> bool
127 return x
in self
.__items
130 # type: () -> Iterator[_T]
131 return iter(self
.__items
)
135 return len(self
.__items
)
137 def add(self
, value
):
139 self
.__items
[value
] = None
141 def discard(self
, value
):
143 self
.__items
.pop(value
, None)
145 def remove(self
, value
):
147 del self
.__items
[value
]
151 return self
.__items
.popitem()[0]
161 return f
"OSet({list(self)})"
164 class FMap(Mapping
[_T
, _T_co
], Interned
):
165 """ordered frozen hashable mapping"""
166 __slots__
= "__items", "__hash", "__dict__", "__weakref__"
169 def __init__(self
, items
):
170 # type: (Mapping[_T, _T_co]) -> None
174 def __init__(self
, items
):
175 # type: (Iterable[tuple[_T, _T_co]]) -> None
183 def __init__(self
, items
=()):
184 # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None
186 self
.__items
= dict(items
) # type: dict[_T, _T_co]
187 self
.__hash
= None # type: None | int
189 def __getitem__(self
, item
):
190 # type: (_T) -> _T_co
191 return self
.__items
[item
]
194 # type: () -> Iterator[_T]
195 return iter(self
.__items
)
199 return len(self
.__items
)
201 def __eq__(self
, other
):
202 # type: (FMap[Any, Any] | Any) -> bool
203 if isinstance(other
, FMap
):
204 return self
.__items
== other
.__items
205 return super().__eq
__(other
)
209 if self
.__hash
is None:
210 self
.__hash
= hash(frozenset(self
.items()))
215 return f
"FMap({self.__items})"
217 def get(self
, key
, default
=None):
218 # type: (_T, _T_co | _T2) -> _T_co | _T2
219 return self
.__items
.get(key
, default
)
221 def __contains__(self
, key
):
222 # type: (_T | object) -> bool
223 return key
in self
.__items
def trailing_zero_count(v, default=-1):
    # type: (int, int) -> int
    """Return the number of trailing zero bits of ``v`` (the index of its lowest set bit).

    The lowest set bit is isolated with the two's-complement identity
    ``v & -v``. This equals ``v & ~(v & (v - 1))`` for every Python int,
    including negative ones. The bit's position is then looked up with
    ``top_set_bit_index``, and ``default`` is forwarded to it unchanged.

    When ``v == 0`` the isolated value is 0, so the result is whatever
    ``top_set_bit_index`` returns for 0. That is presumably ``default``, but
    the helper's zero guard is not visible here.
    """
    lowest_bit = v & -v
    return top_set_bit_index(lowest_bit, default)
233 def top_set_bit_index(v
, default
=-1):
234 # type: (int, int) -> int
237 return v
.bit_length() - 1
241 # added in cpython 3.10
242 bit_count
= int.bit_count
# type: ignore
243 except AttributeError:
246 """returns the number of 1 bits in the absolute value of the input"""
247 return bin(abs(v
)).count('1')
250 class BaseBitSet(AbstractSet
[int]):
251 __slots__
= "__bits", "__dict__", "__weakref__"
260 def _from_bits(cls
, bits
):
261 # type: (int) -> Self
262 return cls(bits
=bits
)
264 def __init__(self
, items
=(), bits
=0):
265 # type: (Iterable[int], int) -> None
267 if isinstance(items
, BaseBitSet
):
272 raise ValueError("can't store negative integers")
275 raise ValueError("can't store an infinite set")
284 def bits(self
, bits
):
285 # type: (int) -> None
287 raise AttributeError("can't write to frozen bitset's bits")
289 raise ValueError("can't store an infinite set")
292 def __contains__(self
, x
):
293 # type: (Any) -> bool
294 if isinstance(x
, int) and x
>= 0:
295 return (1 << x
) & self
.bits
!= 0
299 # type: () -> Iterator[int]
302 index
= trailing_zero_count(bits
)
306 def __reversed__(self
):
307 # type: () -> Iterator[int]
310 index
= top_set_bit_index(bits
)
316 return bit_count(self
.bits
)
321 return f
"{self.__class__.__name__}()"
325 return f
"{self.__class__.__name__}({v})"
326 ranges
= [] # type: list[range]
329 if len(ranges
) != 0 and ranges
[-1].stop
== i
:
331 ranges
[-1].start
, i
+ ranges
[-1].step
, ranges
[-1].step
)
332 elif len(ranges
) != 0 and len(ranges
[-1]) == 1:
333 start
= ranges
[-1][0]
336 ranges
[-1] = range(start
, stop
, step
)
337 elif len(ranges
) != 0 and len(ranges
[-1]) == 2:
338 single
= ranges
[-1][0]
339 start
= ranges
[-1][1]
340 ranges
[-1] = range(single
, single
+ 1)
343 ranges
.append(range(start
, stop
, step
))
345 ranges
.append(range(i
, i
+ 1))
346 if len(ranges
) > MAX_RANGES
:
349 return f
"{self.__class__.__name__}({ranges[0]})"
350 if len(ranges
) <= MAX_RANGES
:
351 range_strs
= [] # type: list[str]
354 range_strs
.append(str(r
[0]))
356 range_strs
.append(f
"*{r}")
357 ranges_str
= ", ".join(range_strs
)
358 return f
"{self.__class__.__name__}([{ranges_str}])"
359 if self
.bits
> 0xFFFFFFFF and len_self
< 10:
361 return f
"{self.__class__.__name__}({v})"
362 return f
"{self.__class__.__name__}(bits={hex(self.bits)})"
364 def __eq__(self
, other
):
365 # type: (Any) -> bool
366 if not isinstance(other
, BaseBitSet
):
367 return super().__eq
__(other
)
368 return self
.bits
== other
.bits
370 def __and__(self
, other
):
371 # type: (Iterable[Any]) -> Self
372 if isinstance(other
, BaseBitSet
):
373 return self
._from
_bits
(self
.bits
& other
.bits
)
376 if isinstance(item
, int) and item
>= 0:
378 return self
._from
_bits
(self
.bits
& bits
)
382 def __or__(self
, other
):
383 # type: (Iterable[Any]) -> Self
384 if isinstance(other
, BaseBitSet
):
385 return self
._from
_bits
(self
.bits | other
.bits
)
388 if isinstance(item
, int) and item
>= 0:
390 return self
._from
_bits
(bits
)
394 def __xor__(self
, other
):
395 # type: (Iterable[Any]) -> Self
396 if isinstance(other
, BaseBitSet
):
397 return self
._from
_bits
(self
.bits ^ other
.bits
)
400 if isinstance(item
, int) and item
>= 0:
402 return self
._from
_bits
(bits
)
406 def __sub__(self
, other
):
407 # type: (Iterable[Any]) -> Self
408 if isinstance(other
, BaseBitSet
):
409 return self
._from
_bits
(self
.bits
& ~other
.bits
)
412 if isinstance(item
, int) and item
>= 0:
414 return self
._from
_bits
(bits
)
416 def __rsub__(self
, other
):
417 # type: (Iterable[Any]) -> Self
418 if isinstance(other
, BaseBitSet
):
419 return self
._from
_bits
(~self
.bits
& other
.bits
)
422 if isinstance(item
, int) and item
>= 0:
424 return self
._from
_bits
(~self
.bits
& bits
)
426 def isdisjoint(self
, other
):
427 # type: (Iterable[Any]) -> bool
428 if isinstance(other
, BaseBitSet
):
429 return self
.bits
& other
.bits
== 0
430 return super().isdisjoint(other
)
433 class BitSet(BaseBitSet
, MutableSet
[int]):
434 """Mutable Bit Set"""
442 def add(self
, value
):
443 # type: (int) -> None
445 raise ValueError("can't store negative integers")
446 self
.bits |
= 1 << value
448 def discard(self
, value
):
449 # type: (int) -> None
451 self
.bits
&= ~
(1 << value
)
457 def __ior__(self
, it
):
458 # type: (AbstractSet[Any]) -> Self
459 if isinstance(it
, BaseBitSet
):
462 return super().__ior
__(it
)
464 def __iand__(self
, it
):
465 # type: (AbstractSet[Any]) -> Self
466 if isinstance(it
, BaseBitSet
):
469 return super().__iand
__(it
)
471 def __ixor__(self
, it
):
472 # type: (AbstractSet[Any]) -> Self
473 if isinstance(it
, BaseBitSet
):
476 return super().__ixor
__(it
)
478 def __isub__(self
, it
):
479 # type: (AbstractSet[Any]) -> Self
480 if isinstance(it
, BaseBitSet
):
481 self
.bits
&= ~it
.bits
483 return super().__isub
__(it
)
486 class FBitSet(BaseBitSet
, Interned
):
497 return super()._hash
()