diff --git a/test/asynchronous/helpers.py b/test/asynchronous/helpers.py index a35c71b107..28260d0a52 100644 --- a/test/asynchronous/helpers.py +++ b/test/asynchronous/helpers.py @@ -40,7 +40,7 @@ except ImportError: HAVE_IPADDRESS = False from functools import wraps -from typing import Any, Callable, Dict, Generator, no_type_check +from typing import Any, Callable, Dict, Generator, Optional, no_type_check from unittest import SkipTest from bson.son import SON @@ -395,7 +395,7 @@ def __init__(self, **kwargs): async def start(self): self.task = create_task(self.run(), name=self.name) - async def join(self, timeout: float | None = 0): # type: ignore[override] + async def join(self, timeout: Optional[float] = None): # type: ignore[override] if self.task is not None: await asyncio.wait([self.task], timeout=timeout) @@ -407,3 +407,18 @@ async def run(self): await self.target(*self.args) finally: self.stopped = True + + +class ExceptionCatchingTask(ConcurrentRunner): + """A Task that stores any exception encountered while running.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.exc = None + + async def run(self): + try: + await super().run() + except BaseException as exc: + self.exc = exc + raise diff --git a/test/asynchronous/test_load_balancer.py b/test/asynchronous/test_load_balancer.py new file mode 100644 index 0000000000..fd50841c87 --- /dev/null +++ b/test/asynchronous/test_load_balancer.py @@ -0,0 +1,199 @@ +# Copyright 2021-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test the Load Balancer unified spec tests.""" +from __future__ import annotations + +import asyncio +import gc +import os +import pathlib +import sys +import threading +from asyncio import Event +from test.asynchronous.helpers import ConcurrentRunner, ExceptionCatchingTask + +import pytest + +sys.path[0:0] = [""] + +from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest +from test.asynchronous.unified_format import generate_test_classes +from test.utils import ( + async_get_pool, + async_wait_until, + create_async_event, +) + +from pymongo.asynchronous.helpers import anext + +_IS_SYNC = False + +pytestmark = pytest.mark.load_balancer + +# Location of JSON test specifications. +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer") + +# Generate unified tests. +globals().update(generate_test_classes(_TEST_PATH, module=__name__)) + + +class TestLB(AsyncIntegrationTest): + RUN_ON_LOAD_BALANCER = True + RUN_ON_SERVERLESS = True + + async def test_connections_are_only_returned_once(self): + if "PyPy" in sys.version: + # Tracked in PYTHON-3011 + self.skipTest("Test is flaky on PyPy") + pool = await async_get_pool(self.client) + n_conns = len(pool.conns) + await self.db.test.find_one({}) + self.assertEqual(len(pool.conns), n_conns) + await (await self.db.test.aggregate([{"$limit": 1}])).to_list() + self.assertEqual(len(pool.conns), n_conns) + + @async_client_context.require_load_balancer + async def test_unpin_committed_transaction(self): + client = await self.async_rs_client() + pool = await async_get_pool(client) + coll = client[self.db.name].test + async with client.start_session() as session: + async with await session.start_transaction(): + self.assertEqual(pool.active_sockets, 0) + await coll.insert_one({}, session=session) + self.assertEqual(pool.active_sockets, 1) # Pinned. + self.assertEqual(pool.active_sockets, 1) # Still pinned. + self.assertEqual(pool.active_sockets, 0) # Unpinned. + + @async_client_context.require_failCommand_fail_point + async def test_cursor_gc(self): + async def create_resource(coll): + cursor = coll.find({}, batch_size=3) + await anext(cursor) + return cursor + + await self._test_no_gc_deadlock(create_resource) + + @async_client_context.require_failCommand_fail_point + async def test_command_cursor_gc(self): + async def create_resource(coll): + cursor = await coll.aggregate([], batchSize=3) + await anext(cursor) + return cursor + + await self._test_no_gc_deadlock(create_resource) + + async def _test_no_gc_deadlock(self, create_resource): + client = await self.async_rs_client() + pool = await async_get_pool(client) + coll = client[self.db.name].test + await coll.insert_many([{} for _ in range(10)]) + self.assertEqual(pool.active_sockets, 0) + # Cause the initial find attempt to fail to induce a reference cycle. + args = { + "mode": {"times": 1}, + "data": { + "failCommands": ["find", "aggregate"], + "closeConnection": True, + }, + } + async with self.fail_point(args): + resource = await create_resource(coll) + if async_client_context.load_balancer: + self.assertEqual(pool.active_sockets, 1) # Pinned. + + task = PoolLocker(pool) + await task.start() + self.assertTrue(await task.wait(task.locked, 5), "timed out") + # Garbage collect the resource while the pool is locked to ensure we + # don't deadlock. + del resource + # On PyPy it can take a few rounds to collect the cursor. + for _ in range(3): + gc.collect() + task.unlock.set() + await task.join(5) + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) + + await async_wait_until(lambda: pool.active_sockets == 0, "return socket") + # Run another operation to ensure the socket still works. + await coll.delete_many({}) + + @async_client_context.require_transactions + async def test_session_gc(self): + client = await self.async_rs_client() + pool = await async_get_pool(client) + session = client.start_session() + await session.start_transaction() + await client.test_session_gc.test.find_one({}, session=session) + # Cleanup the transaction left open on the server unless we're + # testing serverless which does not support killSessions. + if not async_client_context.serverless: + self.addAsyncCleanup(self.client.admin.command, "killSessions", [session.session_id]) + if async_client_context.load_balancer: + self.assertEqual(pool.active_sockets, 1) # Pinned. + + task = PoolLocker(pool) + await task.start() + self.assertTrue(await task.wait(task.locked, 5), "timed out") + # Garbage collect the session while the pool is locked to ensure we + # don't deadlock. + del session + # On PyPy it can take a few rounds to collect the session. + for _ in range(3): + gc.collect() + task.unlock.set() + await task.join(5) + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) + + await async_wait_until(lambda: pool.active_sockets == 0, "return socket") + # Run another operation to ensure the socket still works. + await client[self.db.name].test.delete_many({}) + + +class PoolLocker(ExceptionCatchingTask): + def __init__(self, pool): + super().__init__(target=self.lock_pool) + self.pool = pool + self.daemon = True + self.locked = create_async_event() + self.unlock = create_async_event() + + async def lock_pool(self): + async with self.pool.lock: + self.locked.set() + # Wait for the unlock flag. + unlock_pool = await self.wait(self.unlock, 10) + if not unlock_pool: + raise Exception("timed out waiting for unlock signal: deadlock?") + + async def wait(self, event: Event, timeout: int): + if _IS_SYNC: + return event.wait(timeout) # type: ignore[call-arg] + else: + try: + await asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + return False + return True + + +if __name__ == "__main__": + unittest.main() diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 42bc253b56..03d1032b5b 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -15,10 +15,13 @@ """Test the client_session module.""" from __future__ import annotations +import asyncio import copy import sys import time +from asyncio import iscoroutinefunction from io import BytesIO +from test.asynchronous.helpers import ExceptionCatchingTask from typing import Any, Callable, List, Set, Tuple from pymongo.synchronous.mongo_client import MongoClient @@ -35,7 +38,6 @@ ) from test.utils import ( EventListener, - ExceptionCatchingThread, OvertCommandListener, async_wait_until, ) @@ -184,8 +186,7 @@ async def _test_ops(self, client, *ops): f"{f.__name__} did not return implicit session to pool", ) - @async_client_context.require_sync - def test_implicit_sessions_checkout(self): + async def test_implicit_sessions_checkout(self): # "To confirm that implicit sessions only allocate their server session after a # successful connection checkout" test from Driver Sessions Spec. succeeded = False @@ -193,7 +194,7 @@ def test_implicit_sessions_checkout(self): failures = 0 for _ in range(5): listener = OvertCommandListener() - client = self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) + client = await self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) cursor = client.db.test.find({}) ops: List[Tuple[Callable, List[Any]]] = [ (client.db.test.find_one, [{"_id": 1}]), @@ -210,26 +211,27 @@ def test_implicit_sessions_checkout(self): (cursor.distinct, ["_id"]), (client.db.list_collections, []), ] - threads = [] + tasks = [] listener.reset() - def thread_target(op, *args): - res = op(*args) + async def target(op, *args): + if iscoroutinefunction(op): + res = await op(*args) + else: + res = op(*args) if isinstance(res, (AsyncCursor, AsyncCommandCursor)): - list(res) # type: ignore[call-overload] + await res.to_list() for op, args in ops: - threads.append( - ExceptionCatchingThread( - target=thread_target, args=[op, *args], name=op.__name__ - ) + tasks.append( + ExceptionCatchingTask(target=target, args=[op, *args], name=op.__name__) ) - threads[-1].start() - self.assertEqual(len(threads), len(ops)) - for thread in threads: - thread.join() - self.assertIsNone(thread.exc) - client.close() + await tasks[-1].start() + self.assertEqual(len(tasks), len(ops)) + for t in tasks: + await t.join() + self.assertIsNone(t.exc) + await client.close() lsid_set.clear() for i in listener.started_events: if i.command.get("lsid"): diff --git a/test/helpers.py b/test/helpers.py index 705843efcd..3f51fde08c 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -40,7 +40,7 @@ except ImportError: HAVE_IPADDRESS = False from functools import wraps -from typing import Any, Callable, Dict, Generator, no_type_check +from typing import Any, Callable, Dict, Generator, Optional, no_type_check from unittest import SkipTest from bson.son import SON @@ -395,7 +395,7 @@ def __init__(self, **kwargs): def start(self): self.task = create_task(self.run(), name=self.name) - def join(self, timeout: float | None = 0): # type: ignore[override] + def join(self, timeout: Optional[float] = None): # type: ignore[override] if self.task is not None: asyncio.wait([self.task], timeout=timeout) @@ -407,3 +407,18 @@ def run(self): self.target(*self.args) finally: self.stopped = True + + +class ExceptionCatchingTask(ConcurrentRunner): + """A Task that stores any exception encountered while running.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.exc = None + + def run(self): + try: + super().run() + except BaseException as exc: + self.exc = exc + raise diff --git a/test/test_bson.py b/test/test_bson.py index e601be4915..e704efe451 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -33,7 +33,7 @@ sys.path[0:0] = [""] from test import qcheck, unittest -from test.utils import ExceptionCatchingThread +from test.helpers import ExceptionCatchingTask import bson from bson import ( @@ -1075,7 +1075,7 @@ def target(i): my_int = type(f"MyInt_{i}_{j}", (int,), {}) bson.encode({"my_int": my_int()}) - threads = [ExceptionCatchingThread(target=target, args=(i,)) for i in range(3)] + threads = [ExceptionCatchingTask(target=target, args=(i,)) for i in range(3)] for t in threads: t.start() @@ -1114,7 +1114,7 @@ def __repr__(self): def test_doc_in_invalid_document_error_message_mapping(self): class MyMapping(abc.Mapping): - def keys(): + def keys(self): return ["t"] def __getitem__(self, name): diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index 23bea4d984..7db19b46b5 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -15,10 +15,14 @@ """Test the Load Balancer unified spec tests.""" from __future__ import annotations +import asyncio import gc import os +import pathlib import sys import threading +from asyncio import Event +from test.helpers import ConcurrentRunner, ExceptionCatchingTask import pytest @@ -26,15 +30,26 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ExceptionCatchingThread, get_pool, wait_until +from test.utils import ( + create_event, + get_pool, + wait_until, +) + +from pymongo.synchronous.helpers import next + +_IS_SYNC = True pytestmark = pytest.mark.load_balancer # Location of JSON test specifications. -TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "load_balancer") +if _IS_SYNC: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer") +else: + _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer") # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(_TEST_PATH, module=__name__)) class TestLB(IntegrationTest): @@ -49,13 +64,12 @@ def test_connections_are_only_returned_once(self): n_conns = len(pool.conns) self.db.test.find_one({}) self.assertEqual(len(pool.conns), n_conns) - list(self.db.test.aggregate([{"$limit": 1}])) + (self.db.test.aggregate([{"$limit": 1}])).to_list() self.assertEqual(len(pool.conns), n_conns) @client_context.require_load_balancer def test_unpin_committed_transaction(self): client = self.rs_client() - self.addCleanup(client.close) pool = get_pool(client) coll = client[self.db.name].test with client.start_session() as session: @@ -86,7 +100,6 @@ def create_resource(coll): def _test_no_gc_deadlock(self, create_resource): client = self.rs_client() - self.addCleanup(client.close) pool = get_pool(client) coll = client[self.db.name].test coll.insert_many([{} for _ in range(10)]) @@ -104,19 +117,19 @@ def _test_no_gc_deadlock(self, create_resource): if client_context.load_balancer: self.assertEqual(pool.active_sockets, 1) # Pinned. - thread = PoolLocker(pool) - thread.start() - self.assertTrue(thread.locked.wait(5), "timed out") + task = PoolLocker(pool) + task.start() + self.assertTrue(task.wait(task.locked, 5), "timed out") # Garbage collect the resource while the pool is locked to ensure we # don't deadlock. del resource # On PyPy it can take a few rounds to collect the cursor. for _ in range(3): gc.collect() - thread.unlock.set() - thread.join(5) - self.assertFalse(thread.is_alive()) - self.assertIsNone(thread.exc) + task.unlock.set() + task.join(5) + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) wait_until(lambda: pool.active_sockets == 0, "return socket") # Run another operation to ensure the socket still works. @@ -125,7 +138,6 @@ def _test_no_gc_deadlock(self, create_resource): @client_context.require_transactions def test_session_gc(self): client = self.rs_client() - self.addCleanup(client.close) pool = get_pool(client) session = client.start_session() session.start_transaction() @@ -137,41 +149,51 @@ def test_session_gc(self): if client_context.load_balancer: self.assertEqual(pool.active_sockets, 1) # Pinned. - thread = PoolLocker(pool) - thread.start() - self.assertTrue(thread.locked.wait(5), "timed out") + task = PoolLocker(pool) + task.start() + self.assertTrue(task.wait(task.locked, 5), "timed out") # Garbage collect the session while the pool is locked to ensure we # don't deadlock. del session # On PyPy it can take a few rounds to collect the session. for _ in range(3): gc.collect() - thread.unlock.set() - thread.join(5) - self.assertFalse(thread.is_alive()) - self.assertIsNone(thread.exc) + task.unlock.set() + task.join(5) + self.assertFalse(task.is_alive()) + self.assertIsNone(task.exc) wait_until(lambda: pool.active_sockets == 0, "return socket") # Run another operation to ensure the socket still works. client[self.db.name].test.delete_many({}) -class PoolLocker(ExceptionCatchingThread): +class PoolLocker(ExceptionCatchingTask): def __init__(self, pool): super().__init__(target=self.lock_pool) self.pool = pool self.daemon = True - self.locked = threading.Event() - self.unlock = threading.Event() + self.locked = create_event() + self.unlock = create_event() def lock_pool(self): with self.pool.lock: self.locked.set() # Wait for the unlock flag. - unlock_pool = self.unlock.wait(10) + unlock_pool = self.wait(self.unlock, 10) if not unlock_pool: raise Exception("timed out waiting for unlock signal: deadlock?") + def wait(self, event: Event, timeout: int): + if _IS_SYNC: + return event.wait(timeout) # type: ignore[call-arg] + else: + try: + asyncio.wait_for(event.wait(), timeout=timeout) + except asyncio.TimeoutError: + return False + return True + if __name__ == "__main__": unittest.main() diff --git a/test/test_session.py b/test/test_session.py index 634efa11c0..175a282495 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -15,10 +15,13 @@ """Test the client_session module.""" from __future__ import annotations +import asyncio import copy import sys import time +from asyncio import iscoroutinefunction from io import BytesIO +from test.helpers import ExceptionCatchingTask from typing import Any, Callable, List, Set, Tuple from pymongo.synchronous.mongo_client import MongoClient @@ -35,7 +38,6 @@ ) from test.utils import ( EventListener, - ExceptionCatchingThread, OvertCommandListener, wait_until, ) @@ -184,7 +186,6 @@ def _test_ops(self, client, *ops): f"{f.__name__} did not return implicit session to pool", ) - @client_context.require_sync def test_implicit_sessions_checkout(self): # "To confirm that implicit sessions only allocate their server session after a # successful connection checkout" test from Driver Sessions Spec. @@ -210,25 +211,26 @@ def test_implicit_sessions_checkout(self): (cursor.distinct, ["_id"]), (client.db.list_collections, []), ] - threads = [] + tasks = [] listener.reset() - def thread_target(op, *args): - res = op(*args) + def target(op, *args): + if iscoroutinefunction(op): + res = op(*args) + else: + res = op(*args) if isinstance(res, (Cursor, CommandCursor)): - list(res) # type: ignore[call-overload] + res.to_list() for op, args in ops: - threads.append( - ExceptionCatchingThread( - target=thread_target, args=[op, *args], name=op.__name__ - ) + tasks.append( + ExceptionCatchingTask(target=target, args=[op, *args], name=op.__name__) ) - threads[-1].start() - self.assertEqual(len(threads), len(ops)) - for thread in threads: - thread.join() - self.assertIsNone(thread.exc) + tasks[-1].start() + self.assertEqual(len(tasks), len(ops)) + for t in tasks: + t.join() + self.assertIsNone(t.exc) client.close() lsid_set.clear() for i in listener.started_events: diff --git a/test/utils.py b/test/utils.py index 91000a636a..5c1e0bfb7c 100644 --- a/test/utils.py +++ b/test/utils.py @@ -39,6 +39,7 @@ from bson.objectid import ObjectId from bson.son import SON from pymongo import AsyncMongoClient, monitoring, operations, read_preferences +from pymongo._asyncio_task import create_task from pymongo.cursor_shared import CursorType from pymongo.errors import ConfigurationError, OperationFailure from pymongo.hello import HelloCompat @@ -912,21 +913,6 @@ def is_greenthread_patched(): return gevent_monkey_patched() or eventlet_monkey_patched() -class ExceptionCatchingThread(threading.Thread): - """A thread that stores any exception encountered from run().""" - - def __init__(self, *args, **kwargs): - self.exc = None - super().__init__(*args, **kwargs) - - def run(self): - try: - super().run() - except BaseException as exc: - self.exc = exc - raise - - def parse_read_preference(pref): # Make first letter lowercase to match read_pref's modes. mode_string = pref.get("mode", "primary") @@ -1079,3 +1065,11 @@ async def async_set_fail_point(client, command_args): cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args) await client.admin.command(cmd) + + +def create_async_event(): + return asyncio.Event() + + +def create_event(): + return threading.Event() diff --git a/tools/synchro.py b/tools/synchro.py index 4b6326a49c..fe38b4dcfe 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -122,6 +122,7 @@ "SpecRunnerTask": "SpecRunnerThread", "AsyncMockConnection": "MockConnection", "AsyncMockPool": "MockPool", + "create_async_event": "create_event", } docstring_replacements: dict[tuple[str, str], str] = { @@ -214,6 +215,7 @@ def async_only_test(f: str) -> bool: "test_heartbeat_monitoring.py", "test_index_management.py", "test_grid_file.py", + "test_load_balancer.py", "test_json_util_integration.py", "test_gridfs_spec.py", "test_logger.py",