
Commit

refactor intervals_tests
R7L208 committed Jul 9, 2024
1 parent 981f1ab commit 55e65ad
Showing 2 changed files with 890 additions and 932 deletions.
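The pattern of the refactor: the old get_data_as_sdf("input") / get_data_as_idf("input") fixture helpers are replaced by get_test_df_builder(...), whose as_sdf() and as_idf() accessors return a Spark DataFrame or an IntervalsDF respectively, and the from_idf=True flag is dropped from assertDataFrameEquality calls. A minimal before/after sketch of that pattern follows; only the helper, builder, and assertion calls are taken from this diff, while the import path and the example test itself are illustrative assumptions.

    # Illustrative sketch only; the import path below is an assumption, not part of this commit.
    from tests.base import SparkTest  # assumed base class providing get_test_df_builder / assertDataFrameEquality

    class ExampleIntervalsTests(SparkTest):
        def test_make_disjoint_example(self):
            # old helpers (removed by this commit):
            #   idf_input = self.get_data_as_idf("input")
            #   idf_expected = self.get_data_as_idf("expected")
            idf_input = self.get_test_df_builder("init").as_idf()
            idf_expected = self.get_test_df_builder("expected").as_idf()

            idf_actual = idf_input.make_disjoint()

            # from_idf=True is no longer passed to the equality assertion
            self.assertDataFrameEquality(idf_expected, idf_actual, ignore_row_order=True)
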
123 changes: 61 additions & 62 deletions python/tests/intervals_tests.py
@@ -74,7 +74,7 @@ class IntervalsDFTests(SparkTest):
]

def test_init_series_str(self):
df_input = self.get_data_as_sdf("input")
df_input = self.get_test_df_builder("init").as_sdf()

idf = IntervalsDF(df_input, "start_ts", "end_ts", "series_1")

@@ -91,7 +91,7 @@ def test_init_series_str(self):
self.assertCountEqual(idf.metric_columns, ["metric_1", "metric_2"])

def test_init_series_comma_seperated_str(self):
df_input = self.get_data_as_sdf("input")
df_input = self.get_test_df_builder("init").as_sdf()

idf = IntervalsDF(df_input, "start_ts", "end_ts", "series_1, series_2")

@@ -108,7 +108,7 @@ def test_init_series_comma_seperated_str(self):
self.assertCountEqual(idf.metric_columns, ["metric_1", "metric_2"])

def test_init_series_tuple(self):
df_input = self.get_data_as_sdf("input")
df_input = self.get_test_df_builder("init").as_sdf()

idf = IntervalsDF(df_input, "start_ts", "end_ts", ("series_1",))

@@ -125,7 +125,7 @@ def test_init_series_tuple(self):
self.assertCountEqual(idf.metric_columns, ["metric_1", "metric_2"])

def test_init_series_list(self):
df_input = self.get_data_as_sdf("input")
df_input = self.get_test_df_builder("init").as_sdf()

idf = IntervalsDF(df_input, "start_ts", "end_ts", ["series_1"])

@@ -142,7 +142,7 @@ def test_init_series_list(self):
self.assertCountEqual(idf.metric_columns, ["metric_1", "metric_2"])

def test_init_series_none(self):
df_input = self.get_data_as_sdf("input")
df_input = self.get_test_df_builder("init").as_sdf()

idf = IntervalsDF(df_input, "start_ts", "end_ts", None)

@@ -159,7 +159,7 @@ def test_init_series_none(self):
self.assertCountEqual(idf.metric_columns, ["metric_1", "metric_2"])

def test_init_series_int(self):
df_input = self.get_data_as_sdf("input")
df_input = self.get_test_df_builder("init").as_sdf()

self.assertRaises(
ValueError,
@@ -171,14 +171,12 @@ def test_init_series_int(self):
)

def test_window_property(self):
df_input = self.get_data_as_sdf("input")

idf = IntervalsDF(df_input, "start_ts", "end_ts", "series_1")
idf: IntervalsDF = self.get_test_df_builder("init").as_idf()

self.assertIsInstance(idf.window, pyspark.sql.window.WindowSpec)

def test_fromStackedMetrics_series_str(self):
df_input = self.get_data_as_sdf("input")
df_input = self.get_test_df_builder("init").as_sdf()

self.assertRaises(
ValueError,
@@ -192,7 +190,7 @@ def test_fromStackedMetrics_series_str(self):
)

def test_fromStackedMetrics_series_tuple(self):
df_input = self.get_data_as_sdf("input")
df_input = self.get_test_df_builder("init").as_sdf()

self.assertRaises(
ValueError,
@@ -206,8 +204,8 @@ def test_fromStackedMetrics_series_tuple(self):
)

def test_fromStackedMetrics_series_list(self):
df_input = self.get_data_as_sdf("input")
idf_expected = self.get_data_as_idf("expected")
df_input = self.get_test_df_builder("init").as_sdf()
idf_expected = self.get_test_df_builder("expected").as_idf()

df_input = df_input.withColumn(
"start_ts", f.to_timestamp("start_ts")
@@ -224,11 +222,11 @@ def test_fromStackedMetrics_series_list(self):
"metric_value",
)

self.assertDataFrameEquality(idf, idf_expected, from_idf=True)
self.assertDataFrameEquality(idf, idf_expected)

def test_fromStackedMetrics_metric_names(self):
df_input = self.get_data_as_sdf("input")
idf_expected = self.get_data_as_idf("expected")
df_input = self.get_test_df_builder("init").as_sdf()
idf_expected = self.get_test_df_builder("expected").as_idf()

df_input = df_input.withColumn(
"start_ts", f.to_timestamp("start_ts")
@@ -246,21 +244,21 @@ def test_fromStackedMetrics_metric_names(self):
["metric_1", "metric_2"],
)

self.assertDataFrameEquality(idf, idf_expected, from_idf=True)
self.assertDataFrameEquality(idf, idf_expected)

def test_make_disjoint(self):
idf_input = self.get_data_as_idf("input")
idf_expected = self.get_data_as_idf("expected")
idf_input = self.get_test_df_builder("init").as_idf()
idf_expected = self.get_test_df_builder("expected").as_idf()

idf_actual = idf_input.make_disjoint()

self.assertDataFrameEquality(
idf_expected, idf_actual, from_idf=True, ignore_row_order=True
idf_expected, idf_actual, ignore_row_order=True
)

def test_make_disjoint_contains_interval_already_disjoint(self):
idf_input = self.get_data_as_idf("input")
idf_expected = self.get_data_as_idf("expected")
idf_input = self.get_test_df_builder("init").as_idf()
idf_expected = self.get_test_df_builder("expected").as_idf()
print("expected")
print(idf_expected.df.toPandas())

@@ -269,72 +267,72 @@ def test_make_disjoint_contains_interval_already_disjoint(self):
print(idf_actual)

# self.assertDataFrameEquality(
# idf_expected, idf_actual, from_idf=True, ignore_row_order=True
# idf_expected, idf_actual, ignore_row_order=True
# )

def test_make_disjoint_contains_intervals_equal(self):
idf_input = self.get_data_as_idf("input")
idf_expected = self.get_data_as_idf("expected")
idf_input = self.get_test_df_builder("init").as_idf()
idf_expected = self.get_test_df_builder("expected").as_idf()

idf_actual = idf_input.make_disjoint()

self.assertDataFrameEquality(
idf_expected, idf_actual, from_idf=True, ignore_row_order=True
idf_expected, idf_actual, ignore_row_order=True
)

def test_make_disjoint_intervals_same_start(self):
idf_input = self.get_data_as_idf("input")
idf_expected = self.get_data_as_idf("expected")
idf_input = self.get_test_df_builder("init").as_idf()
idf_expected = self.get_test_df_builder("expected").as_idf()

idf_actual = idf_input.make_disjoint()

self.assertDataFrameEquality(
idf_expected, idf_actual, from_idf=True, ignore_row_order=True
idf_expected, idf_actual, ignore_row_order=True
)

def test_make_disjoint_intervals_same_end(self):
idf_input = self.get_data_as_idf("input")
idf_expected = self.get_data_as_idf("expected")
idf_input = self.get_test_df_builder("init").as_idf()
idf_expected = self.get_test_df_builder("expected").as_idf()

idf_actual = idf_input.make_disjoint()

self.assertDataFrameEquality(
idf_expected, idf_actual, from_idf=True, ignore_row_order=True
idf_expected, idf_actual, ignore_row_order=True
)

def test_make_disjoint_multiple_series(self):
idf_input = self.get_data_as_idf("input")
idf_expected = self.get_data_as_idf("expected")
idf_input = self.get_test_df_builder("init").as_idf()
idf_expected = self.get_test_df_builder("expected").as_idf()

idf_actual = idf_input.make_disjoint()

self.assertDataFrameEquality(
idf_expected, idf_actual, from_idf=True, ignore_row_order=True
idf_expected, idf_actual, ignore_row_order=True
)

def test_make_disjoint_single_metric(self):
idf_input = self.get_data_as_idf("input")
idf_expected = self.get_data_as_idf("expected")
idf_input = self.get_test_df_builder("init").as_idf()
idf_expected = self.get_test_df_builder("expected").as_idf()

idf_actual = idf_input.make_disjoint()

self.assertDataFrameEquality(
idf_expected, idf_actual, from_idf=True, ignore_row_order=True
idf_expected, idf_actual, ignore_row_order=True
)

def test_make_disjoint_interval_is_subset(self):
idf_input = self.get_data_as_idf("input")
idf_expected = self.get_data_as_idf("expected")
idf_input = self.get_test_df_builder("init").as_idf()
idf_expected = self.get_test_df_builder("expected").as_idf()

idf_actual = idf_input.make_disjoint()

self.assertDataFrameEquality(
idf_expected, idf_actual, from_idf=True, ignore_row_order=True
idf_expected, idf_actual, ignore_row_order=True
)

def test_union_other_idf(self):
idf_input_1 = self.get_data_as_idf("input")
idf_input_2 = self.get_data_as_idf("input")
idf_input_1 = self.get_test_df_builder("init").as_idf()
idf_input_2 = self.get_test_df_builder("init").as_idf()

count_idf_1 = idf_input_1.df.count()
count_idf_2 = idf_input_2.df.count()
@@ -346,21 +344,21 @@ def test_union_other_idf(self):
self.assertEqual(count_idf_1 + count_idf_2, count_union)

def test_union_other_df(self):
idf_input = self.get_data_as_idf("input")
df_input = self.get_data_as_sdf("input")
idf_input = self.get_test_df_builder("init").as_idf()
df_input = self.get_test_df_builder("init").as_sdf()

self.assertRaises(TypeError, idf_input.union, df_input)

def test_union_other_list_dicts(self):
idf_input = self.get_data_as_idf("input")
idf_input = self.get_test_df_builder("init").as_idf()

self.assertRaises(
TypeError, idf_input.union, IntervalsDFTests.union_tests_dict_input
)

def test_unionByName_other_idf(self):
idf_input_1 = self.get_data_as_idf("input")
idf_input_2 = self.get_data_as_idf("input")
idf_input_1 = self.get_test_df_builder("init").as_idf()
idf_input_2 = self.get_test_df_builder("init").as_idf()

count_idf_1 = idf_input_1.df.count()
count_idf_2 = idf_input_2.df.count()
@@ -372,41 +370,42 @@ def test_unionByName_other_idf(self):
self.assertEqual(count_idf_1 + count_idf_2, count_union_by_name)

def test_unionByName_other_df(self):
idf_input = self.get_data_as_idf("input")
df_input = self.get_data_as_sdf("input")
idf_input = self.get_test_df_builder("init").as_idf()
df_input = self.get_test_df_builder("init").as_sdf()

self.assertRaises(TypeError, idf_input.unionByName, df_input)

def test_unionByName_other_list_dicts(self):
idf_input = self.get_data_as_idf("input")
idf_input = self.get_test_df_builder("init").as_idf()

self.assertRaises(
TypeError, idf_input.unionByName, IntervalsDFTests.union_tests_dict_input
)

def test_unionByName_extra_column(self):
idf_extra_col = self.get_data_as_idf("input_extra_col")
idf_input = self.get_data_as_idf("input")
idf_extra_col = self.get_test_df_builder("init_extra_col").as_idf()
idf_input = self.get_test_df_builder("init").as_idf()

self.assertRaises(AnalysisException, idf_extra_col.unionByName, idf_input)

def test_unionByName_other_extra_column(self):
idf_input = self.get_data_as_idf("input")
idf_extra_col = self.get_data_as_idf("input_extra_col")
idf_input = self.get_test_df_builder("init").as_idf()
idf_extra_col = self.get_test_df_builder("init_extra_col").as_idf()

self.assertRaises(AnalysisException, idf_input.unionByName, idf_extra_col)

def test_toDF(self):
idf_input = self.get_data_as_idf("input")
expected_df = self.get_data_as_sdf("input")
# NB: init is used for both since the expected df is the same
idf_input = self.get_test_df_builder("init").as_idf()
expected_df = self.get_test_df_builder("init").as_sdf()

actual_df = idf_input.toDF()

self.assertDataFrameEquality(actual_df, expected_df)

def test_toDF_stack(self):
idf_input = self.get_data_as_idf("input")
expected_df = self.get_data_as_sdf("expected")
idf_input = self.get_test_df_builder("init").as_idf()
expected_df = self.get_test_df_builder("expected").as_sdf()

expected_df = expected_df.withColumn(
"start_ts", f.to_timestamp("start_ts")
@@ -419,14 +418,14 @@ def test_toDF_stack(self):
def test_make_disjoint_issue_268(self):
# https://github.com/databrickslabs/tempo/issues/268

idf_input = self.get_data_as_idf("input")
idf_expected = self.get_data_as_idf("expected")
idf_input = self.get_test_df_builder("init").as_idf()
idf_expected = self.get_test_df_builder("expected").as_idf()

idf_actual = idf_input.make_disjoint()
idf_actual.df.show(truncate=False)

self.assertDataFrameEquality(
idf_expected, idf_actual, from_idf=True, ignore_row_order=True
idf_expected, idf_actual, ignore_row_order=True
)


