@@ -478,7 +478,30 @@ def get_interval_overlaps_nd(
478
478
start_columns : list [str ],
479
479
end_columns : list [str ],
480
480
end_included : bool ,
481
- ):
481
+ ) -> tuple [sa .sql .selectable .CompoundSelect , sa .sql .selectable .Select ]:
482
+ """Create selectables for interval overlaps in n dimensions.
483
+
484
+ We define the presence of 'overlap' as presence of a non-empty intersection
485
+ between two intervals.
486
+
487
+ Given that we care about a single dimension and have two intervals :math:`t1` and :math:`t2`,
488
+ we define an overlap follows:
489
+
490
+ .. math::
491
+ \\ begin{align} \\ text{overlap}(t_1, t_2) \\ Leftrightarrow
492
+ &(min(t_1) \\ leq min(t_2) \\ land max(t_1) \\ geq min(t_2)) \\ \\
493
+ &\\ lor \\ \\
494
+ &(min(t_2) \\ leq min(t_1) \\ land max(t_2) \\ geq min(t_1))
495
+ \\ end{align}
496
+
497
+ We can drop the second clause of the above disjunction if we define :math:`t_1` to be the 'leftmost'
498
+ interval. We do so when building our query.
499
+
500
+ Note that the above equations are representative of ``end_included=True`` and the second clause
501
+ of the conjunction would use a strict inequality if ``end_included=False``.
502
+
503
+ We define an overlap in several dimensions as the conjunction of overlaps in every single dimension.
504
+ """
482
505
if is_snowflake (engine ):
483
506
if key_columns :
484
507
key_columns = lowercase_column_names (key_columns )
@@ -502,10 +525,20 @@ def get_interval_overlaps_nd(
502
525
table_key_columns = get_table_columns (table1 , key_columns ) if key_columns else []
503
526
504
527
end_operator = operator .ge if end_included else operator .gt
505
- violation_condition = sa .and_ (
528
+
529
+ # We have a violation in two scenarios:
530
+ # 1. At least two entries are exactly equal in key and interval columns
531
+ # 2. Two entries are not exactly equal in key and interval_columns and fuilfill violation_condition
532
+
533
+ # Scenario 1
534
+ duplicate_selection = duplicates (table1 )
535
+
536
+ # scenario 2
537
+ naive_violation_condition = sa .and_ (
506
538
* [
507
539
sa .and_ (
508
- table1 .c [start_columns [dimension ]] < table2 .c [start_columns [dimension ]],
540
+ table1 .c [start_columns [dimension ]]
541
+ <= table2 .c [start_columns [dimension ]],
509
542
end_operator (
510
543
table1 .c [end_columns [dimension ]], table2 .c [start_columns [dimension ]]
511
544
),
@@ -514,8 +547,24 @@ def get_interval_overlaps_nd(
514
547
]
515
548
)
516
549
517
- join_condition = sa .and_ (* key_conditions , violation_condition )
518
- violation_selection = sa .select (
550
+ interval_inequality_condition = sa .or_ (
551
+ * [
552
+ sa .or_ (
553
+ table1 .c [start_columns [dimension ]]
554
+ != table2 .c [start_columns [dimension ]],
555
+ table2 .c [end_columns [dimension ]] != table2 .c [end_columns [dimension ]],
556
+ )
557
+ for dimension in range (dimensionality )
558
+ ]
559
+ )
560
+
561
+ distinct_violation_condition = sa .and_ (
562
+ naive_violation_condition ,
563
+ interval_inequality_condition ,
564
+ )
565
+
566
+ distinct_join_condition = sa .and_ (* key_conditions , distinct_violation_condition )
567
+ distinct_violation_selection = sa .select (
519
568
* table_key_columns ,
520
569
* [
521
570
table .c [start_column ]
@@ -527,7 +576,7 @@ def get_interval_overlaps_nd(
527
576
for table in [table1 , table2 ]
528
577
for end_column in end_columns
529
578
],
530
- ).select_from (table1 .join (table2 , join_condition ))
579
+ ).select_from (table1 .join (table2 , distinct_join_condition ))
531
580
532
581
# Note, Kevin, 21/12/09
533
582
# The following approach would likely be preferable to the approach used
@@ -544,6 +593,27 @@ def get_interval_overlaps_nd(
544
593
# violation_subquery
545
594
# )
546
595
596
+ # Merge scenarios 1 and 2.
597
+ # We need to 'impute' the missing columns for the duplicate selection in order for the union between
598
+ # both selections to work.
599
+ duplicate_selection = sa .select (
600
+ * (
601
+ # Already existing columns
602
+ [
603
+ duplicate_selection .c [column ]
604
+ for column in distinct_violation_selection .columns .keys ()
605
+ if column in duplicate_selection .columns .keys ()
606
+ ]
607
+ # Fill all missing columns with NULLs.
608
+ + [
609
+ sa .null ().label (column )
610
+ for column in distinct_violation_selection .columns .keys ()
611
+ if column not in duplicate_selection .columns .keys ()
612
+ ]
613
+ )
614
+ )
615
+ violation_selection = duplicate_selection .union (distinct_violation_selection )
616
+
547
617
violation_subquery = violation_selection .subquery ()
548
618
549
619
keys = (
@@ -1102,12 +1172,11 @@ def get_row_mismatch(engine, ref, ref2, match_and_compare):
1102
1172
return result_mismatch , result_n_rows , [selection_difference , selection_n_rows ]
1103
1173
1104
1174
1105
- def get_duplicate_sample (engine , ref ):
1106
- initial_selection = ref .get_selection (engine ).alias ()
1175
+ def duplicates (subquery : sa .sql .selectable .Subquery ) -> sa .sql .selectable .Select :
1107
1176
aggregate_subquery = (
1108
- sa .select (initial_selection , sa .func .count ().label ("n_copies" ))
1109
- .select_from (initial_selection )
1110
- .group_by (* initial_selection .columns )
1177
+ sa .select (subquery , sa .func .count ().label ("n_copies" ))
1178
+ .select_from (subquery )
1179
+ .group_by (* subquery .columns )
1111
1180
.alias ()
1112
1181
)
1113
1182
duplicate_selection = (
@@ -1121,6 +1190,12 @@ def get_duplicate_sample(engine, ref):
1121
1190
.select_from (aggregate_subquery )
1122
1191
.where (aggregate_subquery .c .n_copies > 1 )
1123
1192
)
1193
+ return duplicate_selection
1194
+
1195
+
1196
+ def get_duplicate_sample (engine : sa .engine .Engine , ref : DataReference ) -> tuple :
1197
+ initial_selection = ref .get_selection (engine ).alias ()
1198
+ duplicate_selection = duplicates (initial_selection )
1124
1199
result = engine .connect ().execute (duplicate_selection ).first ()
1125
1200
return result , [duplicate_selection ]
1126
1201
0 commit comments