diff --git a/src/DatabaseLibrary/query.py b/src/DatabaseLibrary/query.py index 0797d69..2e817a1 100644 --- a/src/DatabaseLibrary/query.py +++ b/src/DatabaseLibrary/query.py @@ -200,6 +200,7 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False): state before running your tests, or clearing out your test data after running each a test. Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback. + Sample usage : | Execute Sql Script | ${EXECDIR}${/}resources${/}DDL-setup.sql | | Execute Sql Script | ${EXECDIR}${/}resources${/}DML-setup.sql | @@ -207,7 +208,7 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False): | Execute Sql Script | ${EXECDIR}${/}resources${/}DML-teardown.sql | | Execute Sql Script | ${EXECDIR}${/}resources${/}DDL-teardown.sql | - SQL commands are expected to be delimited by a semi-colon (';'). + SQL commands are expected to be delimited by a semi-colon (';') - they will be executed separately. For example: DELETE FROM person_employee_table; @@ -232,8 +233,9 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False): DELETE FROM employee_table - However, lines that starts with a number sign (`#`) are treated as a - commented line. Thus, none of the contents of that line will be executed. + However, lines that starts with a number sign (`#`) or a double dash ("--") + are treated as a commented line. Thus, none of the contents of that line will be executed. + For example: # Delete the bridging table first... @@ -245,50 +247,68 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False): DELETE FROM employee_table + The slash signs ("/") are always ignored and have no impact on execution order. + Using optional `sansTran` to run command without an explicit transaction commit or rollback: | Execute Sql Script | ${EXECDIR}${/}resources${/}DDL-setup.sql | True | """ - sqlScriptFile = open(sqlScriptFileName ,encoding='UTF-8') - - cur = None - try: - cur = self._dbconnection.cursor() - logger.info('Executing : Execute SQL Script | %s ' % sqlScriptFileName) - sqlStatement = '' - for line in sqlScriptFile: - PY3K = sys.version_info >= (3, 0) - if not PY3K: - #spName = spName.encode('ascii', 'ignore') - line = line.strip().decode("utf-8") - if line.startswith('#'): - continue - elif line.startswith('--'): - continue - - sqlFragments = line.split(';') - if len(sqlFragments) == 1: - sqlStatement += line + ' ' - else: + with open(sqlScriptFileName, encoding='UTF-8') as sql_file: + cur = None + try: + statements_to_execute = [] + cur = self._dbconnection.cursor() + logger.info('Executing : Execute SQL Script | %s ' % sqlScriptFileName) + current_statement = '' + inside_statements_group = False + + for line in sql_file: + line = line.strip() + if line.startswith('#') or line.startswith('--') or line == "/": + continue + if line.lower().startswith("begin"): + inside_statements_group = True + + # semicolons inside the line? use them to separate statements + # ... but not if they are inside a begin/end block (aka. statements group) + sqlFragments = line.split(';') + + # no semicolons + if len(sqlFragments) == 1: + current_statement += line + ' ' + continue + + # "select * from person;" -> ["select..", ""] for sqlFragment in sqlFragments: sqlFragment = sqlFragment.strip() if len(sqlFragment) == 0: continue - - sqlStatement += sqlFragment + ' ' - - self.__execute_sql(cur, sqlStatement) - sqlStatement = '' - - sqlStatement = sqlStatement.strip() - if len(sqlStatement) != 0: - self.__execute_sql(cur, sqlStatement) - - if not sansTran: - self._dbconnection.commit() - finally: - if cur: + if inside_statements_group: + # if statements inside a begin/end block have semicolns, + # they must persist - even with oracle + sqlFragment += "; " + if sqlFragment.lower() == "end; ": + inside_statements_group = False + elif sqlFragment.lower().startswith("begin"): + inside_statements_group = True + current_statement += sqlFragment + if not inside_statements_group: + statements_to_execute.append(current_statement.strip()) + current_statement = '' + + current_statement = current_statement.strip() + if len(current_statement) != 0: + statements_to_execute.append(current_statement) + + for statement in statements_to_execute: + logger.info(f"Executing statement from script file: {statement}") + omit_semicolon = not statement.lower().endswith("end;") + self.__execute_sql(cur, statement, omit_semicolon) if not sansTran: - self._dbconnection.rollback() + self._dbconnection.commit() + finally: + if cur: + if not sansTran: + self._dbconnection.rollback() def execute_sql_string(self, sqlString, sansTran=False): """