From aeb790871362b081482413f20d2a82df230efc6d Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Thu, 11 Aug 2022 23:08:44 -0700 Subject: [PATCH] finish implementing @plain_data() --- src/nmutil/plain_data.py | 31 +++++--- src/nmutil/test/test_plain_data.py | 119 ++++++++++++++++++++++++++--- 2 files changed, 126 insertions(+), 24 deletions(-) diff --git a/src/nmutil/plain_data.py b/src/nmutil/plain_data.py index b1912f1..92c303d 100644 --- a/src/nmutil/plain_data.py +++ b/src/nmutil/plain_data.py @@ -12,6 +12,8 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen): # slots is an ordered set by using dict keys. # always add __dict__ and __weakref__ slots = {"__dict__": None, "__weakref__": None} + if frozen: + slots["__plain_data_init_done"] = None fields = [] any_parents_have_dict = False any_parents_have_weakref = False @@ -41,6 +43,8 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen): elif field == "__weakref__": any_parents_have_weakref = True + fields = tuple(fields) # fields needs to be immutable + if any_parents_have_dict: # work around a CPython bug that unnecessarily checks if parent # classes already have the __dict__ slot. @@ -57,6 +61,8 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen): for name in slots.keys(): retval_dict.pop(name, None) + retval_dict["_fields"] = fields + def add_method_or_error(value, replace=False): name = value.__name__ if name in retval_dict and not replace: @@ -66,8 +72,6 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen): retval_dict[name] = value if frozen: - slots["__plain_data_init_done"] = None - def __setattr__(self, name: str, value): if getattr(self, "__plain_data_init_done", False): raise FrozenPlainDataError(f"cannot assign to field {name!r}") @@ -100,16 +104,12 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen): object.__setattr__(self, "__plain_data_init_done", True) add_method_or_error(__init__, replace=True) + else: + old_init = None # set __slots__ to have everything we need in the preferred order retval_dict["__slots__"] = tuple(slots.keys()) - def __dir__(self): - # don't return fields un-copied so users can't mess with it - return fields.copy() - - add_method_or_error(__dir__) - def __getstate__(self): # pickling support return [getattr(self, name) for name in fields] @@ -186,9 +186,9 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen): # add __qualname__ retval.__qualname__ = cls.__qualname__ - # fixup super() and __class__ - # derived from: https://stackoverflow.com/a/71666065/2597900 - for value in retval.__dict__.values(): + def fix_super_and_class(value): + # fixup super() and __class__ + # derived from: https://stackoverflow.com/a/71666065/2597900 try: closure = value.__closure__ if isinstance(closure, tuple): @@ -197,11 +197,18 @@ def _decorator(cls, *, eq, unsafe_hash, order, repr_, frozen): except (AttributeError, IndexError): pass + for value in retval.__dict__.values(): + fix_super_and_class(value) + + if old_init is not None: + fix_super_and_class(old_init) + return retval -def plain_data(*, eq=True, unsafe_hash=False, order=True, repr=True, +def plain_data(*, eq=True, unsafe_hash=False, order=False, repr=True, frozen=False): + # defaults match dataclass, with the exception of `init` """ Decorator for adding equality comparison, ordered comparison, `repr` support, `hash` support, and frozen type (read-only fields) support to classes that are just plain data. diff --git a/src/nmutil/test/test_plain_data.py b/src/nmutil/test/test_plain_data.py index b07d64b..93facbb 100644 --- a/src/nmutil/test/test_plain_data.py +++ b/src/nmutil/test/test_plain_data.py @@ -1,16 +1,18 @@ # SPDX-License-Identifier: LGPL-3-or-later # Copyright 2022 Jacob Lifshay programmerjake@gmail.com +import operator +import pickle import unittest from nmutil.plain_data import FrozenPlainDataError, plain_data -@plain_data() +@plain_data(order=True) class PlainData0: __slots__ = () -@plain_data() +@plain_data(order=True) class PlainData1: __slots__ = "a", "b", "x", "y" @@ -21,7 +23,7 @@ class PlainData1: self.y = y -@plain_data() +@plain_data(order=True) class PlainData2(PlainData1): __slots__ = "a", "z" @@ -30,12 +32,12 @@ class PlainData2(PlainData1): self.z = z -@plain_data(frozen=True, unsafe_hash=True) +@plain_data(order=True, frozen=True, unsafe_hash=True) class PlainDataF0: __slots__ = () -@plain_data(frozen=True, unsafe_hash=True) +@plain_data(order=True, frozen=True, unsafe_hash=True) class PlainDataF1: __slots__ = "a", "b", "x", "y" @@ -46,17 +48,28 @@ class PlainDataF1: self.y = y +@plain_data(order=True, frozen=True, unsafe_hash=True) +class PlainDataF2(PlainDataF1): + __slots__ = "a", "z" + + def __init__(self, a, b, *, x, y, z): + super().__init__(a, b, x=x, y=y) + self.z = z + + class TestPlainData(unittest.TestCase): - def test_repr(self): - self.assertEqual(repr(PlainData0()), "PlainData0()") - self.assertEqual(repr(PlainData1(1, 2, x="x", y="y")), - "PlainData1(a=1, b=2, x='x', y='y')") - self.assertEqual(repr(PlainData2(1, 2, x="x", y="y", z=3)), - "PlainData2(a=1, b=2, x='x', y='y', z=3)") + def test_fields(self): + self.assertEqual(PlainData0._fields, ()) + self.assertEqual(PlainData1._fields, ("a", "b", "x", "y")) + self.assertEqual(PlainData2._fields, ("a", "b", "x", "y", "z")) + self.assertEqual(PlainDataF0._fields, ()) + self.assertEqual(PlainDataF1._fields, ("a", "b", "x", "y")) + self.assertEqual(PlainDataF2._fields, ("a", "b", "x", "y", "z")) def test_eq(self): self.assertTrue(PlainData0() == PlainData0()) self.assertFalse('a' == PlainData0()) + self.assertFalse(PlainDataF0() == PlainData0()) self.assertTrue(PlainData1(1, 2, x="x", y="y") == PlainData1(1, 2, x="x", y="y")) self.assertFalse(PlainData1(1, 2, x="x", y="y") @@ -64,6 +77,78 @@ class TestPlainData(unittest.TestCase): self.assertFalse(PlainData1(1, 2, x="x", y="y") == PlainData2(1, 2, x="x", y="y", z=3)) + def test_hash(self): + def check_op(v, tuple_v): + with self.subTest(v=v, tuple_v=tuple_v): + self.assertEqual(hash(v), hash(tuple_v)) + + def check(a, b, x, y, z): + tuple_v = a, b, x, y, z + v = PlainDataF2(a=a, b=b, x=x, y=y, z=z) + check_op(v, tuple_v) + + check(1, 2, "x", "y", "z") + + check(1, 2, "x", "y", "a") + check(1, 2, "x", "y", "zz") + + check(1, 2, "x", "a", "z") + check(1, 2, "x", "zz", "z") + + check(1, 2, "a", "y", "z") + check(1, 2, "zz", "y", "z") + + check(1, -10, "x", "y", "z") + check(1, 10, "x", "y", "z") + + check(-10, 2, "x", "y", "z") + check(10, 2, "x", "y", "z") + + def test_order(self): + def check_op(l, r, tuple_l, tuple_r, op): + with self.subTest(l=l, r=r, + tuple_l=tuple_l, tuple_r=tuple_r, op=op): + self.assertEqual(op(l, r), op(tuple_l, tuple_r)) + self.assertEqual(op(r, l), op(tuple_r, tuple_l)) + + def check(a, b, x, y, z): + tuple_l = 1, 2, "x", "y", "z" + l = PlainData2(a=1, b=2, x="x", y="y", z="z") + tuple_r = a, b, x, y, z + r = PlainData2(a=a, b=b, x=x, y=y, z=z) + check_op(l, r, tuple_l, tuple_r, operator.eq) + check_op(l, r, tuple_l, tuple_r, operator.ne) + check_op(l, r, tuple_l, tuple_r, operator.lt) + check_op(l, r, tuple_l, tuple_r, operator.le) + check_op(l, r, tuple_l, tuple_r, operator.gt) + check_op(l, r, tuple_l, tuple_r, operator.ge) + + check(1, 2, "x", "y", "z") + + check(1, 2, "x", "y", "a") + check(1, 2, "x", "y", "zz") + + check(1, 2, "x", "a", "z") + check(1, 2, "x", "zz", "z") + + check(1, 2, "a", "y", "z") + check(1, 2, "zz", "y", "z") + + check(1, -10, "x", "y", "z") + check(1, 10, "x", "y", "z") + + check(-10, 2, "x", "y", "z") + check(10, 2, "x", "y", "z") + + def test_repr(self): + self.assertEqual(repr(PlainData0()), "PlainData0()") + self.assertEqual(repr(PlainData1(1, 2, x="x", y="y")), + "PlainData1(a=1, b=2, x='x', y='y')") + self.assertEqual(repr(PlainData2(1, 2, x="x", y="y", z=3)), + "PlainData2(a=1, b=2, x='x', y='y', z=3)") + self.assertEqual(repr(PlainDataF2(1, 2, x="x", y="y", z=3)), + "PlainDataF2(a=1, b=2, x='x', y='y', z=3)") + def test_frozen(self): not_frozen = PlainData0() not_frozen.a = 1 @@ -74,7 +159,17 @@ class TestPlainData(unittest.TestCase): with self.assertRaises(FrozenPlainDataError): frozen1.a = 1 - # FIXME: add more tests + def test_pickle(self): + def check(v): + with self.subTest(v=v): + self.assertEqual(v, pickle.loads(pickle.dumps(v))) + + check(PlainData0()) + check(PlainData1(a=1, b=2, x="x", y="y")) + check(PlainData2(a=1, b=2, x="x", y="y", z="z")) + check(PlainDataF0()) + check(PlainDataF1(a=1, b=2, x="x", y="y")) + check(PlainDataF2(a=1, b=2, x="x", y="y", z="z")) if __name__ == "__main__": -- 2.30.2