Skip to content

Commit

Permalink
Merge pull request #12 from musicpiano/master
Browse files Browse the repository at this point in the history
add some conditions
  • Loading branch information
djgagne authored Jul 21, 2020
2 parents fa37bab + 720e09b commit 87a7e71
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions mlmicrophysics/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,12 @@ def categorize_output_values(output_values, output_transforms, output_scalers=No
for label, comparison in output_transforms.items():
class_indices = ops[comparison[0]](output_values, float(comparison[1]))
labels[class_indices] = label
transformed_outputs[class_indices] = transforms[comparison[2]](output_values[class_indices],
if comparison[2] != "None":
transformed_outputs[class_indices] = transforms[comparison[2]](output_values[class_indices],
eps=float(comparison[1]))
else:
transformed_outputs[class_indices] = output_values[class_indices]
# If the transform is 'None', then don't transform, copy exactly from the original data
if comparison[3] != "None":
if label not in list(output_scalers.keys()):
output_scalers[label] = scalers[comparison[3]]()
Expand All @@ -374,6 +378,8 @@ def categorize_output_values(output_values, output_transforms, output_scalers=No
transformed_outputs[class_indices].reshape(-1, 1)).ravel()
else:
output_scalers[label] = None
scaled_outputs[class_indices] = transformed_outputs[class_indices]
# If the scaler is None, copy exactly from Transform data, should not be 0
return labels, transformed_outputs, scaled_outputs, output_scalers


Expand Down Expand Up @@ -433,7 +439,8 @@ def assemble_data_files(files, input_cols, output_cols, input_transforms, output
del all_output_data[:]
print("Transforming data")
for var, transform_name in input_transforms.items():
combined_input_data.loc[:, var] = transforms[transform_name](combined_input_data[var])
if transform_name != "None": # handle the situation when you don't need to transform the data
combined_input_data.loc[:, var] = transforms[transform_name](combined_input_data[var])
transformed_output_data = pd.DataFrame(0,
columns=combined_output_data.columns,
index=combined_output_data.index,
Expand Down

0 comments on commit 87a7e71

Please sign in to comment.