Skip to content

Commit

Permalink
Update unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zazass8 committed Oct 9, 2024
1 parent d061a9f commit 25dee2d
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 31 deletions.
2 changes: 1 addition & 1 deletion mlxtend/feature_selection/column_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def transform(self, X, y=None):

# We use the loc or iloc accessor if the input is a pandas dataframe
if hasattr(X, "loc") or hasattr(X, "iloc"):
if type(self.cols) == tuple:
if isinstance(self.cols, tuple):
self.cols = list(self.cols)
types = {type(i) for i in self.cols}
if len(types) > 1:
Expand Down
82 changes: 52 additions & 30 deletions mlxtend/frequent_patterns/tests/test_association_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,20 @@
"consequent support",
"support",
"confidence",
"representativity",
"lift",
"leverage",
"conviction",
"zhangs_metric",
"jaccard",
"certainty",
"kulczynski",
"kulczynski"
]


# fmt: off
def test_default():
res_df = association_rules(df_freq_items)
res_df = association_rules(df_freq_items, df, len(df))
res_df["antecedents"] = res_df["antecedents"].apply(lambda x: str(frozenset(x)))
res_df["consequents"] = res_df["consequents"].apply(lambda x: str(frozenset(x)))
res_df.sort_values(columns_ordered, inplace=True)
Expand Down Expand Up @@ -85,7 +86,7 @@ def test_default():


def test_datatypes():
res_df = association_rules(df_freq_items)
res_df = association_rules(df_freq_items, df, len(df))
for i in res_df["antecedents"]:
assert isinstance(i, frozenset) is True

Expand All @@ -100,7 +101,7 @@ def test_datatypes():
lambda x: set(x)
)

res_df = association_rules(df_freq_items)
res_df = association_rules(df_freq_items, df, len(df))
for i in res_df["antecedents"]:
assert isinstance(i, frozenset) is True

Expand All @@ -110,16 +111,16 @@ def test_datatypes():

def test_no_support_col():
    """association_rules must raise ValueError when "support" is missing."""
    # Keep only the "itemsets" column so the frame has no "support" column.
    df_no_support_col = df_freq_items.loc[:, ["itemsets"]]
    numpy_assert_raises(
        ValueError, association_rules, df_no_support_col, df, len(df)
    )


def test_no_itemsets_col():
    """association_rules must raise ValueError when "itemsets" is missing."""
    # Keep only the "support" column so the frame has no "itemsets" column.
    df_no_itemsets_col = df_freq_items.loc[:, ["support"]]
    numpy_assert_raises(
        ValueError, association_rules, df_no_itemsets_col, df, len(df)
    )


def test_wrong_metric():
    """Passing the bogus value "unicorn" must raise ValueError."""
    # NOTE(review): the two trailing positional arguments presumably map to a
    # boolean option and `metric` -- confirm against the association_rules
    # signature before converting them to keyword arguments.
    numpy_assert_raises(
        ValueError, association_rules, df_freq_items, df, len(df), False, "unicorn"
    )


def test_empty_result():
Expand All @@ -131,6 +132,7 @@ def test_empty_result():
"consequent support",
"support",
"confidence",
"representativity",
"lift",
"leverage",
"conviction",
Expand All @@ -140,82 +142,100 @@ def test_empty_result():
"kulczynski",
]
)
res_df = association_rules(df_freq_items, min_threshold=2)
res_df = association_rules(df_freq_items, df, len(df), min_threshold=2)
assert res_df.equals(expect)


def test_leverage():
    """metric="leverage" with min_threshold=0.1 yields 6 rows for both frames."""
    res_df = association_rules(
        df_freq_items, df, len(df), min_threshold=0.1, metric="leverage"
    )
    assert res_df.values.shape[0] == 6

    # Same expectation for the variant frame `df_freq_items_with_colnames`.
    res_df = association_rules(
        df_freq_items_with_colnames, df, len(df), min_threshold=0.1, metric="leverage"
    )
    assert res_df.values.shape[0] == 6


def test_conviction():
    """metric="conviction" with min_threshold=1.5 yields 11 rows for both frames."""
    res_df = association_rules(
        df_freq_items, df, len(df), min_threshold=1.5, metric="conviction"
    )
    assert res_df.values.shape[0] == 11

    # Same expectation for the variant frame `df_freq_items_with_colnames`.
    res_df = association_rules(
        df_freq_items_with_colnames, df, len(df), min_threshold=1.5, metric="conviction"
    )
    assert res_df.values.shape[0] == 11


def test_lift():
    """metric="lift" with min_threshold=1.1 yields 6 rows for both frames."""
    res_df = association_rules(
        df_freq_items, df, len(df), min_threshold=1.1, metric="lift"
    )
    assert res_df.values.shape[0] == 6

    # Same expectation for the variant frame `df_freq_items_with_colnames`.
    res_df = association_rules(
        df_freq_items_with_colnames, df, len(df), min_threshold=1.1, metric="lift"
    )
    assert res_df.values.shape[0] == 6


