Skip to content

Commit 19ce266

Browse files
authored
Merge branch 'master' into PYTHON-5080
2 parents a131f48 + c8d3afd commit 19ce266

15 files changed

+876
-132
lines changed

.github/workflows/release-python.yml

+3-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ env:
2323
SILK_ASSET_GROUP: mongodb-python-driver
2424
EVERGREEN_PROJECT: mongo-python-driver
2525
# Constant
26-
DRY_RUN: ${{ github.event_name == 'workflow_dispatch' && inputs.dry_run || 'true' }}
26+
# inputs will be empty on a scheduled run. so, we only set dry_run
27+
# to 'false' when the input is set to 'false'.
28+
DRY_RUN: ${{ ! contains(inputs.dry_run, 'false') }}
2729
FOLLOWING_VERSION: ${{ inputs.following_version || '' }}
2830
VERSION: ${{ inputs.version || '10.10.10.10' }}
2931

test/__init__.py

+86-33
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import asyncio
1919
import gc
20+
import inspect
2021
import logging
2122
import multiprocessing
2223
import os
@@ -30,28 +31,6 @@
3031
import unittest
3132
import warnings
3233
from asyncio import iscoroutinefunction
33-
from test.helpers import (
34-
COMPRESSORS,
35-
IS_SRV,
36-
MONGODB_API_VERSION,
37-
MULTI_MONGOS_LB_URI,
38-
TEST_LOADBALANCER,
39-
TEST_SERVERLESS,
40-
TLS_OPTIONS,
41-
SystemCertsPatcher,
42-
client_knobs,
43-
db_pwd,
44-
db_user,
45-
global_knobs,
46-
host,
47-
is_server_resolvable,
48-
port,
49-
print_running_topology,
50-
print_thread_stacks,
51-
print_thread_tracebacks,
52-
sanitize_cmd,
53-
sanitize_reply,
54-
)
5534

5635
from pymongo.uri_parser import parse_uri
5736

@@ -63,7 +42,6 @@
6342
HAVE_IPADDRESS = False
6443
from contextlib import contextmanager
6544
from functools import partial, wraps
66-
from test.version import Version
6745
from typing import Any, Callable, Dict, Generator, overload
6846
from unittest import SkipTest
6947
from urllib.parse import quote_plus
@@ -78,6 +56,32 @@
7856
from pymongo.synchronous.database import Database
7957
from pymongo.synchronous.mongo_client import MongoClient
8058

59+
sys.path[0:0] = [""]
60+
61+
from test.helpers import (
62+
COMPRESSORS,
63+
IS_SRV,
64+
MONGODB_API_VERSION,
65+
MULTI_MONGOS_LB_URI,
66+
TEST_LOADBALANCER,
67+
TEST_SERVERLESS,
68+
TLS_OPTIONS,
69+
SystemCertsPatcher,
70+
client_knobs,
71+
db_pwd,
72+
db_user,
73+
global_knobs,
74+
host,
75+
is_server_resolvable,
76+
port,
77+
print_running_topology,
78+
print_thread_stacks,
79+
print_thread_tracebacks,
80+
sanitize_cmd,
81+
sanitize_reply,
82+
)
83+
from test.version import Version
84+
8185
_IS_SYNC = True
8286

8387

@@ -863,18 +867,66 @@ def max_message_size_bytes(self):
863867
# Reusable client context
864868
client_context = ClientContext()
865869

870+
# Global event loop for async tests.
871+
LOOP = None
866872

867-
def reset_client_context():
868-
if _IS_SYNC:
869-
# sync tests don't need to reset a client context
870-
return
871-
elif client_context.client is not None:
872-
client_context.client.close()
873-
client_context.client = None
874-
client_context._init_client()
873+
874+
def get_loop() -> asyncio.AbstractEventLoop:
875+
"""Get the test suite's global event loop."""
876+
global LOOP
877+
if LOOP is None:
878+
try:
879+
LOOP = asyncio.get_running_loop()
880+
except RuntimeError:
881+
# no running event loop, fallback to get_event_loop.
882+
try:
883+
# Ignore DeprecationWarning: There is no current event loop
884+
with warnings.catch_warnings():
885+
warnings.simplefilter("ignore", DeprecationWarning)
886+
LOOP = asyncio.get_event_loop()
887+
except RuntimeError:
888+
LOOP = asyncio.new_event_loop()
889+
asyncio.set_event_loop(LOOP)
890+
return LOOP
875891

