Skip to content

Commit 3f70a27

Browse files
committed
chore: modernize Beanie codebase and types for py39+
* Removed future annotations imports. * Set maximum line length to 100 from 79. * Re-formatted all files due to above change. * Enabled more Ruff rules (most notably pyupgrade - "UP"). * Fixed "RUF029" rule violation in tests. * Minor docstring fixes. * Introduced is_generic_alias() to beanie.odm.utils.typing module. Provides better support for both new generic types and older generic types from the typing module (e.g. typing.List[str]).
1 parent 380f0d6 commit 3f70a27

File tree

107 files changed

+1201
-2200
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

107 files changed

+1201
-2200
lines changed

.github/scripts/handlers/gh.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import subprocess
22
from dataclasses import dataclass
33
from datetime import datetime
4-
from typing import List
54

65
import requests # type: ignore
76

@@ -31,7 +30,7 @@ def __init__(
3130
self.commits = self.get_commits_after_tag(current_version)
3231
self.prs = [self.get_pr_for_commit(commit) for commit in self.commits]
3332

34-
def get_commits_after_tag(self, tag: str) -> List[str]:
33+
def get_commits_after_tag(self, tag: str) -> list[str]:
3534
result = subprocess.run(
3635
["git", "log", f"{tag}..HEAD", "--pretty=format:%H"],
3736
stdout=subprocess.PIPE,
@@ -64,9 +63,7 @@ def build_markdown_for_many_prs(self) -> str:
6463
return markdown
6564

6665
def commit_changes(self):
67-
self.run_git_command(
68-
["git", "config", "--global", "user.name", "github-actions[bot]"]
69-
)
66+
self.run_git_command(["git", "config", "--global", "user.name", "github-actions[bot]"])
7067
self.run_git_command(
7168
[
7269
"git",
@@ -77,15 +74,13 @@ def commit_changes(self):
7774
]
7875
)
7976
self.run_git_command(["git", "add", "."])
80-
self.run_git_command(
81-
["git", "commit", "-m", f"Bump version to {self.new_version}"]
82-
)
77+
self.run_git_command(["git", "commit", "-m", f"Bump version to {self.new_version}"])
8378
self.run_git_command(["git", "tag", self.new_version])
8479
self.git_push()
8580

8681
def git_push(self):
8782
self.run_git_command(["git", "push", "origin", "main", "--tags"])
8883

8984
@staticmethod
90-
def run_git_command(command: List[str]):
85+
def run_git_command(command: list[str]):
9186
subprocess.run(command, check=True)

.github/scripts/handlers/version.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ def __gt__(self, other):
3131
(self.major > other.major)
3232
or (self.major == other.major and self.minor > other.minor)
3333
or (
34-
self.major == other.major
35-
and self.minor == other.minor
36-
and self.patch > other.patch
34+
self.major == other.major and self.minor == other.minor and self.patch > other.patch
3735
)
3836
)
3937

@@ -47,9 +45,7 @@ def __init__(self):
4745
self.init_py = self.ROOT_PATH / "beanie" / "__init__.py"
4846
self.changelog = self.ROOT_PATH / "docs" / "changelog.md"
4947

50-
self.current_version = self.parse_version_from_pyproject(
51-
self.pyproject
52-
)
48+
self.current_version = self.parse_version_from_pyproject(self.pyproject)
5349
self.pypi_version = self.get_version_from_pypi()
5450

5551
if self.current_version < self.pypi_version:
@@ -68,9 +64,7 @@ def parse_version_from_pyproject(pyproject: Path) -> SemVer:
6864
return SemVer(toml_data["project"]["version"])
6965

7066
def get_version_from_pypi(self) -> SemVer:
71-
response = requests.get(
72-
f"https://pypi.org/pypi/{self.PACKAGE_NAME}/json"
73-
)
67+
response = requests.get(f"https://pypi.org/pypi/{self.PACKAGE_NAME}/json")
7468
if response.status_code == 200:
7569
return SemVer(response.json()["info"]["version"])
7670
raise ValueError("Can't get version from pypi")
@@ -90,9 +84,7 @@ def update_pyproject_version(self):
9084
def update_file_versions(self, files_to_update):
9185
for file_path in files_to_update:
9286
content = file_path.read_text()
93-
content = content.replace(
94-
str(self.pypi_version), str(self.current_version)
95-
)
87+
content = content.replace(str(self.pypi_version), str(self.current_version))
9688
file_path.write_text(content)
9789

9890
def update_changelog(self):

beanie/executors/migrate.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,7 @@ def __init__(self, **kwargs: Any):
4848
or self.get_from_toml("database_name")
4949
)
5050
self.path = Path(
51-
kwargs.get("path")
52-
or self.get_env_value("path")
53-
or self.get_from_toml("path")
51+
kwargs.get("path") or self.get_env_value("path") or self.get_from_toml("path")
5452
)
5553
self.allow_index_dropping = bool(
5654
kwargs.get("allow_index_dropping")
@@ -85,9 +83,9 @@ def get_env_value(field_name) -> Any:
8583
or os.environ.get("beanie_database_name")
8684
)
8785
else:
88-
value = os.environ.get(
89-
f"BEANIE_{field_name.upper()}"
90-
) or os.environ.get(f"beanie_{field_name.lower()}")
86+
value = os.environ.get(f"BEANIE_{field_name.upper()}") or os.environ.get(
87+
f"beanie_{field_name.lower()}"
88+
)
9189
return value
9290

