From 884806ba81e7bb92dc8bcefb945e54e7c518fa41 Mon Sep 17 00:00:00 2001 From: Bart Feenstra Date: Fri, 14 Mar 2025 16:58:45 +0000 Subject: [PATCH] Make gramps.plugins.db.dbapi.dbapi typed --- gramps/plugins/db/dbapi/dbapi.py | 126 +++++++++++++++++++------------ 1 file changed, 76 insertions(+), 50 deletions(-) diff --git a/gramps/plugins/db/dbapi/dbapi.py b/gramps/plugins/db/dbapi/dbapi.py index a0d87dea3f2..1f46349e398 100644 --- a/gramps/plugins/db/dbapi/dbapi.py +++ b/gramps/plugins/db/dbapi/dbapi.py @@ -28,11 +28,15 @@ # Python modules # # ------------------------------------------------------------------------- +from __future__ import annotations import logging import json import time +from collections.abc import Iterator, Collection +from typing import Any, TYPE_CHECKING from gramps.gen.const import GRAMPS_LOCALE as glocale +from gramps.gen.db import DbTxn # ------------------------------------------------------------------------ # @@ -63,6 +67,10 @@ ) from gramps.gen.lib.genderstats import GenderStats from gramps.gen.updatecallback import UpdateCallback +from gramps.gen.utils.grampslocale import GrampsLocale + +if TYPE_CHECKING: + from gramps.plugins.db.dbapi.sqlite import Connection LOG = logging.getLogger(".dbapi") _LOG = logging.getLogger(DBLOGNAME) @@ -78,10 +86,12 @@ class DBAPI(DbGeneric): Database backends class for DB-API 2.0 databases """ + dbapi: Connection + def _initialize(self, directory, username, password): raise NotImplementedError - def use_json_data(self): + def use_json_data(self) -> bool: """ A DBAPI level method for testing if the database supports JSON access. @@ -90,7 +100,7 @@ def use_json_data(self): # if the database has been converted to use JSON data return self.dbapi.column_exists("metadata", "json_data") - def upgrade_table_for_json_data(self, table_name): + def upgrade_table_for_json_data(self, table_name: str) -> None: """ A DBAPI level method for upgrading the given table adding a json_data column. @@ -98,7 +108,7 @@ def upgrade_table_for_json_data(self, table_name): if not self.dbapi.column_exists(table_name, "json_data"): self.dbapi.execute("ALTER TABLE %s ADD COLUMN json_data TEXT;" % table_name) - def _schema_exists(self): + def _schema_exists(self) -> bool: """ Check to see if the schema exists. @@ -107,7 +117,7 @@ def _schema_exists(self): """ return self.dbapi.table_exists("person") - def _create_schema(self, json_data): + def _create_schema(self, json_data) -> None: """ Create and update schema. """ @@ -252,7 +262,7 @@ def _create_schema(self, json_data): self.dbapi.commit() - def _drop_column(self, table_name, column_name): + def _drop_column(self, table_name: str, column_name: str) -> None: """ Used to remove a column of data which we don't need anymore. Must be used within a tranaction @@ -260,10 +270,10 @@ def _drop_column(self, table_name, column_name): """ self.dbapi.drop_column(table_name, column_name) - def _close(self): + def _close(self) -> None: self.dbapi.close() - def _txn_begin(self): + def _txn_begin(self) -> None: """ Lowlevel interface to the backend transaction. Executes a db BEGIN; @@ -272,7 +282,7 @@ def _txn_begin(self): _LOG.debug(" DBAPI %s transaction begin", hex(id(self))) self.dbapi.begin() - def _txn_commit(self): + def _txn_commit(self) -> None: """ Lowlevel interface to the backend transaction. Executes a db END; @@ -281,7 +291,7 @@ def _txn_commit(self): _LOG.debug(" DBAPI %s transaction commit", hex(id(self))) self.dbapi.commit() - def _txn_abort(self): + def _txn_abort(self) -> None: """ Lowlevel interface to the backend transaction. Executes a db ROLLBACK; @@ -289,7 +299,7 @@ def _txn_abort(self): if self.transaction is None: self.dbapi.rollback() - def _collation(self, locale): + def _collation(self, locale: GrampsLocale) -> str: """ Get the adjusted collation if there is one, falling back on the locale.collation. @@ -299,7 +309,7 @@ def _collation(self, locale): return locale.get_collation() return collation - def transaction_begin(self, transaction): + def transaction_begin(self, transaction: DbTxn) -> DbTxn: """ Transactions are handled automatically by the db layer. """ @@ -317,7 +327,7 @@ def transaction_begin(self, transaction): self.dbapi.begin() return transaction - def transaction_commit(self, transaction): + def transaction_commit(self, transaction: DbTxn) -> None: """ Executed at the end of a transaction. """ @@ -364,7 +374,7 @@ def transaction_commit(self, transaction): transaction.clear() self.has_changed += 1 # Also gives commits since startup - def transaction_abort(self, transaction): + def transaction_abort(self, transaction: DbTxn) -> None: """ Executed after a batch operation abort. """ @@ -401,7 +411,7 @@ def _get_metadata(self, key, default="_"): return [] return default - def _set_metadata(self, key, value, use_txn=True): + def _set_metadata(self, key, value, use_txn: bool = True) -> None: """ key: string value: item, will be serialized here @@ -425,7 +435,7 @@ def _set_metadata(self, key, value, use_txn=True): if use_txn: self._txn_commit() - def get_name_group_keys(self): + def get_name_group_keys(self) -> list: """ Return the defined names that have been assigned to a default grouping. """ @@ -433,7 +443,7 @@ def get_name_group_keys(self): # not None test below fixes db corrupted by 11011 for export return [row[0] for row in self.dbapi.fetchall() if row[1] is not None] - def get_name_group_mapping(self, surname): + def get_name_group_mapping(self, surname: str) -> str: """ Return the default grouping name for a surname. """ @@ -444,7 +454,9 @@ def get_name_group_mapping(self, surname): return row[0] return surname - def get_person_handles(self, sort_handles=False, locale=glocale): + def get_person_handles( + self, sort_handles: bool = False, locale: GrampsLocale = glocale + ) -> list[str]: """ Return a list of database handles, one handle for each Person in the database. @@ -464,7 +476,9 @@ def get_person_handles(self, sort_handles=False, locale=glocale): self.dbapi.execute("SELECT handle FROM person") return [row[0] for row in self.dbapi.fetchall()] - def get_family_handles(self, sort_handles=False, locale=glocale): + def get_family_handles( + self, sort_handles: bool = False, locale: GrampsLocale = glocale + ) -> list[str]: """ Return a list of database handles, one handle for each Family in the database. @@ -496,7 +510,7 @@ def get_family_handles(self, sort_handles=False, locale=glocale): self.dbapi.execute("SELECT handle FROM family") return [row[0] for row in self.dbapi.fetchall()] - def get_event_handles(self): + def get_event_handles(self) -> list[str]: """ Return a list of database handles, one handle for each Event in the database. @@ -504,7 +518,9 @@ def get_event_handles(self): self.dbapi.execute("SELECT handle FROM event") return [row[0] for row in self.dbapi.fetchall()] - def get_citation_handles(self, sort_handles=False, locale=glocale): + def get_citation_handles( + self, sort_handles: bool = False, locale: GrampsLocale = glocale + ) -> list[str]: """ Return a list of database handles, one handle for each Citation in the database. @@ -524,7 +540,9 @@ def get_citation_handles(self, sort_handles=False, locale=glocale): self.dbapi.execute("SELECT handle FROM citation") return [row[0] for row in self.dbapi.fetchall()] - def get_source_handles(self, sort_handles=False, locale=glocale): + def get_source_handles( + self, sort_handles: bool = False, locale: GrampsLocale = glocale + ) -> list[str]: """ Return a list of database handles, one handle for each Source in the database. @@ -544,7 +562,9 @@ def get_source_handles(self, sort_handles=False, locale=glocale): self.dbapi.execute("SELECT handle from source") return [row[0] for row in self.dbapi.fetchall()] - def get_place_handles(self, sort_handles=False, locale=glocale): + def get_place_handles( + self, sort_handles: bool = False, locale: GrampsLocale = glocale + ) -> list[str]: """ Return a list of database handles, one handle for each Place in the database. @@ -564,7 +584,7 @@ def get_place_handles(self, sort_handles=False, locale=glocale): self.dbapi.execute("SELECT handle FROM place") return [row[0] for row in self.dbapi.fetchall()] - def get_repository_handles(self): + def get_repository_handles(self) -> list[str]: """ Return a list of database handles, one handle for each Repository in the database. @@ -572,7 +592,9 @@ def get_repository_handles(self): self.dbapi.execute("SELECT handle FROM repository") return [row[0] for row in self.dbapi.fetchall()] - def get_media_handles(self, sort_handles=False, locale=glocale): + def get_media_handles( + self, sort_handles: bool = False, locale: GrampsLocale = glocale + ) -> list[str]: """ Return a list of database handles, one handle for each Media in the database. @@ -592,7 +614,7 @@ def get_media_handles(self, sort_handles=False, locale=glocale): self.dbapi.execute("SELECT handle FROM media") return [row[0] for row in self.dbapi.fetchall()] - def get_note_handles(self): + def get_note_handles(self) -> list[str]: """ Return a list of database handles, one handle for each Note in the database. @@ -600,7 +622,9 @@ def get_note_handles(self): self.dbapi.execute("SELECT handle FROM note") return [row[0] for row in self.dbapi.fetchall()] - def get_tag_handles(self, sort_handles=False, locale=glocale): + def get_tag_handles( + self, sort_handles: bool = False, locale: GrampsLocale = glocale + ) -> list[str]: """ Return a list of database handles, one handle for each Tag in the database. @@ -620,7 +644,7 @@ def get_tag_handles(self, sort_handles=False, locale=glocale): self.dbapi.execute("SELECT handle FROM tag") return [row[0] for row in self.dbapi.fetchall()] - def get_tag_from_name(self, name): + def get_tag_from_name(self, name: str) -> Tag | None: """ Find a Tag in the database from the passed Tag name. @@ -634,13 +658,13 @@ def get_tag_from_name(self, name): return self.serializer.string_to_object(Tag, row[0]) return None - def _get_number_of(self, obj_key): + def _get_number_of(self, obj_key) -> int: table = KEY_TO_NAME_MAP[obj_key] self.dbapi.execute(f"SELECT count(1) FROM {table}") row = self.dbapi.fetchone() return row[0] - def has_name_group_key(self, key): + def has_name_group_key(self, key) -> bool: """ Return if a key exists in the name_group table. """ @@ -648,7 +672,7 @@ def has_name_group_key(self, key): row = self.dbapi.fetchone() return row and row[0] is not None - def set_name_group_mapping(self, name, group): + def set_name_group_mapping(self, name, group) -> None: """ Set the default grouping name for a surname. """ @@ -715,7 +739,7 @@ def _commit_base(self, obj, obj_key, trans, change_time): return old_data - def _commit_raw(self, data, obj_key): + def _commit_raw(self, data, obj_key) -> None: """ Commit a serialized primary object to the database, storing the changes as part of the transaction. @@ -736,7 +760,7 @@ def _commit_raw(self, data, obj_key): [handle, self.serializer.data_to_string(data)], ) - def _update_backlinks(self, obj, transaction): + def _update_backlinks(self, obj, transaction: DbTxn) -> None: if not transaction.batch: # Find existing references self.dbapi.execute( @@ -801,7 +825,7 @@ def _update_backlinks(self, obj, transaction): [obj.handle, obj.__class__.__name__, ref_handle, ref_class_name], ) - def _do_remove(self, handle, transaction, obj_key): + def _do_remove(self, handle: str, transaction: DbTxn, obj_key) -> None: if self.readonly or not handle: return if self._has_handle(obj_key, handle): @@ -813,7 +837,7 @@ def _do_remove(self, handle, transaction, obj_key): if not transaction.batch: transaction.add(obj_key, TXNDEL, handle, data, None) - def _remove_backlinks(self, obj_class, obj_handle, transaction): + def _remove_backlinks(self, obj_class, obj_handle: str, transaction: DbTxn) -> None: """ Removes all references from this object (backlinks). """ @@ -832,7 +856,9 @@ def _remove_backlinks(self, obj_class, obj_handle, transaction): old_data = (obj_handle, obj_class, ref_handle, ref_class_name) transaction.add(REFERENCE_KEY, TXNDEL, key, old_data, None) - def find_backlink_handles(self, handle, include_classes=None): + def find_backlink_handles( + self, handle: str, include_classes: Collection[str] | None = None + ) -> Iterator[tuple[Any, str]]: """ Find all objects that hold a reference to the object handle. @@ -858,7 +884,7 @@ def find_backlink_handles(self, handle, include_classes=None): if (include_classes is None) or (row[0] in include_classes): yield (row[0], row[1]) - def find_initial_person(self): + def find_initial_person(self) -> Person | None: """ Returns first person in the database """ @@ -874,7 +900,7 @@ def find_initial_person(self): return self.get_person_from_handle(row[0]) return None - def _iter_handles(self, obj_key): + def _iter_handles(self, obj_key) -> Iterator[str]: """ Return an iterator over handles in the database """ @@ -913,7 +939,7 @@ def _iter_raw_place_tree_data(self): to_do.append(row[0]) yield (row[0], self.serializer.string_to_data(row[1])) - def reindex_reference_map(self, callback): + def reindex_reference_map(self, callback) -> None: """ Reindex all primary records in the database. """ @@ -1019,12 +1045,12 @@ def rebuild_secondary(self, callback=None): gstats = self.get_gender_stats() self.genderStats = GenderStats(gstats) - def _has_handle(self, obj_key, handle): + def _has_handle(self, obj_key, handle: str) -> bool: table = KEY_TO_NAME_MAP[obj_key] self.dbapi.execute(f"SELECT 1 FROM {table} WHERE handle = ?", [handle]) return self.dbapi.fetchone() is not None - def _has_gramps_id(self, obj_key, gramps_id): + def _has_gramps_id(self, obj_key, gramps_id) -> bool: table = KEY_TO_NAME_MAP[obj_key] self.dbapi.execute(f"SELECT 1 FROM {table} WHERE gramps_id = ?", [gramps_id]) return self.dbapi.fetchone() is not None @@ -1034,7 +1060,7 @@ def _get_gramps_ids(self, obj_key): self.dbapi.execute(f"SELECT gramps_id FROM {table}") return [row[0] for row in self.dbapi.fetchall()] - def _get_raw_data(self, obj_key, handle): + def _get_raw_data(self, obj_key, handle: str): table = KEY_TO_NAME_MAP[obj_key] self.dbapi.execute( f"SELECT {self.serializer.data_field} FROM {table} WHERE handle = ?", @@ -1056,7 +1082,7 @@ def _get_raw_from_id_data(self, obj_key, gramps_id): return self.serializer.string_to_data(row[0]) return None - def get_gender_stats(self): + def get_gender_stats(self) -> dict[str, tuple[int, int, int]]: """ Returns a dictionary of {given_name: (male_count, female_count, unknown_count)} @@ -1067,7 +1093,7 @@ def get_gender_stats(self): gstats[row[0]] = (row[1], row[2], row[3]) return gstats - def save_gender_stats(self, gstats): + def save_gender_stats(self, gstats: GenderStats) -> None: self._txn_begin() self.dbapi.execute("DELETE FROM gender_stats") for key in gstats.stats: @@ -1080,7 +1106,7 @@ def save_gender_stats(self, gstats): ) self._txn_commit() - def undo_reference(self, data, handle): + def undo_reference(self, data, handle: str) -> None: """ Helper method to undo a reference map entry """ @@ -1097,7 +1123,7 @@ def undo_reference(self, data, handle): data, ) - def undo_data(self, data, handle, obj_key): + def undo_data(self, data, handle: str, obj_key) -> None: """ Helper method to undo/redo the changes made """ @@ -1119,7 +1145,7 @@ def undo_data(self, data, handle, obj_key): obj = self.serializer.data_to_object(data, cls) self._update_secondary_values(obj) - def get_surname_list(self): + def get_surname_list(self) -> list[str]: """ Return the list of locale-sorted surnames contained in the database. """ @@ -1129,7 +1155,7 @@ def get_surname_list(self): surname_list.append(row[0]) return surname_list - def _sql_type(self, schema_type, max_length): + def _sql_type(self, schema_type: str, max_length: int) -> str: """ Given a schema type, return the SQL type for a new column. @@ -1144,7 +1170,7 @@ def _sql_type(self, schema_type, max_length): return "REAL" return "BLOB" - def _create_secondary_columns(self): + def _create_secondary_columns(self) -> None: """ Create secondary columns. """ @@ -1169,7 +1195,7 @@ def _create_secondary_columns(self): f"ALTER TABLE {table_name} ADD COLUMN {field} {sql_type}" ) - def _update_secondary_values(self, obj): + def _update_secondary_values(self, obj) -> None: """ Given a primary object update its secondary field values in the database. @@ -1202,7 +1228,7 @@ def _update_secondary_values(self, obj): self._sql_cast_list(values) + [obj.handle], ) - def _sql_cast_list(self, values): + def _sql_cast_list(self, values) -> list[Any]: """ Given a list of field names and values, return the values in the appropriate type.