use in multithreaded environments.
Added tests.
Added Python 3.4 to the list of test environments.
__license__ = 'BSD'
import time
+import threading
class cached_property(object):
return self
value = obj.__dict__[self.func.__name__] = self.func(obj)
return value
+
+
+class threaded_cached_property(cached_property):
+ """ A cached_property version for use in environments where multiple
+ threads might concurrently try to access the property.
+ """
+ def __init__(self, func):
+ super(threaded_cached_property, self).__init__(func)
+ self.lock = threading.RLock()
+
+ def __get__(self, obj, cls):
+ with self.lock:
+ # Double check if the value was computed before the lock was
+ # acquired.
+ prop_name = self.func.__name__
+ if prop_name in obj.__dict__:
+ return obj.__dict__[prop_name]
+
+ # If not, do the calculation and release the lock.
+ return super(threaded_cached_property, self).__get__(obj, cls)
\ No newline at end of file
"""
from time import sleep
-from threading import Thread
+from threading import Lock, Thread
import unittest
from cached_property import cached_property
def __init__(self):
self.total = 0
+ self.lock = Lock()
@cached_property
def add_cached(self):
sleep(1)
- self.total += 1
+ # Need to guard this since += isn't atomic.
+ with self.lock:
+ self.total += 1
return self.total
c = Check()
threads = []
- for x in range(10):
+ num_threads = 10
+ for x in range(num_threads):
thread = Thread(target=lambda: c.add_cached)
thread.start()
threads.append(thread)
# TODO: This assertion should be failing.
# See https://github.com/pydanny/cached-property/issues/6
- self.assertEqual(c.add_cached, 10)
+ # This assertion hinges on the fact the system executing the test can
+ # spawn and start running num_threads threads within the sleep period
+ # (defined in the Check class as 1 second). If num_threads were to be
+ # massively increased (try 10000), the actual value returned would be
+ # between 1 and num_threads, depending on thread scheduling and
+ # preemption.
+ self.assertEqual(c.add_cached, num_threads)
--- /dev/null
+# -*- coding: utf-8 -*-
+
+"""
+test_threaded_cache_property.py
+----------------------------------
+
+Tests for `cached-property` module, threaded_cache_property.
+"""
+
+from time import sleep
+from threading import Thread, Lock
+import unittest
+
+from cached_property import threaded_cached_property
+
+
+class TestCachedProperty(unittest.TestCase):
+
+ def test_cached_property(self):
+
+ class Check(object):
+
+ def __init__(self):
+ self.total1 = 0
+ self.total2 = 0
+
+ @property
+ def add_control(self):
+ self.total1 += 1
+ return self.total1
+
+ @threaded_cached_property
+ def add_cached(self):
+ self.total2 += 1
+ return self.total2
+
+ c = Check()
+
+ # The control shows that we can continue to add 1.
+ self.assertEqual(c.add_control, 1)
+ self.assertEqual(c.add_control, 2)
+
+ # The cached version demonstrates how nothing new is added
+ self.assertEqual(c.add_cached, 1)
+ self.assertEqual(c.add_cached, 1)
+
+ def test_reset_cached_property(self):
+
+ class Check(object):
+
+ def __init__(self):
+ self.total = 0
+
+ @threaded_cached_property
+ def add_cached(self):
+ self.total += 1
+ return self.total
+
+ c = Check()
+
+ # Run standard cache assertion
+ self.assertEqual(c.add_cached, 1)
+ self.assertEqual(c.add_cached, 1)
+
+ # Reset the cache.
+ del c.add_cached
+ self.assertEqual(c.add_cached, 2)
+ self.assertEqual(c.add_cached, 2)
+
+ def test_none_cached_property(self):
+
+ class Check(object):
+
+ def __init__(self):
+ self.total = None
+
+ @threaded_cached_property
+ def add_cached(self):
+ return self.total
+
+ c = Check()
+
+ # Run standard cache assertion
+ self.assertEqual(c.add_cached, None)
+
+
+class TestThreadingIssues(unittest.TestCase):
+
+ def test_threads(self):
+ """ How well does this implementation work with threads?"""
+
+ class Check(object):
+
+ def __init__(self):
+ self.total = 0
+ self.lock = Lock()
+
+ @threaded_cached_property
+ def add_cached(self):
+ sleep(1)
+ # Need to guard this since += isn't atomic.
+ with self.lock:
+ self.total += 1
+ return self.total
+
+ c = Check()
+ threads = []
+ for x in range(10):
+ thread = Thread(target=lambda: c.add_cached)
+ thread.start()
+ threads.append(thread)
+
+ for thread in threads:
+ thread.join()
+
+ self.assertEqual(c.add_cached, 1)
[tox]
-envlist = py26, py27, py33
+envlist = py26, py27, py33, py34
[testenv]
setenv =