Skip to content

Commit

Permalink
Merge pull request #198 from MarketSquare/feature/execute_parameters
Browse files Browse the repository at this point in the history
Pass parameters list to cursor.execute
  • Loading branch information
amochin authored Nov 20, 2023
2 parents d562452 + e929dc0 commit 170050b
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 41 deletions.
40 changes: 29 additions & 11 deletions src/DatabaseLibrary/assertion.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import List, Optional

from robot.api import logger

Expand All @@ -22,7 +22,12 @@ class Assertion:
"""

def check_if_exists_in_database(
self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
self,
selectStatement: str,
sansTran: bool = False,
msg: Optional[str] = None,
alias: Optional[str] = None,
parameters: Optional[List] = None,
):
"""
Check if any row would be returned by given the input ``selectStatement``. If there are no results, then this will
Expand All @@ -43,13 +48,18 @@ def check_if_exists_in_database(
| Check If Exists In Database | SELECT id FROM person WHERE first_name = 'John' | sansTran=True |
"""
logger.info(f"Executing : Check If Exists In Database | {selectStatement}")
if not self.query(selectStatement, sansTran, alias=alias):
if not self.query(selectStatement, sansTran, alias=alias, parameters=parameters):
raise AssertionError(
msg or f"Expected to have have at least one row, but got 0 rows from: '{selectStatement}'"
)

def check_if_not_exists_in_database(
self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
self,
selectStatement: str,
sansTran: bool = False,
msg: Optional[str] = None,
alias: Optional[str] = None,
parameters: Optional[List] = None,
):
"""
This is the negation of `check_if_exists_in_database`.
Expand All @@ -71,14 +81,19 @@ def check_if_not_exists_in_database(
| Check If Not Exists In Database | SELECT id FROM person WHERE first_name = 'John' | sansTran=True |
"""
logger.info(f"Executing : Check If Not Exists In Database | {selectStatement}")
query_results = self.query(selectStatement, sansTran, alias=alias)
query_results = self.query(selectStatement, sansTran, alias=alias, parameters=parameters)
if query_results:
raise AssertionError(
msg or f"Expected to have have no rows from '{selectStatement}', but got some rows: {query_results}"
)

def row_count_is_0(
self, selectStatement: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
self,
selectStatement: str,
sansTran: bool = False,
msg: Optional[str] = None,
alias: Optional[str] = None,
parameters: Optional[List] = None,
):
"""
Check if any rows are returned from the submitted ``selectStatement``. If there are, then this will throw an
Expand All @@ -99,7 +114,7 @@ def row_count_is_0(
| Row Count is 0 | SELECT id FROM person WHERE first_name = 'John' | sansTran=True |
"""
logger.info(f"Executing : Row Count Is 0 | {selectStatement}")
num_rows = self.row_count(selectStatement, sansTran, alias=alias)
num_rows = self.row_count(selectStatement, sansTran, alias=alias, parameters=parameters)
if num_rows > 0:
raise AssertionError(msg or f"Expected 0 rows, but {num_rows} were returned from: '{selectStatement}'")

Expand All @@ -110,6 +125,7 @@ def row_count_is_equal_to_x(
sansTran: bool = False,
msg: Optional[str] = None,
alias: Optional[str] = None,
parameters: Optional[List] = None,
):
"""
Check if the number of rows returned from ``selectStatement`` is equal to the value submitted. If not, then this
Expand All @@ -129,7 +145,7 @@ def row_count_is_equal_to_x(
| Row Count Is Equal To X | SELECT id FROM person WHERE first_name = 'John' | 0 | sansTran=True |
"""
logger.info(f"Executing : Row Count Is Equal To X | {selectStatement} | {numRows}")
num_rows = self.row_count(selectStatement, sansTran, alias=alias)
num_rows = self.row_count(selectStatement, sansTran, alias=alias, parameters=parameters)
if num_rows != int(numRows.encode("ascii")):
raise AssertionError(
msg or f"Expected {numRows} rows, but {num_rows} were returned from: '{selectStatement}'"
Expand All @@ -142,6 +158,7 @@ def row_count_is_greater_than_x(
sansTran: bool = False,
msg: Optional[str] = None,
alias: Optional[str] = None,
parameters: Optional[List] = None,
):
"""
Check if the number of rows returned from ``selectStatement`` is greater than the value submitted. If not, then
Expand All @@ -161,7 +178,7 @@ def row_count_is_greater_than_x(
| Row Count Is Greater Than X | SELECT id FROM person | 1 | sansTran=True |
"""
logger.info(f"Executing : Row Count Is Greater Than X | {selectStatement} | {numRows}")
num_rows = self.row_count(selectStatement, sansTran, alias=alias)
num_rows = self.row_count(selectStatement, sansTran, alias=alias, parameters=parameters)
if num_rows <= int(numRows.encode("ascii")):
raise AssertionError(
msg or f"Expected more than {numRows} rows, but {num_rows} were returned from '{selectStatement}'"
Expand All @@ -174,6 +191,7 @@ def row_count_is_less_than_x(
sansTran: bool = False,
msg: Optional[str] = None,
alias: Optional[str] = None,
parameters: Optional[List] = None,
):
"""
Check if the number of rows returned from ``selectStatement`` is less than the value submitted. If not, then this
Expand All @@ -194,7 +212,7 @@ def row_count_is_less_than_x(
"""
logger.info(f"Executing : Row Count Is Less Than X | {selectStatement} | {numRows}")
num_rows = self.row_count(selectStatement, sansTran, alias=alias)
num_rows = self.row_count(selectStatement, sansTran, alias=alias, parameters=parameters)
if num_rows >= int(numRows.encode("ascii")):
raise AssertionError(
msg or f"Expected less than {numRows} rows, but {num_rows} were returned from '{selectStatement}'"
Expand All @@ -204,7 +222,7 @@ def table_must_exist(
self, tableName: str, sansTran: bool = False, msg: Optional[str] = None, alias: Optional[str] = None
):
"""
Check if the table given exists in the database.
Check if the given table exists in the database.
Set optional input ``sansTran`` to True to run command without an
explicit transaction commit or rollback.
Expand Down
74 changes: 44 additions & 30 deletions src/DatabaseLibrary/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ class Query:
"""

def query(
self, selectStatement: str, sansTran: bool = False, returnAsDict: bool = False, alias: Optional[str] = None
self, selectStatement: str, sansTran: bool = False, returnAsDict: bool = False, alias: Optional[str] = None, parameters: Optional[List] = None
):
"""
Uses the input ``selectStatement`` to query for the values that will be returned as a list of tuples. Set
optional input ``sansTran`` to True to run command without an explicit transaction commit or rollback.
Uses the input ``selectStatement`` to query for the values that will be returned as a list of tuples.
Set optional input ``returnAsDict`` to True to return values as a list of dictionaries.
Use optional ``alias`` parameter to specify what connection should be used for the query if you have more
Expand Down Expand Up @@ -61,15 +60,20 @@ def query(
And get the following
See, Franz Allan
Using optional ``sansTran`` to run command without an explicit transaction commit or rollback:
Use optional ``parameters`` for query variable substitution (variable substitution syntax may be different
depending on the database client):
| parameters | Create List | person |
| Query | SELECT * FROM %s | parameters=${parameters} |
Use optional ``sansTran`` to run command without an explicit transaction commit or rollback:
| @{queryResults} | Query | SELECT * FROM person | True |
"""
db_connection = self.connection_store.get_connection(alias)
cur = None
try:
cur = db_connection.client.cursor()
logger.info(f"Executing : Query | {selectStatement} ")
self.__execute_sql(cur, selectStatement)
self.__execute_sql(cur, selectStatement, parameters=parameters)
all_rows = cur.fetchall()
if returnAsDict:
col_names = [c[0] for c in cur.description]
Expand All @@ -79,10 +83,9 @@ def query(
if cur and not sansTran:
db_connection.client.rollback()

def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optional[str] = None):
def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optional[str] = None, parameters: Optional[List] = None):
"""
Uses the input ``selectStatement`` to query the database and returns the number of rows from the query. Set
optional input ``sansTran`` to True to run command without an explicit transaction commit or rollback.
Uses the input ``selectStatement`` to query the database and returns the number of rows from the query.
For example, given we have a table `person` with the following data:
| id | first_name | last_name |
Expand All @@ -107,15 +110,20 @@ def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optiona
Use optional ``alias`` parameter to specify what connection should be used for the query if you have more
than one connection open.
Using optional ``sansTran`` to run command without an explicit transaction commit or rollback:
Use optional ``parameters`` for query variable substitution (variable substitution syntax may be different
depending on the database client):
| parameters | Create List | person |
| ${rowCount} | Row Count | SELECT * FROM %s | parameters=${parameters} |
Use optional ``sansTran`` to run command without an explicit transaction commit or rollback:
| ${rowCount} | Row Count | SELECT * FROM person | True |
"""
db_connection = self.connection_store.get_connection(alias)
cur = None
try:
cur = db_connection.client.cursor()
logger.info(f"Executing : Row Count | {selectStatement}")
self.__execute_sql(cur, selectStatement)
self.__execute_sql(cur, selectStatement, parameters=parameters)
data = cur.fetchall()
if db_connection.module_name in ["sqlite3", "ibm_db", "ibm_db_dbi", "pyodbc"]:
return len(data)
Expand All @@ -124,10 +132,9 @@ def row_count(self, selectStatement: str, sansTran: bool = False, alias: Optiona
if cur and not sansTran:
db_connection.client.rollback()

def description(self, selectStatement: str, sansTran: bool = False, alias: Optional[str] = None):
def description(self, selectStatement: str, sansTran: bool = False, alias: Optional[str] = None, parameters: Optional[List] = None):
"""
Uses the input ``selectStatement`` to query a table in the db which will be used to determine the description. Set
optional input ``sansTran` to True to run command without an explicit transaction commit or rollback.
Uses the input ``selectStatement`` to query a table in the db which will be used to determine the description.
For example, given we have a table `person` with the following data:
| id | first_name | last_name |
Expand All @@ -146,6 +153,11 @@ def description(self, selectStatement: str, sansTran: bool = False, alias: Optio
Use optional ``alias`` parameter to specify what connection should be used for the query if you have more
than one connection open.
Use optional ``parameters`` for query variable substitution (variable substitution syntax may be different
depending on the database client):
| parameters | Create List | person |
| ${desc} | Description | SELECT * FROM %s | parameters=${parameters} |
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
| @{queryResults} | Description | SELECT * FROM person | True |
"""
Expand All @@ -154,7 +166,7 @@ def description(self, selectStatement: str, sansTran: bool = False, alias: Optio
try:
cur = db_connection.client.cursor()
logger.info("Executing : Description | {selectStatement}")
self.__execute_sql(cur, selectStatement)
self.__execute_sql(cur, selectStatement, parameters=parameters)
description = list(cur.description)
if sys.version_info[0] < 3:
for row in range(0, len(description)):
Expand All @@ -166,8 +178,7 @@ def description(self, selectStatement: str, sansTran: bool = False, alias: Optio

def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, alias: Optional[str] = None):
"""
Delete all the rows within a given table. Set optional input `sansTran` to True to run command without an
explicit transaction commit or rollback.
Delete all the rows within a given table.
For example, given we have a table `person` in a database
Expand All @@ -184,7 +195,7 @@ def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, ali
Use optional ``alias`` parameter to specify what connection should be used for the query if you have more
than one connection open.
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
Use optional `sansTran` to run command without an explicit transaction commit or rollback:
| Delete All Rows From Table | person | True |
"""
db_connection = self.connection_store.get_connection(alias)
Expand All @@ -207,8 +218,7 @@ def delete_all_rows_from_table(self, tableName: str, sansTran: bool = False, ali
def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, alias: Optional[str] = None):
"""
Executes the content of the `sqlScriptFileName` as SQL commands. Useful for setting the database to a known
state before running your tests, or clearing out your test data after running each a test. Set optional input
`sansTran` to True to run command without an explicit transaction commit or rollback.
state before running your tests, or clearing out your test data after running each a test.
Sample usage :
| Execute Sql Script | ${EXECDIR}${/}resources${/}DDL-setup.sql |
Expand Down Expand Up @@ -262,7 +272,7 @@ def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, ali
Use optional ``alias`` parameter to specify what connection should be used for the query if you have more
than one connection open.
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
Use optional `sansTran` to run command without an explicit transaction commit or rollback:
| Execute Sql Script | ${EXECDIR}${/}resources${/}DDL-setup.sql | True |
"""
db_connection = self.connection_store.get_connection(alias)
Expand Down Expand Up @@ -331,10 +341,9 @@ def execute_sql_script(self, sqlScriptFileName: str, sansTran: bool = False, ali
if cur and not sansTran:
db_connection.client.rollback()

def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Optional[str] = None):
def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Optional[str] = None, parameters: Optional[List] = None):
"""
Executes the sqlString as SQL commands. Useful to pass arguments to your sql. Set optional input `sansTran` to
True to run command without an explicit transaction commit or rollback.
Executes the sqlString as SQL commands. Useful to pass arguments to your sql.
SQL commands are expected to be delimited by a semicolon (';').
Expand All @@ -348,15 +357,20 @@ def execute_sql_string(self, sqlString: str, sansTran: bool = False, alias: Opti
Use optional ``alias`` parameter to specify what connection should be used for the query if you have more
than one connection open.
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
Use optional ``parameters`` for query variable substitution (variable substitution syntax may be different
depending on the database client):
| parameters | Create List | person_employee_table |
| Execute Sql String | SELECT * FROM %s | parameters=${parameters} |
Use optional `sansTran` to run command without an explicit transaction commit or rollback:
| Execute Sql String | DELETE FROM person_employee_table; DELETE FROM person_table | True |
"""
db_connection = self.connection_store.get_connection(alias)
cur = None
try:
cur = db_connection.client.cursor()
logger.info(f"Executing : Execute SQL String | {sqlString}")
self.__execute_sql(cur, sqlString)
self.__execute_sql(cur, sqlString, parameters=parameters)
if not sansTran:
db_connection.client.commit()
finally:
Expand All @@ -381,8 +395,6 @@ def call_stored_procedure(
It also depends on the database, how the procedure returns the values - as params or as result sets.
E.g. calling a procedure in *PostgreSQL* returns even a single value of an OUT param as a result set.
Set optional input `sansTran` to True to run command without an explicit transaction commit or rollback.
Simple example:
| @{Params} = | Create List | Jerry | out_second_name |
| @{Param values} @{Result sets} = | Call Stored Procedure | Get_second_name | ${Params} |
Expand All @@ -404,7 +416,7 @@ def call_stored_procedure(
Use optional ``alias`` parameter to specify what connection should be used for the query if you have more
than one connection open.
Using optional `sansTran` to run command without an explicit transaction commit or rollback:
Use optional `sansTran` to run command without an explicit transaction commit or rollback:
| @{Param values} @{Result sets} = | Call Stored Procedure | DBName.SchemaName.StoredProcName | ${Params} | True |
"""
db_connection = self.connection_store.get_connection(alias)
Expand Down Expand Up @@ -507,7 +519,7 @@ def call_stored_procedure(
if cur and not sansTran:
db_connection.client.rollback()

def __execute_sql(self, cur, sql_statement: str, omit_trailing_semicolon: Optional[bool] = None):
def __execute_sql(self, cur, sql_statement: str, omit_trailing_semicolon: Optional[bool] = None, parameters: Optional[List] = None):
"""
Runs the `sql_statement` using `cur` as Cursor object.
Use `omit_trailing_semicolon` parameter (bool) for explicit instruction,
Expand All @@ -519,5 +531,7 @@ def __execute_sql(self, cur, sql_statement: str, omit_trailing_semicolon: Option
omit_trailing_semicolon = self.omit_trailing_semicolon
if omit_trailing_semicolon:
sql_statement = sql_statement.rstrip(";")
if parameters is None:
parameters = []
logger.debug(f"Executing sql: {sql_statement}")
return cur.execute(sql_statement)
return cur.execute(sql_statement, parameters)
Loading

0 comments on commit 170050b

Please sign in to comment.