attempt to speed up code
[bigint-presentation-code.git] / src / bigint_presentation_code / util.py
1 from abc import ABCMeta, abstractmethod
2 from collections import defaultdict
3 from typing import (AbstractSet, Any, Callable, Iterable, Iterator, Mapping,
4 MutableSet, TypeVar, overload)
5
6 from bigint_presentation_code.type_util import Self, final
7
# Type variables for the generic containers defined in this module.
_T = TypeVar("_T")
_T2 = TypeVar("_T2")
# Covariant variable: used as the element type of OFSet and the value
# type of FMap.
_T_co = TypeVar("_T_co", covariant=True)

# Public names exported by ``from ... import *``.
__all__ = [
    "BaseBitSet", "bit_count", "BitSet", "FBitSet", "FMap",
    "OFSet", "OSet", "top_set_bit_index", "trailing_zero_count",
    "Interned",
]
24
25
class _InternedMeta(ABCMeta):
    """Metaclass for :class:`Interned`.

    After normal construction the new instance is passed to its class's
    ``_Interned__intern`` hook; whatever that hook returns (possibly an
    already-existing equal instance) is the result of ``Cls(...)``.
    """

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        return super().__call__(*args, **kwds)._Interned__intern()


class Interned(metaclass=_InternedMeta):
    """Mixin whose constructor returns one canonical instance per value.

    On its first instantiation, each concrete subclass gets its own intern
    table keyed by the class's original ``__hash__``. Later constructions
    that compare equal (via the original ``__eq__``) return the instance
    already stored. The class's ``__hash__`` is then replaced by a read of
    the hash cached at intern time. Its ``__eq__`` is replaced by an
    identity check for same-class operands.

    NOTE(review): the intern table keeps strong references, so interned
    instances are never freed while their class is alive. Confirm this is
    intended.
    """

    def __init_subclass__(cls, **kwargs):
        # type: (**Any) -> None
        super().__init_subclass__(**kwargs)
        # Give every subclass the one-time setup hook in its *own* namespace.
        # Without this, a subclass instantiated after its parent finds the
        # parent's installed ``__intern`` through the MRO. It is then interned
        # in the parent's table with the parent's hash/eq, and may even get
        # back a parent instance instead of one of its own type.
        if "_Interned__intern" not in cls.__dict__:
            cls._Interned__intern = Interned._Interned__init_intern

    def __init_intern(self):
        # type: (Self) -> Self
        """Do the one-time setup for ``type(self)``, then intern ``self``.

        Replaces the class's ``__hash__`` and ``__eq__`` and creates the
        class's intern table. Installs the fast ``__intern`` as the class's
        hook, then interns ``self`` through it.
        """
        cls = type(self)
        # If the class inherited methods already replaced by a parent's
        # setup, unwrap them to the original implementations.
        old_hash = cls.__hash__
        old_hash = getattr(old_hash, "_Interned__old_hash", old_hash)
        old_eq = cls.__eq__
        old_eq = getattr(old_eq, "_Interned__old_eq", old_eq)

        def __hash__(self):
            # type: (Self) -> int
            # hash computed once by __intern when the instance was stored
            return self._Interned__hash  # type: ignore
        __hash__._Interned__old_hash = old_hash  # type: ignore
        cls.__hash__ = __hash__

        def __eq__(self,  # type: Self
                   __other,  # type: Any
                   *, __eq=old_eq,  # type: Callable[[Self, Any], bool]
                   ):
            # type: (...) -> bool
            # Equal same-class instances are interned to a single object,
            # so identity suffices; anything else uses the original __eq__.
            if self.__class__ is __other.__class__:
                return self is __other
            return __eq(self, __other)
        __eq__._Interned__old_eq = old_eq  # type: ignore
        cls.__eq__ = __eq__

        # hash -> instances with that hash (per class)
        table = defaultdict(list)  # type: dict[int, list[Self]]

        def __intern(self,  # type: Self
                     *, __hash=old_hash,  # type: Callable[[Self], int]
                     __eq=old_eq,  # type: Callable[[Self, Any], bool]
                     __table=table,  # type: dict[int, list[Self]]
                     __NotImplemented=NotImplemented,  # type: Any
                     ):
            # type: (...) -> Self
            # Keyword defaults bind the originals and the table as locals.
            h = __hash(self)
            bucket = __table[h]
            for i in bucket:
                v = __eq(self, i)
                if v is not __NotImplemented and v:
                    return i
            self.__dict__["_Interned__hash"] = h
            bucket.append(self)
            return self
        cls._Interned__intern = __intern
        return __intern(self)

    _Interned__intern = __init_intern
