From ca01ac37143886d5a3a362aaa7ed6e03d7d82652 Mon Sep 17 00:00:00 2001 From: ArthurKordes <75675106+ArthurKordes@users.noreply.github.com> Date: Wed, 6 Nov 2024 14:05:47 +0100 Subject: [PATCH] Unit test output (#69) * Second unit test and chispa * some more unit tests * Added unit tests and test schema (#65) Co-authored-by: esra ozturk * Update pyproject.toml * Update pyproject.toml * Isolated some output logic and added unit tests * Merge fix * Update pyproject.toml * Update README.md --------- Co-authored-by: esraozturkerdem Co-authored-by: esra ozturk --- README.md | 2 + pyproject.toml | 3 +- src/dq_suite/output_transformations.py | 296 +++++++++------- tests/conftest.py | 12 + tests/test_data/dq_result.json | 30 ++ tests/test_data/dq_rules.json | 21 ++ tests/test_data/test_schema.py | 7 + tests/test_input_helpers.py | 16 +- tests/test_output_transformations.py | 450 ++++++++++++++++++++++++- 9 files changed, 702 insertions(+), 135 deletions(-) create mode 100644 tests/conftest.py create mode 100644 tests/test_data/dq_result.json diff --git a/README.md b/README.md index 7cd601e..6d94283 100644 --- a/README.md +++ b/README.md @@ -80,3 +80,5 @@ Version 0.8: Implemented output historization Version 0.9: Added dataset descriptions Version 0.10: Switched to GX 1.0 + +Version 0.11: Stability and testability improvements diff --git a/pyproject.toml b/pyproject.toml index 94d8afb..403f29e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "dq-suite-amsterdam" -version = "0.10.6" +version = "0.11.0" authors = [ { name="Arthur Kordes", email="a.kordes@amsterdam.nl" }, { name="Aysegul Cayir Aydar", email="a.cayiraydar@amsterdam.nl" }, @@ -34,6 +34,7 @@ dev = [ 'pylint ~= 2.16', 'autoflake ~= 2.0.1', 'coverage ~= 7.6.1', + 'chispa ~= 0.10.1', ] [tool.isort] diff --git a/src/dq_suite/output_transformations.py b/src/dq_suite/output_transformations.py index 8625db0..c99f6e6 100644 --- a/src/dq_suite/output_transformations.py +++ b/src/dq_suite/output_transformations.py @@ -32,6 +32,8 @@ def create_empty_dataframe( def list_of_dicts_to_df( list_of_dicts: List[dict], spark_session: SparkSession, schema: StructType ) -> DataFrame: + if not isinstance(list_of_dicts, list): + raise TypeError("'list_of_dicts' should be of type 'list'") if len(list_of_dicts) == 0: return create_empty_dataframe( spark_session=spark_session, schema=schema @@ -45,6 +47,8 @@ def construct_regel_id( df: DataFrame, output_columns_list: list[str], ) -> DataFrame: + if not isinstance(output_columns_list, list): + raise TypeError("'output_columns_list' should be of type 'list'") df_with_id = df.withColumn( "regelId", xxhash64(col("regelNaam"), col("regelParameters"), col("bronTabelId")), @@ -52,10 +56,10 @@ def construct_regel_id( return df_with_id.select(*output_columns_list) -def create_parameter_list_from_results(result: dict) -> list[dict]: +def get_parameters_from_results(result: dict) -> list[dict]: parameters = result["kwargs"] parameters.pop("batch_id", None) - return [parameters] + return parameters def get_target_attr_for_rule(result: dict) -> str: @@ -111,28 +115,87 @@ def get_grouped_ids_per_deviating_value( ] -def extract_dq_validatie_data( +def extract_dataset_data(dq_rules_dict: dict) -> list[dict]: + name = dq_rules_dict["dataset"]["name"] + layer = dq_rules_dict["dataset"]["layer"] + return [{"bronDatasetId": name, "medaillonLaag": layer}] + + +def extract_table_data(dq_rules_dict: dict) -> list[dict]: + extracted_data = [] + dataset_name = dq_rules_dict["dataset"]["name"] + for param in dq_rules_dict["tables"]: + table_name = param["table_name"] + tabel_id = f"{dataset_name}_{table_name}" + unique_identifier = param["unique_identifier"] + extracted_data.append( + { + "bronTabelId": tabel_id, + "tabelNaam": table_name, + "uniekeSleutel": unique_identifier, + } + ) + return extracted_data + + +def extract_attribute_data(dq_rules_dict: dict) -> list[dict]: + extracted_data = [] + dataset_name = dq_rules_dict["dataset"]["name"] + used_ids = set() # To keep track of used IDs + for param in dq_rules_dict["tables"]: + table_name = param["table_name"] + tabel_id = f"{dataset_name}_{table_name}" + for rule in param["rules"]: + parameters = rule.get("parameters", []) + if isinstance(parameters, dict) and "column" in parameters: + attribute_name = parameters["column"] + # Create a unique ID + unique_id = f"{tabel_id}_{attribute_name}" + # Check if the ID is already used + if unique_id not in used_ids: + used_ids.add(unique_id) + extracted_data.append( + { + "bronAttribuutId": unique_id, + "attribuutNaam": attribute_name, + "bronTabelId": tabel_id, + } + ) + return extracted_data + + +def extract_regel_data(dq_rules_dict: dict) -> list[dict]: + extracted_data = [] + dataset_name = dq_rules_dict["dataset"]["name"] + for table in dq_rules_dict["tables"]: + table_name = table["table_name"] + tabel_id = f"{dataset_name}_{table_name}" + for rule in table["rules"]: + rule_name = rule["rule_name"] + parameters = rule.get("parameters") + norm = rule.get("norm", None) + column = parameters.get("column", None) + extracted_data.append( + { + "regelNaam": rule_name, + "regelParameters": parameters, + "norm": norm, + "bronTabelId": tabel_id, + "attribuut": column, + } + ) + return extracted_data + + +def extract_validatie_data( table_name: str, dataset_name: str, run_time: datetime, dq_result: CheckpointDescriptionDict, - catalog_name: str, - spark_session: SparkSession, -) -> None: - """ - [insert explanation here] - - :param table_name: Name of the tables - :param dataset_name: - :param run_time: - :param dq_result: # TODO: add dataclass? - :param catalog_name: - :param spark_session: - """ - tabel_id = f"{dataset_name}_{table_name}" - +) -> list[dict]: # "validation_results" is typed List[Dict[str, Any]] in GX dq_result = dq_result["validation_results"] + tabel_id = f"{dataset_name}_{table_name}" extracted_data = [] for validation_result in dq_result: @@ -148,7 +211,7 @@ def extract_dq_validatie_data( ) number_of_valid_records = element_count - unexpected_count expectation_type = expectation_result["expectation_type"] - parameter_list = create_parameter_list_from_results( + parameter_list = get_parameters_from_results( result=expectation_result ) expectation_result["kwargs"].get("column") @@ -167,7 +230,84 @@ def extract_dq_validatie_data( "bronTabelId": tabel_id, } ) + return extracted_data + + +def extract_afwijking_data( + df: DataFrame, + unique_identifier: str, + table_name: str, + dataset_name: str, + run_time: datetime, + dq_result: CheckpointDescriptionDict, +) -> list[dict]: + # "validation_results" is typed List[Dict[str, Any]] in GX + dq_result = dq_result["validation_results"] + tabel_id = f"{dataset_name}_{table_name}" + + extracted_data = [] + if not isinstance(unique_identifier, list): + unique_identifier = [unique_identifier] + + for validation_result in dq_result: + for expectation_result in validation_result["expectations"]: + expectation_type = expectation_result["expectation_type"] + parameter_list = get_parameters_from_results( + result=expectation_result + ) + attribute = get_target_attr_for_rule(result=expectation_result) + deviating_attribute_value = expectation_result["result"].get( + "partial_unexpected_list", [] + ) + unique_deviating_values = get_unique_deviating_values( + deviating_attribute_value + ) + for value in unique_deviating_values: + filtered_df = filter_df_based_on_deviating_values( + value=value, attribute=attribute, df=df + ) + grouped_ids = get_grouped_ids_per_deviating_value( + filtered_df=filtered_df, unique_identifier=unique_identifier + ) + if isinstance(attribute, list): + value = str(value) + extracted_data.append( + { + "identifierVeldWaarde": grouped_ids, + "afwijkendeAttribuutWaarde": value, + "dqDatum": run_time, + "regelNaam": expectation_type, + "regelParameters": parameter_list, + "bronTabelId": tabel_id, + } + ) + return extracted_data + + +def create_dq_validatie( + table_name: str, + dataset_name: str, + run_time: datetime, + dq_result: CheckpointDescriptionDict, + catalog_name: str, + spark_session: SparkSession, +) -> None: + """ + [insert explanation here] + :param table_name: Name of the tables + :param dataset_name: + :param run_time: + :param dq_result: # TODO: add dataclass? + :param catalog_name: + :param spark_session: + """ + extracted_data = extract_validatie_data( + table_name=table_name, + dataset_name=dataset_name, + run_time=run_time, + dq_result=dq_result, + ) df_validatie = list_of_dicts_to_df( list_of_dicts=extracted_data, spark_session=spark_session, @@ -196,7 +336,7 @@ def extract_dq_validatie_data( pass -def extract_dq_afwijking_data( +def create_dq_afwijking( table_name: str, dataset_name: str, dq_result: CheckpointDescriptionDict, @@ -218,48 +358,14 @@ def extract_dq_afwijking_data( :param catalog_name: :param spark_session: """ - tabel_id = f"{dataset_name}_{table_name}" - - # "validation_results" is typed List[Dict[str, Any]] in GX - dq_result = dq_result["validation_results"] - - extracted_data = [] - if not isinstance(unique_identifier, list): - unique_identifier = [unique_identifier] - - for validation_result in dq_result: - for expectation_result in validation_result["expectations"]: - expectation_type = expectation_result["expectation_type"] - parameter_list = create_parameter_list_from_results( - result=expectation_result - ) - attribute = get_target_attr_for_rule(result=expectation_result) - deviating_attribute_value = expectation_result["result"].get( - "partial_unexpected_list", [] - ) - unique_deviating_values = get_unique_deviating_values( - deviating_attribute_value - ) - for value in unique_deviating_values: - filtered_df = filter_df_based_on_deviating_values( - value=value, attribute=attribute, df=df - ) - grouped_ids = get_grouped_ids_per_deviating_value( - filtered_df=filtered_df, unique_identifier=unique_identifier - ) - if isinstance(attribute, list): - value = str(value) - extracted_data.append( - { - "identifierVeldWaarde": grouped_ids, - "afwijkendeAttribuutWaarde": value, - "dqDatum": run_time, - "regelNaam": expectation_type, - "regelParameters": parameter_list, - "bronTabelId": tabel_id, - } - ) - + extracted_data = extract_afwijking_data( + df=df, + unique_identifier=unique_identifier, + table_name=table_name, + dataset_name=dataset_name, + run_time=run_time, + dq_result=dq_result, + ) df_afwijking = list_of_dicts_to_df( list_of_dicts=extracted_data, spark_session=spark_session, @@ -299,9 +405,7 @@ def create_brondataset( :param catalog_name: :param spark_session: """ - name = dq_rules_dict["dataset"]["name"] - layer = dq_rules_dict["dataset"]["layer"] - extracted_data = [{"bronDatasetId": name, "medaillonLaag": layer}] + extracted_data = extract_dataset_data(dq_rules_dict=dq_rules_dict) df_brondataset = list_of_dicts_to_df( list_of_dicts=extracted_data, @@ -336,19 +440,7 @@ def create_brontabel( :param catalog_name: :param spark_session: """ - extracted_data = [] - dataset_name = dq_rules_dict["dataset"]["name"] - for param in dq_rules_dict["tables"]: - table_name = param["table_name"] - tabel_id = f"{dataset_name}_{table_name}" - unique_identifier = param["unique_identifier"] - extracted_data.append( - { - "bronTabelId": tabel_id, - "tabelNaam": table_name, - "uniekeSleutel": unique_identifier, - } - ) + extracted_data = extract_table_data(dq_rules_dict=dq_rules_dict) df_brontabel = list_of_dicts_to_df( list_of_dicts=extracted_data, @@ -384,29 +476,7 @@ def create_bronattribute( :param catalog_name: :param spark_session: """ - extracted_data = [] - dataset_name = dq_rules_dict["dataset"]["name"] - used_ids = set() # To keep track of used IDs - for param in dq_rules_dict["tables"]: - table_name = param["table_name"] - tabel_id = f"{dataset_name}_{table_name}" - for rule in param["rules"]: - parameters = rule.get("parameters", []) - for parameter in parameters: - if isinstance(parameter, dict) and "column" in parameter: - attribute_name = parameter["column"] - # Create a unique ID - unique_id = f"{tabel_id}_{attribute_name}" - # Check if the ID is already used - if unique_id not in used_ids: - used_ids.add(unique_id) - extracted_data.append( - { - "bronAttribuutId": unique_id, - "attribuutNaam": attribute_name, - "bronTabelId": tabel_id, - } - ) + extracted_data = extract_attribute_data(dq_rules_dict=dq_rules_dict) df_bronattribuut = list_of_dicts_to_df( list_of_dicts=extracted_data, @@ -443,25 +513,7 @@ def create_dq_regel( :param catalog_name: :param spark_session: """ - extracted_data = [] - dataset_name = dq_rules_dict["dataset"]["name"] - for table in dq_rules_dict["tables"]: - table_name = table["table_name"] - tabel_id = f"{dataset_name}_{table_name}" - for rule in table["rules"]: - rule_name = rule["rule_name"] - parameters = rule.get("parameters") - norm = rule.get("norm", None) - column = parameters.get("column", None) - extracted_data.append( - { - "regelNaam": rule_name, - "regelParameters": parameters, - "norm": norm, - "bronTabelId": tabel_id, - "attribuut": column, - } - ) + extracted_data = extract_regel_data(dq_rules_dict=dq_rules_dict) df_regel = list_of_dicts_to_df( list_of_dicts=extracted_data, @@ -532,7 +584,7 @@ def write_validation_table( unique_identifier: str, run_time: datetime, ): - extract_dq_validatie_data( + create_dq_validatie( table_name=validation_settings_obj.table_name, dataset_name=dataset_name, run_time=run_time, @@ -540,7 +592,7 @@ def write_validation_table( catalog_name=validation_settings_obj.catalog_name, spark_session=validation_settings_obj.spark_session, ) - extract_dq_afwijking_data( + create_dq_afwijking( table_name=validation_settings_obj.table_name, dataset_name=dataset_name, dq_result=validation_output, diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1a8a77d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +import pytest +from tests import TEST_DATA_FOLDER + + +@pytest.fixture +def rules_file_path(): + return f"{TEST_DATA_FOLDER}/dq_rules.json" + + +@pytest.fixture +def result_file_path(): + return f"{TEST_DATA_FOLDER}/dq_result.json" diff --git a/tests/test_data/dq_result.json b/tests/test_data/dq_result.json new file mode 100644 index 0000000..72d9602 --- /dev/null +++ b/tests/test_data/dq_result.json @@ -0,0 +1,30 @@ +{ + "validation_results": [ + { + "expectations": [ + { + "result": { + "element_count": 23538, + "unexpected_count": 1, + "unexpected_percent": 0.00424844931599966, + "partial_unexpected_list": [ + null + ], + "partial_unexpected_counts": [ + { + "value": null, + "count": 1 + } + ] + }, + "expectation_type": "ExpectColumnDistinctValuesToEqualSet", + "kwargs": { + "column": "the_column", + "value_set": [1, 2, 3] + }, + "success": "failure" + } + ] + } + ] +} \ No newline at end of file diff --git a/tests/test_data/dq_rules.json b/tests/test_data/dq_rules.json index ec8f214..d7f2120 100644 --- a/tests/test_data/dq_rules.json +++ b/tests/test_data/dq_rules.json @@ -16,6 +16,27 @@ } } ] + }, + { + "unique_identifier": "other_id", + "table_name": "the_other_table", + "rules": [ + { + "rule_name": "ExpectColumnValuesToBeBetween", + "parameters": { + "column": "the_other_column", + "min_value": 6, + "max_value": 10000 + } + }, + { + "rule_name": "ExpectTableRowCountToBeBetween", + "parameters": { + "min_value": 1, + "max_value": 1000 + } + } + ] } ] } \ No newline at end of file diff --git a/tests/test_data/test_schema.py b/tests/test_data/test_schema.py index a36570b..db49f78 100644 --- a/tests/test_data/test_schema.py +++ b/tests/test_data/test_schema.py @@ -3,3 +3,10 @@ SCHEMA = ( StructType().add("the_string", "string").add("the_timestamp", "timestamp") ) + +SCHEMA2 = ( + StructType() + .add("voornaam", "string") + .add("achternaam", "string") + .add("leeftijd", "integer") +) diff --git a/tests/test_input_helpers.py b/tests/test_input_helpers.py index 430c541..1f615cc 100644 --- a/tests/test_input_helpers.py +++ b/tests/test_input_helpers.py @@ -1,5 +1,4 @@ import pytest -from tests import TEST_DATA_FOLDER from src.dq_suite.input_helpers import ( load_data_quality_rules_from_json_string, @@ -14,13 +13,8 @@ @pytest.fixture -def real_file_path(): - return f"{TEST_DATA_FOLDER}/dq_rules.json" - - -@pytest.fixture -def data_quality_rules_json_string(real_file_path): - return read_data_quality_rules_from_json(file_path=real_file_path) +def data_quality_rules_json_string(rules_file_path): + return read_data_quality_rules_from_json(file_path=rules_file_path) @pytest.fixture @@ -35,7 +29,7 @@ def rules_dict(data_quality_rules_dict): return data_quality_rules_dict["tables"][0] -@pytest.mark.usefixtures("real_file_path") +@pytest.mark.usefixtures("rules_file_path") class TestReadDataQualityRulesFromJson: def test_read_data_quality_rules_from_json_raises_file_not_found_error( self, @@ -44,10 +38,10 @@ def test_read_data_quality_rules_from_json_raises_file_not_found_error( read_data_quality_rules_from_json(file_path="nonexistent_file_path") def test_read_data_quality_rules_from_json_returns_json_string( - self, real_file_path + self, rules_file_path ): data_quality_rules_json_string = read_data_quality_rules_from_json( - file_path=real_file_path + file_path=rules_file_path ) assert isinstance(data_quality_rules_json_string, str) diff --git a/tests/test_output_transformations.py b/tests/test_output_transformations.py index ec64b38..8c09d49 100644 --- a/tests/test_output_transformations.py +++ b/tests/test_output_transformations.py @@ -1,9 +1,46 @@ + +import json +from datetime import datetime + import pytest +from chispa import assert_df_equality from pyspark.sql import SparkSession -from src.dq_suite.output_transformations import create_empty_dataframe +from src.dq_suite.output_transformations import ( + construct_regel_id, + create_empty_dataframe, + get_parameters_from_results, + extract_afwijking_data, + extract_attribute_data, + extract_dataset_data, + extract_regel_data, + extract_table_data, + extract_validatie_data, + filter_df_based_on_deviating_values, + get_grouped_ids_per_deviating_value, + get_target_attr_for_rule, + get_unique_deviating_values, + list_of_dicts_to_df, +) from .test_data.test_schema import SCHEMA as AFWIJKING_SCHEMA +from .test_data.test_schema import SCHEMA2 as AFWIJKING_SCHEMA2 + + +@pytest.mark.usefixtures("rules_file_path") +@pytest.fixture() +def read_test_rules_as_dict(rules_file_path): + with open(rules_file_path, "r") as json_file: + dq_rules_json_string = json_file.read() + return json.loads(dq_rules_json_string) + + +@pytest.mark.usefixtures("result_file_path") +@pytest.fixture() +def read_test_result_as_dict(result_file_path): + with open(result_file_path, "r") as json_file: + dq_result_json_string = json_file.read() + return json.loads(dq_result_json_string) @pytest.fixture() @@ -19,3 +56,414 @@ def test_create_empty_dataframe_returns_empty_dataframe(self, spark): schema=AFWIJKING_SCHEMA, ) assert len(empty_dataframe.head(1)) == 0 + + +@pytest.mark.usefixtures("spark") +class TestListOfDictsToDf: + def test_list_of_dicts_to_df_raises_type_error(self, spark): + with pytest.raises(TypeError): + list_of_dicts_to_df( + list_of_dicts={}, spark_session=spark, schema=AFWIJKING_SCHEMA + ) + + def test_list_of_dicts_to_df_returns_dataframe(self, spark): + current_timestamp = datetime.now() + source_data = [ + {"the_string": "test_string", "the_timestamp": current_timestamp} + ] + + actual_df = list_of_dicts_to_df( + list_of_dicts=source_data, + spark_session=spark, + schema=AFWIJKING_SCHEMA, + ) + + expected_data = [("test_string", current_timestamp)] + expected_df = spark.createDataFrame( + expected_data, ["the_string", "the_timestamp"] + ) + assert_df_equality(actual_df, expected_df) + + +@pytest.mark.usefixtures("spark") +class TestConstructRegelId: + def test_output_columns_list_raises_type_error(self, spark): + df = spark.createDataFrame([("123", "456")], ["123", "456"]) + with pytest.raises(TypeError): + construct_regel_id(df=df, output_columns_list="123") + + def test_construct_regel_id_returns_correct_hash(self, spark): + input_data = [ + ("test_regelNaam", "test_regelParameters", "test_bronTabelId") + ] + input_df = spark.createDataFrame( + input_data, ["regelNaam", "regelParameters", "bronTabelId"] + ) + + actual_df = construct_regel_id( + df=input_df, + output_columns_list=[ + "regelId", + "regelNaam", + ], + ) + + expected_data = [(5287467170918921248, "test_regelNaam")] + expected_df = spark.createDataFrame( + expected_data, ["regelId", "regelNaam"] + ) + expected_df.schema["regelId"].nullable = False + assert_df_equality(actual_df, expected_df) + + +class TestGetParametersFromResults: + def test_get_parameters_from_results_with_and_without_batch_id(self): + result = { + "kwargs": { + "param1": 10, + "param2": "example", + "batch_id": 123, + } + } + result2 = {"kwargs": {"param1": 10, "param2": "example"}} + expected_output = {"param1": 10, "param2": "example"} + + assert get_parameters_from_results(result) == expected_output + assert get_parameters_from_results(result2) == expected_output + + def get_parameters_from_results(self): + result = {"kwargs": {}} + + expected_output = [{}] + assert get_parameters_from_results(result) == expected_output + + def get_parameters_from_results(self): + result = {} + + with pytest.raises(KeyError): + get_parameters_from_results(result) + + +class TestGetTargetAttrForRule: + def test_get_target_attr_for_rule_with_column(self): + result = {"kwargs": {"column": "age", "column_list": ["age", "name"]}} + expected_output = "age" + assert get_target_attr_for_rule(result) == expected_output + + def test_get_target_attr_for_rule_without_column(self): + result = {"kwargs": {"column_list": ["age", "name"]}} + expected_output = ["age", "name"] + assert get_target_attr_for_rule(result) == expected_output + + def test_get_target_attr_for_rule_no_column_or_column_list(self): + result = {"kwargs": {}} + expected_output = None + assert get_target_attr_for_rule(result) == expected_output + + def test_get_target_attr_for_rule_no_kwargs_key(self): + result = {} + with pytest.raises(KeyError): + get_target_attr_for_rule(result) + + +class TestGetUniqueDeviatingValues: + def test_get_unique_deviating_values_empty_list(self): + result = get_unique_deviating_values([]) + expected_output = set() + assert result == expected_output + + def test_get_unique_deviating_values_list_of_strings(self): + result = get_unique_deviating_values(["apple", "banana", "cherry"]) + expected_output = {"apple", "banana", "cherry"} + assert result == expected_output + + def test_get_unique_deviating_values_with_duplicate_strings(self): + result = get_unique_deviating_values(["apple", "banana", "apple"]) + expected_output = {"apple", "banana"} + assert result == expected_output + + def test_get_unique_deviating_values_with_duplicate_dicts(self): + result = get_unique_deviating_values( + [ + {"key1": "value1", "key2": "value2"}, + {"key1": "value1", "key2": "value2"}, # same dict + ] + ) + expected_output = {(("key1", "value1"), ("key2", "value2"))} + assert result == expected_output + + def test_get_unique_deviating_values_with_mixed_dicts_and_strings(self): + result = get_unique_deviating_values( + [ + "apple", + {"key1": "value1", "key2": "value2"}, + "banana", + {"key1": "value1", "key2": "value2"}, # same dict + "apple", # same string + ] + ) + expected_output = { + "apple", + "banana", + (("key1", "value1"), ("key2", "value2")), + } + assert result == expected_output + + +@pytest.mark.usefixtures("spark") +class TestFilterDfBasedOnDeviatingValues: + def test_filter_df_based_on_deviating_values_none_value(self, spark): + data = [("test", None, 20), ("John", None, 24), ("Alice", "Jansen", 45)] + df = spark.createDataFrame(data, AFWIJKING_SCHEMA2) + + result_df = filter_df_based_on_deviating_values(None, "achternaam", df) + expected_data = [("test", None, 20), ("John", None, 24)] + expected_df = spark.createDataFrame(expected_data, AFWIJKING_SCHEMA2) + assert_df_equality(result_df, expected_df) + + def test_filter_df_based_on_deviating_values_single_attribute(self, spark): + data = [ + ("Alice", "Jansen", 30), + ("John", "Doe", 42), + ("Alice", "Taylor", 28), + ] + df = spark.createDataFrame(data, AFWIJKING_SCHEMA2) + result_df = filter_df_based_on_deviating_values("Alice", "voornaam", df) + expected_data = [("Alice", "Jansen", 30), ("Alice", "Taylor", 28)] + expected_df = spark.createDataFrame(expected_data, AFWIJKING_SCHEMA2) + assert_df_equality(result_df, expected_df) + + def test_filter_df_based_on_deviating_values_compound_key(self, spark): + data = [ + ("Alice", "Jansen", 30), + ("John", "Doe", 42), + ("Alice", "Taylor", 28), + ] + df = spark.createDataFrame(data, AFWIJKING_SCHEMA2) + + result_df = filter_df_based_on_deviating_values( + [("voornaam", "Alice"), ("achternaam", "Jansen")], + ["voornaam", "achternaam"], + df, + ) + expected_data = [("Alice", "Jansen", 30)] + expected_df = spark.createDataFrame(expected_data, AFWIJKING_SCHEMA2) + assert_df_equality(result_df, expected_df) + + +@pytest.mark.usefixtures("spark") +class TestGetGroupedIdsPerDeviatingValue: + def test_get_grouped_ids_per_deviating_value(self, spark): + data = [ + ("Alice", "Jansen", 30), + ("John", "Doe", 25), + ("Alice", "Smith", 30), + ("John", "Doe", 25), + ] + df = spark.createDataFrame(data, AFWIJKING_SCHEMA2) + filtered_df = df.filter(df.voornaam == "Alice") + unique_identifier = ["voornaam", "achternaam"] + grouped_ids = get_grouped_ids_per_deviating_value( + filtered_df, unique_identifier + ) + + expected_grouped_ids = [["Alice", "Jansen"], ["Alice", "Smith"]] + assert grouped_ids == expected_grouped_ids + + +@pytest.mark.usefixtures("read_test_rules_as_dict") +class TestExtractDatasetData: + def test_extract_dataset_data_raises_type_error(self): + with pytest.raises(TypeError): + extract_dataset_data(dq_rules_dict="123") + + def test_extract_dataset_data_returns_correct_list( + self, read_test_rules_as_dict + ): + test_output = extract_dataset_data( + dq_rules_dict=read_test_rules_as_dict + ) + expected_result = [ + {"bronDatasetId": "the_dataset", "medaillonLaag": "the_layer"} + ] + assert test_output == expected_result + + +@pytest.mark.usefixtures("read_test_rules_as_dict") +class TestExtractTableData: + def test_extract_table_data_raises_type_error(self): + with pytest.raises(TypeError): + extract_dataset_data(dq_rules_dict="123") + + def test_extract_table_data_returns_correct_list( + self, read_test_rules_as_dict + ): + test_output = extract_table_data(dq_rules_dict=read_test_rules_as_dict) + expected_result = [ + { + "bronTabelId": "the_dataset_the_table", + "tabelNaam": "the_table", + "uniekeSleutel": "id", + }, + { + "bronTabelId": "the_dataset_the_other_table", + "tabelNaam": "the_other_table", + "uniekeSleutel": "other_id", + }, + ] + assert test_output == expected_result + + +@pytest.mark.usefixtures("read_test_rules_as_dict") +class TestExtractAttributeData: + def test_extract_attribute_data_raises_type_error(self): + with pytest.raises(TypeError): + extract_attribute_data(dq_rules_dict="123") + + def test_extract_attribute_data_returns_correct_list( + self, read_test_rules_as_dict + ): + test_output = extract_attribute_data( + dq_rules_dict=read_test_rules_as_dict + ) + expected_result = [ + { + "bronAttribuutId": "the_dataset_the_table_the_column", + "attribuutNaam": "the_column", + "bronTabelId": "the_dataset_the_table", + }, + { + "bronAttribuutId": "the_dataset_the_other_table_the_other_column", + "attribuutNaam": "the_other_column", + "bronTabelId": "the_dataset_the_other_table", + }, + ] + assert test_output == expected_result + + +@pytest.mark.usefixtures("read_test_rules_as_dict") +class TestExtractRegelData: + def test_extract_regel_data_raises_type_error(self): + with pytest.raises(TypeError): + extract_regel_data(dq_rules_dict="123") + + def test_extract_regel_data_returns_correct_list( + self, read_test_rules_as_dict + ): + test_output = extract_regel_data(dq_rules_dict=read_test_rules_as_dict) + expected_result = [ + { + "regelNaam": "ExpectColumnDistinctValuesToEqualSet", + "regelParameters": { + "column": "the_column", + "value_set": [1, 2, 3], + }, + "bronTabelId": "the_dataset_the_table", + "attribuut": "the_column", + "norm": None, + }, + { + "regelNaam": "ExpectColumnValuesToBeBetween", + "regelParameters": { + "column": "the_other_column", + "min_value": 6, + "max_value": 10000, + }, + "bronTabelId": "the_dataset_the_other_table", + "attribuut": "the_other_column", + "norm": None, + }, + { + "regelNaam": "ExpectTableRowCountToBeBetween", + "regelParameters": {"min_value": 1, "max_value": 1000}, + "bronTabelId": "the_dataset_the_other_table", + "attribuut": None, + "norm": None, + }, + ] + assert test_output == expected_result + + +@pytest.mark.usefixtures("read_test_result_as_dict") +class TestExtractValidatieData: + def test_extract_validatie_data_raises_type_error(self): + with pytest.raises(TypeError): + extract_validatie_data( + table_name="table_name", + dataset_name="dataset_name", + run_time=datetime.now(), + dq_result="123", + ) + + def test_extract_validatie_data_returns_correct_list( + self, read_test_result_as_dict + ): + test_output = extract_validatie_data( + table_name="table_name", + dataset_name="dataset_name", + run_time=datetime.now(), + dq_result=read_test_result_as_dict, + ) + test_sample = test_output[0] + del test_sample["dqDatum"] #timestamp will be impossible to get right + expected_result = { + "aantalValideRecords": 23537, + "aantalReferentieRecords": 23538, + "dqResultaat": "success", + "percentageValideRecords": 99, + "regelNaam": "ExpectColumnDistinctValuesToEqualSet", + "regelParameters": { + "column": "the_column", + "value_set": [1, 2, 3], + }, + "bronTabelId": "dataset_name_table_name", + } + assert test_sample == expected_result + + +@pytest.mark.usefixtures("spark") +@pytest.mark.usefixtures("read_test_result_as_dict") +class TestExtractAfwijkingData: + def test_extract_afwijking_data_raises_type_error(self, spark): + with pytest.raises(TypeError): + mock_data = [("str1", "str2")] + mock_df = spark.createDataFrame( + mock_data, ["the_string", "the_other_string"] + ) + extract_afwijking_data( + df=mock_df, + unique_identifier="id", + table_name="table_name", + dataset_name="dataset_name", + run_time=datetime.now(), + dq_result="123", + ) + + def test_extract_afwijking_data_returns_correct_list( + self, spark, read_test_result_as_dict + ): + input_data = [("id1", None), ("id2", "the_value")] + input_df = spark.createDataFrame( + input_data, ["the_key", "the_column"] + ) + test_output = extract_afwijking_data( + df=input_df, + unique_identifier="the_key", + table_name="table_name", + dataset_name="dataset_name", + run_time=datetime.now(), + dq_result=read_test_result_as_dict, + ) + test_sample = test_output[0] + del test_sample["dqDatum"] #timestamp will be impossible to get right + expected_result = { + "identifierVeldWaarde": [["id1"]], + "afwijkendeAttribuutWaarde": None, + "regelNaam": "ExpectColumnDistinctValuesToEqualSet", + "regelParameters": { + "column": "the_column", + "value_set": [1, 2, 3] + }, + "bronTabelId": "dataset_name_table_name", + } + assert test_sample == expected_result