From e9d3c31b5fa2005148c7aaab71c180e5b7842bb8 Mon Sep 17 00:00:00 2001 From: Dillon Walls Date: Thu, 20 Jun 2024 16:08:08 -0400 Subject: [PATCH] incremental improvements to ming's type hinting (while avoiding circular imports) --- ming/odm/mapper.py | 18 ++++++++++++------ ming/session.py | 9 +++++---- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/ming/odm/mapper.py b/ming/odm/mapper.py index fa957db..fead002 100644 --- a/ming/odm/mapper.py +++ b/ming/odm/mapper.py @@ -1,12 +1,18 @@ +from __future__ import annotations + +import typing import warnings from copy import copy from ming.base import Object, NoDefault from ming.utils import wordwrap -from .base import ObjectState, state, _with_hooks +from .base import ObjectState, ObjectState, state, _with_hooks from .property import FieldProperty +if typing.TYPE_CHECKING: + # from ming.odm import ODMSession + from . import ODMSession, MappedClass def mapper(cls, collection=None, session=None, **kwargs): """Gets or creates the mapper for the given ``cls`` :class:`.MappedClass`""" @@ -75,14 +81,14 @@ def __repr__(self): self.mapped_class.__name__, self.collection.m.collection_name) @_with_hooks('insert') - def insert(self, obj, state, session, **kwargs): + def insert(self, obj: MappedClass, state: ObjectState, session: ODMSession, **kwargs): doc = self.collection(state.document, skip_from_bson=True) ret = session.impl.insert(doc, validate=False) state.status = state.clean return ret @_with_hooks('update') - def update(self, obj, state, session, **kwargs): + def update(self, obj: MappedClass, state: ObjectState, session: ODMSession, **kwargs): fields = state.options.get('fields', None) if fields is None: fields = () @@ -93,12 +99,12 @@ def update(self, obj, state, session, **kwargs): return ret @_with_hooks('delete') - def delete(self, obj, state, session, **kwargs): + def delete(self, obj: MappedClass, state: ObjectState, session: ODMSession, **kwargs): doc = self.collection(state.document, skip_from_bson=True) return session.impl.delete(doc) @_with_hooks('remove') - def remove(self, session, *args, **kwargs): + def remove(self, session: ODMSession, *args, **kwargs): return session.impl.remove(self.collection, *args, **kwargs) def create(self, doc, options, remake=True): @@ -176,7 +182,7 @@ def compile(self): for p in self.properties: p.compile(self) - def update_partial(self, session, *args, **kwargs): + def update_partial(self, session: ODMSession, *args, **kwargs): return session.impl.update_partial(self.collection, *args, **kwargs) def _from_doc(self, doc, options, validate=True): diff --git a/ming/session.py b/ming/session.py index 5a263b0..63d6ee3 100644 --- a/ming/session.py +++ b/ming/session.py @@ -5,7 +5,8 @@ import pymongo import pymongo.errors -from pymongo.database import Database +import pymongo.collection +import pymongo.database from .base import Cursor, Object from .datastore import DataStore @@ -48,14 +49,14 @@ def by_name(cls, name): result = cls._registry[name] = cls(cls._datastores.get(name)) return result - def _impl(self, cls): + def _impl(self, cls) -> pymongo.collection.Collection: try: return self.db[cls.m.collection_name] except TypeError: raise exc.MongoGone('MongoDB is not connected') @property - def db(self) -> Database: + def db(self) -> pymongo.database.Database: if not self.bind: raise exc.MongoGone('No MongoDB connection for "%s"' % getattr(self, '_name', 'unknown connection')) return self.bind.db @@ -146,7 +147,7 @@ def _prep_save(self, doc, validate): return data @annotate_doc_failure - def save(self, doc, *args, **kwargs): + def save(self, doc, *args, **kwargs) -> bson.ObjectId: data = self._prep_save(doc, kwargs.pop('validate', True)) if args: data = {arg: data[arg] for arg in args}