diff --git a/multidb/pinning.py b/multidb/pinning.py index aa529c2..ed2412d 100644 --- a/multidb/pinning.py +++ b/multidb/pinning.py @@ -7,7 +7,7 @@ __all__ = ['this_thread_is_pinned', 'pin_this_thread', 'unpin_this_thread', - 'use_primary_db', 'use_master', 'db_write'] + 'use_primary_db', 'use_secondary_db', 'use_master', 'db_write'] _locals = threading.local() @@ -33,8 +33,8 @@ def unpin_this_thread(): _locals.pinned = False -class UsePrimaryDB(object): - """A contextmanager/decorator to use the master database.""" +class _UseDB(object): + """A contextmanager/decorator to use the specified database.""" def __call__(self, func): @wraps(func) def decorator(*args, **kw): @@ -43,14 +43,31 @@ def decorator(*args, **kw): return decorator def __enter__(self): - _locals.old = this_thread_is_pinned() - pin_this_thread() + _locals.old = getattr(_locals, 'old', []) + _locals.old.append(this_thread_is_pinned()) def __exit__(self, type, value, tb): - if not _locals.old: + previous_state = _locals.old.pop() + if previous_state: + pin_this_thread() + else: unpin_this_thread() +class UsePrimaryDB(_UseDB): + """A contextmanager/decorator to use the primary database.""" + def __enter__(self): + super(UsePrimaryDB, self).__enter__() + pin_this_thread() + + +class UseSecondaryDB(_UseDB): + """A contextmanager/decorator to use the secondary database.""" + def __enter__(self): + super(UseSecondaryDB, self).__enter__() + unpin_this_thread() + + class DeprecatedUseMaster(UsePrimaryDB): def __enter__(self): warnings.warn( @@ -62,6 +79,7 @@ def __enter__(self): use_primary_db = UsePrimaryDB() +use_secondary_db = UseSecondaryDB() use_master = DeprecatedUseMaster() diff --git a/multidb/tests.py b/multidb/tests.py index 125cff4..36c1e09 100644 --- a/multidb/tests.py +++ b/multidb/tests.py @@ -20,7 +20,8 @@ pinning_cookie_samesite, pinning_cookie_secure, pinning_seconds, PinningRouterMiddleware) from multidb.pinning import (this_thread_is_pinned, pin_this_thread, - unpin_this_thread, use_primary_db, db_write) + unpin_this_thread, use_primary_db, + use_secondary_db, db_write) class UnpinningTestCase(TestCase): @@ -202,6 +203,38 @@ def check(): check() assert not this_thread_is_pinned() + def test_decorator_nested(self): + @use_primary_db + def check_inner(): + assert this_thread_is_pinned() + + @use_primary_db + def check_outer(): + assert this_thread_is_pinned() + check_inner() + assert this_thread_is_pinned() + + unpin_this_thread() + assert not this_thread_is_pinned() + check_outer() + assert not this_thread_is_pinned() + + def test_decorator_nested_mixed(self): + @use_primary_db + def check_inner(): + assert this_thread_is_pinned() + + @use_secondary_db + def check_outer(): + assert not this_thread_is_pinned() + check_inner() + assert not this_thread_is_pinned() + + unpin_this_thread() + assert not this_thread_is_pinned() + check_outer() + assert not this_thread_is_pinned() + def test_decorator_resets(self): @use_primary_db def check(): @@ -211,6 +244,38 @@ def check(): check() assert this_thread_is_pinned() + def test_decorator_resets_nested(self): + @use_primary_db + def check_inner(): + assert this_thread_is_pinned() + + @use_primary_db + def check_outer(): + assert this_thread_is_pinned() + check_inner() + assert this_thread_is_pinned() + + pin_this_thread() + assert this_thread_is_pinned() + check_outer() + assert this_thread_is_pinned() + + def test_decorator_resets_nested_mixed(self): + @use_primary_db + def check_inner(): + assert this_thread_is_pinned() + + @use_secondary_db + def check_outer(): + assert not this_thread_is_pinned() + check_inner() + assert not this_thread_is_pinned() + + pin_this_thread() + assert this_thread_is_pinned() + check_outer() + assert this_thread_is_pinned() + def test_context_manager(self): unpin_this_thread() assert not this_thread_is_pinned() @@ -218,6 +283,26 @@ def test_context_manager(self): assert this_thread_is_pinned() assert not this_thread_is_pinned() + def test_context_manager_nested(self): + unpin_this_thread() + assert not this_thread_is_pinned() + with use_primary_db: + assert this_thread_is_pinned() + with use_primary_db: + assert this_thread_is_pinned() + assert this_thread_is_pinned() + assert not this_thread_is_pinned() + + def test_context_manager_nested_mixed(self): + unpin_this_thread() + assert not this_thread_is_pinned() + with use_secondary_db: + assert not this_thread_is_pinned() + with use_primary_db: + assert this_thread_is_pinned() + assert not this_thread_is_pinned() + assert not this_thread_is_pinned() + def test_context_manager_resets(self): pin_this_thread() assert this_thread_is_pinned() @@ -225,6 +310,26 @@ def test_context_manager_resets(self): assert this_thread_is_pinned() assert this_thread_is_pinned() + def test_context_manager_resets_nested(self): + pin_this_thread() + assert this_thread_is_pinned() + with use_primary_db: + assert this_thread_is_pinned() + with use_primary_db: + assert this_thread_is_pinned() + assert this_thread_is_pinned() + assert this_thread_is_pinned() + + def test_context_manager_resets_nested_mixed(self): + pin_this_thread() + assert this_thread_is_pinned() + with use_secondary_db: + assert not this_thread_is_pinned() + with use_primary_db: + assert this_thread_is_pinned() + assert not this_thread_is_pinned() + assert this_thread_is_pinned() + def test_context_manager_exception(self): unpin_this_thread() assert not this_thread_is_pinned() @@ -282,6 +387,201 @@ def thread2_worker(): self.assertEqual(pinned[1], False) +class UseSecondaryDBTests(TestCase): + def test_decorator(self): + @use_secondary_db + def check(): + assert not this_thread_is_pinned() + pin_this_thread() + assert this_thread_is_pinned() + check() + assert this_thread_is_pinned() + + def test_decorator_nested(self): + @use_secondary_db + def check_inner(): + assert not this_thread_is_pinned() + + @use_secondary_db + def check_outer(): + assert not this_thread_is_pinned() + check_inner() + assert not this_thread_is_pinned() + + pin_this_thread() + assert this_thread_is_pinned() + check_outer() + assert this_thread_is_pinned() + + def test_decorator_nested_mixed(self): + @use_secondary_db + def check_inner(): + assert not this_thread_is_pinned() + + @use_primary_db + def check_outer(): + assert this_thread_is_pinned() + check_inner() + assert this_thread_is_pinned() + + pin_this_thread() + assert this_thread_is_pinned() + check_outer() + assert this_thread_is_pinned() + + def test_decorator_resets(self): + @use_secondary_db + def check(): + assert not this_thread_is_pinned() + unpin_this_thread() + assert not this_thread_is_pinned() + check() + assert not this_thread_is_pinned() + + def test_decorator_resets_nested(self): + @use_secondary_db + def check_inner(): + assert not this_thread_is_pinned() + + @use_secondary_db + def check_outer(): + assert not this_thread_is_pinned() + check_inner() + assert not this_thread_is_pinned() + + unpin_this_thread() + assert not this_thread_is_pinned() + check_outer() + assert not this_thread_is_pinned() + + def test_decorator_resets_nested_mixed(self): + @use_secondary_db + def check_inner(): + assert not this_thread_is_pinned() + + @use_primary_db + def check_outer(): + assert this_thread_is_pinned() + check_inner() + assert this_thread_is_pinned() + + unpin_this_thread() + assert not this_thread_is_pinned() + check_outer() + assert not this_thread_is_pinned() + + def test_context_manager(self): + pin_this_thread() + assert this_thread_is_pinned() + with use_secondary_db: + assert not this_thread_is_pinned() + assert this_thread_is_pinned() + + def test_context_manager_nested(self): + pin_this_thread() + assert this_thread_is_pinned() + with use_secondary_db: + assert not this_thread_is_pinned() + with use_secondary_db: + assert not this_thread_is_pinned() + assert not this_thread_is_pinned() + assert this_thread_is_pinned() + + def test_context_manager_nested_mixed(self): + pin_this_thread() + assert this_thread_is_pinned() + with use_primary_db: + assert this_thread_is_pinned() + with use_secondary_db: + assert not this_thread_is_pinned() + assert this_thread_is_pinned() + assert this_thread_is_pinned() + + def test_context_manager_resets(self): + unpin_this_thread() + assert not this_thread_is_pinned() + with use_secondary_db: + assert not this_thread_is_pinned() + assert not this_thread_is_pinned() + + def test_context_manager_resets_nested(self): + unpin_this_thread() + assert not this_thread_is_pinned() + with use_secondary_db: + assert not this_thread_is_pinned() + with use_secondary_db: + assert not this_thread_is_pinned() + assert not this_thread_is_pinned() + assert not this_thread_is_pinned() + + def test_context_manager_resets_nested_mixed(self): + unpin_this_thread() + assert not this_thread_is_pinned() + with use_primary_db: + assert this_thread_is_pinned() + with use_secondary_db: + assert not this_thread_is_pinned() + assert this_thread_is_pinned() + assert not this_thread_is_pinned() + + def test_context_manager_exception(self): + pin_this_thread() + assert this_thread_is_pinned() + with self.assertRaises(ValueError): + with use_secondary_db: + assert not this_thread_is_pinned() + raise ValueError + assert this_thread_is_pinned() + + def test_multithreaded_unpinning(self): + thread1_lock = Lock() + thread2_lock = Lock() + thread1_lock.acquire() + thread2_lock.acquire() + orchestrator = Lock() + orchestrator.acquire() + + pinned = {} + + def thread1_worker(): + pin_this_thread() + with use_secondary_db: + orchestrator.release() + thread1_lock.acquire() + + pinned[1] = this_thread_is_pinned() + + def thread2_worker(): + unpin_this_thread() + with use_secondary_db: + orchestrator.release() + thread2_lock.acquire() + + pinned[2] = this_thread_is_pinned() + orchestrator.release() + + thread1 = Thread(target=thread1_worker) + thread2 = Thread(target=thread2_worker) + + # thread1 starts, entering `use_primary_db` from a pinned state + thread1.start() + orchestrator.acquire() + + # thread2 starts, entering `use_primary_db` from an unpinned state + thread2.start() + orchestrator.acquire() + + # thread2 finishes, returning to an unpinned state + thread2_lock.release() + thread2.join() + self.assertEqual(pinned[2], False) + + # thread1 finishes, returning to a pinned state + thread1_lock.release() + thread1.join() + self.assertEqual(pinned[1], True) + + class DeprecationTestCase(TestCase): def test_masterslaverouter(self): with warnings.catch_warnings(record=True) as w: