from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, BaseTy,
Fn, FnAnalysis, GenAsmState,
- Loc, LocKind, OpKind,
+ Loc, LocKind, LocSet, OpKind,
OpStage, PreRASimState,
ProgramPoint, SSAVal, Ty)
+from bigint_presentation_code.util import OFSet
class TestCompilerIR(unittest.TestCase):
self.assertEqual(sorted(expected), expected)
+ def test_loc_set_hash_intern(self):
+ # type: () -> None
+ # hashes should match all other collections.abc.Set types, which are
+ # supposed to match frozenset but don't until Python 3.11 because of a
+ # bug fixed in:
+ # https://github.com/python/cpython/commit/c878f5d81772dc6f718d6608c78baa4be9a4f176
+ a = LocSet([])
+ self.assertEqual(hash(a), hash(OFSet()))
+ starts = 0, 1, 0, 1, 2
+ GPR = LocKind.GPR
+ expected = OFSet(Loc(kind=GPR, start=i, reg_len=1) for i in starts)
+ b = LocSet(expected)
+ c = LocSet(Loc(kind=GPR, start=i, reg_len=1) for i in starts)
+ d = LocSet(Loc(kind=GPR, start=i, reg_len=1) for i in starts)
+ # hashes should be equal to OFSet's hash
+ self.assertEqual(hash(b), hash(expected))
+ self.assertEqual(hash(c), hash(expected))
+ self.assertEqual(hash(d), hash(expected))
+ # they should intern to the same object
+ self.assertIs(b, d)
+ self.assertIs(c, d)
+
def make_add_fn(self):
# type: () -> tuple[Fn, SSAVal]
fn = Fn()
@final
-class _LocSetHashHelper(AbstractSet[Loc]):
- """helper to more quickly compute LocSet's hash"""
-
- def __init__(self, locs):
- # type: (Iterable[Loc]) -> None
- super().__init__()
- self.locs = list(locs)
-
- def __hash__(self):
- # type: () -> int
- return super()._hash()
-
- def __contains__(self, x):
- # type: (Loc | Any) -> bool
- return x in self.locs
-
- def __iter__(self):
- # type: () -> Iterator[Loc]
- return iter(self.locs)
-
- def __len__(self):
- return len(self.locs)
-
-
-@plain_data(frozen=True, eq=False, repr=False)
-@final
-class LocSet(AbstractSet[Loc], metaclass=InternedMeta):
- __slots__ = "starts", "ty", "_LocSet__hash"
-
+class LocSet(OFSet[Loc], metaclass=InternedMeta):
def __init__(self, __locs=()):
# type: (Iterable[Loc]) -> None
+ super().__init__(__locs)
if isinstance(__locs, LocSet):
- self.starts = __locs.starts # type: FMap[LocKind, FBitSet]
- self.ty = __locs.ty # type: Ty | None
- self._LocSet__hash = __locs._LocSet__hash # type: int
+ self.__starts = __locs.starts
+ self.__ty = __locs.ty
return
starts = {i: BitSet() for i in LocKind}
ty = None # type: None | Ty
-
- def locs():
- # type: () -> Iterable[Loc]
- nonlocal ty
- for loc in __locs:
- if ty is None:
- ty = loc.ty
- if ty != loc.ty:
- raise ValueError(f"conflicting types: {ty} != {loc.ty}")
- starts[loc.kind].add(loc.start)
- yield loc
- self._LocSet__hash = _LocSetHashHelper(locs()).__hash__()
- self.starts = FMap(
+ for loc in self:
+ if ty is None:
+ ty = loc.ty
+ if ty != loc.ty:
+ raise ValueError(f"conflicting types: {ty} != {loc.ty}")
+ starts[loc.kind].add(loc.start)
+ self.__starts = FMap(
(k, FBitSet(v)) for k, v in starts.items() if len(v) != 0)
- self.ty = ty
+ self.__ty = ty
+
+ @property
+ def starts(self):
+ # type: () -> FMap[LocKind, FBitSet]
+ return self.__starts
+
+ @property
+ def ty(self):
+ # type: () -> Ty | None
+ return self.__ty
@cached_property
def stops(self):
yield loc
return LocSet(locs())
- def __contains__(self, loc):
- # type: (Loc | Any) -> bool
- if not isinstance(loc, Loc) or loc.ty != self.ty:
- return False
- if loc.kind not in self.starts:
- return False
- return loc.start in self.starts[loc.kind]
-
- def __iter__(self):
- # type: () -> Iterator[Loc]
- if self.ty is None:
- return
- for kind, starts in self.starts.items():
- for start in starts:
- yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len)
-
- @cached_property
- def __len(self):
- return sum((len(v) for v in self.starts.values()), 0)
-
- def __len__(self):
- return self.__len
-
- def __hash__(self):
- return self._LocSet__hash
-
- def __eq__(self, __other):
- # type: (LocSet | Any) -> bool
- if isinstance(__other, LocSet):
- return self.ty == __other.ty and self.starts == __other.starts
- return super().__eq__(__other)
-
@lru_cache(maxsize=None, typed=True)
def max_conflicts_with(self, other):
# type: (LocSet | Loc) -> int
return sum(other.conflicts(i) for i in self)
def __repr__(self):
- items = [] # type: list[str]
- for name in fields(self):
- if name.startswith("_"):
- continue
- items.append(f"{name}={getattr(self, name)!r}")
- return f"LocSet({', '.join(items)})"
+ return f"LocSet(starts={self.starts!r}, ty={self.ty!r})"
@plain_data(frozen=True, unsafe_hash=True)