79
80
class OFSet(AbstractSet[_T_co], Interned):
    """Ordered frozen set: an immutable set iterating in insertion order.

    Elements are kept as the keys of a private ``dict`` (all values
    ``None``), which provides both membership tests and ordering.
    """
    __slots__ = "__items", "__dict__", "__weakref__"

    def __init__(self, items=()):
        # type: (Iterable[_T_co]) -> None
        super().__init__()
        # Another OFSet is immutable, so its backing dict can be shared
        # instead of copied; any other iterable is de-duplicated in order.
        self.__items = (items.__items if isinstance(items, OFSet)
                        else dict.fromkeys(items))

    def __contains__(self, x):
        # type: (Any) -> bool
        items = self.__items
        return x in items

    def __iter__(self):
        # type: () -> Iterator[_T_co]
        items = self.__items
        return iter(items)

    def __len__(self):
        # type: () -> int
        items = self.__items
        return len(items)

    def __hash__(self):
        # type: () -> int
        # order-insensitive hash supplied by the Set ABC
        return self._hash()

    def __repr__(self):
        # type: () -> str
        if not self:
            return "OFSet()"
        return f"OFSet({list(self)})"
114
115
class OSet(MutableSet[_T]):
    """Ordered mutable set: iterates in insertion order.

    Elements are the keys of a private ``dict`` whose values are ``None``.
    """
    __slots__ = "__items", "__dict__"

    def __init__(self, items=()):
        # type: (Iterable[_T]) -> None
        super().__init__()
        self.__items = dict.fromkeys(items)

    def __contains__(self, x):
        # type: (Any) -> bool
        items = self.__items
        return x in items

    def __iter__(self):
        # type: () -> Iterator[_T]
        items = self.__items
        return iter(items)

    def __len__(self):
        # type: () -> int
        items = self.__items
        return len(items)

    def add(self, value):
        # type: (_T) -> None
        # re-adding an existing element keeps its original position
        items = self.__items
        items[value] = None

    def discard(self, value):
        # type: (_T) -> None
        # absent elements are ignored
        items = self.__items
        items.pop(value, None)

    def remove(self, value):
        # type: (_T) -> None
        # raises KeyError when ``value`` is absent
        items = self.__items
        del items[value]

    def pop(self):
        # type: () -> _T
        # removes and returns the most recently inserted element;
        # raises KeyError when empty
        key, _ = self.__items.popitem()
        return key

    def clear(self):
        # type: () -> None
        items = self.__items
        items.clear()

    def __repr__(self):
        # type: () -> str
        if not self:
            return "OSet()"
        return f"OSet({list(self)})"
162
163
class FMap(Mapping[_T, _T_co], Interned):
    """Immutable, hashable mapping that preserves insertion order.

    Contents are copied into a private ``dict`` at construction time; the
    hash is computed lazily and cached.
    """
    __slots__ = "__items", "__hash", "__dict__", "__weakref__"

    @overload
    def __init__(self, items):
        # type: (Mapping[_T, _T_co]) -> None
        ...

    @overload
    def __init__(self, items):
        # type: (Iterable[tuple[_T, _T_co]]) -> None
        ...

    @overload
    def __init__(self):
        # type: () -> None
        ...

    def __init__(self, items=()):
        # type: (Mapping[_T, _T_co] | Iterable[tuple[_T, _T_co]]) -> None
        super().__init__()
        # hash cache, filled on first call to __hash__
        self.__hash = None  # type: None | int
        # private copy, so later changes to ``items`` can't leak in
        self.__items = dict(items)  # type: dict[_T, _T_co]

    def __getitem__(self, item):
        # type: (_T) -> _T_co
        items = self.__items
        return items[item]

    def __iter__(self):
        # type: () -> Iterator[_T]
        items = self.__items
        return iter(items)

    def __len__(self):
        # type: () -> int
        items = self.__items
        return len(items)

    def __eq__(self, other):
        # type: (FMap[Any, Any] | Any) -> bool
        if not isinstance(other, FMap):
            # generic Mapping comparison for everything else
            return super().__eq__(other)
        return self.__items == other.__items

    def __hash__(self):
        # type: () -> int
        cached = self.__hash
        if cached is None:
            # order-insensitive, matching dict equality used by __eq__
            cached = self.__hash = hash(frozenset(self.items()))
        return cached

    def __repr__(self):
        # type: () -> str
        return f"FMap({self.__items})"

    def get(self, key, default=None):
        # type: (_T, _T_co | _T2) -> _T_co | _T2
        items = self.__items
        return items.get(key, default)

    def __contains__(self, key):
        # type: (_T | object) -> bool
        items = self.__items
        return key in items
224
225
def trailing_zero_count(v, default=-1):
    # type: (int, int) -> int
    """Return the index of the lowest set bit of ``v``.

    Returns ``default`` when ``v`` is zero. Negative values use their
    two's-complement bits, so ``trailing_zero_count(-4) == 2``.
    """
    # two's-complement identity: ``v & -v`` keeps only the lowest set bit
    lowest_bit = v & -v
    return top_set_bit_index(lowest_bit, default)


