Skip to content

Commit

Permalink
preserve DS artifact mapping
Browse files: browse the repository at this point in the history
  • Loading branch information
ShubhamVashisth7 committed Jul 21, 2023
1 parent 74ba4f6 commit 7541806
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions training_manager/build_transformation_recommender.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self):
self.model = RandomForestClassifier(random_state=RANDOM_STATE)
self.label_encoder = LabelEncoder()

def __load_column_embeddings(self, path_to_embeddings: str = '../../storage/CoLR_embeddings_data_transformation'):
def __load_column_embeddings(self, path_to_embeddings: str = '../kg_augmentor/embeddings/CoLR_embeddings_data_transformation'):
for data_type in os.listdir(path=path_to_embeddings):
if data_type == '.DS_Store':
continue
Expand Down Expand Up @@ -90,7 +90,7 @@ def average_embeddings(embeddings: list):

scaling_transformation_column = []
table_embedding_column = []

table_id = []
# average embedding by grouping on table
embeddings_per_table = []
previous_table_id = list(scaling_df['Table_id'])[0]
Expand All @@ -106,11 +106,13 @@ def average_embeddings(embeddings: list):
averaged_embeddings = average_embeddings(embeddings=embeddings_per_table)
scaling_transformation_column.append(transformation)
table_embedding_column.append(averaged_embeddings)
table_id.append(previous_table_id)
# refresh variables
embeddings_per_table = [column_embedding]
previous_table_id = current_table_id

self.modeling_data_scaling = pd.DataFrame({'Transformation': scaling_transformation_column, 'Embeddings': table_embedding_column})
self.modeling_data_unary = unary_df.drop('Transformed_column_id', axis=1)
self.modeling_data_scaling = pd.DataFrame({'Transformation': scaling_transformation_column, 'Transformed_table_id': table_id, 'Embeddings': table_embedding_column})
self.modeling_data_unary = unary_df
self.modeling_data_scaling.to_csv('modeling_data_scaling.csv', index=False)
self.modeling_data_unary.to_csv('modeling_data_unary.csv', index=False)
print('done.')
Expand Down

0 comments on commit 7541806

Please sign in to comment.