# From: Jacob Lifshay
# Date: Sat, 9 Oct 2021 01:00:10 +0000 (-0700)
# Subject: add deduped
# X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=fbb284ecf93ead5d748f9c19e5b1b899a06c6b55;p=nmutil.git
#
# Reconstructed from the commit diff above, which added two files:
#   src/nmutil/deduped.py            -- the implementation (first section)
#   src/nmutil/test/test_deduped.py  -- the unit tests (second section)

import functools
import gc
import unittest
import weakref

__all__ = ["deduped"]


# ---------------------------------------------------------------------------
# src/nmutil/deduped.py
# ---------------------------------------------------------------------------

class _KeyBuilder:
    """Accumulate an identity-based cache key and the references that keep
    that key meaningful.

    The key is made of object ids, which are only unique while the objects
    are alive. Every object that contributes an id is therefore also
    recorded in a reference table: weakly when the object supports weak
    references (so the cache doesn't keep it alive), strongly otherwise.

    `do_delete` is called with the dead `weakref.ref` whenever a weakly
    referenced object is destroyed; the owner uses it to drop the cache
    entry before the id can be reused by an unrelated object.
    """

    def __init__(self, do_delete):
        # flat sequence of key parts, in the order `add` was called
        self.__keys = []
        # id(obj) -> weakref.ref(obj), or obj itself if it can't be weakly
        # referenced
        self.__refs = {}
        self.__do_delete = do_delete

    def add_ref(self, v):
        """Record a reference to `v` (at most once per object)."""
        v_id = id(v)
        if v_id in self.__refs:
            return
        try:
            # The callback has to be passed positionally: weakref.ref()
            # rejects keyword arguments with TypeError, which the handler
            # below would silently mistake for "not weakly referenceable"
            # and so hold every object strongly forever.
            v = weakref.ref(v, self.__do_delete)
        except TypeError:
            # int, str, tuple, None, ... -- fall back to a strong reference
            pass
        self.__refs[v_id] = v

    def add(self, k, v):
        """Append the pair `(k, v)` to the key.

        `v` is always keyed by identity. `k` is keyed by value when it is
        a `str` (a keyword-argument name: equal names must match even when
        they are distinct string objects) and by identity otherwise. A
        `str` key part can never compare equal to an `int` id, so the two
        forms can't be confused.
        """
        if isinstance(k, str):
            self.__keys.append(k)
        else:
            self.__keys.append(id(k))
            self.add_ref(k)
        self.__keys.append(id(v))
        self.add_ref(v)

    def finish(self):
        """Return `(key, refs)`: the hashable key tuple, and the tuple of
        references that must be stored alongside the cached value for as
        long as the key is in use.
        """
        return tuple(self.__keys), tuple(self.__refs.values())


def deduped(*, global_keys=()):
    """decorator that causes functions to deduplicate their results based on
    their input args and the requested globals. For each set of arguments, it
    will always return the exact same object, by storing it internally.
    Arguments are compared by their identity, so they don't need to be
    hashable.

    Keyword arguments are matched by name, independent of the order they
    were passed in. Positional and keyword forms of the same parameter are
    *not* unified: `f(1, 2)` and `f(1, b=2)` are cached separately.

    Arguments that support weak references are only referenced weakly by
    the cache; once such an argument is destroyed, every cached result that
    was keyed on it is dropped. (If the cached result itself refers to the
    argument, the result keeps the argument -- and the cache entry -- alive.)
    All other arguments are kept alive by the cache so their ids can't be
    reused.

    `global_keys` is an iterable of zero-argument callables; each is called
    on every invocation and its return value is treated as an additional
    argument.

    Raises `TypeError` if `global_keys` contains a non-callable, if the
    decorated object isn't callable, or if it is a `staticmethod` or
    `classmethod` object (apply those *after* `@deduped`).

    Usage:
    ```
    # for functions that don't depend on global variables
    @deduped()
    def my_fn1(a, b, *, c=1):
        return a + b * c

    my_global = 23

    # for functions that depend on global variables
    @deduped(global_keys=[lambda: my_global])
    def my_fn2(a, b, *, c=2):
        return a + b * c + my_global
    ```
    """
    global_keys = tuple(global_keys)
    # explicit raise rather than assert: asserts vanish under `python -O`
    if not all(map(callable, global_keys)):
        raise TypeError("global_keys must be an iterable of callables")

    def decorator(f):
        if isinstance(f, (staticmethod, classmethod)):
            raise TypeError("@staticmethod or @classmethod should be applied "
                            "to the result of @deduped, not the other way"
                            " around")
        if not callable(f):
            raise TypeError("@deduped must be applied to a callable")

        # key tuple -> (cached return value, refs keeping the key valid)
        cache = {}

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            def evict(_dead_ref):
                # `key` is bound below, before any key object can die:
                # args, kwargs and global_values hold all of them until
                # this call returns.
                cache.pop(key, None)

            key_builder = _KeyBuilder(evict)
            for arg in args:
                key_builder.add(None, arg)
            # sorted so that keyword order doesn't affect the key
            for name in sorted(kwargs):
                key_builder.add(name, kwargs[name])
            # hold the values in a list so temporaries returned by the
            # global_keys callables stay alive while the key is built
            global_values = [global_key() for global_key in global_keys]
            for value in global_values:
                key_builder.add(None, value)
            key, refs = key_builder.finish()
            try:
                return cache[key][0]
            except KeyError:
                pass
            retval = f(*args, **kwargs)
            # keep reference to stuff used for key to avoid ids
            # getting reused for something else.
            cache[key] = retval, refs
            return retval
        return wrapper
    return decorator