876892

877893
class PyMongoTestCase(unittest.TestCase):
894+
if not _IS_SYNC:
895+
# An async TestCase that uses a single event loop for all tests.
896+
# Inspired by TestCase.
897+
def setUp(self):
898+
pass
899+
900+
def tearDown(self):
901+
pass
902+
903+
def addCleanup(self, func, /, *args, **kwargs):
904+
self.addCleanup(*(func, *args), **kwargs)
905+
906+
def _callSetUp(self):
907+
self.setUp()
908+
self._callAsync(self.setUp)
909+
910+
def _callTestMethod(self, method):
911+
self._callMaybeAsync(method)
912+
913+
def _callTearDown(self):
914+
self._callAsync(self.tearDown)
915+
self.tearDown()
916+
917+
def _callCleanup(self, function, *args, **kwargs):
918+
self._callMaybeAsync(function, *args, **kwargs)
919+
920+
def _callAsync(self, func, /, *args, **kwargs):
921+
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
922+
return get_loop().run_until_complete(func(*args, **kwargs))
923+
924+
def _callMaybeAsync(self, func, /, *args, **kwargs):
925+
if inspect.iscoroutinefunction(func):
926+
return get_loop().run_until_complete(func(*args, **kwargs))
927+
else:
928+
return func(*args, **kwargs)
929+
878930
def assertEqualCommand(self, expected, actual, msg=None):
879931
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
880932

@@ -1136,8 +1188,6 @@ class IntegrationTest(PyMongoTestCase):
11361188

11371189
@client_context.require_connection
11381190
def setUp(self) -> None:
1139-
if not _IS_SYNC:
1140-
reset_client_context()
11411191
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
11421192
raise SkipTest("this test does not support load balancers")
11431193
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
@@ -1186,6 +1236,9 @@ def tearDown(self) -> None:
11861236

11871237

11881238
def setup():
1239+
if not _IS_SYNC:
1240+
# Set up the event loop.
1241+
get_loop()
11891242
client_context.init()
11901243
warnings.resetwarnings()
11911244
warnings.simplefilter("always")

test/asynchronous/__init__.py

+87-34
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import asyncio
1919
import gc
20+
import inspect
2021
import logging
2122
import multiprocessing
2223
import os
@@ -30,28 +31,6 @@
3031
import unittest
3132
import warnings
3233
from asyncio import iscoroutinefunction
33-
from test.helpers import (
34-
COMPRESSORS,
35-
IS_SRV,
36-
MONGODB_API_VERSION,
37-
MULTI_MONGOS_LB_URI,
38-
TEST_LOADBALANCER,
39-
TEST_SERVERLESS,
40-
TLS_OPTIONS,
41-
SystemCertsPatcher,
42-
client_knobs,
43-
db_pwd,
44-
db_user,
45-
global_knobs,
46-
host,
47-
is_server_resolvable,
48-
port,
49-
print_running_topology,
50-
print_thread_stacks,
51-
print_thread_tracebacks,
52-
sanitize_cmd,
53-
sanitize_reply,
54-
)
5534

5635
from pymongo.uri_parser import parse_uri
5736

@@ -63,7 +42,6 @@
6342
HAVE_IPADDRESS = False
6443
from contextlib import asynccontextmanager, contextmanager
6544
from functools import partial, wraps
66-
from test.version import Version
6745
from typing import Any, Callable, Dict, Generator, overload
6846
from unittest import SkipTest
6947
from urllib.parse import quote_plus
@@ -78,6 +56,32 @@
7856
from pymongo.server_api import ServerApi
7957
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
8058

59+
sys.path[0:0] = [""]
60+
61+
from test.helpers import (
62+
COMPRESSORS,
63+
IS_SRV,
64+
MONGODB_API_VERSION,
65+
MULTI_MONGOS_LB_URI,
66+
TEST_LOADBALANCER,
67+
TEST_SERVERLESS,
68+
TLS_OPTIONS,
69+
SystemCertsPatcher,
70+
client_knobs,
71+
db_pwd,
72+
db_user,
73+
global_knobs,
74+
host,
75+
is_server_resolvable,
76+
port,
77+
print_running_topology,
78+
print_thread_stacks,
79+
print_thread_tracebacks,
80+
sanitize_cmd,
81+
sanitize_reply,
82+
)
83+
from test.version import Version
84+
8185
_IS_SYNC = False
8286

