for name in slots.keys():
retval_dict.pop(name, None)
- retval_dict["_fields"] = fields
+ retval_dict["__plain_data_fields"] = fields
def add_method_or_error(value, replace=False):
name = value.__name__
return _decorator(cls, eq=eq, unsafe_hash=unsafe_hash, order=order,
repr_=repr, frozen=frozen)
return decorator
+
+
+def fields(pd):
+ """ get the tuple of field names of the passed-in
+ `@plain_data()`-decorated class.
+
+ This is similar to `dataclasses.fields`, except this returns a
+ different type.
+
+ Returns: tuple[str, ...]
+
+ e.g.:
+ ```
+ @plain_data()
+ class MyBaseClass:
+ __slots__ = "a_field", "field2"
+ def __init__(self, a_field, field2):
+ self.a_field = a_field
+ self.field2 = field2
+
+ assert fields(MyBaseClass) == ("a_field", "field2")
+ assert fields(MyBaseClass(1, 2)) == ("a_field", "field2")
+
+ @plain_data()
+ class MyClass(MyBaseClass):
+ __slots__ = "child_field",
+ def __init__(self, a_field, field2, child_field):
+ super().__init__(a_field=a_field, field2=field2)
+ self.child_field = child_field
+
+ assert fields(MyClass) == ("a_field", "field2", "child_field")
+ assert fields(MyClass(1, 2, 3)) == ("a_field", "field2", "child_field")
+ ```
+ """
+ retval = getattr(pd, "__plain_data_fields", None)
+ if not isinstance(retval, tuple):
+ raise TypeError("the passed-in object must be a class or an instance"
+ " of a class decorated with @plain_data()")
+ return retval
+
+
+__NOT_SPECIFIED = object()
+
+
+def replace(pd, **changes):
+ """ Return a new instance of the passed-in `@plain_data()`-decorated
+ object, but with the specified fields replaced with new values.
+ This is quite useful with frozen `@plain_data()` classes.
+
+ e.g.:
+ ```
+ @plain_data(frozen=True)
+ class MyClass:
+ __slots__ = "a", "b", "c"
+ def __init__(self, a, b, *, c):
+ self.a = a
+ self.b = b
+ self.c = c
+
+ v1 = MyClass(1, 2, c=3)
+ v2 = replace(v1, b=4)
+ assert v2 == MyClass(a=1, b=4, c=3)
+ assert v2 is not v1
+ ```
+ """
+ kwargs = {}
+ ty = type(pd)
+ # call fields on ty rather than pd to ensure we're not called with a
+ # class rather than an instance.
+ for name in fields(ty):
+ value = changes.pop(name, __NOT_SPECIFIED)
+ if value is __NOT_SPECIFIED:
+ kwargs[name] = getattr(pd, name)
+ else:
+ kwargs[name] = value
+ if len(changes) != 0:
+ raise TypeError(f"can't set unknown field {changes.popitem()[0]!r}")
+ return ty(**kwargs)
import operator
import pickle
import unittest
-from nmutil.plain_data import FrozenPlainDataError, plain_data
+from nmutil.plain_data import (FrozenPlainDataError, plain_data,
+ fields, replace)
@plain_data(order=True)
class TestPlainData(unittest.TestCase):
def test_fields(self):
- self.assertEqual(PlainData0._fields, ())
- self.assertEqual(PlainData1._fields, ("a", "b", "x", "y"))
- self.assertEqual(PlainData2._fields, ("a", "b", "x", "y", "z"))
- self.assertEqual(PlainDataF0._fields, ())
- self.assertEqual(PlainDataF1._fields, ("a", "b", "x", "y"))
- self.assertEqual(PlainDataF2._fields, ("a", "b", "x", "y", "z"))
+ self.assertEqual(fields(PlainData0), ())
+ self.assertEqual(fields(PlainData0()), ())
+ self.assertEqual(fields(PlainData1), ("a", "b", "x", "y"))
+ self.assertEqual(fields(PlainData1(1, 2, x="x", y="y")),
+ ("a", "b", "x", "y"))
+ self.assertEqual(fields(PlainData2), ("a", "b", "x", "y", "z"))
+ self.assertEqual(fields(PlainData2(1, 2, x="x", y="y", z=3)),
+ ("a", "b", "x", "y", "z"))
+ self.assertEqual(fields(PlainDataF0), ())
+ self.assertEqual(fields(PlainDataF0()), ())
+ self.assertEqual(fields(PlainDataF1), ("a", "b", "x", "y"))
+ self.assertEqual(fields(PlainDataF1(1, 2, x="x", y="y")),
+ ("a", "b", "x", "y"))
+ self.assertEqual(fields(PlainDataF2), ("a", "b", "x", "y", "z"))
+ self.assertEqual(fields(PlainDataF2(1, 2, x="x", y="y", z=3)),
+ ("a", "b", "x", "y", "z"))
+ with self.assertRaisesRegex(
+ TypeError,
+ r"the passed-in object must be a class or an instance of a "
+ r"class decorated with @plain_data\(\)"):
+ fields(type)
+
+ def test_replace(self):
+ with self.assertRaisesRegex(
+ TypeError,
+ r"the passed-in object must be a class or an instance of a "
+ r"class decorated with @plain_data\(\)"):
+ replace(PlainData0)
+ with self.assertRaisesRegex(TypeError, "can't set unknown field 'a'"):
+ replace(PlainData0(), a=1)
+ with self.assertRaisesRegex(TypeError, "can't set unknown field 'z'"):
+ replace(PlainDataF1(1, 2, x="x", y="y"), a=3, z=1)
+ self.assertEqual(replace(PlainData0()), PlainData0())
+ self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y")),
+ PlainDataF1(1, 2, x="x", y="y"))
+ self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y"), a=3),
+ PlainDataF1(3, 2, x="x", y="y"))
+ self.assertEqual(replace(PlainDataF1(1, 2, x="x", y="y"), x=5, a=3),
+ PlainDataF1(3, 2, x=5, y="y"))
def test_eq(self):
self.assertTrue(PlainData0() == PlainData0())