def top_set_bit_index(v, default=-1):
    # type: (int, int) -> int
    """Return the index of the highest set bit of ``v``.

    Returns ``default`` when ``v`` is zero or negative.
    """
    return v.bit_length() - 1 if v > 0 else default
238
239
if hasattr(int, "bit_count"):
    # CPython 3.10+ ships a native popcount on int
    bit_count = int.bit_count  # type: ignore
else:
    def bit_count(v):
        # type: (int) -> int
        """Count the 1 bits in the binary representation of ``abs(v)``."""
        magnitude = abs(v)
        return bin(magnitude).count("1")
248
249
class BaseBitSet(AbstractSet[int]):
    """Set of non-negative ints stored as the bits of a single integer.

    Integer ``i`` is a member iff bit ``i`` of :attr:`bits` is set.
    Concrete subclasses implement :meth:`_frozen` to say whether
    :attr:`bits` may be reassigned after construction.

    The binary operators accept any iterable as the other operand. Items
    that cannot be members (non-ints, negative ints) are ignored.
    """
    __slots__ = "__bits", "__dict__", "__weakref__"

    @classmethod
    @abstractmethod
    def _frozen(cls):
        # type: () -> bool
        """Return True if the ``bits`` setter must reject writes."""
        return False

    @classmethod
    def _from_bits(cls, bits):
        # type: (int) -> Self
        """Construct a ``cls`` instance directly from the mask ``bits``."""
        return cls(bits=bits)

    @staticmethod
    def _bits_of(items):
        # type: (Iterable[Any]) -> int
        """Return the mask of the non-negative ints found in ``items``.

        Items that can't be bit-set members (non-ints, negative ints) are
        skipped. Bits are OR-ed in, so duplicate items are harmless.
        """
        bits = 0
        for item in items:
            if isinstance(item, int) and item >= 0:
                bits |= 1 << item
        return bits

    def __init__(self, items=(), bits=0):
        # type: (Iterable[int], int) -> None
        """Create a bit set holding ``items`` plus the members of ``bits``.

        Raises ValueError for a negative item, or if the combined mask is
        negative (a negative int would denote an infinite set).
        """
        super().__init__()
        if isinstance(items, BaseBitSet):
            bits |= items.bits
        else:
            for item in items:
                if item < 0:
                    raise ValueError("can't store negative integers")
                bits |= 1 << item
        if bits < 0:
            raise ValueError("can't store an infinite set")
        self.__bits = bits

    @property
    def bits(self):
        # type: () -> int
        """The membership mask: bit ``i`` is set iff ``i`` is in the set."""
        return self.__bits

    @bits.setter
    def bits(self, bits):
        # type: (int) -> None
        if self._frozen():
            raise AttributeError("can't write to frozen bitset's bits")
        if bits < 0:
            raise ValueError("can't store an infinite set")
        self.__bits = bits

    def __contains__(self, x):
        # type: (Any) -> bool
        if isinstance(x, int) and x >= 0:
            return (1 << x) & self.bits != 0
        return False

    def __iter__(self):
        # type: () -> Iterator[int]
        # yields members in ascending order (lowest set bit first)
        bits = self.bits
        while bits != 0:
            index = trailing_zero_count(bits)
            yield index
            bits -= 1 << index

    def __reversed__(self):
        # type: () -> Iterator[int]
        # yields members in descending order (highest set bit first)
        bits = self.bits
        while bits != 0:
            index = top_set_bit_index(bits)
            yield index
            bits -= 1 << index

    def __len__(self):
        # type: () -> int
        return bit_count(self.bits)

    def __repr__(self):
        # type: () -> str
        # Up to 3 members print as a list. Otherwise try to compress the
        # members into at most MAX_RANGES arithmetic ranges. Failing that,
        # print a short list for sparse large-valued sets, else the hex mask.
        if self.bits == 0:
            return f"{self.__class__.__name__}()"
        len_self = len(self)
        if len_self <= 3:
            v = list(self)
            return f"{self.__class__.__name__}({v})"
        ranges = []  # type: list[range]
        MAX_RANGES = 5
        for i in self:
            if len(ranges) != 0 and ranges[-1].stop == i:
                # i continues the current range
                ranges[-1] = range(
                    ranges[-1].start, i + ranges[-1].step, ranges[-1].step)
            elif len(ranges) != 0 and len(ranges[-1]) == 1:
                # a single element plus i defines the step of a new range
                start = ranges[-1][0]
                step = i - start
                stop = i + step
                ranges[-1] = range(start, stop, step)
            elif len(ranges) != 0 and len(ranges[-1]) == 2:
                # a 2-element range that doesn't extend: split off its first
                # element and pair its second element with i instead
                single = ranges[-1][0]
                start = ranges[-1][1]
                ranges[-1] = range(single, single + 1)
                step = i - start
                stop = i + step
                ranges.append(range(start, stop, step))
            else:
                ranges.append(range(i, i + 1))
            if len(ranges) > MAX_RANGES:
                break
        if len(ranges) == 1:
            return f"{self.__class__.__name__}({ranges[0]})"
        if len(ranges) <= MAX_RANGES:
            range_strs = []  # type: list[str]
            for r in ranges:
                if len(r) == 1:
                    range_strs.append(str(r[0]))
                else:
                    range_strs.append(f"*{r}")
            ranges_str = ", ".join(range_strs)
            return f"{self.__class__.__name__}([{ranges_str}])"
        if self.bits > 0xFFFFFFFF and len_self < 10:
            v = list(self)
            return f"{self.__class__.__name__}({v})"
        return f"{self.__class__.__name__}(bits={hex(self.bits)})"

    def __eq__(self, other):
        # type: (Any) -> bool
        if not isinstance(other, BaseBitSet):
            # generic Set comparison for non-bit-set operands
            return super().__eq__(other)
        return self.bits == other.bits

    def __and__(self, other):
        # type: (Iterable[Any]) -> Self
        if isinstance(other, BaseBitSet):
            return self._from_bits(self.bits & other.bits)
        mask = self._bits_of(other)
        return self._from_bits(self.bits & mask)

    __rand__ = __and__

    def __or__(self, other):
        # type: (Iterable[Any]) -> Self
        if isinstance(other, BaseBitSet):
            return self._from_bits(self.bits | other.bits)
        mask = self._bits_of(other)
        return self._from_bits(self.bits | mask)

    __ror__ = __or__

    def __xor__(self, other):
        # type: (Iterable[Any]) -> Self
        if isinstance(other, BaseBitSet):
            return self._from_bits(self.bits ^ other.bits)
        # Collect ``other``'s members first. Toggling a bit per item would
        # let duplicates in a non-set iterable (e.g. a list) cancel out.
        mask = self._bits_of(other)
        return self._from_bits(self.bits ^ mask)

    __rxor__ = __xor__

    def __sub__(self, other):
        # type: (Iterable[Any]) -> Self
        if isinstance(other, BaseBitSet):
            return self._from_bits(self.bits & ~other.bits)
        mask = self._bits_of(other)
        return self._from_bits(self.bits & ~mask)

    def __rsub__(self, other):
        # type: (Iterable[Any]) -> Self
        # ``other - self``
        if isinstance(other, BaseBitSet):
            return self._from_bits(~self.bits & other.bits)
        mask = self._bits_of(other)
        return self._from_bits(~self.bits & mask)

    def isdisjoint(self, other):
        # type: (Iterable[Any]) -> bool
        if isinstance(other, BaseBitSet):
            return self.bits & other.bits == 0
        return super().isdisjoint(other)
