Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PYTHON-5087 - Convert test.test_load_balancer to async #2103

Merged
merged 9 commits into from
Feb 6, 2025
19 changes: 17 additions & 2 deletions test/asynchronous/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
except ImportError:
HAVE_IPADDRESS = False
from functools import wraps
from typing import Any, Callable, Dict, Generator, no_type_check
from typing import Any, Callable, Dict, Generator, Optional, no_type_check
from unittest import SkipTest

from bson.son import SON
Expand Down Expand Up @@ -395,7 +395,7 @@ def __init__(self, **kwargs):
async def start(self):
self.task = create_task(self.run(), name=self.name)

async def join(self, timeout: float | None = 0): # type: ignore[override]
async def join(self, timeout: Optional[float] = None): # type: ignore[override]
if self.task is not None:
await asyncio.wait([self.task], timeout=timeout)

Expand All @@ -407,3 +407,18 @@ async def run(self):
await self.target(*self.args)
finally:
self.stopped = True


class ExceptionCatchingTask(ConcurrentRunner):
"""A Task that stores any exception encountered while running."""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.exc = None

async def run(self):
try:
await super().run()
except BaseException as exc:
self.exc = exc
raise
199 changes: 199 additions & 0 deletions test/asynchronous/test_load_balancer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test the Load Balancer unified spec tests."""
from __future__ import annotations

import asyncio
import gc
import os
import pathlib
import sys
import threading
from asyncio import Event
from test.asynchronous.helpers import ConcurrentRunner, ExceptionCatchingTask

import pytest

sys.path[0:0] = [""]

from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.asynchronous.unified_format import generate_test_classes
from test.utils import (
async_get_pool,
async_wait_until,
create_async_event,
)

from pymongo.asynchronous.helpers import anext

_IS_SYNC = False

pytestmark = pytest.mark.load_balancer

# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer")

# Generate unified tests.
globals().update(generate_test_classes(_TEST_PATH, module=__name__))


class TestLB(AsyncIntegrationTest):
RUN_ON_LOAD_BALANCER = True
RUN_ON_SERVERLESS = True

async def test_connections_are_only_returned_once(self):
if "PyPy" in sys.version:
# Tracked in PYTHON-3011
self.skipTest("Test is flaky on PyPy")
pool = await async_get_pool(self.client)
n_conns = len(pool.conns)
await self.db.test.find_one({})
self.assertEqual(len(pool.conns), n_conns)
await (await self.db.test.aggregate([{"$limit": 1}])).to_list()
self.assertEqual(len(pool.conns), n_conns)

@async_client_context.require_load_balancer
async def test_unpin_committed_transaction(self):
client = await self.async_rs_client()
pool = await async_get_pool(client)
coll = client[self.db.name].test
async with client.start_session() as session:
async with await session.start_transaction():
self.assertEqual(pool.active_sockets, 0)
await coll.insert_one({}, session=session)
self.assertEqual(pool.active_sockets, 1) # Pinned.
self.assertEqual(pool.active_sockets, 1) # Still pinned.
self.assertEqual(pool.active_sockets, 0) # Unpinned.

@async_client_context.require_failCommand_fail_point
async def test_cursor_gc(self):
async def create_resource(coll):
cursor = coll.find({}, batch_size=3)
await anext(cursor)
return cursor

await self._test_no_gc_deadlock(create_resource)

@async_client_context.require_failCommand_fail_point
async def test_command_cursor_gc(self):
async def create_resource(coll):
cursor = await coll.aggregate([], batchSize=3)
await anext(cursor)
return cursor

await self._test_no_gc_deadlock(create_resource)

async def _test_no_gc_deadlock(self, create_resource):
client = await self.async_rs_client()
pool = await async_get_pool(client)
coll = client[self.db.name].test
await coll.insert_many([{} for _ in range(10)])
self.assertEqual(pool.active_sockets, 0)
# Cause the initial find attempt to fail to induce a reference cycle.
args = {
"mode": {"times": 1},
"data": {
"failCommands": ["find", "aggregate"],
"closeConnection": True,
},
}
async with self.fail_point(args):
resource = await create_resource(coll)
if async_client_context.load_balancer:
self.assertEqual(pool.active_sockets, 1) # Pinned.

task = PoolLocker(pool)
await task.start()
self.assertTrue(await task.wait(task.locked, 5), "timed out")
# Garbage collect the resource while the pool is locked to ensure we
# don't deadlock.
del resource
# On PyPy it can take a few rounds to collect the cursor.
for _ in range(3):
gc.collect()
task.unlock.set()
await task.join(5)
self.assertFalse(task.is_alive())
self.assertIsNone(task.exc)

await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
# Run another operation to ensure the socket still works.
await coll.delete_many({})

@async_client_context.require_transactions
async def test_session_gc(self):
client = await self.async_rs_client()
pool = await async_get_pool(client)
session = client.start_session()
await session.start_transaction()
await client.test_session_gc.test.find_one({}, session=session)
# Cleanup the transaction left open on the server unless we're
# testing serverless which does not support killSessions.
if not async_client_context.serverless:
self.addAsyncCleanup(self.client.admin.command, "killSessions", [session.session_id])
if async_client_context.load_balancer:
self.assertEqual(pool.active_sockets, 1) # Pinned.

task = PoolLocker(pool)
await task.start()
self.assertTrue(await task.wait(task.locked, 5), "timed out")
# Garbage collect the session while the pool is locked to ensure we
# don't deadlock.
del session
# On PyPy it can take a few rounds to collect the session.
for _ in range(3):
gc.collect()
task.unlock.set()
await task.join(5)
self.assertFalse(task.is_alive())
self.assertIsNone(task.exc)

await async_wait_until(lambda: pool.active_sockets == 0, "return socket")
# Run another operation to ensure the socket still works.
await client[self.db.name].test.delete_many({})


class PoolLocker(ExceptionCatchingTask):
def __init__(self, pool):
super().__init__(target=self.lock_pool)
self.pool = pool
self.daemon = True
self.locked = create_async_event()
self.unlock = create_async_event()

async def lock_pool(self):
async with self.pool.lock:
self.locked.set()
# Wait for the unlock flag.
unlock_pool = await self.wait(self.unlock, 10)
if not unlock_pool:
raise Exception("timed out waiting for unlock signal: deadlock?")

async def wait(self, event: Event, timeout: int):
if _IS_SYNC:
return event.wait(timeout) # type: ignore[call-arg]
else:
try:
await asyncio.wait_for(event.wait(), timeout=timeout)
except asyncio.TimeoutError:
return False
return True


if __name__ == "__main__":
unittest.main()
38 changes: 20 additions & 18 deletions test/asynchronous/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
"""Test the client_session module."""
from __future__ import annotations

import asyncio
import copy
import sys
import time
from asyncio import iscoroutinefunction
from io import BytesIO
from test.asynchronous.helpers import ExceptionCatchingTask
from typing import Any, Callable, List, Set, Tuple

from pymongo.synchronous.mongo_client import MongoClient
Expand All @@ -35,7 +38,6 @@
)
from test.utils import (
EventListener,
ExceptionCatchingThread,
OvertCommandListener,
async_wait_until,
)
Expand Down Expand Up @@ -184,16 +186,15 @@ async def _test_ops(self, client, *ops):
f"{f.__name__} did not return implicit session to pool",
)

@async_client_context.require_sync
def test_implicit_sessions_checkout(self):
async def test_implicit_sessions_checkout(self):
# "To confirm that implicit sessions only allocate their server session after a
# successful connection checkout" test from Driver Sessions Spec.
succeeded = False
lsid_set = set()
failures = 0
for _ in range(5):
listener = OvertCommandListener()
client = self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
client = await self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1)
cursor = client.db.test.find({})
ops: List[Tuple[Callable, List[Any]]] = [
(client.db.test.find_one, [{"_id": 1}]),
Expand All @@ -210,26 +211,27 @@ def test_implicit_sessions_checkout(self):
(cursor.distinct, ["_id"]),
(client.db.list_collections, []),
]
threads = []
tasks = []
listener.reset()

def thread_target(op, *args):
res = op(*args)
async def target(op, *args):
if iscoroutinefunction(op):
res = await op(*args)
else:
res = op(*args)
if isinstance(res, (AsyncCursor, AsyncCommandCursor)):
list(res) # type: ignore[call-overload]
await res.to_list()

for op, args in ops:
threads.append(
ExceptionCatchingThread(
target=thread_target, args=[op, *args], name=op.__name__
)
tasks.append(
ExceptionCatchingTask(target=target, args=[op, *args], name=op.__name__)
)
threads[-1].start()
self.assertEqual(len(threads), len(ops))
for thread in threads:
thread.join()
self.assertIsNone(thread.exc)
client.close()
await tasks[-1].start()
self.assertEqual(len(tasks), len(ops))
for t in tasks:
await t.join()
self.assertIsNone(t.exc)
await client.close()
lsid_set.clear()
for i in listener.started_events:
if i.command.get("lsid"):
Expand Down
19 changes: 17 additions & 2 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
except ImportError:
HAVE_IPADDRESS = False
from functools import wraps
from typing import Any, Callable, Dict, Generator, no_type_check
from typing import Any, Callable, Dict, Generator, Optional, no_type_check
from unittest import SkipTest

from bson.son import SON
Expand Down Expand Up @@ -395,7 +395,7 @@ def __init__(self, **kwargs):
def start(self):
self.task = create_task(self.run(), name=self.name)

def join(self, timeout: float | None = 0): # type: ignore[override]
def join(self, timeout: Optional[float] = None): # type: ignore[override]
if self.task is not None:
asyncio.wait([self.task], timeout=timeout)

Expand All @@ -407,3 +407,18 @@ def run(self):
self.target(*self.args)
finally:
self.stopped = True


class ExceptionCatchingTask(ConcurrentRunner):
"""A Task that stores any exception encountered while running."""

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.exc = None

def run(self):
try:
super().run()
except BaseException as exc:
self.exc = exc
raise
6 changes: 3 additions & 3 deletions test/test_bson.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
sys.path[0:0] = [""]

from test import qcheck, unittest
from test.utils import ExceptionCatchingThread
from test.helpers import ExceptionCatchingTask

import bson
from bson import (
Expand Down Expand Up @@ -1075,7 +1075,7 @@ def target(i):
my_int = type(f"MyInt_{i}_{j}", (int,), {})
bson.encode({"my_int": my_int()})

threads = [ExceptionCatchingThread(target=target, args=(i,)) for i in range(3)]
threads = [ExceptionCatchingTask(target=target, args=(i,)) for i in range(3)]
for t in threads:
t.start()

Expand Down Expand Up @@ -1114,7 +1114,7 @@ def __repr__(self):

def test_doc_in_invalid_document_error_message_mapping(self):
class MyMapping(abc.Mapping):
def keys():
def keys(self):
return ["t"]

def __getitem__(self, name):
Expand Down
Loading
Loading