Skip to content

Commit

Permalink
Improve running SQL script files - separate statements more precisely…
Browse files Browse the repository at this point in the history
… and handle begin/end blocks properly
  • Loading branch information
amochin committed Jul 11, 2023
1 parent 841260a commit 46d6120
Showing 1 changed file with 59 additions and 39 deletions.
98 changes: 59 additions & 39 deletions src/DatabaseLibrary/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,15 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False):
state before running your tests, or clearing out your test data after running each a test. Set optional input
`sansTran` to True to run command without an explicit transaction commit or rollback.
Sample usage :
| Execute Sql Script | ${EXECDIR}${/}resources${/}DDL-setup.sql |
| Execute Sql Script | ${EXECDIR}${/}resources${/}DML-setup.sql |
| #interesting stuff here |
| Execute Sql Script | ${EXECDIR}${/}resources${/}DML-teardown.sql |
| Execute Sql Script | ${EXECDIR}${/}resources${/}DDL-teardown.sql |
SQL commands are expected to be delimited by a semi-colon (';').
SQL commands are expected to be delimited by a semi-colon (';') - they will be executed separately.
For example:
DELETE FROM person_employee_table;
Expand All @@ -232,8 +233,9 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False):
DELETE
FROM employee_table
However, lines that starts with a number sign (`#`) are treated as a
commented line. Thus, none of the contents of that line will be executed.
However, lines that starts with a number sign (`#`) or a double dash ("--")
are treated as a commented line. Thus, none of the contents of that line will be executed.
For example:
# Delete the bridging table first...
Expand All @@ -245,50 +247,68 @@ def execute_sql_script(self, sqlScriptFileName, sansTran=False):
DELETE
FROM employee_table
The slash signs ("/") are always ignored and have no impact on execution order.
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| Execute Sql Script | ${EXECDIR}${/}resources${/}DDL-setup.sql | True |
"""
sqlScriptFile = open(sqlScriptFileName ,encoding='UTF-8')

cur = None
try:
cur = self._dbconnection.cursor()
logger.info('Executing : Execute SQL Script | %s ' % sqlScriptFileName)
sqlStatement = ''
for line in sqlScriptFile:
PY3K = sys.version_info >= (3, 0)
if not PY3K:
#spName = spName.encode('ascii', 'ignore')
line = line.strip().decode("utf-8")
if line.startswith('#'):
continue
elif line.startswith('--'):
continue

sqlFragments = line.split(';')
if len(sqlFragments) == 1:
sqlStatement += line + ' '
else:
with open(sqlScriptFileName, encoding='UTF-8') as sql_file:
cur = None
try:
statements_to_execute = []
cur = self._dbconnection.cursor()
logger.info('Executing : Execute SQL Script | %s ' % sqlScriptFileName)
current_statement = ''
inside_statements_group = False

for line in sql_file:
line = line.strip()
if line.startswith('#') or line.startswith('--') or line == "/":
continue
if line.lower().startswith("begin"):
inside_statements_group = True

# semicolons inside the line? use them to separate statements
# ... but not if they are inside a begin/end block (aka. statements group)
sqlFragments = line.split(';')

# no semicolons
if len(sqlFragments) == 1:
current_statement += line + ' '
continue

# "select * from person;" -> ["select..", ""]
for sqlFragment in sqlFragments:
sqlFragment = sqlFragment.strip()
if len(sqlFragment) == 0:
continue

sqlStatement += sqlFragment + ' '

self.__execute_sql(cur, sqlStatement)
sqlStatement = ''

sqlStatement = sqlStatement.strip()
if len(sqlStatement) != 0:
self.__execute_sql(cur, sqlStatement)

if not sansTran:
self._dbconnection.commit()
finally:
if cur:
if inside_statements_group:
# if statements inside a begin/end block have semicolns,
# they must persist - even with oracle
sqlFragment += "; "
if sqlFragment.lower() == "end; ":
inside_statements_group = False
elif sqlFragment.lower().startswith("begin"):
inside_statements_group = True
current_statement += sqlFragment
if not inside_statements_group:
statements_to_execute.append(current_statement.strip())
current_statement = ''

current_statement = current_statement.strip()
if len(current_statement) != 0:
statements_to_execute.append(current_statement)

for statement in statements_to_execute:
logger.info(f"Executing statement from script file: {statement}")
omit_semicolon = not statement.lower().endswith("end;")
self.__execute_sql(cur, statement, omit_semicolon)
if not sansTran:
self._dbconnection.rollback()
self._dbconnection.commit()
finally:
if cur:
if not sansTran:
self._dbconnection.rollback()

def execute_sql_string(self, sqlString, sansTran=False):
"""
Expand Down

0 comments on commit 46d6120

Please sign in to comment.