# slots is an ordered set by using dict keys.
# always add __dict__ and __weakref__
slots = {"__dict__": None, "__weakref__": None}
+ if frozen:
+ slots["__plain_data_init_done"] = None
fields = []
any_parents_have_dict = False
any_parents_have_weakref = False
elif field == "__weakref__":
any_parents_have_weakref = True
+ fields = tuple(fields) # fields needs to be immutable
+
if any_parents_have_dict:
# work around a CPython bug that unnecessarily checks if parent
# classes already have the __dict__ slot.
for name in slots.keys():
retval_dict.pop(name, None)
+ retval_dict["_fields"] = fields
+
def add_method_or_error(value, replace=False):
name = value.__name__
if name in retval_dict and not replace:
retval_dict[name] = value
if frozen:
- slots["__plain_data_init_done"] = None
-
def __setattr__(self, name: str, value):
if getattr(self, "__plain_data_init_done", False):
raise FrozenPlainDataError(f"cannot assign to field {name!r}")
object.__setattr__(self, "__plain_data_init_done", True)
add_method_or_error(__init__, replace=True)
+ else:
+ old_init = None
# set __slots__ to have everything we need in the preferred order
retval_dict["__slots__"] = tuple(slots.keys())
- def __dir__(self):
- # don't return fields un-copied so users can't mess with it
- return fields.copy()
-
- add_method_or_error(__dir__)
-
def __getstate__(self):
# pickling support
return [getattr(self, name) for name in fields]
# add __qualname__
retval.__qualname__ = cls.__qualname__
- # fixup super() and __class__
- # derived from: https://stackoverflow.com/a/71666065/2597900
- for value in retval.__dict__.values():
+ def fix_super_and_class(value):
+ # fixup super() and __class__
+ # derived from: https://stackoverflow.com/a/71666065/2597900
try:
closure = value.__closure__
if isinstance(closure, tuple):
except (AttributeError, IndexError):
pass
+ for value in retval.__dict__.values():
+ fix_super_and_class(value)
+
+ if old_init is not None:
+ fix_super_and_class(old_init)
+
return retval
-def plain_data(*, eq=True, unsafe_hash=False, order=True, repr=True,
+def plain_data(*, eq=True, unsafe_hash=False, order=False, repr=True,
frozen=False):
+ # defaults match dataclass, with the exception of `init`
""" Decorator for adding equality comparison, ordered comparison,
`repr` support, `hash` support, and frozen type (read-only fields)
support to classes that are just plain data.
# SPDX-License-Identifier: LGPL-3-or-later
# Copyright 2022 Jacob Lifshay programmerjake@gmail.com
+import operator
+import pickle
import unittest
from nmutil.plain_data import FrozenPlainDataError, plain_data
-@plain_data()
+@plain_data(order=True)
class PlainData0:
__slots__ = ()
-@plain_data()
+@plain_data(order=True)
class PlainData1:
__slots__ = "a", "b", "x", "y"
self.y = y
-@plain_data()
+@plain_data(order=True)
class PlainData2(PlainData1):
__slots__ = "a", "z"
self.z = z
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data(order=True, frozen=True, unsafe_hash=True)
class PlainDataF0:
__slots__ = ()
-@plain_data(frozen=True, unsafe_hash=True)
+@plain_data(order=True, frozen=True, unsafe_hash=True)
class PlainDataF1:
__slots__ = "a", "b", "x", "y"
self.y = y
+@plain_data(order=True, frozen=True, unsafe_hash=True)
+class PlainDataF2(PlainDataF1):
+ __slots__ = "a", "z"
+
+ def __init__(self, a, b, *, x, y, z):
+ super().__init__(a, b, x=x, y=y)
+ self.z = z
+
+
class TestPlainData(unittest.TestCase):
- def test_repr(self):
- self.assertEqual(repr(PlainData0()), "PlainData0()")
- self.assertEqual(repr(PlainData1(1, 2, x="x", y="y")),
- "PlainData1(a=1, b=2, x='x', y='y')")
- self.assertEqual(repr(PlainData2(1, 2, x="x", y="y", z=3)),
- "PlainData2(a=1, b=2, x='x', y='y', z=3)")
+ def test_fields(self):
+ self.assertEqual(PlainData0._fields, ())
+ self.assertEqual(PlainData1._fields, ("a", "b", "x", "y"))
+ self.assertEqual(PlainData2._fields, ("a", "b", "x", "y", "z"))
+ self.assertEqual(PlainDataF0._fields, ())
+ self.assertEqual(PlainDataF1._fields, ("a", "b", "x", "y"))
+ self.assertEqual(PlainDataF2._fields, ("a", "b", "x", "y", "z"))
def test_eq(self):
self.assertTrue(PlainData0() == PlainData0())
self.assertFalse('a' == PlainData0())
+ self.assertFalse(PlainDataF0() == PlainData0())
self.assertTrue(PlainData1(1, 2, x="x", y="y")
== PlainData1(1, 2, x="x", y="y"))
self.assertFalse(PlainData1(1, 2, x="x", y="y")
self.assertFalse(PlainData1(1, 2, x="x", y="y")
== PlainData2(1, 2, x="x", y="y", z=3))
+ def test_hash(self):
+ def check_op(v, tuple_v):
+ with self.subTest(v=v, tuple_v=tuple_v):
+ self.assertEqual(hash(v), hash(tuple_v))
+
+ def check(a, b, x, y, z):
+ tuple_v = a, b, x, y, z
+ v = PlainDataF2(a=a, b=b, x=x, y=y, z=z)
+ check_op(v, tuple_v)
+
+ check(1, 2, "x", "y", "z")
+
+ check(1, 2, "x", "y", "a")
+ check(1, 2, "x", "y", "zz")
+
+ check(1, 2, "x", "a", "z")
+ check(1, 2, "x", "zz", "z")
+
+ check(1, 2, "a", "y", "z")
+ check(1, 2, "zz", "y", "z")
+
+ check(1, -10, "x", "y", "z")
+ check(1, 10, "x", "y", "z")
+
+ check(-10, 2, "x", "y", "z")
+ check(10, 2, "x", "y", "z")
+
+ def test_order(self):
+ def check_op(l, r, tuple_l, tuple_r, op):
+ with self.subTest(l=l, r=r,
+ tuple_l=tuple_l, tuple_r=tuple_r, op=op):
+ self.assertEqual(op(l, r), op(tuple_l, tuple_r))
+ self.assertEqual(op(r, l), op(tuple_r, tuple_l))
+
+ def check(a, b, x, y, z):
+ tuple_l = 1, 2, "x", "y", "z"
+ l = PlainData2(a=1, b=2, x="x", y="y", z="z")
+ tuple_r = a, b, x, y, z
+ r = PlainData2(a=a, b=b, x=x, y=y, z=z)
+ check_op(l, r, tuple_l, tuple_r, operator.eq)
+ check_op(l, r, tuple_l, tuple_r, operator.ne)
+ check_op(l, r, tuple_l, tuple_r, operator.lt)
+ check_op(l, r, tuple_l, tuple_r, operator.le)
+ check_op(l, r, tuple_l, tuple_r, operator.gt)
+ check_op(l, r, tuple_l, tuple_r, operator.ge)
+
+ check(1, 2, "x", "y", "z")
+
+ check(1, 2, "x", "y", "a")
+ check(1, 2, "x", "y", "zz")
+
+ check(1, 2, "x", "a", "z")
+ check(1, 2, "x", "zz", "z")
+
+ check(1, 2, "a", "y", "z")
+ check(1, 2, "zz", "y", "z")
+
+ check(1, -10, "x", "y", "z")
+ check(1, 10, "x", "y", "z")
+
+ check(-10, 2, "x", "y", "z")
+ check(10, 2, "x", "y", "z")
+
+ def test_repr(self):
+ self.assertEqual(repr(PlainData0()), "PlainData0()")
+ self.assertEqual(repr(PlainData1(1, 2, x="x", y="y")),
+ "PlainData1(a=1, b=2, x='x', y='y')")
+ self.assertEqual(repr(PlainData2(1, 2, x="x", y="y", z=3)),
+ "PlainData2(a=1, b=2, x='x', y='y', z=3)")
+ self.assertEqual(repr(PlainDataF2(1, 2, x="x", y="y", z=3)),
+ "PlainDataF2(a=1, b=2, x='x', y='y', z=3)")
+
def test_frozen(self):
not_frozen = PlainData0()
not_frozen.a = 1
with self.assertRaises(FrozenPlainDataError):
frozen1.a = 1
- # FIXME: add more tests
+ def test_pickle(self):
+ def check(v):
+ with self.subTest(v=v):
+ self.assertEqual(v, pickle.loads(pickle.dumps(v)))
+
+ check(PlainData0())
+ check(PlainData1(a=1, b=2, x="x", y="y"))
+ check(PlainData2(a=1, b=2, x="x", y="y", z="z"))
+ check(PlainDataF0())
+ check(PlainDataF1(a=1, b=2, x="x", y="y"))
+ check(PlainDataF2(a=1, b=2, x="x", y="y", z="z"))
if __name__ == "__main__":