add deduped
authorJacob Lifshay <programmerjake@gmail.com>
Sat, 9 Oct 2021 01:00:10 +0000 (18:00 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sat, 9 Oct 2021 01:00:10 +0000 (18:00 -0700)
src/nmutil/deduped.py [new file with mode: 0644]
src/nmutil/test/test_deduped.py [new file with mode: 0644]

diff --git a/src/nmutil/deduped.py b/src/nmutil/deduped.py
new file mode 100644 (file)
index 0000000..d6fca1c
--- /dev/null
@@ -0,0 +1,83 @@
+import functools
+import weakref
+
+
+class _KeyBuilder:
+    def __init__(self, do_delete):
+        self.__keys = []
+        self.__refs = {}
+        self.__do_delete = do_delete
+
+    def add_ref(self, v):
+        v_id = id(v)
+        if v_id in self.__refs:
+            return
+        try:
+            v = weakref.ref(v, callback=self.__do_delete)
+        except TypeError:
+            pass
+        self.__refs[v_id] = v
+
+    def add(self, k, v):
+        self.__keys.append(id(k))
+        self.__keys.append(id(v))
+        self.add_ref(k)
+        self.add_ref(v)
+
+    def finish(self):
+        return tuple(self.__keys), tuple(self.__refs.values())
+
+
+def deduped(*, global_keys=()):
+    """decorator that causes functions to deduplicate their results based on
+    their input args and the requested globals. For each set of arguments, it
+    will always return the exact same object, by storing it internally.
+    Arguments are compared by their identity, so they don't need to be
+    hashable.
+
+    Usage:
+    ```
+    # for functions that don't depend on global variables
+    @deduped()
+    def my_fn1(a, b, *, c=1):
+        return a + b * c
+
+    my_global = 23
+
+    # for functions that depend on global variables
+    @deduped(global_keys=[lambda: my_global])
+    def my_fn2(a, b, *, c=2):
+        return a + b * c + my_global
+    ```
+    """
+    global_keys = tuple(global_keys)
+    assert all(map(callable, global_keys))
+
+    def decorator(f):
+        if isinstance(f, (staticmethod, classmethod)):
+            raise TypeError("@staticmethod or @classmethod should be applied "
+                            "to the result of @deduped, not the other way"
+                            " around")
+        assert callable(f)
+
+        map = {}
+
+        @functools.wraps(f)
+        def wrapper(*args, **kwargs):
+            key_builder = _KeyBuilder(lambda _: map.pop(key, None))
+            for arg in args:
+                key_builder.add(None, arg)
+            for k, v in kwargs.items():
+                key_builder.add(k, v)
+            for global_key in global_keys:
+                key_builder.add(None, global_key())
+            key, refs = key_builder.finish()
+            if key in map:
+                return map[key][0]
+            retval = f(*args, **kwargs)
+            # keep reference to stuff used for key to avoid ids
+            # getting reused for something else.
+            map[key] = retval, refs
+            return retval
+        return wrapper
+    return decorator
diff --git a/src/nmutil/test/test_deduped.py b/src/nmutil/test/test_deduped.py
new file mode 100644 (file)
index 0000000..42a4edf
--- /dev/null
@@ -0,0 +1,90 @@
+import unittest
+from nmutil.deduped import deduped
+
+
+class TestDeduped(unittest.TestCase):
+    def test_deduped1(self):
+        global_key = 1
+        call_count = 0
+
+        def call_counter():
+            nonlocal call_count
+            retval = call_count
+            call_count += 1
+            return retval
+
+        class C:
+            def __init__(self, name):
+                self.name = name
+
+            @deduped()
+            def method(self, a, *, b=1):
+                return self, a, b, call_counter()
+
+            @deduped(global_keys=[lambda: global_key])
+            def method_with_global(self, a, *, b=1):
+                return self, a, b, call_counter(), global_key
+
+            @staticmethod
+            @deduped()
+            def smethod(a, *, b=1):
+                return a, b, call_counter()
+
+            @classmethod
+            @deduped()
+            def cmethod(cls, a, *, b=1):
+                return cls, a, b, call_counter()
+
+            def __repr__(self):
+                return f"{self.__class__.__name__}({self.name})"
+
+        class D(C):
+            pass
+
+        c1 = C("c1")
+        c2 = C("c2")
+
+        # run everything twice to ensure caching works
+        for which_pass in ("first", "second"):
+            with self.subTest(which_pass=which_pass):
+                self.assertEqual(C.cmethod(1), (C, 1, 1, 0))
+                self.assertEqual(C.cmethod(2), (C, 2, 1, 1))
+                self.assertEqual(C.cmethod(1, b=5), (C, 1, 5, 2))
+                self.assertEqual(D.cmethod(1, b=5), (D, 1, 5, 3))
+                self.assertEqual(D.smethod(1, b=5), (1, 5, 4))
+                self.assertEqual(C.smethod(1, b=5), (1, 5, 4))
+                self.assertEqual(c1.method(None), (c1, None, 1, 5))
+                global_key = 2
+                self.assertEqual(c1.cmethod(1, b=5), (C, 1, 5, 2))
+                self.assertEqual(c1.smethod(1, b=5), (1, 5, 4))
+                self.assertEqual(c1.method(1, b=5), (c1, 1, 5, 6))
+                self.assertEqual(c2.method(1, b=5), (c2, 1, 5, 7))
+                self.assertEqual(c1.method_with_global(1), (c1, 1, 1, 8, 2))
+                global_key = 1
+                self.assertEqual(c1.cmethod(1, b=5), (C, 1, 5, 2))
+                self.assertEqual(c1.smethod(1, b=5), (1, 5, 4))
+                self.assertEqual(c1.method(1, b=5), (c1, 1, 5, 6))
+                self.assertEqual(c2.method(1, b=5), (c2, 1, 5, 7))
+                self.assertEqual(c1.method_with_global(1), (c1, 1, 1, 9, 1))
+        self.assertEqual(call_count, 10)
+
+    def test_bad_methods(self):
+        with self.assertRaisesRegex(TypeError,
+                                    ".*@staticmethod.*applied.*@deduped.*"):
+            class C:
+                @deduped()
+                @staticmethod
+                def f():
+                    pass
+
+        with self.assertRaisesRegex(TypeError,
+                                    ".*@classmethod.*applied.*@deduped.*"):
+            class C:
+                @deduped()
+                @classmethod
+                def f():
+                    pass
+
+
+if __name__ == '__main__':
+    unittest.main()