1 # SPDX-License-Identifier: LGPL-3-or-later
2 # Copyright 2022 Jacob Lifshay programmerjake@gmail.com
6 class FrozenPlainDataError(AttributeError):
11 """ helper for __repr__ for when fields aren't set """
17 __NOT_SET
= __NotSet()
20 def __ignored_classes():
21 classes
= [object] # type: list[type]
28 Generic
, SupportsAbs
, SupportsBytes
, SupportsComplex
, SupportsFloat
,
29 SupportsInt
, SupportsRound
)
32 Generic
, SupportsAbs
, SupportsBytes
, SupportsComplex
, SupportsFloat
,
33 SupportsInt
, SupportsRound
]
35 from collections
.abc
import (
36 Awaitable
, Coroutine
, AsyncIterable
, AsyncIterator
, AsyncGenerator
,
37 Hashable
, Iterable
, Iterator
, Generator
, Reversible
, Sized
, Container
,
38 Callable
, Collection
, Set
, MutableSet
, Mapping
, MutableMapping
,
39 MappingView
, KeysView
, ItemsView
, ValuesView
, Sequence
,
43 Awaitable
, Coroutine
, AsyncIterable
, AsyncIterator
, AsyncGenerator
,
44 Hashable
, Iterable
, Iterator
, Generator
, Reversible
, Sized
, Container
,
45 Callable
, Collection
, Set
, MutableSet
, Mapping
, MutableMapping
,
46 MappingView
, KeysView
, ItemsView
, ValuesView
, Sequence
,
49 # rest aren't supported by python 3.7, so try to import them and skip if
53 # typing_extensions uses typing.Protocol if available
54 from typing_extensions
import Protocol
55 classes
.append(Protocol
)
60 yield from cls
.__mro
__
63 __IGNORED_CLASSES
= frozenset(__ignored_classes())
66 def _decorator(cls
, *, eq
, unsafe_hash
, order
, repr_
, frozen
):
67 if not isinstance(cls
, type):
69 "plain_data() can only be used as a class decorator")
70 # slots is an ordered set by using dict keys.
71 # always add __dict__ and __weakref__
72 slots
= {"__dict__": None, "__weakref__": None}
74 slots
["__plain_data_init_done"] = None
76 any_parents_have_dict
= False
77 any_parents_have_weakref
= False
78 for cur_cls
in reversed(cls
.__mro
__):
79 d
= getattr(cur_cls
, "__dict__", {})
80 if cur_cls
is not cls
:
82 any_parents_have_dict
= True
83 if "__weakref__" in d
:
84 any_parents_have_weakref
= True
85 if cur_cls
in __IGNORED_CLASSES
:
88 cur_slots
= cur_cls
.__slots
__
89 except AttributeError as e
:
90 raise TypeError(f
"{cur_cls.__module__}.{cur_cls.__qualname__}"
91 " must have __slots__ so plain_data() can "
92 "determine what fields exist in "
93 f
"{cls.__module__}.{cls.__qualname__}") from e
94 if not isinstance(cur_slots
, tuple):
95 raise TypeError("plain_data() requires __slots__ to be a "
97 for field
in cur_slots
:
98 if not isinstance(field
, str):
99 raise TypeError("plain_data() requires __slots__ to be a "
101 if not field
.isidentifier() or keyword
.iskeyword(field
):
103 "plain_data() requires __slots__ entries to be valid "
104 "Python identifiers and not keywords")
105 if field
not in slots
:
109 fields
= tuple(fields
) # fields needs to be immutable
111 if any_parents_have_dict
:
112 # work around a CPython bug that unnecessarily checks if parent
113 # classes already have the __dict__ slot.
114 del slots
["__dict__"]
116 if any_parents_have_weakref
:
117 # work around a CPython bug that unnecessarily checks if parent
118 # classes already have the __weakref__ slot.
119 del slots
["__weakref__"]
121 # now create a new class having everything we need
122 retval_dict
= dict(cls
.__dict
__)
123 # remove all old descriptors:
124 for name
in slots
.keys():
125 retval_dict
.pop(name
, None)
127 retval_dict
["__plain_data_fields"] = fields
129 def add_method_or_error(value
, replace
=False):
130 name
= value
.__name
__
131 if name
in retval_dict
and not replace
:
133 f
"can't generate {name} method: attribute already exists")
134 value
.__qualname
__ = f
"{cls.__qualname__}.{value.__name__}"
135 retval_dict
[name
] = value
138 def __setattr__(self
, name
: str, value
):
139 if getattr(self
, "__plain_data_init_done", False):
140 raise FrozenPlainDataError(f
"cannot assign to field {name!r}")
141 elif name
not in slots
and not name
.startswith("_"):
142 raise AttributeError(
143 f
"cannot assign to unknown field {name!r}")
144 object.__setattr
__(self
, name
, value
)
146 add_method_or_error(__setattr__
)
148 def __delattr__(self
, name
):
149 if getattr(self
, "__plain_data_init_done", False):
150 raise FrozenPlainDataError(f
"cannot delete field {name!r}")
151 object.__delattr
__(self
, name
)
153 add_method_or_error(__delattr__
)
155 old_init
= cls
.__init
__
157 def __init__(self
, *args
, **kwargs
):
158 if hasattr(self
, "__plain_data_init_done"):
159 # we're already in an __init__ call (probably a
160 # superclass's __init__), don't set
161 # __plain_data_init_done too early
162 return old_init(self
, *args
, **kwargs
)
163 object.__setattr
__(self
, "__plain_data_init_done", False)
165 return old_init(self
, *args
, **kwargs
)
167 object.__setattr
__(self
, "__plain_data_init_done", True)
169 add_method_or_error(__init__
, replace
=True)
173 # set __slots__ to have everything we need in the preferred order
174 retval_dict
["__slots__"] = tuple(slots
.keys())
176 def __getstate__(self
):
178 return [getattr(self
, name
) for name
in fields
]
180 add_method_or_error(__getstate__
)
182 def __setstate__(self
, state
):
184 for name
, value
in zip(fields
, state
):
185 # bypass frozen setattr
186 object.__setattr
__(self
, name
, value
)
188 add_method_or_error(__setstate__
)
190 # get source code that gets a tuple of all fields
191 def fields_tuple(var
):
195 l
.append(f
"{var}.{name}, ")
196 return "(" + "".join(l
) + ")"
200 def __eq__(self, other):
201 if other.__class__ is not self.__class__:
202 return NotImplemented
203 return {fields_tuple('self')} == {fields_tuple('other')}
205 add_method_or_error(__eq__)
211 return hash({fields_tuple('self')})
213 add_method_or_error(__hash__)
218 def __lt__(self, other):
219 if other.__class__ is not self.__class__:
220 return NotImplemented
221 return {fields_tuple('self')} < {fields_tuple('other')}
223 add_method_or_error(__lt__)
225 def __le__(self, other):
226 if other.__class__ is not self.__class__:
227 return NotImplemented
228 return {fields_tuple('self')} <= {fields_tuple('other')}
230 add_method_or_error(__le__)
232 def __gt__(self, other):
233 if other.__class__ is not self.__class__:
234 return NotImplemented
235 return {fields_tuple('self')} > {fields_tuple('other')}
237 add_method_or_error(__gt__)
239 def __ge__(self, other):
240 if other.__class__ is not self.__class__:
241 return NotImplemented
242 return {fields_tuple('self')} >= {fields_tuple('other')}
244 add_method_or_error(__ge__)
251 parts
.append(f
"{name}={getattr(self, name, __NOT_SET)!r}")
252 return f
"{self.__class__.__qualname__}({', '.join(parts)})"
254 add_method_or_error(__repr__
)
257 retval
= type(cls
)(cls
.__name
__, cls
.__bases
__, retval_dict
)
260 retval
.__qualname
__ = cls
.__qualname
__
262 def fix_super_and_class(value
):
263 # fixup super() and __class__
264 # derived from: https://stackoverflow.com/a/71666065/2597900
266 closure
= value
.__closure
__
267 if isinstance(closure
, tuple):
268 if closure
[0].cell_contents
is cls
:
269 closure
[0].cell_contents
= retval
270 except (AttributeError, IndexError):
273 for value
in retval
.__dict
__.values():
274 fix_super_and_class(value
)
276 if old_init
is not None:
277 fix_super_and_class(old_init
)
282 def plain_data(*, eq
=True, unsafe_hash
=False, order
=False, repr=True,
284 # defaults match dataclass, with the exception of `init`
285 """ Decorator for adding equality comparison, ordered comparison,
286 `repr` support, `hash` support, and frozen type (read-only fields)
287 support to classes that are just plain data.
289 This is kinda like dataclasses, but uses `__slots__` instead of type
290 annotations, as well as requiring you to write your own `__init__`
293 return _decorator(cls
, eq
=eq
, unsafe_hash
=unsafe_hash
, order
=order
,
294 repr_
=repr, frozen
=frozen
)
299 """ get the tuple of field names of the passed-in
300 `@plain_data()`-decorated class.
302 This is similar to `dataclasses.fields`, except this returns a
305 Returns: tuple[str, ...]
311 __slots__ = "a_field", "field2"
312 def __init__(self, a_field, field2):
313 self.a_field = a_field
316 assert fields(MyBaseClass) == ("a_field", "field2")
317 assert fields(MyBaseClass(1, 2)) == ("a_field", "field2")
320 class MyClass(MyBaseClass):
321 __slots__ = "child_field",
322 def __init__(self, a_field, field2, child_field):
323 super().__init__(a_field=a_field, field2=field2)
324 self.child_field = child_field
326 assert fields(MyClass) == ("a_field", "field2", "child_field")
327 assert fields(MyClass(1, 2, 3)) == ("a_field", "field2", "child_field")
330 retval
= getattr(pd
, "__plain_data_fields", None)
331 if not isinstance(retval
, tuple):
332 raise TypeError("the passed-in object must be a class or an instance"
333 " of a class decorated with @plain_data()")
337 __NOT_SPECIFIED
= object()
340 def replace(pd
, **changes
):
341 """ Return a new instance of the passed-in `@plain_data()`-decorated
342 object, but with the specified fields replaced with new values.
343 This is quite useful with frozen `@plain_data()` classes.
347 @plain_data(frozen=True)
349 __slots__ = "a", "b", "c"
350 def __init__(self, a, b, *, c):
355 v1 = MyClass(1, 2, c=3)
356 v2 = replace(v1, b=4)
357 assert v2 == MyClass(a=1, b=4, c=3)
363 # call fields on ty rather than pd to ensure we're not called with a
364 # class rather than an instance.
365 for name
in fields(ty
):
366 value
= changes
.pop(name
, __NOT_SPECIFIED
)
367 if value
is __NOT_SPECIFIED
:
368 kwargs
[name
] = getattr(pd
, name
)
371 if len(changes
) != 0:
372 raise TypeError(f
"can't set unknown field {changes.popitem()[0]!r}")