From 900506e0410607e66c542aa50ef0ded876d32102 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 12 Oct 2022 18:33:17 -0700 Subject: [PATCH] change plain_data to ignore more base classes, so it'll work with ABCs and stuff --- src/nmutil/plain_data.py | 59 +++++++++++++++++++++++++++--- src/nmutil/test/test_plain_data.py | 59 ++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 6 deletions(-) diff --git a/src/nmutil/plain_data.py b/src/nmutil/plain_data.py index ace4d61..0ebcb5d 100644 --- a/src/nmutil/plain_data.py +++ b/src/nmutil/plain_data.py @@ -15,6 +15,52 @@ class __NotSet: __NOT_SET = __NotSet() +def __ignored_classes(): + classes = [object] # type: list[type] + + from abc import ABC + + classes += [ABC] + + from typing import ( + Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat, + SupportsInt, SupportsRound) + + classes += [ + Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat, + SupportsInt, SupportsRound] + + from collections.abc import ( + Awaitable, Coroutine, AsyncIterable, AsyncIterator, AsyncGenerator, + Hashable, Iterable, Iterator, Generator, Reversible, Sized, Container, + Callable, Collection, Set, MutableSet, Mapping, MutableMapping, + MappingView, KeysView, ItemsView, ValuesView, Sequence, + MutableSequence) + + classes += [ + Awaitable, Coroutine, AsyncIterable, AsyncIterator, AsyncGenerator, + Hashable, Iterable, Iterator, Generator, Reversible, Sized, Container, + Callable, Collection, Set, MutableSet, Mapping, MutableMapping, + MappingView, KeysView, ItemsView, ValuesView, Sequence, + MutableSequence] + + # rest aren't supported by python 3.7, so try to import them and skip if + # that errors + + try: + # typing_extensions uses typing.Protocol if available + from typing_extensions import Protocol + classes.append(Protocol) + except ImportError: + pass + + for cls in classes: + yield from cls.__mro__ + + +__IGNORED_CLASSES = frozenset(__ignored_classes()) + + def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen): if not isinstance(cls, type): raise TypeError( @@ -28,7 +74,13 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen): any_parents_have_dict = False any_parents_have_weakref = False for cur_cls in reversed(cls.__mro__): - if cur_cls is object: + d = getattr(cur_cls, "__dict__", {}) + if cur_cls is not cls: + if "__dict__" in d: + any_parents_have_dict = True + if "__weakref__" in d: + any_parents_have_weakref = True + if cur_cls in __IGNORED_CLASSES: continue try: cur_slots = cur_cls.__slots__ @@ -47,11 +99,6 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen): if field not in slots: fields.append(field) slots[field] = None - if cur_cls is not cls: - if field == "__dict__": - any_parents_have_dict = True - elif field == "__weakref__": - any_parents_have_weakref = True fields = tuple(fields) # fields needs to be immutable diff --git a/src/nmutil/test/test_plain_data.py b/src/nmutil/test/test_plain_data.py index a087e6e..f16faba 100644 --- a/src/nmutil/test/test_plain_data.py +++ b/src/nmutil/test/test_plain_data.py @@ -4,9 +4,18 @@ import operator import pickle import unittest +import typing from nmutil.plain_data import (FrozenPlainDataError, plain_data, fields, replace) +try: + from typing import Protocol +except ImportError: + try: + from typing_extensions import Protocol + except ImportError: + Protocol = None + @plain_data(order=True) class PlainData0: @@ -67,6 +76,51 @@ class UnsetField: setattr(self, name, value) +T = typing.TypeVar("T") + + +@plain_data() +class GenericClass(typing.Generic[T]): + __slots__ = "a", + + def __init__(self, a): + self.a = a + + +@plain_data() +class MySet(typing.AbstractSet[int]): + __slots__ = () + + def __contains__(self, x): + raise NotImplementedError + + def __iter__(self): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + +@plain_data() +class MyIntLike(typing.SupportsInt): + __slots__ = () + + def __int__(self): + return 1 + + +if Protocol is not None: + class MyProtocol(Protocol): + def my_method(self): ... + + @plain_data() + class MyProtocolImpl(MyProtocol): + __slots__ = () + + def my_method(self): + pass + + class TestPlainData(unittest.TestCase): def test_fields(self): self.assertEqual(fields(PlainData0), ()) @@ -85,6 +139,11 @@ class TestPlainData(unittest.TestCase): self.assertEqual(fields(PlainDataF2), ("a", "b", "x", "y", "z")) self.assertEqual(fields(PlainDataF2(1, 2, x="x", y="y", z=3)), ("a", "b", "x", "y", "z")) + self.assertEqual(fields(GenericClass(1)), ("a",)) + self.assertEqual(fields(MySet()), ()) + self.assertEqual(fields(MyIntLike()), ()) + if Protocol is not None: + self.assertEqual(fields(MyProtocolImpl()), ()) with self.assertRaisesRegex( TypeError, r"the passed-in object must be a class or an instance of a " -- 2.30.2