Skip to content

Commit

Permalink
Merge branch 'parameters' of https://github.com/carnegiemedal/Robotfr…
Browse files Browse the repository at this point in the history
…amework-Database-Library into feature/execute_parameters
  • Loading branch information
bhirsz committed Nov 19, 2023
2 parents 7df2cca + 079e5ba commit 40749fe
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 10 deletions.
43 changes: 33 additions & 10 deletions src/DatabaseLibrary/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ class Query:
"""

def query(
self, selectStatement: str, sansTran: bool = False, returnAsDict: bool = False, alias: Optional[str] = None
self,
selectStatement: str,
sansTran: bool = False,
returnAsDict: bool = False,
alias: Optional[str] = None,
parameters: Optional[List] = None,
):
"""
Uses the input ``selectStatement`` to query for the values that will be returned as a list of tuples. Set
Expand Down Expand Up @@ -69,7 +74,7 @@ def query(
try:
cur = db_connection.client.cursor()
logger.info(f"Executing : Query | {selectStatement} ")
self.__execute_sql(cur, selectStatement)
self.__execute_sql(cur, selectStatement, parameters=parameters)
all_rows = cur.fetchall()
if returnAsDict:
col_names = [c[0] for c in cur.description]
Expand All @@ -79,7 +84,13 @@ def query(
if cur and not sansTran:
db_connection.client.rollback()

def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optional[str] = None):
def row_count(
self,
selectStatement: str,
sansTran: bool = False,
alias: Optional[str] = None,
parameters: Optional[List] = None,
):
"""
Uses the input ``selectStatement`` to query the database and returns the number of rows from the query. Set
optional input ``sansTran`` to True to run command without an explicit transaction commit or rollback.
Expand Down Expand Up @@ -115,7 +126,7 @@ def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optiona
try:
cur = db_connection.client.cursor()
logger.info(f"Executing : Row Count | {selectStatement}")
self.__execute_sql(cur, selectStatement)
self.__execute_sql(cur, selectStatement, parameters=parameters)
data = cur.fetchall()
if db_connection.module_name in ["sqlite3", "ibm_db", "ibm_db_dbi", "pyodbc"]:
return len(data)
Expand All @@ -124,7 +135,13 @@ def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optiona
if cur and not sansTran:
db_connection.client.rollback()

def description(self, selectStatement: str, sansTran: bool = False, alias: Optional[str] = None):
def description(
self,
selectStatement: str,
sansTran: bool = False,
alias: Optional[str] = None,
parameters: Optional[List] = None,
):
"""
Uses the input ``selectStatement`` to query a table in the db which will be used to determine the description. Set
optional input ``sansTran` to True to run command without an explicit transaction commit or rollback.
Expand Down Expand Up @@ -154,7 +171,7 @@ def description(self, selectStatement: str, sansTran: bool = False, alias: Optio
try:
cur = db_connection.client.cursor()
logger.info("Executing : Description | {selectStatement}")
self.__execute_sql(cur, selectStatement)
self.__execute_sql(cur, selectStatement, parameters=parameters)
description = list(cur.description)
if sys.version_info[0] < 3:
for row in range(0, len(description)):
Expand Down Expand Up @@ -331,7 +348,9 @@ def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, ali
if cur and not sansTran:
db_connection.client.rollback()

def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Optional[str] = None):
def execute_sql_string(
self, sqlString: str, sansTran: bool = False, alias: Optional[str] = None, parameters: Optional[List] = None
):
"""
Executes the sqlString as SQL commands. Useful to pass arguments to your sql. Set optional input `sansTran` to
True to run command without an explicit transaction commit or rollback.
Expand All @@ -356,7 +375,7 @@ def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Opti
try:
cur = db_connection.client.cursor()
logger.info(f"Executing : Execute SQL String | {sqlString}")
self.__execute_sql(cur, sqlString)
self.__execute_sql(cur, sqlString, parameters=parameters)
if not sansTran:
db_connection.client.commit()
finally:
Expand Down Expand Up @@ -507,7 +526,9 @@ def call_stored_procedure(
if cur and not sansTran:
db_connection.client.rollback()

def __execute_sql(self, cur, sql_statement: str, omit_trailing_semicolon: Optional[bool] = None):
def __execute_sql(
self, cur, sql_statement: str, omit_trailing_semicolon: Optional[bool] = None, parameters: Optional[List] = None
):
"""
Runs the `sql_statement` using `cur` as Cursor object.
Use `omit_trailing_semicolon` parameter (bool) for explicit instruction,
Expand All @@ -519,5 +540,7 @@ def __execute_sql(self, cur, sql_statement: str, omit_trailing_semicolon: Option
omit_trailing_semicolon = self.omit_trailing_semicolon
if omit_trailing_semicolon:
sql_statement = sql_statement.rstrip(";")
if parameters is None:
parameters = []
logger.debug(f"Executing sql: {sql_statement}")
return cur.execute(sql_statement)
return cur.execute(sql_statement, parameters)
13 changes: 13 additions & 0 deletions test/tests/common_tests/basic_tests.robot
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,19 @@ SQL Statement Ending With Semicolon Works
SQL Statement Ending Without Semicolon Works
Query SELECT * FROM person;

SQL Statement With Parameters Works
@{params}= Create List 2

IF "${DB_MODULE}" in ["oracledb"]
${output}= Query SELECT * FROM person WHERE id < :id parameters=${params}
ELSE IF "${DB_MODULE}" in ["sqlite3", "pyodbc"]
${output}= Query SELECT * FROM person WHERE id < ? parameters=${params}
ELSE
${output}= Query SELECT * FROM person WHERE id < %s parameters=${params}
END

Length Should Be ${output} 1

Create Person Table
[Setup] Log No setup for this test
${output}= Create Person Table
Expand Down

0 comments on commit 40749fe

Please sign in to comment.