From 6c689405250738a6b67b56db13fdad2c50e3b97c Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Tue, 8 Nov 2022 18:18:00 -0800 Subject: [PATCH] get LocSet hash working correctly --- .../_tests/test_compiler_ir.py | 25 +++- src/bigint_presentation_code/compiler_ir.py | 107 ++++-------------- src/bigint_presentation_code/util.py | 5 +- 3 files changed, 51 insertions(+), 86 deletions(-) diff --git a/src/bigint_presentation_code/_tests/test_compiler_ir.py b/src/bigint_presentation_code/_tests/test_compiler_ir.py index ba29ee0..9763a07 100644 --- a/src/bigint_presentation_code/_tests/test_compiler_ir.py +++ b/src/bigint_presentation_code/_tests/test_compiler_ir.py @@ -2,9 +2,10 @@ import unittest from bigint_presentation_code.compiler_ir import (GPR_SIZE_IN_BYTES, BaseTy, Fn, FnAnalysis, GenAsmState, - Loc, LocKind, OpKind, + Loc, LocKind, LocSet, OpKind, OpStage, PreRASimState, ProgramPoint, SSAVal, Ty) +from bigint_presentation_code.util import OFSet class TestCompilerIR(unittest.TestCase): @@ -23,6 +24,28 @@ class TestCompilerIR(unittest.TestCase): self.assertEqual(sorted(expected), expected) + def test_loc_set_hash_intern(self): + # type: () -> None + # hashes should match all other collections.abc.Set types, which are + # supposed to match frozenset but don't until Python 3.11 because of a + # bug fixed in: + # https://github.com/python/cpython/commit/c878f5d81772dc6f718d6608c78baa4be9a4f176 + a = LocSet([]) + self.assertEqual(hash(a), hash(OFSet())) + starts = 0, 1, 0, 1, 2 + GPR = LocKind.GPR + expected = OFSet(Loc(kind=GPR, start=i, reg_len=1) for i in starts) + b = LocSet(expected) + c = LocSet(Loc(kind=GPR, start=i, reg_len=1) for i in starts) + d = LocSet(Loc(kind=GPR, start=i, reg_len=1) for i in starts) + # hashes should be equal to OFSet's hash + self.assertEqual(hash(b), hash(expected)) + self.assertEqual(hash(c), hash(expected)) + self.assertEqual(hash(d), hash(expected)) + # they should intern to the same object + self.assertIs(b, d) + self.assertIs(c, d) + def make_add_fn(self): # type: () -> tuple[Fn, SSAVal] fn = Fn() diff --git a/src/bigint_presentation_code/compiler_ir.py b/src/bigint_presentation_code/compiler_ir.py index d3b52e8..756615d 100644 --- a/src/bigint_presentation_code/compiler_ir.py +++ b/src/bigint_presentation_code/compiler_ir.py @@ -643,59 +643,35 @@ SPECIAL_GPRS = ( @final -class _LocSetHashHelper(AbstractSet[Loc]): - """helper to more quickly compute LocSet's hash""" - - def __init__(self, locs): - # type: (Iterable[Loc]) -> None - super().__init__() - self.locs = list(locs) - - def __hash__(self): - # type: () -> int - return super()._hash() - - def __contains__(self, x): - # type: (Loc | Any) -> bool - return x in self.locs - - def __iter__(self): - # type: () -> Iterator[Loc] - return iter(self.locs) - - def __len__(self): - return len(self.locs) - - -@plain_data(frozen=True, eq=False, repr=False) -@final -class LocSet(AbstractSet[Loc], metaclass=InternedMeta): - __slots__ = "starts", "ty", "_LocSet__hash" - +class LocSet(OFSet[Loc], metaclass=InternedMeta): def __init__(self, __locs=()): # type: (Iterable[Loc]) -> None + super().__init__(__locs) if isinstance(__locs, LocSet): - self.starts = __locs.starts # type: FMap[LocKind, FBitSet] - self.ty = __locs.ty # type: Ty | None - self._LocSet__hash = __locs._LocSet__hash # type: int + self.__starts = __locs.starts + self.__ty = __locs.ty return starts = {i: BitSet() for i in LocKind} ty = None # type: None | Ty - - def locs(): - # type: () -> Iterable[Loc] - nonlocal ty - for loc in __locs: - if ty is None: - ty = loc.ty - if ty != loc.ty: - raise ValueError(f"conflicting types: {ty} != {loc.ty}") - starts[loc.kind].add(loc.start) - yield loc - self._LocSet__hash = _LocSetHashHelper(locs()).__hash__() - self.starts = FMap( + for loc in self: + if ty is None: + ty = loc.ty + if ty != loc.ty: + raise ValueError(f"conflicting types: {ty} != {loc.ty}") + starts[loc.kind].add(loc.start) + self.__starts = FMap( (k, FBitSet(v)) for k, v in starts.items() if len(v) != 0) - self.ty = ty + self.__ty = ty + + @property + def starts(self): + # type: () -> FMap[LocKind, FBitSet] + return self.__starts + + @property + def ty(self): + # type: () -> Ty | None + return self.__ty @cached_property def stops(self): @@ -756,38 +732,6 @@ class LocSet(AbstractSet[Loc], metaclass=InternedMeta): yield loc return LocSet(locs()) - def __contains__(self, loc): - # type: (Loc | Any) -> bool - if not isinstance(loc, Loc) or loc.ty != self.ty: - return False - if loc.kind not in self.starts: - return False - return loc.start in self.starts[loc.kind] - - def __iter__(self): - # type: () -> Iterator[Loc] - if self.ty is None: - return - for kind, starts in self.starts.items(): - for start in starts: - yield Loc(kind=kind, start=start, reg_len=self.ty.reg_len) - - @cached_property - def __len(self): - return sum((len(v) for v in self.starts.values()), 0) - - def __len__(self): - return self.__len - - def __hash__(self): - return self._LocSet__hash - - def __eq__(self, __other): - # type: (LocSet | Any) -> bool - if isinstance(__other, LocSet): - return self.ty == __other.ty and self.starts == __other.starts - return super().__eq__(__other) - @lru_cache(maxsize=None, typed=True) def max_conflicts_with(self, other): # type: (LocSet | Loc) -> int @@ -800,12 +744,7 @@ class LocSet(AbstractSet[Loc], metaclass=InternedMeta): return sum(other.conflicts(i) for i in self) def __repr__(self): - items = [] # type: list[str] - for name in fields(self): - if name.startswith("_"): - continue - items.append(f"{name}={getattr(self, name)!r}") - return f"LocSet({', '.join(items)})" + return f"LocSet(starts={self.starts!r}, ty={self.ty!r})" @plain_data(frozen=True, unsafe_hash=True) diff --git a/src/bigint_presentation_code/util.py b/src/bigint_presentation_code/util.py index 87c6975..03eaeff 100644 --- a/src/bigint_presentation_code/util.py +++ b/src/bigint_presentation_code/util.py @@ -57,7 +57,10 @@ class OFSet(AbstractSet[_T_co], metaclass=InternedMeta): def __init__(self, items=()): # type: (Iterable[_T_co]) -> None super().__init__() - self.__items = {v: None for v in items} + if isinstance(items, OFSet): + self.__items = items.__items + else: + self.__items = {v: None for v in items} def __contains__(self, x): # type: (Any) -> bool -- 2.30.2