8387

@@ -865,18 +869,66 @@ async def max_message_size_bytes(self):
865869
# Reusable client context
866870
async_client_context = AsyncClientContext()
867871

872+
# Global event loop for async tests.
873+
LOOP = None
874+
875+
876+
def get_loop() -> asyncio.AbstractEventLoop:
877+
"""Get the test suite's global event loop."""
878+
global LOOP
879+
if LOOP is None:
880+
try:
881+
LOOP = asyncio.get_running_loop()
882+
except RuntimeError:
883+
# no running event loop, fallback to get_event_loop.
884+
try:
885+
# Ignore DeprecationWarning: There is no current event loop
886+
with warnings.catch_warnings():
887+
warnings.simplefilter("ignore", DeprecationWarning)
888+
LOOP = asyncio.get_event_loop()
889+
except RuntimeError:
890+
LOOP = asyncio.new_event_loop()
891+
asyncio.set_event_loop(LOOP)
892+
return LOOP
893+
894+
895+
class AsyncPyMongoTestCase(unittest.TestCase):
896+
if not _IS_SYNC:
897+
# An async TestCase that uses a single event loop for all tests.
898+
# Inspired by IsolatedAsyncioTestCase.
899+
async def asyncSetUp(self):
900+
pass
868901

869-
async def reset_client_context():
870-
if _IS_SYNC:
871-
# sync tests don't need to reset a client context
872-
return
873-
elif async_client_context.client is not None:
874-
await async_client_context.client.close()
875-
async_client_context.client = None
876-
await async_client_context._init_client()
902+
async def asyncTearDown(self):
903+
pass
877904

905+
def addAsyncCleanup(self, func, /, *args, **kwargs):
906+
self.addCleanup(*(func, *args), **kwargs)
907+
908+
def _callSetUp(self):
909+
self.setUp()
910+
self._callAsync(self.asyncSetUp)
911+
912+
def _callTestMethod(self, method):
913+
self._callMaybeAsync(method)
914+
915+
def _callTearDown(self):
916+
self._callAsync(self.asyncTearDown)
917+
self.tearDown()
918+
919+
def _callCleanup(self, function, *args, **kwargs):
920+
self._callMaybeAsync(function, *args, **kwargs)
921+
922+
def _callAsync(self, func, /, *args, **kwargs):
923+
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
924+
return get_loop().run_until_complete(func(*args, **kwargs))
925+
926+
def _callMaybeAsync(self, func, /, *args, **kwargs):
927+
if inspect.iscoroutinefunction(func):
928+
return get_loop().run_until_complete(func(*args, **kwargs))
929+
else:
930+
return func(*args, **kwargs)
878931

879-
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
880932
def assertEqualCommand(self, expected, actual, msg=None):
881933
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
882934

@@ -1154,8 +1206,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
11541206

11551207
@async_client_context.require_connection
11561208
async def asyncSetUp(self) -> None:
1157-
if not _IS_SYNC:
1158-
await reset_client_context()
11591209
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
11601210
raise SkipTest("this test does not support load balancers")
11611211
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
@@ -1204,6 +1254,9 @@ async def asyncTearDown(self) -> None:
12041254

12051255

12061256
async def async_setup():
1257+
if not _IS_SYNC:
1258+
# Set up the event loop.
1259+
get_loop()
12071260
await async_client_context.init()
12081261
warnings.resetwarnings()
12091262
warnings.simplefilter("always")

test/asynchronous/test_gridfs_spec.py

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2015-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test the AsyncGridFS unified spec tests."""
16+
from __future__ import annotations
17+
18+
import os
19+
import sys
20+
from pathlib import Path
21+
22+
sys.path[0:0] = [""]
23+
24+
from test import unittest
25+
from test.asynchronous.unified_format import generate_test_classes
26+
27+
_IS_SYNC = False
28+
29+
# Location of JSON test specifications.
30+
if _IS_SYNC:
31+
TEST_PATH = os.path.join(Path(__file__).resolve().parent, "gridfs")
32+
else:
33+
TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "gridfs")
34+
35+
# Generate unified tests.
36+
globals().update(generate_test_classes(TEST_PATH, module=__name__))
37+
38+
if __name__ == "__main__":
39+
unittest.main()

0 commit comments

Comments
 (0)