9391
@staticmethod
@@ -96,11 +94,7 @@ def get_from_toml(field_name) -> Any:
9694
if path.is_file():
9795
with path.open("rb") as f:
9896
toml_data = tomllib.load(f)
99-
val = (
100-
toml_data.get("tool", {})
101-
.get("beanie", {})
102-
.get("migrations", {})
103-
)
97+
val = toml_data.get("tool", {}).get("beanie", {}).get("migrations", {})
10498
else:
10599
val = {}
106100
return val.get(field_name)
@@ -114,9 +108,7 @@ def migrations():
114108
async def run_migrate(settings: MigrationSettings):
115109
DBHandler.set_db(settings.connection_uri, settings.database_name)
116110
root = await MigrationNode.build(settings.path)
117-
mode = RunningMode(
118-
direction=settings.direction, distance=settings.distance
119-
)
111+
mode = RunningMode(direction=settings.direction, distance=settings.distance)
120112
await root.run(
121113
mode=mode,
122114
allow_index_dropping=settings.allow_index_dropping,
@@ -158,9 +150,7 @@ async def run_migrate(settings: MigrationSettings):
158150
type=str,
159151
help="MongoDB connection URI",
160152
)
161-
@click.option(
162-
"-db", "--database_name", required=False, type=str, help="DataBase name"
163-
)
153+
@click.option("-db", "--database_name", required=False, type=str, help="DataBase name")
164154
@click.option(
165155
"-p",
166156
"--path",

beanie/migrations/controllers/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from abc import ABC, abstractmethod
2-
from typing import List, Type
32

43
from beanie.odm.documents import Document
54

@@ -14,5 +13,5 @@ async def run(self, session):
1413

1514
@property
1615
@abstractmethod
17-
def models(self) -> List[Type[Document]]:
16+
def models(self) -> list[type[Document]]:
1817
pass

beanie/migrations/controllers/free_fall.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from inspect import signature
2-
from typing import Any, List, Type
2+
from typing import Any
33

44
from beanie.migrations.controllers.base import BaseMigrationController
55
from beanie.odm.documents import Document
66

77

8-
def free_fall_migration(document_models: List[Type[Document]]):
8+
def free_fall_migration(document_models: list[type[Document]]):
99
class FreeFallMigrationController(BaseMigrationController):
1010
def __init__(self, function):
1111
self.function = function
@@ -16,7 +16,7 @@ def __call__(self, *args: Any, **kwargs: Any):
1616
pass
1717

1818
@property
19-
def models(self) -> List[Type[Document]]:
19+
def models(self) -> list[type[Document]]:
2020
return self.document_models
2121

2222
async def run(self, session):

beanie/migrations/controllers/iterative.py

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
2+
from collections.abc import Callable
23
from inspect import isclass, signature
3-
from typing import Any, List, Optional, Type, Union
4+
from typing import Any, Optional, Union
45

56
from beanie.migrations.controllers.base import BaseMigrationController
67
from beanie.migrations.utils import update_dict
@@ -10,7 +11,7 @@
1011

1112
class DummyOutput:
1213
def __init__(self):
13-
super(DummyOutput, self).__setattr__("_internal_structure_dict", {})
14+
super().__setattr__("_internal_structure_dict", {})
1415

1516
def __setattr__(self, key, value):
1617
self._internal_structure_dict[key] = value
@@ -26,9 +27,7 @@ def dict(self, to_parse: Optional[Union[dict, "DummyOutput"]] = None):
2627
if to_parse is None:
2728
to_parse = self
2829
input_dict = (
29-
to_parse._internal_structure_dict
30-
if isinstance(to_parse, DummyOutput)
31-
else to_parse
30+
to_parse._internal_structure_dict if isinstance(to_parse, DummyOutput) else to_parse
3231
)
3332
result_dict = {}
3433
for key, value in input_dict.items():
@@ -40,29 +39,21 @@ def dict(self, to_parse: Optional[Union[dict, "DummyOutput"]] = None):
4039

4140

4241
def iterative_migration(
43-
document_models: Optional[List[Type[Document]]] = None,
42+
document_models: Optional[list[type[Document]]] = None,
4443
batch_size: int = 10000,
4544
):
4645
class IterativeMigration(BaseMigrationController):
47-
def __init__(self, function):
46+
def __init__(self, function: Callable) -> None:
4847
self.function = function
4948
self.function_signature = signature(function)
50-
input_signature = self.function_signature.parameters.get(
51-
"input_document"
52-
)
49+
input_signature = self.function_signature.parameters.get("input_document")
5350
if input_signature is None:
5451
raise RuntimeError("input_signature must not be None")
55-
self.input_document_model: Type[Document] = (
56-
input_signature.annotation
57-
)
58-
output_signature = self.function_signature.parameters.get(
59-
"output_document"
60-
)
52+
self.input_document_model: type[Document] = input_signature.annotation
53+
output_signature = self.function_signature.parameters.get("output_document")
6154
if output_signature is None:
6255
raise RuntimeError("output_signature must not be None")
63-
self.output_document_model: Type[Document] = (
64-
output_signature.annotation
65-
)
56+
self.output_document_model: type[Document] = output_signature.annotation
6657

6758
if (
6859
not isclass(self.input_document_model)
@@ -71,8 +62,7 @@ def __init__(self, function):
7162
or not issubclass(self.output_document_model, Document)
7263
):
7364
raise TypeError(
74-
"input_document and output_document "
75-
"must have annotation of Document subclass"
65+
"input_document and output_document must have annotation of Document subclass"
7666
)
7767

7868
self.batch_size = batch_size
@@ -81,7 +71,7 @@ def __call__(self, *args: Any, **kwargs: Any):
8171
pass
8272

8373
@property
84-
def models(self) -> List[Type[Document]]:
74+
def models(self) -> list[type[Document]]:
8575
preset_models = document_models
8676
if preset_models is None:
8777
preset_models = []
@@ -93,9 +83,7 @@ def models(self) -> List[Type[Document]]:
9383
async def run(self, session):
9484
output_documents = []
9585
all_migration_ops = []
96-
async for input_document in self.input_document_model.find_all(
97-
session=session
98-
):
86+
async for input_document in self.input_document_model.find_all(session=session):
9987
output = DummyOutput()
10088
function_kwargs = {
10189
"input_document": input_document,
@@ -105,14 +93,10 @@ async def run(self, session):
10593
function_kwargs["self"] = None
10694
await self.function(**function_kwargs)
10795
output_dict = (
108-
input_document.dict()
109-
if not IS_PYDANTIC_V2
110-
else input_document.model_dump()
96+
input_document.dict() if not IS_PYDANTIC_V2 else input_document.model_dump()
11197
)
11298
update_dict(output_dict, output.dict())
113-
output_document = parse_model(
114-
self.output_document_model, output_dict
115-
)
99+
output_document = parse_model(self.output_document_model, output_dict)
116100
output_documents.append(output_document)
117101

118102
if len(output_documents) == self.batch_size:

beanie/migrations/models.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from datetime import datetime
22
from enum import Enum
3-
from typing import List, Optional
3+
from typing import Optional
44

55
from pydantic import Field
66
from pydantic.main import BaseModel
@@ -29,5 +29,5 @@ class RunningMode(BaseModel):
2929

3030
class ParsedMigrations(BaseModel):
3131
path: str
32-
names: List[str]
32+
names: list[str]
3333
current: Optional[MigrationLog] = None

0 commit comments

Comments
 (0)