__NOT_SET = __NotSet()
+def __ignored_classes():
+ classes = [object] # type: list[type]
+
+ from abc import ABC
+
+ classes += [ABC]
+
+ from typing import (
+ Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat,
+ SupportsInt, SupportsRound)
+
+ classes += [
+ Generic, SupportsAbs, SupportsBytes, SupportsComplex, SupportsFloat,
+ SupportsInt, SupportsRound]
+
+ from collections.abc import (
+ Awaitable, Coroutine, AsyncIterable, AsyncIterator, AsyncGenerator,
+ Hashable, Iterable, Iterator, Generator, Reversible, Sized, Container,
+ Callable, Collection, Set, MutableSet, Mapping, MutableMapping,
+ MappingView, KeysView, ItemsView, ValuesView, Sequence,
+ MutableSequence)
+
+ classes += [
+ Awaitable, Coroutine, AsyncIterable, AsyncIterator, AsyncGenerator,
+ Hashable, Iterable, Iterator, Generator, Reversible, Sized, Container,
+ Callable, Collection, Set, MutableSet, Mapping, MutableMapping,
+ MappingView, KeysView, ItemsView, ValuesView, Sequence,
+ MutableSequence]
+
+ # rest aren't supported by python 3.7, so try to import them and skip if
+ # that errors
+
+ try:
+ # typing_extensions uses typing.Protocol if available
+ from typing_extensions import Protocol
+ classes.append(Protocol)
+ except ImportError:
+ pass
+
+ for cls in classes:
+ yield from cls.__mro__
+
+
+__IGNORED_CLASSES = frozenset(__ignored_classes())
+
+
def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen):
if not isinstance(cls, type):
raise TypeError(
any_parents_have_dict = False
any_parents_have_weakref = False
for cur_cls in reversed(cls.__mro__):
- if cur_cls is object:
+ d = getattr(cur_cls, "__dict__", {})
+ if cur_cls is not cls:
+ if "__dict__" in d:
+ any_parents_have_dict = True
+ if "__weakref__" in d:
+ any_parents_have_weakref = True
+ if cur_cls in __IGNORED_CLASSES:
continue
try:
cur_slots = cur_cls.__slots__
if field not in slots:
fields.append(field)
slots[field] = None
- if cur_cls is not cls:
- if field == "__dict__":
- any_parents_have_dict = True
- elif field == "__weakref__":
- any_parents_have_weakref = True
fields = tuple(fields) # fields needs to be immutable
import operator
import pickle
import unittest
+import typing
from nmutil.plain_data import (FrozenPlainDataError, plain_data,
fields, replace)
+try:
+ from typing import Protocol
+except ImportError:
+ try:
+ from typing_extensions import Protocol
+ except ImportError:
+ Protocol = None
+
@plain_data(order=True)
class PlainData0:
setattr(self, name, value)
+T = typing.TypeVar("T")
+
+
+@plain_data()
+class GenericClass(typing.Generic[T]):
+ __slots__ = "a",
+
+ def __init__(self, a):
+ self.a = a
+
+
+@plain_data()
+class MySet(typing.AbstractSet[int]):
+ __slots__ = ()
+
+ def __contains__(self, x):
+ raise NotImplementedError
+
+ def __iter__(self):
+ raise NotImplementedError
+
+ def __len__(self):
+ raise NotImplementedError
+
+
+@plain_data()
+class MyIntLike(typing.SupportsInt):
+ __slots__ = ()
+
+ def __int__(self):
+ return 1
+
+
+if Protocol is not None:
+ class MyProtocol(Protocol):
+ def my_method(self): ...
+
+ @plain_data()
+ class MyProtocolImpl(MyProtocol):
+ __slots__ = ()
+
+ def my_method(self):
+ pass
+
+
class TestPlainData(unittest.TestCase):
def test_fields(self):
self.assertEqual(fields(PlainData0), ())
self.assertEqual(fields(PlainDataF2), ("a", "b", "x", "y", "z"))
self.assertEqual(fields(PlainDataF2(1, 2, x="x", y="y", z=3)),
("a", "b", "x", "y", "z"))
+ self.assertEqual(fields(GenericClass(1)), ("a",))
+ self.assertEqual(fields(MySet()), ())
+ self.assertEqual(fields(MyIntLike()), ())
+ if Protocol is not None:
+ self.assertEqual(fields(MyProtocolImpl()), ())
with self.assertRaisesRegex(
TypeError,
r"the passed-in object must be a class or an instance of a "