
Persistence of TabularTransform constructor default arguments #88

Open
phantom-duck opened this issue Jul 13, 2023 · 1 comment

@phantom-duck

Hello! I think there is a potential issue when one tries to create several TabularTransform instances (omnixai.preprocessing.tabular.TabularTransform).

Specifically, the __init__ method of TabularTransform has default values for its three arguments, and each default is an instance of a transformation class. Python evaluates default argument expressions once, when the function is defined, and never again. This means that the default values are objects shared across all instances of TabularTransform that are created without explicit arguments.
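The same behaviour can be reproduced without OmniXAI. The following is a minimal sketch in plain Python; `Encoder` and `Transform` are hypothetical stand-ins, not the library's actual classes:

```python
class Encoder:
    """Hypothetical stand-in for a stateful transform such as a one-hot encoder."""

    def __init__(self):
        self.categories = None

    def fit(self, values):
        # Fitting overwrites the state stored on this encoder object.
        self.categories = sorted(set(values))
        return self


class Transform:
    # Encoder() is evaluated once, when this def statement is executed,
    # so every Transform created without an argument receives the same object.
    def __init__(self, encoder=Encoder()):
        self.encoder = encoder

    def fit(self, values):
        self.encoder.fit(values)
        return self


t1 = Transform().fit(["A", "B"])
t2 = Transform().fit(["h", "l"])

print(t1.encoder is t2.encoder)  # True: both instances hold the shared default
print(t1.encoder.categories)     # ['h', 'l']: t1's state was overwritten by t2's fit
```

The shared object is stored on the function itself and can be inspected through `Transform.__init__.__defaults__`.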

This leads to highly unintuitive and misleading results in some cases. As an example:

import pandas as pd
from omnixai.data.tabular import Tabular
from omnixai.preprocessing.tabular import TabularTransform

df1 = pd.DataFrame({"cat_feat": ["A", "B", "B", "A"], "num_feat": [1, 2, 3, 4]})
print(df1)
# output:
#   cat_feat  num_feat
# 0        A         1
# 1        B         2
# 2        B         3
# 3        A         4

tabular1 = Tabular(df1, categorical_columns=["cat_feat"])
transform1 = TabularTransform().fit(tabular1)
print(transform1.cate_shape)
# output: 2
print(transform1.cate_transform.get_feature_names())
# output: ['x0_A' 'x0_B']

df2 = pd.DataFrame({"cat_feat_1": ["A", "B", "B", "A"], "cat_feat_2": ["h", "h", "h", "l"], "num_feat": [1, 2, 3, 4]})
print(df2)
# output:
#   cat_feat_1 cat_feat_2  num_feat
# 0          A          h         1
# 1          B          h         2
# 2          B          h         3
# 3          A          l         4

tabular2 = Tabular(df2, categorical_columns=["cat_feat_1", "cat_feat_2"])
transform2 = TabularTransform().fit(tabular2)
print(transform2.cate_shape)
# output: 4
print(transform2.cate_transform.get_feature_names())
# output: ['x0_A' 'x0_B' 'x1_h' 'x1_l']

print(transform1.cate_shape)
# output: 2
print(transform1.cate_transform.get_feature_names())
# output: ['x0_A' 'x0_B' 'x1_h' 'x1_l']

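Although transform1 is never touched after it is fitted, its cate_transform.get_feature_names() call ends up returning the same array as transform2's, while transform1.cate_shape still reports 2. This is consistent with both instances sharing the same default cate_transform object.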

@yangwenz
Collaborator

Thanks for reporting this issue. We will move the creation of the default parameters into the body of the __init__ function in the next minor version.
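One common way to do this in Python is a `None` sentinel, so that a fresh default object is built on every call. The sketch below uses hypothetical `Encoder` and `Transform` classes and is only an illustration of the pattern, not the actual OmniXAI change:

```python
class Encoder:
    """Hypothetical stand-in for a stateful transform such as a one-hot encoder."""

    def __init__(self):
        self.categories = None

    def fit(self, values):
        self.categories = sorted(set(values))
        return self


class Transform:
    def __init__(self, encoder=None):
        # The default is now created inside __init__, once per instance,
        # instead of once at function definition time.
        self.encoder = Encoder() if encoder is None else encoder

    def fit(self, values):
        self.encoder.fit(values)
        return self


t1 = Transform().fit(["A", "B"])
t2 = Transform().fit(["h", "l"])

print(t1.encoder is t2.encoder)  # False: each instance owns its encoder
print(t1.encoder.categories)     # ['A', 'B']: unaffected by fitting t2
```

Callers who pass their own encoder explicitly see no change in behaviour with this pattern; only the no-argument case differs.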