# ---------------------------------------------------------------------------
# src/nmutil/test/test_deduped.py
#
# As a separate file this section starts with:
#     import unittest
#     from nmutil.deduped import deduped
# ---------------------------------------------------------------------------

class TestDeduped(unittest.TestCase):
    def test_deduped1(self):
        global_key = 1
        call_count = 0

        def call_counter():
            nonlocal call_count
            retval = call_count
            call_count += 1
            return retval

        class C:
            def __init__(self, name):
                self.name = name

            @deduped()
            def method(self, a, *, b=1):
                return self, a, b, call_counter()

            @deduped(global_keys=[lambda: global_key])
            def method_with_global(self, a, *, b=1):
                return self, a, b, call_counter(), global_key

            @staticmethod
            @deduped()
            def smethod(a, *, b=1):
                return a, b, call_counter()

            @classmethod
            @deduped()
            def cmethod(cls, a, *, b=1):
                return cls, a, b, call_counter()

            def __repr__(self):
                return f"{self.__class__.__name__}({self.name})"

        class D(C):
            pass

        c1 = C("c1")
        c2 = C("c2")

        # run everything twice to ensure caching works
        for which_pass in ("first", "second"):
            with self.subTest(which_pass=which_pass):
                self.assertEqual(C.cmethod(1), (C, 1, 1, 0))
                self.assertEqual(C.cmethod(2), (C, 2, 1, 1))
                self.assertEqual(C.cmethod(1, b=5), (C, 1, 5, 2))
                self.assertEqual(D.cmethod(1, b=5), (D, 1, 5, 3))
                self.assertEqual(D.smethod(1, b=5), (1, 5, 4))
                self.assertEqual(C.smethod(1, b=5), (1, 5, 4))
                self.assertEqual(c1.method(None), (c1, None, 1, 5))
                global_key = 2
                self.assertEqual(c1.cmethod(1, b=5), (C, 1, 5, 2))
                self.assertEqual(c1.smethod(1, b=5), (1, 5, 4))
                self.assertEqual(c1.method(1, b=5), (c1, 1, 5, 6))
                self.assertEqual(c2.method(1, b=5), (c2, 1, 5, 7))
                self.assertEqual(c1.method_with_global(1), (c1, 1, 1, 8, 2))
                global_key = 1
                self.assertEqual(c1.cmethod(1, b=5), (C, 1, 5, 2))
                self.assertEqual(c1.smethod(1, b=5), (1, 5, 4))
                self.assertEqual(c1.method(1, b=5), (c1, 1, 5, 6))
                self.assertEqual(c2.method(1, b=5), (c2, 1, 5, 7))
                self.assertEqual(c1.method_with_global(1), (c1, 1, 1, 9, 1))
                self.assertEqual(call_count, 10)

    def test_bad_methods(self):
        with self.assertRaisesRegex(TypeError,
                                    ".*@staticmethod.*applied.*@deduped.*"):
            class C:
                @deduped()
                @staticmethod
                def f():
                    pass

        with self.assertRaisesRegex(TypeError,
                                    ".*@classmethod.*applied.*@deduped.*"):
            class C:
                @deduped()
                @classmethod
                def f():
                    pass

    def test_weakly_referenced_args_are_released(self):
        class Arg:
            pass

        @deduped()
        def f(arg):
            # deliberately doesn't refer to `arg`
            return object()

        arg = Arg()
        first = f(arg)
        self.assertIs(f(arg), first)
        arg_ref = weakref.ref(arg)
        del arg
        gc.collect()
        # the cache must not be what keeps the argument alive
        self.assertIsNone(arg_ref())

    def test_keyword_args(self):
        @deduped()
        def f(**kwargs):
            return object()

        # equal names that are distinct str objects must share an entry
        prefix = "long_"
        name1 = prefix + "name"
        name2 = prefix + "name"
        self.assertIs(f(**{name1: 1}), f(**{name2: 1}))
        # keyword order must not matter, but the values must
        self.assertIs(f(a=1, b=2), f(b=2, a=1))
        self.assertIsNot(f(a=1, b=2), f(a=2, b=1))


if __name__ == '__main__':
    unittest.main()