Skip to content

Commit

Permalink
additional checks for idf dataframe equality
Browse files Browse the repository at this point in the history
  • Loading branch information
R7L208 committed Jul 9, 2024
1 parent 09e6423 commit 981f1ab
Showing 1 changed file with 12 additions and 4 deletions.
16 changes: 12 additions & 4 deletions python/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ def as_idf(self) -> IntervalsDF:
"""
sdf = self.as_sdf()
if self.idf_construct is not None:
return getattr(IntervalsDF, self.idf_construct)(sdf, **self.tsdf)
return getattr(IntervalsDF, self.idf_construct)(sdf, **self.idf)
else:
return IntervalsDF(self.as_sdf(), **self.tsdf)
return IntervalsDF(self.as_sdf(), **self.idf)


class SparkTest(unittest.TestCase):
Expand Down Expand Up @@ -305,8 +305,8 @@ def assertSchemaContainsField(self, schema, field):

def assertDataFrameEquality(
self,
df1: Union[TSDF, DataFrame],
df2: Union[TSDF, DataFrame],
df1: Union[TSDF, DataFrame, IntervalsDF],
df2: Union[TSDF, DataFrame, IntervalsDF],
ignore_row_order: bool = False,
ignore_column_order: bool = True,
ignore_nullable: bool = True,
Expand All @@ -324,6 +324,14 @@ def assertDataFrameEquality(
df1 = df1.df
df2 = df2.df

# Handle IDFs
if isinstance(df1, IntervalsDF):
# df2 must also be a IntervalsDF
self.assertIsInstance(df2, IntervalsDF)
# get the underlying Spark DataFrames
df1 = df1.df
df2 = df2.df

# handle DataFrames
assert_df_equality(
df1,
Expand Down

0 comments on commit 981f1ab

Please sign in to comment.