Skip to content

Commit

Permalink
Merge pull request #52 from Amsterdam/compound_primary_keys
Browse files Browse the repository at this point in the history
Compound primary keys
  • Loading branch information
SSchotten authored Sep 16, 2024
2 parents 2ccdc7d + 9560936 commit 460a358
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 48 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "dq-suite-amsterdam"
version = "0.9.0"
version = "0.9.1"
authors = [
{ name="Arthur Kordes", email="[email protected]" },
{ name="Aysegul Cayir Aydar", email="[email protected]" },
Expand Down
1 change: 1 addition & 0 deletions src/dq_suite/df_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,5 +206,6 @@ def run(
validation_output=validation_output,
validation_settings_obj=validation_settings_obj,
df=df,
dataset_name=validation_dict["dataset"]["name"],
unique_identifier=rules_dict["unique_identifier"],
)
133 changes: 86 additions & 47 deletions src/dq_suite/output_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def list_of_dicts_to_df(


def construct_regel_id(
    df: DataFrame,
    output_columns_list: list[str],
) -> DataFrame:
    """Add a deterministic 'regelId' hash column and project the given columns.

    The id is an xxhash64 over the rule name, its parameters and the source
    table id, so the same rule applied to the same table always produces the
    same identifier.

    :param df: dataframe containing 'regelNaam', 'regelParameters' and
        'bronTabelId' columns
    :param output_columns_list: columns to keep in the returned dataframe
    """
    regel_id_column = xxhash64(
        col("regelNaam"), col("regelParameters"), col("bronTabelId")
    )
    return df.withColumn("regelId", regel_id_column).select(*output_columns_list)
Expand All @@ -51,8 +51,56 @@ def create_parameter_list_from_results(result: dict) -> list[dict]:
return [parameters]


def get_target_attr_for_rule(result: dict) -> "str | list[str] | None":
    """Return the attribute(s) targeted by a data-quality rule.

    Single-column expectations store their target under the 'column' kwarg;
    multi-column (compound-key) expectations use 'column_list' instead.

    :param result: a single entry from the GX validation results
    :return: the column name, the list of column names, or None when the
        expectation has neither kwarg (e.g. table-level expectations)
    """
    # Original annotated the return as `str`, but 'column_list' yields a list
    # and a missing kwarg yields None; look the kwargs dict up only once.
    kwargs = result["expectation_config"]["kwargs"]
    if "column" in kwargs:
        return kwargs.get("column")
    return kwargs.get("column_list")


def get_unique_deviating_values(deviating_attribute_value: list[str]) -> set[str]:
    """Deduplicate the deviating values reported for a rule.

    Dict entries (compound-key rows) are converted to a tuple of their items
    first, because a dict is unhashable and cannot be added to a set.

    :param deviating_attribute_value: the 'partial_unexpected_list' of a rule
    :return: the set of distinct deviating values
    """
    return {
        tuple(value.items()) if isinstance(value, dict) else value
        for value in deviating_attribute_value
    }


def filter_df_based_on_deviating_values(
    value: str,
    attribute: str,
    df: DataFrame,
) -> DataFrame:
    """Restrict df to the rows that carry one particular deviating value.

    :param value: a deviating value, None, or — for compound keys — a tuple
        of (column, value) pairs as produced by get_unique_deviating_values
    :param attribute: a column name, or a list of column names for compound keys
    :param df: the validated dataframe
    :return: the filtered dataframe
    """
    if value is None:
        return df.filter(col(attribute).isNull())
    if isinstance(attribute, list):
        # Compound key: "attribute" is a list of columns and "value" is a
        # dict-like tuple whose indices match it; each pair stores the key at
        # index 0 and the actual value at index 1.
        for column_name, key_value_pair in zip(attribute, value):
            df = df.filter(col(column_name) == key_value_pair[1])
        return df
    return df.filter(col(attribute) == value)


def get_grouped_ids_per_deviating_value(
    filtered_df: DataFrame,
    unique_identifier: list[str],
) -> list[str]:
    """Collect the identifier values of the filtered rows, grouped per row.

    flatMap flattens the selected identifier columns into one flat list, so
    the values are re-chunked into consecutive groups of
    len(unique_identifier) — one group of id values per deviating row.

    :param filtered_df: dataframe already filtered to one deviating value
    :param unique_identifier: the (possibly compound) key column names
    :return: a list of id-value groups, one per row
    """
    flat_ids = (
        filtered_df.select(unique_identifier)
        .rdd.flatMap(lambda row: row)
        .collect()
    )
    group_size = len(unique_identifier)
    return [
        flat_ids[start:start + group_size]
        for start in range(0, len(flat_ids), group_size)
    ]


def extract_dq_validatie_data(
table_name: str,
dataset_name: str,
dq_result: dict,
catalog_name: str,
spark_session: SparkSession,
Expand All @@ -65,10 +113,8 @@ def extract_dq_validatie_data(
:param catalog_name:
:param spark_session:
"""

# Access run_time attribute
tabel_id = f"{dataset_name}_{table_name}"
run_time = dq_result["meta"]["run_id"].run_time
# Extracted data
extracted_data = []
for result in dq_result["results"]:
element_count = int(result["result"].get("element_count", 0))
Expand All @@ -88,10 +134,10 @@ def extract_dq_validatie_data(
"dqResultaat": output_text,
"regelNaam": expectation_type,
"regelParameters": parameter_list,
"bronTabelId": table_name,
"bronTabelId": tabel_id,
}
)

df_validatie = list_of_dicts_to_df(
list_of_dicts=extracted_data,
spark_session=spark_session,
Expand All @@ -115,6 +161,7 @@ def extract_dq_validatie_data(

def extract_dq_afwijking_data(
table_name: str,
dataset_name: str,
dq_result: dict, # TODO: add dataclass?
df: DataFrame,
unique_identifier: str,
Expand All @@ -131,53 +178,42 @@ def extract_dq_afwijking_data(
:param catalog_name:
:param spark_session:
"""
# Extracting information from the JSON
run_time = dq_result["meta"]["run_id"].run_time # Access run_time attribute
# Extracted data for df
tabel_id = f"{dataset_name}_{table_name}"
run_time = dq_result["meta"]["run_id"].run_time # Get the run timestamp
extracted_data = []

# To store unique combinations of value and IDs
unique_entries = set()
if not isinstance(unique_identifier, list): unique_identifier = [unique_identifier]

for result in dq_result["results"]:
expectation_type = result["expectation_config"]["expectation_type"]
parameter_list = create_parameter_list_from_results(result=result)
attribute = result["expectation_config"]["kwargs"].get("column")
afwijkende_attribuut_waarde = result["result"].get(
attribute = get_target_attr_for_rule(result=result)
deviating_attribute_value = result["result"].get(
"partial_unexpected_list", []
)
for value in afwijkende_attribuut_waarde:
if value is None:
filtered_df = df.filter(col(attribute).isNull())
ids = (
filtered_df.select(unique_identifier)
.rdd.flatMap(lambda x: x)
.collect()
)
else:
filtered_df = df.filter(col(attribute) == value)
ids = (
filtered_df.select(unique_identifier)
.rdd.flatMap(lambda x: x)
.collect()
)

for id_value in ids:
entry = id_value
if (
entry not in unique_entries
): # Check for uniqueness before appending
unique_entries.add(entry)
extracted_data.append(
{
"identifierVeldWaarde": id_value,
"afwijkendeAttribuutWaarde": value,
"dqDatum": run_time,
"regelNaam": expectation_type,
"regelParameters": parameter_list,
"bronTabelId": table_name,
}
)
unique_deviating_values = get_unique_deviating_values(
deviating_attribute_value
)
for value in unique_deviating_values:
filtered_df = filter_df_based_on_deviating_values(
value=value,
attribute=attribute,
df=df
)
grouped_ids = get_grouped_ids_per_deviating_value(
filtered_df=filtered_df,
unique_identifier=unique_identifier
)
if isinstance(attribute, list): value = str(value)
extracted_data.append(
{
"identifierVeldWaarde": grouped_ids,
"afwijkendeAttribuutWaarde": value,
"dqDatum": run_time,
"regelNaam": expectation_type,
"regelParameters": parameter_list,
"bronTabelId": tabel_id,
}
)

df_afwijking = list_of_dicts_to_df(
list_of_dicts=extracted_data,
Expand Down Expand Up @@ -431,18 +467,21 @@ def write_validation_table(
validation_output: Any,
validation_settings_obj: ValidationSettings,
df: DataFrame,
dataset_name: str,
unique_identifier: str,
):
for results in validation_output.values():
result = results["validation_result"]
extract_dq_validatie_data(
validation_settings_obj.table_name,
dataset_name,
result,
validation_settings_obj.catalog_name,
validation_settings_obj.spark_session,
)
extract_dq_afwijking_data(
validation_settings_obj.table_name,
dataset_name,
result,
df,
unique_identifier,
Expand Down

0 comments on commit 460a358

Please sign in to comment.