Support cached coroutines (async/await)
authorVolker Braun <vbraun.name@gmail.com>
Sun, 25 Feb 2018 14:30:58 +0000 (15:30 +0100)
committerVolker Braun <vbraun.name@gmail.com>
Sun, 25 Feb 2018 16:13:16 +0000 (17:13 +0100)
.gitignore
README.rst
cached_property.py
conftest.py [new file with mode: 0644]
tests/test_async_cached_property.py [new file with mode: 0644]
tests/test_coroutine_cached_property.py [new file with mode: 0644]
tox.ini

index 07ee70684356161846040599ba78834f34a2cb39..f456abc935db5e200c64aee6e22a76d00f3976b8 100644 (file)
@@ -26,6 +26,8 @@ pip-log.txt
 .tox
 nosetests.xml
 htmlcov
+.cache
+.pytest_cache
 
 # Translations
 *.mo
index d7e46f54ef3965ae6dd095a9dbc1fa1bd60ef236..27bb40aa5c35ec05ba42bc5c15f55ac543b29ee8 100644 (file)
@@ -145,6 +145,49 @@ Now use it:
     >>> self.assertEqual(m.boardwalk, 550)
 
 
+Working with async/await (Python 3.5+)
+--------------------------------------
+
+The cached property can be async, in which case you have to use await
+as usual to get the value. Because of the caching, the value is only
+computed once and then cached:
+
+.. code-block:: python
+
+    from cached_property import cached_property
+
+    class Monopoly(object):
+
+        def __init__(self):
+            self.boardwalk_price = 500
+
+        @cached_property
+        async def boardwalk(self):
+            self.boardwalk_price += 50
+            return self.boardwalk_price
+
+Now use it:
+
+.. code-block:: python
+
+    >>> async def print_boardwalk():
+    ...     monopoly = Monopoly()
+    ...     print(await monopoly.boardwalk)
+    ...     print(await monopoly.boardwalk)
+    ...     print(await monopoly.boardwalk)
+    >>> import asyncio
+    >>> asyncio.get_event_loop().run_until_complete(print_boardwalk())
+    550
+    550
+    550
+
+Note that this does not work with threading either, most asyncio
+objects are not thread-safe. And if you run separate event loops in
+each thread, the cached version will most likely have the wrong event
+loop. To summarize, either use cooperative multitasking (event loop)
+or threading, but not both at the same time.
+
+
 Timing out the cache
 --------------------
 
index 7e8bebebcc8cef430e2d396a6ffaff560e6e7ecc..191eb099003d9a3b107eb254f16959704ed53594 100644 (file)
@@ -7,6 +7,10 @@ __license__ = 'BSD'
 
 from time import time
 import threading
+try:
+    import asyncio
+except ImportError:
+    asyncio = None
 
 
 class cached_property(object):
@@ -23,9 +27,19 @@ class cached_property(object):
     def __get__(self, obj, cls):
         if obj is None:
             return self
+        if asyncio and asyncio.iscoroutinefunction(self.func):
+            return self._wrap_in_coroutine(obj)
         value = obj.__dict__[self.func.__name__] = self.func(obj)
         return value
 
