Skip to content

Commit

Permalink
incremental improvements to ming's type hinting (while avoiding circu…
Browse files Browse the repository at this point in the history
…lar imports)
  • Loading branch information
dill0wn committed Jun 28, 2024
1 parent e610de7 commit e9d3c31
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
18 changes: 12 additions & 6 deletions ming/odm/mapper.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from __future__ import annotations

import typing
import warnings
from copy import copy

from ming.base import Object, NoDefault
from ming.utils import wordwrap

from .base import ObjectState, state, _with_hooks
from .base import ObjectState, ObjectState, state, _with_hooks
from .property import FieldProperty

if typing.TYPE_CHECKING:
# from ming.odm import ODMSession
from . import ODMSession, MappedClass

def mapper(cls, collection=None, session=None, **kwargs):
"""Gets or creates the mapper for the given ``cls`` :class:`.MappedClass`"""
Expand Down Expand Up @@ -75,14 +81,14 @@ def __repr__(self):
self.mapped_class.__name__, self.collection.m.collection_name)

@_with_hooks('insert')
def insert(self, obj, state, session, **kwargs):
def insert(self, obj: MappedClass, state: ObjectState, session: ODMSession, **kwargs):
doc = self.collection(state.document, skip_from_bson=True)
ret = session.impl.insert(doc, validate=False)
state.status = state.clean
return ret

@_with_hooks('update')
def update(self, obj, state, session, **kwargs):
def update(self, obj: MappedClass, state: ObjectState, session: ODMSession, **kwargs):
fields = state.options.get('fields', None)
if fields is None:
fields = ()
Expand All @@ -93,12 +99,12 @@ def update(self, obj, state, session, **kwargs):
return ret

@_with_hooks('delete')
def delete(self, obj, state, session, **kwargs):
def delete(self, obj: MappedClass, state: ObjectState, session: ODMSession, **kwargs):
doc = self.collection(state.document, skip_from_bson=True)
return session.impl.delete(doc)

@_with_hooks('remove')
def remove(self, session, *args, **kwargs):
def remove(self, session: ODMSession, *args, **kwargs):
return session.impl.remove(self.collection, *args, **kwargs)

def create(self, doc, options, remake=True):
Expand Down Expand Up @@ -176,7 +182,7 @@ def compile(self):
for p in self.properties:
p.compile(self)

def update_partial(self, session, *args, **kwargs):
def update_partial(self, session: ODMSession, *args, **kwargs):
return session.impl.update_partial(self.collection, *args, **kwargs)

def _from_doc(self, doc, options, validate=True):
Expand Down
9 changes: 5 additions & 4 deletions ming/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

import pymongo
import pymongo.errors
from pymongo.database import Database
import pymongo.collection
import pymongo.database

from .base import Cursor, Object
from .datastore import DataStore
Expand Down Expand Up @@ -48,14 +49,14 @@ def by_name(cls, name):
result = cls._registry[name] = cls(cls._datastores.get(name))
return result

def _impl(self, cls):
def _impl(self, cls) -> pymongo.collection.Collection:
try:
return self.db[cls.m.collection_name]
except TypeError:
raise exc.MongoGone('MongoDB is not connected')

@property
def db(self) -> Database:
def db(self) -> pymongo.database.Database:
if not self.bind:
raise exc.MongoGone('No MongoDB connection for "%s"' % getattr(self, '_name', 'unknown connection'))
return self.bind.db
Expand Down Expand Up @@ -146,7 +147,7 @@ def _prep_save(self, doc, validate):
return data

@annotate_doc_failure
def save(self, doc, *args, **kwargs):
def save(self, doc, *args, **kwargs) -> bson.ObjectId:
data = self._prep_save(doc, kwargs.pop('validate', True))
if args:
data = {arg: data[arg] for arg in args}
Expand Down

0 comments on commit e9d3c31

Please sign in to comment.