Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Array string distance alpha #2195

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions splink/internals/comparison_level_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,66 @@ def create_sql(self, sql_dialect: SplinkDialect) -> str:

def create_label_for_charts(self) -> str:
    """Return the chart label for this level, quoting its km threshold."""
    threshold_km = self.km_threshold
    return f"Distance less than {threshold_km}km"

class ArrayStringDistanceLevel(ComparisonLevelCreator):
    # Functions whose score is a similarity (higher means more alike). For
    # these the threshold is a lower bound, so the comparison must be `>=`
    # rather than the `<=` used for edit distances.
    _SIMILARITY_FUNCTIONS = ("jaro", "jaro_winkler")

    def __init__(
        self,
        col_name: str | ColumnExpression,
        distance_threshold: int | float,
        distance_function: str,
    ):
        """Represents a comparison level based around the pairwise string
        distance (or similarity) between the elements of two arrays

        Args:
            col_name (str): Input column name
            distance_threshold (int | float): Threshold that the aggregated
                pairwise score is compared against. It is an upper bound
                (`<=`) for "levenshtein" and "damerau_levenshtein", and a
                lower bound (`>=`) for "jaro" and "jaro_winkler".
            distance_function (str): Distance function name to calculate
                pair-wise between arrays. One of "levenshtein",
                "damerau_levenshtein", "jaro_winkler" or "jaro".
        """

        self.col_expression = ColumnExpression.instantiate_if_str(col_name)
        self.distance_threshold = validate_numeric_parameter(
            lower_bound=0,
            upper_bound=float("inf"),
            parameter_value=distance_threshold,
            level_name=self.__class__.__name__,
            parameter_name="distance_threshold",
        )
        # NOTE(review): open question from code review - whether users should
        # be able to pass any function name (as in DistanceFunctionLevel) or
        # only this fixed set, which gets transpiled per dialect.
        self.distance_function = validate_categorical_parameter(
            allowed_values=[
                "levenshtein",
                "damerau_levenshtein",
                "jaro_winkler",
                "jaro",
            ],
            parameter_value=distance_function,
            level_name=self.__class__.__name__,
            parameter_name="distance_function",
        )

    def _comparator(self) -> str:
        """Return the SQL operator used to compare the score to the threshold."""
        if self.distance_function in self._SIMILARITY_FUNCTIONS:
            return ">="
        return "<="

    @unsupported_splink_dialects(["sqlite", "spark", "postgres", "athena"])
    def create_sql(self, sql_dialect: SplinkDialect) -> str:
        """Build the SQL condition for this level.

        The generated SQL pairs every element of the left array with every
        element of the right array, applies the dialect's distance function
        to each pair, takes `list_max` of the scores and compares the result
        with the threshold.
        """
        self.col_expression.sql_dialect = sql_dialect
        col = self.col_expression
        # Each allowed distance_function value has a matching
        # `<name>_function_name` attribute on the dialect. Only the one that
        # is needed is looked up, and the value was validated in __init__.
        d_fn = getattr(sql_dialect, f"{self.distance_function}_function_name")
        # NOTE(review): `list_max` combined with `<=` means every pair must be
        # within the threshold for edit distances - confirm whether
        # `list_min` (any pair) was intended.
        return (
            f"""list_max(
                    list_transform(
                        flatten(
                            list_transform(
                                {col.name_l},
                                l_item -> list_transform(
                                    {col.name_r},
                                    r_item -> [l_item, r_item]
                                )
                            )
                        ),
                        pair -> {d_fn}(pair[1], pair[2])
                    )
                ) {self._comparator()} {self.distance_threshold}"""
        )

    def create_label_for_charts(self) -> str:
        """Return the chart label, naming the distance function and threshold."""
        return (
            f"Array string distance ({self.distance_function}) "
            f"{self._comparator()} {self.distance_threshold}"
        )
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's include the distance function name here



class ArrayIntersectLevel(ComparisonLevelCreator):
Expand Down
33 changes: 33 additions & 0 deletions splink/internals/comparison_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,39 @@ def datetime_parse_function(self):
def cll_class(self):
return cll.AbsoluteDateDifferenceLevel

class ArrayStringDistance(ComparisonCreator):
    def __init__(
        self,
        col_name: str,
        # The default list is copied below and never mutated.
        distance_threshold_or_thresholds: Union[
            Iterable[Union[int, float]], int, float
        ] = [1],
        # NOTE(review): code review suggested a custom type listing the
        # allowable values, similar to DateMetricType.
        distance_function: str = "levenshtein",
    ):
        """Comparison of two array columns based on the pairwise string
        distance between their elements.

        Args:
            col_name (str): Input column name
            distance_threshold_or_thresholds (Union[int, float, list]): One
                threshold, or several; each one produces its own
                ArrayStringDistanceLevel. Defaults to [1].
            distance_function (str): Name of the string distance function
                passed on to every ArrayStringDistanceLevel. Defaults to
                "levenshtein".
        """
        thresholds_as_iterable = ensure_is_iterable(distance_threshold_or_thresholds)
        self.thresholds = [*thresholds_as_iterable]
        self.distance_function = distance_function
        super().__init__(col_name)

    def create_comparison_levels(self) -> List[ComparisonLevelCreator]:
        """Return the levels: null, exact array overlap, one level per
        threshold, then the else level."""
        return [
            cll.NullLevel(self.col_expression),
            cll.ArrayIntersectLevel(self.col_expression, min_intersection=1),
            *[
                cll.ArrayStringDistanceLevel(
                    self.col_expression,
                    distance_threshold=threshold,
                    distance_function=self.distance_function,
                )
                for threshold in self.thresholds
            ],
            cll.ElseLevel(),
        ]

    def create_description(self) -> str:
        """Describe the comparison, naming the distance function and the
        thresholds used."""
        comma_separated_thresholds_string = ", ".join(map(str, self.thresholds))
        plural = "s" if len(self.thresholds) > 1 else ""
        return (
            f"Array string distance ({self.distance_function}) at "
            f"threshold{plural} "
            f"{comma_separated_thresholds_string} vs. anything else"
        )

    def create_output_column_name(self) -> str:
        """Return the output column name taken from the column expression."""
        return self.col_expression.output_column_name

class ArrayIntersectAtSizes(ComparisonCreator):
def __init__(
Expand Down