def test_confidence():
    """metric="confidence" with min_threshold=0.8 yields 9 rows for both frames."""
    res_df = association_rules(
        df_freq_items, df, len(df), min_threshold=0.8, metric="confidence"
    )
    assert res_df.values.shape[0] == 9

    # Same expectation for the variant frame `df_freq_items_with_colnames`.
    res_df = association_rules(
        df_freq_items_with_colnames, df, len(df), min_threshold=0.8, metric="confidence"
    )
    assert res_df.values.shape[0] == 9


def test_representativity():
    """metric="representativity" with min_threshold=1.0 yields 16 rows for both frames."""
    res_df = association_rules(
        df_freq_items, df, len(df), min_threshold=1.0, metric="representativity"
    )
    assert res_df.values.shape[0] == 16

    # Same expectation for the variant frame `df_freq_items_with_colnames`.
    res_df = association_rules(
        df_freq_items_with_colnames, df, len(df), min_threshold=1.0, metric="representativity"
    )
    assert res_df.values.shape[0] == 16


def test_jaccard():
    """metric="jaccard" with min_threshold=0.7 yields 8 rows for both frames."""
    res_df = association_rules(
        df_freq_items, df, len(df), min_threshold=0.7, metric="jaccard"
    )
    assert res_df.values.shape[0] == 8

    # Same expectation for the variant frame `df_freq_items_with_colnames`.
    res_df = association_rules(
        df_freq_items_with_colnames, df, len(df), min_threshold=0.7, metric="jaccard"
    )
    assert res_df.values.shape[0] == 8


def test_certainty():
    """metric="certainty" with min_threshold=0.6 yields 3 rows for both frames."""
    res_df = association_rules(
        df_freq_items, df, len(df), metric="certainty", min_threshold=0.6
    )
    assert res_df.values.shape[0] == 3

    # Same expectation for the variant frame `df_freq_items_with_colnames`.
    res_df = association_rules(
        df_freq_items_with_colnames, df, len(df), metric="certainty", min_threshold=0.6
    )
    assert res_df.values.shape[0] == 3


def test_kulczynski():
    """Check row counts for metric="kulczynski" at two different thresholds."""
    res_df = association_rules(
        df_freq_items, df, len(df), metric="kulczynski", min_threshold=0.9
    )
    assert res_df.values.shape[0] == 2

    # NOTE(review): unlike the sibling metric tests, this second call uses a
    # different threshold (0.6 rather than 0.9) and expects 16 rows -- confirm
    # the asymmetry is intentional.
    res_df = association_rules(
        df_freq_items_with_colnames, df, len(df), metric="kulczynski", min_threshold=0.6
    )
    assert res_df.values.shape[0] == 16


def test_frozenset_selection():
res_df = association_rules(df_freq_items)
res_df = association_rules(df_freq_items, df, len(df))

sel = res_df[res_df["consequents"] == frozenset((3, 5))]
assert sel.values.shape[0] == 1
Expand All @@ -231,17 +251,19 @@ def test_frozenset_selection():


def test_override_metric_with_support():
    """Compare row counts for the default metric, metric="support" and support_only=True."""
    res_df = association_rules(
        df_freq_items_with_colnames, df, len(df), min_threshold=0.8
    )
    # default metric is confidence
    assert res_df.values.shape[0] == 9

    # Explicitly selecting the "support" metric at the same threshold.
    res_df = association_rules(
        df_freq_items_with_colnames, df, len(df), min_threshold=0.8, metric="support"
    )
    assert res_df.values.shape[0] == 2

    # support_only=True is expected to give the same row count as metric="support".
    res_df = association_rules(
        df_freq_items_with_colnames, df, len(df), min_threshold=0.8, support_only=True
    )
    assert res_df.values.shape[0] == 2

Expand Down Expand Up @@ -274,7 +296,7 @@ def test_on_df_with_missing_entries():

df = pd.DataFrame(dict)

numpy_assert_raises(KeyError, association_rules, df)
numpy_assert_raises(KeyError, association_rules, df , df, len(df))


def test_on_df_with_missing_entries_support_only():
Expand Down Expand Up @@ -304,13 +326,13 @@ def test_on_df_with_missing_entries_support_only():
}

df = pd.DataFrame(dict)
df_result = association_rules(df, support_only=True, min_threshold=0.1)
df_result = association_rules(df, df, len(df), support_only=True, min_threshold=0.1)

assert df_result["support"].shape == (18,)
assert int(np.isnan(df_result["support"].values).any()) != 1


def test_with_empty_dataframe():
    """association_rules must raise ValueError for a zero-row itemsets frame."""
    # Use a distinct local name so `df` below still refers to the frame
    # defined outside this function instead of being shadowed.
    df_freq = df_freq_items_with_colnames.iloc[:0]
    with pytest.raises(ValueError):
        association_rules(df_freq, df, len(df))

0 comments on commit 25dee2d

Please sign in to comment.