From 4d0bc4336ea3e1de132b2aa29cf6dc1ee2c6fc8b Mon Sep 17 00:00:00 2001 From: George Sakkis Date: Mon, 20 Apr 2015 03:46:09 +0300 Subject: [PATCH] Fix the previous commit and add more test assertions to show why it was wrong. Although the previous commit correctly cached and returned only the first computed value (since dict.setdefault() is atomic), the actual computation could be performed more than once in a multithreaded environment, with all but the first computed values being discarded. --- cached_property.py | 29 +++++++++++++++-- tests/test_cached_property.py | 59 ++++++++++++++++++++++++++++------- 2 files changed, 73 insertions(+), 15 deletions(-) diff --git a/cached_property.py b/cached_property.py index 8e07fde..50b40ac 100644 --- a/cached_property.py +++ b/cached_property.py @@ -23,11 +23,34 @@ class cached_property(object): def __get__(self, obj, cls): if obj is None: return self - return obj.__dict__.setdefault(self.func.__name__, self.func(obj)) + value = obj.__dict__[self.func.__name__] = self.func(obj) + return value -# Leave for backwards compatibility -threaded_cached_property = cached_property +class threaded_cached_property(object): + """ + A cached_property version for use in environments where multiple threads + might concurrently try to access the property.
+ """ + + def __init__(self, func): + self.__doc__ = getattr(func, '__doc__') + self.func = func + self.lock = threading.RLock() + + def __get__(self, obj, cls): + if obj is None: + return self + + obj_dict = obj.__dict__ + name = self.func.__name__ + with self.lock: + try: + # check if the value was computed before the lock was acquired + return obj_dict[name] + except KeyError: + # if not, do the calculation and release the lock + return obj_dict.setdefault(name, self.func(obj)) class cached_property_with_ttl(object): diff --git a/tests/test_cached_property.py b/tests/test_cached_property.py index 56678cf..cd04663 100644 --- a/tests/test_cached_property.py +++ b/tests/test_cached_property.py @@ -1,15 +1,18 @@ # -*- coding: utf-8 -*- -"""Tests for cached_property""" +"""Tests for cached_property and threaded_cached_property""" from time import sleep from threading import Lock, Thread import unittest -from cached_property import cached_property +from cached_property import cached_property, threaded_cached_property class TestCachedProperty(unittest.TestCase): + """Tests for cached_property""" + + cached_property_factory = cached_property def test_cached_property(self): @@ -24,7 +27,7 @@ class TestCachedProperty(unittest.TestCase): self.total1 += 1 return self.total1 - @cached_property + @self.cached_property_factory def add_cached(self): self.total2 += 1 return self.total2 @@ -38,10 +41,11 @@ class TestCachedProperty(unittest.TestCase): # The cached version demonstrates how nothing new is added self.assertEqual(c.add_cached, 1) self.assertEqual(c.add_cached, 1) + self.assertEqual(c.total2, 1) # It's customary for descriptors to return themselves if accessed # though the class, rather than through an instance. 
- self.assertTrue(isinstance(Check.add_cached, cached_property)) + self.assertTrue(isinstance(Check.add_cached, self.cached_property_factory)) def test_reset_cached_property(self): @@ -50,7 +54,7 @@ class TestCachedProperty(unittest.TestCase): def __init__(self): self.total = 0 - @cached_property + @self.cached_property_factory def add_cached(self): self.total += 1 return self.total @@ -60,11 +64,13 @@ class TestCachedProperty(unittest.TestCase): # Run standard cache assertion self.assertEqual(c.add_cached, 1) self.assertEqual(c.add_cached, 1) + self.assertEqual(c.total, 1) # Reset the cache. del c.add_cached self.assertEqual(c.add_cached, 2) self.assertEqual(c.add_cached, 2) + self.assertEqual(c.total, 2) def test_none_cached_property(self): @@ -73,7 +79,7 @@ class TestCachedProperty(unittest.TestCase): def __init__(self): self.total = None - @cached_property + @self.cached_property_factory def add_cached(self): return self.total @@ -81,17 +87,33 @@ class TestCachedProperty(unittest.TestCase): # Run standard cache assertion self.assertEqual(c.add_cached, None) + self.assertEqual(c.total, None) def test_threads(self): - """How well does this implementation work with threads?""" - + """ + How well does the standard cached_property implementation work with + threads? It doesn't, use threaded_cached_property instead! + """ + num_threads = 10 + check = self._run_threads(num_threads) + # Threads means that caching is bypassed. + # This assertion hinges on the fact the system executing the test can + # spawn and start running num_threads threads within the sleep period + # (defined in the Check class as 1 second). If num_threads were to be + # massively increased (try 10000), the actual value returned would be + # between 1 and num_threads, depending on thread scheduling and + # preemption. 
+ self.assertEqual(check.add_cached, num_threads) + self.assertEqual(check.total, num_threads) + + def _run_threads(self, num_threads): class Check(object): def __init__(self): self.total = 0 self.lock = Lock() - @cached_property + @self.cached_property_factory def add_cached(self): sleep(1) # Need to guard this since += isn't atomic. @@ -100,13 +122,26 @@ class TestCachedProperty(unittest.TestCase): return self.total c = Check() + threads = [] - for x in range(10): + for _ in range(num_threads): thread = Thread(target=lambda: c.add_cached) thread.start() threads.append(thread) - for thread in threads: thread.join() - self.assertEqual(c.add_cached, 1) + return c + + +class TestThreadedCachedProperty(TestCachedProperty): + """Tests for threaded_cached_property""" + + cached_property_factory = threaded_cached_property + + def test_threads(self): + """How well does this implementation work with threads?""" + num_threads = 10 + check = self._run_threads(num_threads) + self.assertEqual(check.add_cached, 1) + self.assertEqual(check.total, 1) -- 2.30.2