1
1
import asyncio
2
+ from collections .abc import Callable
2
3
from inspect import isclass , signature
3
- from typing import Any , List , Optional , Type , Union
4
+ from typing import Any , Optional , Union
4
5
5
6
from beanie .migrations .controllers .base import BaseMigrationController
6
7
from beanie .migrations .utils import update_dict
10
11
11
12
class DummyOutput :
12
13
def __init__ (self ):
13
- super (DummyOutput , self ).__setattr__ ("_internal_structure_dict" , {})
14
+ super ().__setattr__ ("_internal_structure_dict" , {})
14
15
15
16
def __setattr__ (self , key , value ):
16
17
self ._internal_structure_dict [key ] = value
@@ -26,9 +27,7 @@ def dict(self, to_parse: Optional[Union[dict, "DummyOutput"]] = None):
26
27
if to_parse is None :
27
28
to_parse = self
28
29
input_dict = (
29
- to_parse ._internal_structure_dict
30
- if isinstance (to_parse , DummyOutput )
31
- else to_parse
30
+ to_parse ._internal_structure_dict if isinstance (to_parse , DummyOutput ) else to_parse
32
31
)
33
32
result_dict = {}
34
33
for key , value in input_dict .items ():
@@ -40,29 +39,21 @@ def dict(self, to_parse: Optional[Union[dict, "DummyOutput"]] = None):
40
39
41
40
42
41
def iterative_migration (
43
- document_models : Optional [List [ Type [Document ]]] = None ,
42
+ document_models : Optional [list [ type [Document ]]] = None ,
44
43
batch_size : int = 10000 ,
45
44
):
46
45
class IterativeMigration (BaseMigrationController ):
47
- def __init__ (self , function ) :
46
+ def __init__ (self , function : Callable ) -> None :
48
47
self .function = function
49
48
self .function_signature = signature (function )
50
- input_signature = self .function_signature .parameters .get (
51
- "input_document"
52
- )
49
+ input_signature = self .function_signature .parameters .get ("input_document" )
53
50
if input_signature is None :
54
51
raise RuntimeError ("input_signature must not be None" )
55
- self .input_document_model : Type [Document ] = (
56
- input_signature .annotation
57
- )
58
- output_signature = self .function_signature .parameters .get (
59
- "output_document"
60
- )
52
+ self .input_document_model : type [Document ] = input_signature .annotation
53
+ output_signature = self .function_signature .parameters .get ("output_document" )
61
54
if output_signature is None :
62
55
raise RuntimeError ("output_signature must not be None" )
63
- self .output_document_model : Type [Document ] = (
64
- output_signature .annotation
65
- )
56
+ self .output_document_model : type [Document ] = output_signature .annotation
66
57
67
58
if (
68
59
not isclass (self .input_document_model )
@@ -71,8 +62,7 @@ def __init__(self, function):
71
62
or not issubclass (self .output_document_model , Document )
72
63
):
73
64
raise TypeError (
74
- "input_document and output_document "
75
- "must have annotation of Document subclass"
65
+ "input_document and output_document must have annotation of Document subclass"
76
66
)
77
67
78
68
self .batch_size = batch_size
@@ -81,7 +71,7 @@ def __call__(self, *args: Any, **kwargs: Any):
81
71
pass
82
72
83
73
@property
84
- def models (self ) -> List [ Type [Document ]]:
74
+ def models (self ) -> list [ type [Document ]]:
85
75
preset_models = document_models
86
76
if preset_models is None :
87
77
preset_models = []
@@ -93,9 +83,7 @@ def models(self) -> List[Type[Document]]:
93
83
async def run (self , session ):
94
84
output_documents = []
95
85
all_migration_ops = []
96
- async for input_document in self .input_document_model .find_all (
97
- session = session
98
- ):
86
+ async for input_document in self .input_document_model .find_all (session = session ):
99
87
output = DummyOutput ()
100
88
function_kwargs = {
101
89
"input_document" : input_document ,
@@ -105,14 +93,10 @@ async def run(self, session):
105
93
function_kwargs ["self" ] = None
106
94
await self .function (** function_kwargs )
107
95
output_dict = (
108
- input_document .dict ()
109
- if not IS_PYDANTIC_V2
110
- else input_document .model_dump ()
96
+ input_document .dict () if not IS_PYDANTIC_V2 else input_document .model_dump ()
111
97
)
112
98
update_dict (output_dict , output .dict ())
113
- output_document = parse_model (
114
- self .output_document_model , output_dict
115
- )
99
+ output_document = parse_model (self .output_document_model , output_dict )
116
100
output_documents .append (output_document )
117
101
118
102
if len (output_documents ) == self .batch_size :
0 commit comments