diff --git a/src/DatabaseLibrary/__init__.py b/src/DatabaseLibrary/__init__.py index efb757f..cd70e0b 100644 --- a/src/DatabaseLibrary/__init__.py +++ b/src/DatabaseLibrary/__init__.py @@ -69,7 +69,7 @@ class DatabaseLibrary(ConnectionManager, Query, Assertion): The library is basically compatible with any [https://peps.python.org/pep-0249|Python Database API Specification 2.0] module. However, the actual implementation in existing Python modules is sometimes quite different, which requires custom handling in the library. - Therefore there are some modules, which are "natively" supported in the library - and others, which may work and may not. + Therefore, there are some modules, which are "natively" supported in the library - and others, which may work and may not. See more on the [https://github.com/MarketSquare/Robotframework-Database-Library|project page on GitHub]. """ diff --git a/src/DatabaseLibrary/assertion.py b/src/DatabaseLibrary/assertion.py index ea58694..50c3c19 100644 --- a/src/DatabaseLibrary/assertion.py +++ b/src/DatabaseLibrary/assertion.py @@ -21,7 +21,9 @@ class Assertion: Assertion handles all the assertions of Database Library. """ - def check_if_exists_in_database(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None): + def check_if_exists_in_database( + self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None + ): """ Check if any row would be returned by given the input `selectStatement`. If there are no results, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction @@ -52,7 +54,9 @@ def check_if_exists_in_database(self, selectStatement: str, sansTran: bool = Fal msg or f"Expected to have have at least one row, but got 0 rows from: '{selectStatement}'" ) - def check_if_not_exists_in_database(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None): + def check_if_not_exists_in_database( + self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None + ): """ This is the negation of `check_if_exists_in_database`. @@ -86,7 +90,9 @@ def check_if_not_exists_in_database(self, selectStatement: str, sansTran: bool = msg or f"Expected to have have no rows from '{selectStatement}', but got some rows: {query_results}" ) - def row_count_is_0(self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None): + def row_count_is_0( + self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None + ): """ Check if any rows are returned from the submitted `selectStatement`. If there are, then this will throw an AssertionError. Set optional input `sansTran` to True to run command without an explicit transaction commit or @@ -117,7 +123,12 @@ def row_count_is_0(self, selectStatement: str, sansTran: bool = False, msg: Opti raise AssertionError(msg or f"Expected 0 rows, but {num_rows} were returned from: '{selectStatement}'") def row_count_is_equal_to_x( - self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None + self, + selectStatement: str, + numRows: str, + sansTran: bool = False, + msg: Optional[str] = None, + alias: str = "default", ): """ Check if the number of rows returned from `selectStatement` is equal to the value submitted. If not, then this @@ -152,7 +163,12 @@ def row_count_is_equal_to_x( ) def row_count_is_greater_than_x( - self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None + self, + selectStatement: str, + numRows: str, + sansTran: bool = False, + msg: Optional[str] = None, + alias: str = "default", ): """ Check if the number of rows returned from `selectStatement` is greater than the value submitted. If not, then @@ -187,7 +203,12 @@ def row_count_is_greater_than_x( ) def row_count_is_less_than_x( - self, selectStatement: str, numRows: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None + self, + selectStatement: str, + numRows: str, + sansTran: bool = False, + msg: Optional[str] = None, + alias: str = "default", ): """ Check if the number of rows returned from `selectStatement` is less than the value submitted. If not, then this @@ -221,7 +242,9 @@ def row_count_is_less_than_x( msg or f"Expected less than {numRows} rows, but {num_rows} were returned from '{selectStatement}'" ) - def table_must_exist(self, tableName: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None): + def table_must_exist( + self, tableName: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None + ): """ Check if the table given exists in the database. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. The default error message can be overridden with the `msg` argument. @@ -243,20 +266,20 @@ def table_must_exist(self, tableName: str, sansTran: bool = False, msg: Optional | Table Must Exist | first_name | msg=my error message | """ logger.info(f"Executing : Table Must Exist | {tableName}") - _, db_api_module_name = self._cache.switch(alias) - if db_api_module_name in ["cx_Oracle", "oracledb"]: + db_connection = self._get_connection_with_alias(alias) + if db_connection.module_name in ["cx_Oracle", "oracledb"]: query = ( "SELECT * FROM all_objects WHERE object_type IN ('TABLE','VIEW') AND " f"owner = SYS_CONTEXT('USERENV', 'SESSION_USER') AND object_name = UPPER('{tableName}')" ) table_exists = self.row_count(query, sansTran, alias=alias) > 0 - elif db_api_module_name in ["sqlite3"]: + elif db_connection.module_name in ["sqlite3"]: query = f"SELECT name FROM sqlite_master WHERE type='table' AND name='{tableName}' COLLATE NOCASE" table_exists = self.row_count(query, sansTran, alias=alias) > 0 - elif db_api_module_name in ["ibm_db", "ibm_db_dbi"]: + elif db_connection.module_name in ["ibm_db", "ibm_db_dbi"]: query = f"SELECT name FROM SYSIBM.SYSTABLES WHERE type='T' AND name=UPPER('{tableName}')" table_exists = self.row_count(query, sansTran, alias=alias) > 0 - elif db_api_module_name in ["teradata"]: + elif db_connection.module_name in ["teradata"]: query = f"SELECT TableName FROM DBC.TablesV WHERE TableKind='T' AND TableName='{tableName}'" table_exists = self.row_count(query, sansTran, alias=alias) > 0 else: diff --git a/src/DatabaseLibrary/connection_manager.py b/src/DatabaseLibrary/connection_manager.py index 13259b0..4f4381c 100644 --- a/src/DatabaseLibrary/connection_manager.py +++ b/src/DatabaseLibrary/connection_manager.py @@ -13,7 +13,8 @@ # limitations under the License. import importlib -from typing import Optional +from dataclasses import dataclass +from typing import Any, Dict, Optional try: import ConfigParser @@ -21,7 +22,12 @@ import configparser as ConfigParser from robot.api import logger -from robot.utils import ConnectionCache + + +@dataclass +class Connection: + client: Any + module_name: str class ConnectionManager: @@ -30,8 +36,14 @@ class ConnectionManager: """ def __init__(self): - self.omit_trailing_semicolon = False - self._cache = ConnectionCache("No sessions created") + self.omit_trailing_semicolon: bool = False + self._connections: Dict[str, Connection] = {} + self.default_alias: str = "default" + + def _register_connection(self, client: Any, module_name: str, alias: str): + if alias in self._connections: + logger.warn(f"Overwriting not closed connection for alias = '{alias}'") + self._connections[alias] = Connection(client, module_name) def connect_to_database( self, @@ -45,7 +57,7 @@ def connect_to_database( dbDriver: Optional[str] = None, dbConfigFile: Optional[str] = None, driverMode: Optional[str] = None, - alias: Optional[str] = "default", + alias: str = "default", ): """ Loads the DB API 2.0 module given `dbapiModuleName` then uses it to @@ -261,10 +273,10 @@ def connect_to_database( host=dbHost, port=dbPort, ) - self._cache.register((db_connection, db_api_module_name), alias=alias) + self._register_connection(db_connection, db_api_module_name, alias) def connect_to_database_using_custom_params( - self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: Optional[str] = "default" + self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: str = "default" ): """ Loads the DB API 2.0 module given `dbapiModuleName` then uses it to @@ -299,10 +311,10 @@ def connect_to_database_using_custom_params( ) db_connection = eval(db_connect_string) - self._cache.register((db_connection, db_api_module_name), alias=alias) + self._register_connection(db_connection, db_api_module_name, alias) def connect_to_database_using_custom_connection_string( - self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: Optional[str] = "default" + self, dbapiModuleName: Optional[str] = None, db_connect_string: str = "", alias: str = "default" ): """ Loads the DB API 2.0 module given `dbapiModuleName` then uses it to @@ -323,9 +335,9 @@ def connect_to_database_using_custom_connection_string( f"'{db_connect_string}')" ) db_connection = db_api_2.connect(db_connect_string) - self._cache.register((db_connection, db_api_module_name), alias=alias) + self._register_connection(db_connection, db_api_module_name, alias) - def disconnect_from_database(self, error_if_no_connection: bool = False, alias: Optional[str] = "default"): + def disconnect_from_database(self, error_if_no_connection: bool = False, alias: Optional[str] = None): """ Disconnects from the database. @@ -338,13 +350,12 @@ def disconnect_from_database(self, error_if_no_connection: bool = False, alias: | Disconnect From Database | alias=my_alias | # disconnects from current connection to the database | """ logger.info("Executing : Disconnect From Database") + if not alias: + alias = "default" try: - db_connection, _ = self._cache.switch(alias) - except RuntimeError: # Non-existing index or alias - db_connection = None - if db_connection: - db_connection.close() - else: + db_connection = self._connections.pop(alias) + db_connection.client.close() + except KeyError: # Non-existing alias log_msg = "No open database connection to close" if error_if_no_connection: raise ConnectionError(log_msg) from None @@ -358,12 +369,11 @@ def disconnect_from_all_databases(self): | Disconnect From All Databases | # disconnects from all connections to the database | """ logger.info("Executing : Disconnect From All Databases") - for db_connection, _ in self._cache: - if db_connection: - db_connection.close() - self._cache.empty_cache() + for db_connection in self._connections.values(): + db_connection.client.close() + self._connections = {} - def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = "default"): + def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = None): """ Turn the autocommit on the database connection ON or OFF. @@ -381,10 +391,10 @@ def set_auto_commit(self, autoCommit: bool = True, alias: Optional[str] = "defau | Set Auto Commit | False """ logger.info("Executing : Set Auto Commit") - db_connection, _ = self._cache.switch(alias) - db_connection.autocommit = autoCommit + db_connection = self._get_connection_with_alias(alias) + db_connection.client.autocommit = autoCommit - def switch_database(self, alias): + def switch_database(self, alias: str): """ Switch default database. @@ -392,4 +402,23 @@ def switch_database(self, alias): | Switch Database | my_alias | | Switch Database | alias=my_alias | """ - self._cache.switch(alias) + if alias not in self._connections: + raise ValueError(f"Alias '{alias}' not found in existing connections.") + self.default_alias = alias + + def _get_connection_with_alias(self, alias: Optional[str]) -> Connection: + """ + Return connection with given alias. + + If alias is not provided, it will return default connection. + If there is no default connection, it will return last opened connection. + """ + if not self._connections: + raise ValueError(f"No database connection is open.") + if not alias: + if self.default_alias in self._connections: + return self._connections[self.default_alias] + return list(self._connections.values())[-1] + if alias not in self._connections: + raise ValueError(f"Alias '{alias}' not found in existing connections.") + return self._connections[alias] diff --git a/src/DatabaseLibrary/query.py b/src/DatabaseLibrary/query.py index a59ba53..899cd75 100644 --- a/src/DatabaseLibrary/query.py +++ b/src/DatabaseLibrary/query.py @@ -24,7 +24,9 @@ class Query: Query handles all the querying done by the Database Library. """ - def query(self, selectStatement: str, sansTran: bool = False, returnAsDict: bool = False, alias: Optional[str] = None): + def query( + self, selectStatement: str, sansTran: bool = False, returnAsDict: bool = False, alias: Optional[str] = None + ): """ Uses the input `selectStatement` to query for the values that will be returned as a list of tuples. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. @@ -59,10 +61,10 @@ def query(self, selectStatement: str, sansTran: bool = False, returnAsDict: bool Using optional `sansTran` to run command without an explicit transaction commit or rollback: | @{queryResults} | Query | SELECT * FROM person | True | """ - db_connection, _ = self._cache.switch(alias) + db_connection = self._get_connection_with_alias(alias) cur = None try: - cur = db_connection.cursor() + cur = db_connection.client.cursor() logger.info(f"Executing : Query | {selectStatement} ") self.__execute_sql(cur, selectStatement) all_rows = cur.fetchall() @@ -72,7 +74,7 @@ def query(self, selectStatement: str, sansTran: bool = False, returnAsDict: bool return all_rows finally: if cur and not sansTran: - db_connection.rollback() + db_connection.client.rollback() def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optional[str] = None): """ @@ -102,19 +104,19 @@ def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optiona Using optional `sansTran` to run command without an explicit transaction commit or rollback: | ${rowCount} | Row Count | SELECT * FROM person | True | """ - db_connection, db_api_module_name = self._cache.switch(alias) + db_connection = self._get_connection_with_alias(alias) cur = None try: - cur = db_connection.cursor() + cur = db_connection.client.cursor() logger.info(f"Executing : Row Count | {selectStatement}") self.__execute_sql(cur, selectStatement) data = cur.fetchall() - if db_api_module_name in ["sqlite3", "ibm_db", "ibm_db_dbi", "pyodbc"]: + if db_connection.module_name in ["sqlite3", "ibm_db", "ibm_db_dbi", "pyodbc"]: return len(data) return cur.rowcount finally: if cur and not sansTran: - db_connection.rollback() + db_connection.client.rollback() def description(self, selectStatement: str, sansTran: bool = False, alias: Optional[str] = None): """ @@ -138,10 +140,10 @@ def description(self, selectStatement: str, sansTran: bool = False, alias: Optio Using optional `sansTran` to run command without an explicit transaction commit or rollback: | @{queryResults} | Description | SELECT * FROM person | True | """ - db_connection, _ = self._cache.switch(alias) + db_connection = self._get_connection_with_alias(alias) cur = None try: - cur = db_connection.cursor() + cur = db_connection.client.cursor() logger.info("Executing : Description | {selectStatement}") self.__execute_sql(cur, selectStatement) description = list(cur.description) @@ -151,7 +153,7 @@ def description(self, selectStatement: str, sansTran: bool = False, alias: Optio return description finally: if cur and not sansTran: - db_connection.rollback() + db_connection.client.rollback() def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, alias: Optional[str] = None): """ @@ -173,22 +175,22 @@ def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, ali Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Delete All Rows From Table | person | True | """ - db_connection, _ = self._cache.switch(alias) + db_connection = self._get_connection_with_alias(alias) cur = None query = f"DELETE FROM {tableName}" try: - cur = db_connection.cursor() + cur = db_connection.client.cursor() logger.info(f"Executing : Delete All Rows From Table | {query}") result = self.__execute_sql(cur, query) if result is not None: if not sansTran: - db_connection.commit() + db_connection.client.commit() return result if not sansTran: - db_connection.commit() + db_connection.client.commit() finally: if cur and not sansTran: - db_connection.rollback() + db_connection.client.rollback() def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, alias: Optional[str] = None): """ @@ -249,12 +251,12 @@ def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, ali Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Execute Sql Script | ${EXECDIR}${/}resources${/}DDL-setup.sql | True | """ - db_connection, _ = self._cache.switch(alias) + db_connection = self._get_connection_with_alias(alias) with open(sqlScriptFileName, encoding="UTF-8") as sql_file: cur = None try: statements_to_execute = [] - cur = db_connection.cursor() + cur = db_connection.client.cursor() logger.info(f"Executing : Execute SQL Script | {sqlScriptFileName}") current_statement = "" inside_statements_group = False @@ -310,10 +312,10 @@ def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, ali omit_semicolon = not statement.lower().endswith("end;") self.__execute_sql(cur, statement, omit_semicolon) if not sansTran: - db_connection.commit() + db_connection.client.commit() finally: if cur and not sansTran: - db_connection.rollback() + db_connection.client.rollback() def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Optional[str] = None): """ @@ -332,19 +334,21 @@ def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Opti Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Execute Sql String | DELETE FROM person_employee_table; DELETE FROM person_table | True | """ - db_connection, _ = self._cache.switch(alias) + db_connection = self._get_connection_with_alias(alias) cur = None try: - cur = db_connection.cursor() + cur = db_connection.client.cursor() logger.info(f"Executing : Execute SQL String | {sqlString}") self.__execute_sql(cur, sqlString) if not sansTran: - db_connection.commit() + db_connection.client.commit() finally: if cur and not sansTran: - db_connection.rollback() + db_connection.client.rollback() - def call_stored_procedure(self, spName: str, spParams: Optional[List[str]] = None, sansTran: bool = False, alias: Optional[str] = None): + def call_stored_procedure( + self, spName: str, spParams: Optional[List[str]] = None, sansTran: bool = False, alias: Optional[str] = None + ): """ Calls a stored procedure `spName` with the `spParams` - a *list* of parameters the procedure requires. Use the special *CURSOR* value for OUT params, which should receive result sets - @@ -383,21 +387,21 @@ def call_stored_procedure(self, spName: str, spParams: Optional[List[str]] = Non Using optional `sansTran` to run command without an explicit transaction commit or rollback: | @{Param values} @{Result sets} = | Call Stored Procedure | DBName.SchemaName.StoredProcName | ${Params} | True | """ - db_connection, db_api_module_name = self._cache.switch(alias) + db_connection = self._get_connection_with_alias(alias) if spParams is None: spParams = [] cur = None try: logger.info(f"Executing : Call Stored Procedure | {spName} | {spParams}") - if db_api_module_name == "pymssql": - cur = db_connection.cursor(as_dict=False) + if db_connection.module_name == "pymssql": + cur = db_connection.client.cursor(as_dict=False) else: - cur = db_connection.cursor() + cur = db_connection.client.cursor() param_values = [] result_sets = [] - if db_api_module_name == "pymysql": + if db_connection.module_name == "pymysql": cur.callproc(spName, spParams) # first proceed the result sets if available @@ -414,22 +418,22 @@ def call_stored_procedure(self, spName: str, spParams: Optional[List[str]] = Non cur.execute(f"select @_{spName}_{i}") param_values.append(cur.fetchall()[0][0]) - elif db_api_module_name in ["oracledb", "cx_Oracle"]: + elif db_connection.module_name in ["oracledb", "cx_Oracle"]: # check if "CURSOR" params were passed - they will be replaced # with cursor variables for storing the result sets params_substituted = spParams.copy() cursor_params = [] for i in range(0, len(spParams)): if spParams[i] == "CURSOR": - cursor_param = db_connection.cursor() + cursor_param = db_connection.client.cursor() params_substituted[i] = cursor_param cursor_params.append(cursor_param) param_values = cur.callproc(spName, params_substituted) for result_set in cursor_params: result_sets.append(list(result_set)) - elif db_api_module_name in ["psycopg2", "psycopg3"]: - cur = db_connection.cursor() + elif db_connection.module_name in ["psycopg2", "psycopg3"]: + cur = db_connection.client.cursor() # check if "CURSOR" params were passed - they will be replaced # with cursor variables for storing the result sets params_substituted = spParams.copy() @@ -446,7 +450,7 @@ def call_stored_procedure(self, spName: str, spParams: Optional[List[str]] = Non result_set = cur.fetchall() result_sets.append(list(result_set)) else: - if db_api_module_name in ["psycopg3"]: + if db_connection.module_name in ["psycopg3"]: result_sets_available = True while result_sets_available: result_sets.append(list(cur.fetchall())) @@ -457,10 +461,10 @@ def call_stored_procedure(self, spName: str, spParams: Optional[List[str]] = Non else: logger.info( - f"CAUTION! Calling a stored procedure for '{db_api_module_name}' is not tested, " + f"CAUTION! Calling a stored procedure for '{db_connection.module_name}' is not tested, " "results might be invalid!" ) - cur = db_connection.cursor() + cur = db_connection.client.cursor() param_values = cur.callproc(spName, spParams) logger.info("Reading the procedure results..") result_sets_available = True @@ -476,12 +480,12 @@ def call_stored_procedure(self, spName: str, spParams: Optional[List[str]] = Non result_sets_available = False if not sansTran: - db_connection.commit() + db_connection.client.commit() return param_values, result_sets finally: if cur and not sansTran: - db_connection.rollback() + db_connection.client.rollback() def __execute_sql(self, cur, sql_statement: str, omit_trailing_semicolon: Optional[bool] = None): """ diff --git a/test/tests/common_tests/aliased_connection.robot b/test/tests/common_tests/aliased_connection.robot index 29b8902..1ff6ee6 100644 --- a/test/tests/common_tests/aliased_connection.robot +++ b/test/tests/common_tests/aliased_connection.robot @@ -45,7 +45,7 @@ Switch Not Existing Alias Execute SQL Script - Insert Data In Person table Connect To DB alias=aliased_conn - ${output}= Execute SQL Script ${CURDIR}/../insert_data_in_person_table.sql alias=aliased_conn + ${output}= Execute SQL Script ../../resources/insert_data_in_person_table.sql alias=aliased_conn Should Be Equal As Strings ${output} None Check If Exists In DB - Franz Allan