-
Notifications
You must be signed in to change notification settings - Fork 158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Array string distance alpha #2195
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -742,6 +742,66 @@ def create_sql(self, sql_dialect: SplinkDialect) -> str: | |||||||||||||
|
||||||||||||||
def create_label_for_charts(self) -> str:
    """Return the chart label for this level, built from ``self.km_threshold``."""
    threshold_km = self.km_threshold
    return f"Distance less than {threshold_km}km"
|
||||||||||||||
class ArrayStringDistanceLevel(ComparisonLevelCreator):
    """Comparison level on the best pairwise string score between two arrays.

    Every element of the left-hand array is paired with every element of the
    right-hand array, the chosen string function is applied to each pair, and
    the best-scoring pair is compared against ``distance_threshold``.
    """

    # Maps each accepted ``distance_function`` value to the name of the
    # dialect attribute holding the SQL function name. Attribute *names* are
    # stored (rather than values) so that only the requested function is
    # looked up on the dialect.
    _DIALECT_FUNCTION_ATTRS = {
        "levenshtein": "levenshtein_function_name",
        "damerau_levenshtein": "damerau_levenshtein_function_name",
        "jaro_winkler": "jaro_winkler_function_name",
        "jaro": "jaro_function_name",
    }

    # Functions that return a similarity (higher = closer match). The other
    # accepted functions are edit distances (lower = closer match).
    _SIMILARITY_FUNCTIONS = ("jaro_winkler", "jaro")

    def __init__(
        self,
        col_name: str | ColumnExpression,
        distance_threshold: int | float,
        distance_function: str,
    ):
        """Represents a comparison level based around the string distance
        between the elements of two arrays

        Args:
            col_name (str): Input column name
            distance_threshold (int | float): Threshold for this level. For
                edit distances (levenshtein, damerau_levenshtein) it is the
                maximum distance allowed for the closest pair of elements;
                for similarities (jaro, jaro_winkler) it is the minimum
                similarity required of the most similar pair.
            distance_function (str): Name of the string function applied
                pair-wise between the arrays. One of "levenshtein",
                "damerau_levenshtein", "jaro_winkler" or "jaro".
        """

        self.col_expression = ColumnExpression.instantiate_if_str(col_name)
        self.distance_threshold = validate_numeric_parameter(
            lower_bound=0,
            upper_bound=float("inf"),
            parameter_value=distance_threshold,
            level_name=self.__class__.__name__,
            parameter_name="distance_threshold",
        )
        self.distance_function = validate_categorical_parameter(
            allowed_values=list(self._DIALECT_FUNCTION_ATTRS),
            parameter_value=distance_function,
            level_name=self.__class__.__name__,
            parameter_name="distance_function",
        )

    def _uses_similarity_score(self) -> bool:
        """Return True when the configured function is a similarity score."""
        return self.distance_function in self._SIMILARITY_FUNCTIONS

    @unsupported_splink_dialects(["sqlite", "spark", "postgres", "athena"])
    def create_sql(self, sql_dialect: SplinkDialect) -> str:
        """Build the SQL condition for this level.

        Args:
            sql_dialect (SplinkDialect): Dialect supplying the name of the
                string function to call.

        Returns:
            str: A SQL boolean expression over the left and right array
                columns.
        """
        self.col_expression.sql_dialect = sql_dialect
        col = self.col_expression
        d_fn = getattr(
            sql_dialect, self._DIALECT_FUNCTION_ATTRS[self.distance_function]
        )

        # The level should match when the *best* pair of elements is close
        # enough: the largest similarity must reach the threshold, or the
        # smallest edit distance must not exceed it.
        if self._uses_similarity_score():
            aggregator, comparator = "list_max", ">="
        else:
            aggregator, comparator = "list_min", "<="

        # The inner list_transform calls build the cross product of the two
        # arrays as [x, y] pairs; flatten collapses it to a single list of
        # pairs, which is then scored element-wise.
        return f"""{aggregator}(
            list_transform(
                flatten(
                    list_transform(
                        {col.name_l},
                        x -> list_transform(
                            {col.name_r},
                            y -> [x, y]
                        )
                    )
                ),
                pair -> {d_fn}(pair[1], pair[2])
            )
        ) {comparator} {self.distance_threshold}"""

    def create_label_for_charts(self) -> str:
        """Return a chart label naming the function, comparator and threshold."""
        if self._uses_similarity_score():
            measure, comparator = "similarity", ">="
        else:
            measure, comparator = "distance", "<="
        return (
            f"Array {self.distance_function} {measure} "
            f"{comparator} {self.distance_threshold}"
        )
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's include the distance function name here. |
||||||||||||||
|
||||||||||||||
|
||||||||||||||
class ArrayIntersectLevel(ComparisonLevelCreator): | ||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -535,6 +535,39 @@ def datetime_parse_function(self): | |
def cll_class(self): | ||
return cll.AbsoluteDateDifferenceLevel | ||
|
||
class ArrayStringDistance(ComparisonCreator):
    """Comparison of two array columns by pairwise string distance.

    Produces a null level, an exact-overlap level (array intersection of at
    least one element), one ``ArrayStringDistanceLevel`` per threshold, and a
    final else level.
    """

    def __init__(
        self,
        col_name: str,
        distance_threshold_or_thresholds: Union[
            Iterable[Union[int, float]], int, float
        ] = [1],
        distance_function: str = "levenshtein",
    ):
        """
        Args:
            col_name (str): Input column name
            distance_threshold_or_thresholds: A single threshold, or an
                iterable of thresholds, each of which becomes its own
                comparison level (in the order given).
            distance_function (str): Name of the string function passed on
                to each ``ArrayStringDistanceLevel``.
        """
        thresholds_as_iterable = ensure_is_iterable(distance_threshold_or_thresholds)
        # Copy into a fresh list so the caller's iterable (and the shared
        # default argument) is never aliased by this instance.
        self.thresholds = [*thresholds_as_iterable]
        self.distance_function = distance_function
        super().__init__(col_name)

    def create_comparison_levels(self) -> List[ComparisonLevelCreator]:
        """Return the ordered comparison levels for this comparison."""
        return [
            cll.NullLevel(self.col_expression),
            cll.ArrayIntersectLevel(self.col_expression, min_intersection=1),
            *[
                cll.ArrayStringDistanceLevel(
                    self.col_expression,
                    distance_threshold=threshold,
                    distance_function=self.distance_function,
                )
                for threshold in self.thresholds
            ],
            cll.ElseLevel(),
        ]

    def create_description(self) -> str:
        """Return a human-readable summary naming the function and thresholds."""
        comma_separated_thresholds_string = ", ".join(map(str, self.thresholds))
        plural = "s" if len(self.thresholds) > 1 else ""
        # "threshold" rather than "maximum size": for similarity functions the
        # value is a minimum, and nothing here measures a size.
        return (
            f"Array string distance ({self.distance_function}) at "
            f"threshold{plural} {comma_separated_thresholds_string} "
            "vs. anything else"
        )

    def create_output_column_name(self) -> str:
        """Return the output column name taken from the column expression."""
        return self.col_expression.output_column_name
|
||
class ArrayIntersectAtSizes(ComparisonCreator): | ||
def __init__( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's an open question whether it is better for the user to be able to define any function name they want (as in `DistanceFunctionLevel`) or to offer only a fixed set of options (which are then auto-transpiled for the dialect).