+    def _wrap_in_coroutine(self, obj):
+        @asyncio.coroutine
+        def wrapper():
+            future = asyncio.ensure_future(self.func(obj))
+            obj.__dict__[self.func.__name__] = future
+            return future
+        return wrapper()
+
 
 class threaded_cached_property(object):
     """
diff --git a/conftest.py b/conftest.py
new file mode 100644 (file)
index 0000000..d68017d
--- /dev/null
@@ -0,0 +1,24 @@
+
+import sys
+
+# Whether "import asyncio" works
+has_asyncio = (
+    sys.version_info[0] == 3 and sys.version_info[1] >= 4
+)
+
+# Whether the async and await keywords work
+has_async_await = (
+    sys.version_info[0] == 3 and sys.version_info[1] >= 5
+)
+
+
+print('conftest.py', has_asyncio, has_async_await)
+
+
+collect_ignore = []
+
+if not has_asyncio:
+    collect_ignore.append('tests/test_coroutine_cached_property.py')
+
+if not has_async_await:
+    collect_ignore.append('tests/test_async_cached_property.py')
diff --git a/tests/test_async_cached_property.py b/tests/test_async_cached_property.py
new file mode 100644 (file)
index 0000000..36892be
--- /dev/null
@@ -0,0 +1,135 @@
+# -*- coding: utf-8 -*-
+
+import time
+import unittest
+import asyncio
+from threading import Lock, Thread
+from freezegun import freeze_time
+
+import cached_property
+
+
+def unittest_run_loop(f):
+    def wrapper(*args, **kwargs):
+        coro = asyncio.coroutine(f)
+        future = coro(*args, **kwargs)
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(future)
+    return wrapper
+
+
+def CheckFactory(cached_property_decorator, threadsafe=False):
+    """
+    Create dynamically a Check class whose add_cached method is decorated by
+    the cached_property_decorator.
+    """
+
+    class Check(object):
+
+        def __init__(self):
+            self.control_total = 0
+            self.cached_total = 0
+            self.lock = Lock()
+
+        async def add_control(self):
+            self.control_total += 1
+            return self.control_total
+
+        @cached_property_decorator
+        async def add_cached(self):
+            if threadsafe:
+                time.sleep(1)
+                # Need to guard this since += isn't atomic.
+                with self.lock:
+                    self.cached_total += 1
+            else:
+                self.cached_total += 1
+
+            return self.cached_total
+
+        def run_threads(self, num_threads):
+            threads = []
+            for _ in range(num_threads):
+                def call_add_cached():
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                    loop.run_until_complete(self.add_cached)
+                thread = Thread(target=call_add_cached)
+                thread.start()
+                threads.append(thread)
+            for thread in threads:
+                thread.join()
+
+    return Check
+
+
+class TestCachedProperty(unittest.TestCase):
+    """Tests for cached_property"""
+
+    cached_property_factory = cached_property.cached_property
+
+    async def assert_control(self, check, expected):
+        """
+        Assert that both `add_control` and 'control_total` equal `expected`
+        """
+        self.assertEqual(await check.add_control(), expected)
+        self.assertEqual(check.control_total, expected)
+
+    async def assert_cached(self, check, expected):
+        """
+        Assert that both `add_cached` and 'cached_total` equal `expected`
+        """
+        print('assert_cached', check.add_cached)
+        self.assertEqual(await check.add_cached, expected)
+        self.assertEqual(check.cached_total, expected)
+
+    @unittest_run_loop
+    async def test_cached_property(self):
+        Check = CheckFactory(self.cached_property_factory)
+        check = Check()
+
+        # The control shows that we can continue to add 1
+        await self.assert_control(check, 1)
+        await self.assert_control(check, 2)
+
+        # The cached version demonstrates how nothing is added after the first
+        await self.assert_cached(check, 1)
+        await self.assert_cached(check, 1)
+
+        # The cache does not expire
+        with freeze_time("9999-01-01"):
+            await self.assert_cached(check, 1)
+
+        # Typically descriptors return themselves if accessed though the class
+        # rather than through an instance.
+        self.assertTrue(isinstance(Check.add_cached,
+                                   self.cached_property_factory))
+
+    @unittest_run_loop
+    async def test_reset_cached_property(self):
+        Check = CheckFactory(self.cached_property_factory)
+        check = Check()
+
+        # Run standard cache assertion
+        await self.assert_cached(check, 1)
+        await self.assert_cached(check, 1)
+
+        # Clear the cache
+        del check.add_cached
+
+        # Value is cached again after the next access
+        await self.assert_cached(check, 2)
+        await self.assert_cached(check, 2)
+
+    @unittest_run_loop
+    async def test_none_cached_property(self):
+        class Check(object):
+
+            def __init__(self):
+                self.cached_total = None
+
+            @self.cached_property_factory
+            async def add_cached(self):
+                return self.cached_total
+
+        await self.assert_cached(Check(), None)
diff --git a/tests/test_coroutine_cached_property.py b/tests/test_coroutine_cached_property.py
new file mode 100644 (file)
index 0000000..ede9baf
--- /dev/null
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+"""
+The same tests as in :mod:`.test_async_cached_property`, but with the old
+yield from instead of the new async/await syntax. Used to test Python 3.4
+compatibility which has asyncio but doesn't have async/await yet.
+"""
+
+import unittest
+import asyncio
+from freezegun import freeze_time
+
+import cached_property
+
+
+def unittest_run_loop(f):
+    def wrapper(*args, **kwargs):
+        coro = asyncio.coroutine(f)
+        future = coro(*args, **kwargs)
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(future)
+    return wrapper
+
+
+def CheckFactory(cached_property_decorator):
+    """
+    Create dynamically a Check class whose add_cached method is decorated by
+    the cached_property_decorator.
+    """
+
+    class Check(object):
+
+        def __init__(self):
+            self.control_total = 0
+            self.cached_total = 0
+
+        @asyncio.coroutine
+        def add_control(self):
+            self.control_total += 1
+            return self.control_total
+
+        @cached_property_decorator
+        @asyncio.coroutine
+        def add_cached(self):
+            self.cached_total += 1
+            return self.cached_total
+
+    return Check
+
+
+class TestCachedProperty(unittest.TestCase):
+    """Tests for cached_property"""
+
+    cached_property_factory = cached_property.cached_property
+
+    @asyncio.coroutine
+    def assert_control(self, check, expected):
+        """
+        Assert that both `add_control` and 'control_total` equal `expected`
+        """
+        value = yield from check.add_control()
+        self.assertEqual(value, expected)
+        self.assertEqual(check.control_total, expected)
+
+    @asyncio.coroutine
+    def assert_cached(self, check, expected):
+        """
+        Assert that both `add_cached` and 'cached_total` equal `expected`
+        """
+        print('assert_cached', check.add_cached)
+        value = yield from check.add_cached
+        self.assertEqual(value, expected)
+        self.assertEqual(check.cached_total, expected)
+
+    @unittest_run_loop
+    @asyncio.coroutine
+    def test_cached_property(self):
+        Check = CheckFactory(self.cached_property_factory)
+        check = Check()
+
+        # The control shows that we can continue to add 1
+        yield from self.assert_control(check, 1)
+        yield from self.assert_control(check, 2)
+
+        # The cached version demonstrates how nothing is added after the first
+        yield from self.assert_cached(check, 1)
+        yield from self.assert_cached(check, 1)
+
+        # The cache does not expire
+        with freeze_time("9999-01-01"):
+            yield from self.assert_cached(check, 1)
+
+        # Typically descriptors return themselves if accessed though the class
+        # rather than through an instance.
+        self.assertTrue(isinstance(Check.add_cached,
+                                   self.cached_property_factory))
+
+    @unittest_run_loop
+    @asyncio.coroutine
+    def test_reset_cached_property(self):
+        Check = CheckFactory(self.cached_property_factory)
+        check = Check()
+
+        # Run standard cache assertion
+        yield from self.assert_cached(check, 1)
+        yield from self.assert_cached(check, 1)
+
+        # Clear the cache
+        del check.add_cached
+
+        # Value is cached again after the next access
+        yield from self.assert_cached(check, 2)
+        yield from self.assert_cached(check, 2)
+
+    @unittest_run_loop
+    @asyncio.coroutine
+    def test_none_cached_property(self):
+        class Check(object):
+
+            def __init__(self):
+                self.cached_total = None
+
+            @self.cached_property_factory
+            @asyncio.coroutine
+            def add_cached(self):
+                return self.cached_total
+
+        yield from self.assert_cached(Check(), None)
diff --git a/tox.ini b/tox.ini
index 33e856efa2bd5661512251f6cb68739c5de674ad..9f42a58f0b8b6945d9e1dc18efd28cc88882e30a 100644 (file)
--- a/tox.ini
+++ b/tox.ini
@@ -7,3 +7,4 @@ setenv =
 commands = py.test
 deps =
     pytest
+    freezegun