431
432
class BitSet(BaseBitSet, MutableSet[int]):
    """Mutable set of non-negative ints backed by a single integer mask."""

    @final
    @classmethod
    def _frozen(cls):
        # type: () -> bool
        # mutable: the ``bits`` setter accepts writes
        return False

    def add(self, value):
        # type: (int) -> None
        if value < 0:
            raise ValueError("can't store negative integers")
        self.bits = self.bits | (1 << value)

    def discard(self, value):
        # type: (int) -> None
        # a negative value can never be a member, so there's nothing to clear
        if value >= 0:
            self.bits = self.bits & ~(1 << value)

    def clear(self):
        # type: () -> None
        self.bits = 0

    # In-place operators: a single bitwise op when ``it`` is another bit
    # set, otherwise the generic MutableSet implementation.

    def __ior__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if not isinstance(it, BaseBitSet):
            return super().__ior__(it)
        self.bits = self.bits | it.bits
        return self

    def __iand__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if not isinstance(it, BaseBitSet):
            return super().__iand__(it)
        self.bits = self.bits & it.bits
        return self

    def __ixor__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if not isinstance(it, BaseBitSet):
            return super().__ixor__(it)
        self.bits = self.bits ^ it.bits
        return self

    def __isub__(self, it):
        # type: (AbstractSet[Any]) -> Self
        if not isinstance(it, BaseBitSet):
            return super().__isub__(it)
        self.bits = self.bits & ~it.bits
        return self
484
485
class FBitSet(BaseBitSet, Interned):
    """Immutable (frozen) bit set whose instances are interned."""

    @final
    @classmethod
    def _frozen(cls):
        # type: () -> bool
        # frozen: the ``bits`` setter rejects writes
        return True

    def __hash__(self):
        # type: () -> int
        # Order-insensitive Set hash, consistent with the Set-based __eq__
        # for non-bit-set operands.
        result = super()._hash()
        return result