From c4f2a72668ee3840d6e6b08a9413e6ae084794be Mon Sep 17 00:00:00 2001 From: Bartlomiej Hirsz Date: Fri, 10 Nov 2023 11:08:34 +0100 Subject: [PATCH] Move connection to ConnectionStore class --- src/DatabaseLibrary/assertion.py | 2 +- src/DatabaseLibrary/connection_manager.py | 102 +++++++++++++--------- src/DatabaseLibrary/query.py | 14 +-- 3 files changed, 68 insertions(+), 50 deletions(-) diff --git a/src/DatabaseLibrary/assertion.py b/src/DatabaseLibrary/assertion.py index 104a85b..bb19b63 100644 --- a/src/DatabaseLibrary/assertion.py +++ b/src/DatabaseLibrary/assertion.py @@ -287,7 +287,7 @@ def table_must_exist( | Table Must Exist | first_name | msg=my error message | """ logger.info(f"Executing : Table Must Exist | {tableName}") - db_connection = self._get_connection_with_alias(alias) + db_connection = self.connection_store.get_connection(alias) if db_connection.module_name in ["cx_Oracle", "oracledb"]: query = ( "SELECT * FROM all_objects WHERE object_type IN ('TABLE','VIEW') AND " diff --git a/src/DatabaseLibrary/connection_manager.py b/src/DatabaseLibrary/connection_manager.py index 964b214..09a4279 100644 --- a/src/DatabaseLibrary/connection_manager.py +++ b/src/DatabaseLibrary/connection_manager.py @@ -30,17 +30,12 @@ class Connection: module_name: str -class ConnectionManager: - """ - Connection Manager handles the connection & disconnection to the database. - """ - +class ConnectionStore: def __init__(self): - self.omit_trailing_semicolon: bool = False self._connections: Dict[str, Connection] = {} self.default_alias: str = "default" - def _register_connection(self, client: Any, module_name: str, alias: str): + def register_connection(self, client: Any, module_name: str, alias: str): if alias in self._connections: if alias == self.default_alias: logger.warn("Overwriting not closed connection.") @@ -48,6 +43,53 @@ def _register_connection(self, client: Any, module_name: str, alias: str): logger.warn(f"Overwriting not closed connection for alias = '{alias}'") self._connections[alias] = Connection(client, module_name) + def get_connection(self, alias: Optional[str]): + """ + Return connection with given alias. + + If alias is not provided, it will return default connection. + If there is no default connection, it will return last opened connection. + """ + if not self._connections: + raise ValueError(f"No database connection is open.") + if not alias: + if self.default_alias in self._connections: + return self._connections[self.default_alias] + return list(self._connections.values())[-1] + if alias not in self._connections: + raise ValueError(f"Alias '{alias}' not found in existing connections.") + return self._connections[alias] + + def pop_connection(self, alias: Optional[str]): + if not self._connections: + return None + if not alias: + alias = self.default_alias + if alias not in self._connections: + alias = list(self._connections.keys())[-1] + return self._connections.pop(alias, None) + + def clear(self): + self._connections = {} + + def switch(self, alias: str): + if alias not in self._connections: + raise ValueError(f"Alias '{alias}' not found in existing connections.") + self.default_alias = alias + + def __iter__(self): + return iter(self._connections.values()) + + +class ConnectionManager: + """ + Connection Manager handles the connection & disconnection to the database. + """ + + def __init__(self): + self.omit_trailing_semicolon: bool = False + self.connection_store: ConnectionStore = ConnectionStore() + def connect_to_database( self, dbapiModuleName: Optional[str] = None, @@ -279,7 +321,7 @@ def connect_to_database( host=dbHost, port=dbPort, ) - self._register_connection(db_connection, db_api_module_name, alias) + self.connection_store.register_connection(db_connection, db_api_module_name, alias) def connect_to_database_using_custom_params( self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: str = "default" @@ -317,7 +359,7 @@ def connect_to_database_using_custom_params( ) db_connection = eval(db_connect_string) - self._register_connection(db_connection, db_api_module_name, alias) + self.connection_store.register_connection(db_connection, db_api_module_name, alias) def connect_to_database_using_custom_connection_string( self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: str = "default" @@ -341,7 +383,7 @@ def connect_to_database_using_custom_connection_string( f"'{db_connect_string}')" ) db_connection = db_api_2.connect(db_connect_string) - self._register_connection(db_connection, db_api_module_name, alias) + self.connection_store.register_connection(db_connection, db_api_module_name, alias) def disconnect_from_database(self, error_if_no_connection: bool = False, alias: Optional[str] = None): """ @@ -356,19 +398,14 @@ def disconnect_from_database(self, error_if_no_connection: bool = False, alias: | Disconnect From Database | alias=my_alias | # disconnects from current connection to the database | """ logger.info("Executing : Disconnect From Database") - if not alias: - if not self._connections or self.default_alias in self._connections: - alias = self.default_alias - else: - alias = list(self._connections.keys())[-1] - try: - db_connection = self._connections.pop(alias) - db_connection.client.close() - except KeyError: # Non-existing alias + db_connection = self.connection_store.pop_connection(alias) + if db_connection is None: log_msg = "No open database connection to close" if error_if_no_connection: raise ConnectionError(log_msg) from None logger.info(log_msg) + else: + db_connection.client.close() def disconnect_from_all_databases(self): """ @@ -378,9 +415,9 @@ def disconnect_from_all_databases(self): | Disconnect From All Databases | # disconnects from all connections to the database | """ logger.info("Executing : Disconnect From All Databases") - for db_connection in self._connections.values(): + for db_connection in self.connection_store: db_connection.client.close() - self._connections = {} + self.connection_store.clear() def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = None): """ @@ -400,7 +437,7 @@ def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = None): | Set Auto Commit | False """ logger.info("Executing : Set Auto Commit") - db_connection = self._get_connection_with_alias(alias) + db_connection = self.connection_store.get_connection(alias) db_connection.client.autocommit = autoCommit def switch_database(self, alias: str): @@ -411,23 +448,4 @@ def switch_database(self, alias: str): | Switch Database | my_alias | | Switch Database | alias=my_alias | """ - if alias not in self._connections: - raise ValueError(f"Alias '{alias}' not found in existing connections.") - self.default_alias = alias - - def _get_connection_with_alias(self, alias: Optional[str]) -> Connection: - """ - Return connection with given alias. - - If alias is not provided, it will return default connection. - If there is no default connection, it will return last opened connection. - """ - if not self._connections: - raise ValueError(f"No database connection is open.") - if not alias: - if self.default_alias in self._connections: - return self._connections[self.default_alias] - return list(self._connections.values())[-1] - if alias not in self._connections: - raise ValueError(f"Alias '{alias}' not found in existing connections.") - return self._connections[alias] + self.connection_store.switch(alias) diff --git a/src/DatabaseLibrary/query.py b/src/DatabaseLibrary/query.py index e5b6cb4..2da52fe 100644 --- a/src/DatabaseLibrary/query.py +++ b/src/DatabaseLibrary/query.py @@ -64,7 +64,7 @@ def query( Using optional ``sansTran`` to run command without an explicit transaction commit or rollback: | @{queryResults} | Query | SELECT * FROM person | True | """ - db_connection = self._get_connection_with_alias(alias) + db_connection = self.connection_store.get_connection(alias) cur = None try: cur = db_connection.client.cursor() @@ -110,7 +110,7 @@ def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optiona Using optional ``sansTran`` to run command without an explicit transaction commit or rollback: | ${rowCount} | Row Count | SELECT * FROM person | True | """ - db_connection = self._get_connection_with_alias(alias) + db_connection = self.connection_store.get_connection(alias) cur = None try: cur = db_connection.client.cursor() @@ -149,7 +149,7 @@ def description(self, selectStatement: str, sansTran: bool = False, alias: Optio Using optional `sansTran` to run command without an explicit transaction commit or rollback: | @{queryResults} | Description | SELECT * FROM person | True | """ - db_connection = self._get_connection_with_alias(alias) + db_connection = self.connection_store.get_connection(alias) cur = None try: cur = db_connection.client.cursor() @@ -187,7 +187,7 @@ def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, ali Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Delete All Rows From Table | person | True | """ - db_connection = self._get_connection_with_alias(alias) + db_connection = self.connection_store.get_connection(alias) cur = None query = f"DELETE FROM {tableName}" try: @@ -265,7 +265,7 @@ def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, ali Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Execute Sql Script | ${EXECDIR}${/}resources${/}DDL-setup.sql | True | """ - db_connection = self._get_connection_with_alias(alias) + db_connection = self.connection_store.get_connection(alias) with open(sqlScriptFileName, encoding="UTF-8") as sql_file: cur = None try: @@ -351,7 +351,7 @@ def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Opti Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Execute Sql String | DELETE FROM person_employee_table; DELETE FROM person_table | True | """ - db_connection = self._get_connection_with_alias(alias) + db_connection = self.connection_store.get_connection(alias) cur = None try: cur = db_connection.client.cursor() @@ -407,7 +407,7 @@ def call_stored_procedure( Using optional `sansTran` to run command without an explicit transaction commit or rollback: | @{Param values} @{Result sets} = | Call Stored Procedure | DBName.SchemaName.StoredProcName | ${Params} | True | """ - db_connection = self._get_connection_with_alias(alias) + db_connection = self.connection_store.get_connection(alias) if spParams is None: spParams = [] cur = None