add multiprocessing #92

Open · wants to merge 99 commits into base: developer

Commits (99)
0e2242e
Add files for multiprocessing
Apr 18, 2024
2add9b9
Update identify_associations_multiprocess.py
LFT18 Apr 18, 2024
f645ea4
Clean multiprocessing script
LFT18 Apr 19, 2024
f471704
Update __main__.py multiprocessing
LFT18 Apr 19, 2024
85c28e5
Update schema.py multiprocessing
LFT18 Apr 19, 2024
6330e92
Update __init__.py multiprocessing
LFT18 Apr 19, 2024
bbe1b4e
Update preprocessing.py
LFT18 Apr 19, 2024
820c554
:fire: clean-up duplicated src/move files (pkg was in main folder)
Apr 22, 2024
ce9a9dc
:sparkles: add identify_associations_multiprocess to src/move/tasks
Apr 23, 2024
5327223
:bug: make multiprocessing not stale: assign # of threads for each pr…
Apr 23, 2024
eaa858a
Merge pull request #1 from enryH/main
LFT18 Apr 23, 2024
5ab5e59
Updated identify_associations_multiprocess.py
Apr 23, 2024
33f565a
Update config files for small tries
Apr 23, 2024
63f128b
Multiprocessing for analyze_latent
Apr 24, 2024
ca389d2
Analyze latent multiprocessing
Apr 24, 2024
e08a94b
Analyze latent multiprocessing
Apr 24, 2024
e94ef90
Fix bayes_k calculation
Apr 25, 2024
f4f0aa3
Fix analyze_latent_multiprocessing
Apr 26, 2024
6a0b665
Update and new functions
May 11, 2024
a5310a6
Delete files and fix multiloop
May 11, 2024
e67bb75
Clean identify_association_multiprocess.py
May 21, 2024
86bfed5
Clean analyze_latent multiprocessing.py
May 21, 2024
f9d4961
Update perturbations.py
LFT18 Jun 7, 2024
c2c49e8
Update perturbations.py
LFT18 Jun 10, 2024
0a4bcae
Delete src/move/tasks/analyze_latent_efficient.py
LFT18 Jun 13, 2024
ce20dac
Delete src/move/tasks/analyze_latent_multiprocessing.py
LFT18 Jun 13, 2024
4a72842
Delete src/move/tasks/identify_associations_multiprocess_loop.py
LFT18 Jun 13, 2024
f537d21
Delete src/move/tasks/identify_associations_multiprocess_may.py
LFT18 Jun 13, 2024
568aaa8
Delete src/move/tasks/identify_associations_selected.py
LFT18 Jun 13, 2024
ede3707
Delete src/move/tasks/analyze_latent_original.py
LFT18 Jun 13, 2024
5df8a01
Remove multiprocess_loop
LFT18 Jun 13, 2024
52c37fd
Remove multiprocess_loop
LFT18 Jun 13, 2024
2e29f23
:art: format with black
Jun 18, 2024
13678cb
Merge branch 'main' into LFT18-main
Jun 18, 2024
f92a862
Merge branch 'developer' into LFT18-main
Jun 18, 2024
b7824f7
:art: add trigger of actions from PR
Jun 20, 2024
e5253a2
:art: format with black
Jun 20, 2024
d4118a3
:fire: remove duplicated code and intermediate scripts
Jun 20, 2024
3e19e24
Merge branch 'developer' into LFT18-main
Jun 21, 2024
dbc0238
:bug: fix f-string formatting errors
Jun 24, 2024
80352e6
:bug: remove unused imports
Jun 24, 2024
488b4a4
:rewind: add configuration files back in from developer branch
Jul 3, 2024
9b3a27e
:art: isort imports
Jul 3, 2024
05d4c34
:construction: see if this advances CI to the next step
Jul 3, 2024
ebb72ad
:fire: remove intermediate files of development
Jul 3, 2024
efbfd5c
:construction: multiprocess only defined for bayes factors
Jul 3, 2024
fbbeb19
:bug: remove non-existing, intermediate tasks (used for developing), …
Jul 3, 2024
cb3ad30
:bug: also deactivate multiprocessing for KS as it's not implemented
Jul 3, 2024
8c61d35
:art: fix flake8-bugbear issues except missing multiprocessing of t-t…
Jul 3, 2024
5aa03ff
:bug: format and fix import
Jul 3, 2024
8b06298
:bug: use perturb_continuous_data_extended from perturbations
Jul 3, 2024
6cfd1f8
:fire: comments and old configurations; format
Jul 4, 2024
8277891
:fire: remove duplicated functionality
Jul 4, 2024
44802eb
:sparkles: integrate multiprocessing into analyze_latent.py
Jul 4, 2024
f64d779
:sparkles: merge multiprocessing bayes factors into identify_associat…
Jul 4, 2024
6a17110
:fire: remove old schema entries, increase run time
Jul 5, 2024
b8b4769
:zip: do not save intermediate files for single-process bayes_approach
Jul 5, 2024
2927d3f
:fire: remove comments
Jul 5, 2024
202eb74
:fire: remove unused code
Jul 5, 2024
d6bc896
:art: reorder functions
Jul 5, 2024
b99ce97
:construction: move bayes_parallel to own module
Jul 5, 2024
171c915
:construction: unify interface
Jul 5, 2024
95c9f40
:fire: remove not-used code
Jul 5, 2024
e556fae
:art: start separating recurrent code into fcts
Jul 5, 2024
1a2774d
:art: adapt to look more similar to single-core bayes factor fct
Jul 5, 2024
680c164
:sparkles: add back masking of self-perturbed feat.
Jul 5, 2024
8f43c90
:art: initialize logger at the top of the module
Jul 8, 2024
6178c8f
:sparkles: pass feature_mask to bayes_parallel
Jul 8, 2024
60ed227
:bug: add condition for masking
Jul 8, 2024
f5da671
:art: align masking strategies
Jul 8, 2024
7eae82d
:bug: fix cont perturbation
Jul 8, 2024
50a623e
:bug: remove redefinition of nan_mask
Jul 8, 2024
2df057b
:art: only define logger once in module
Jul 8, 2024
2e87e12
:art: align single process bayes and multiprocess bayes fct
Jul 8, 2024
aa2e5d9
:art: just document in code that this cannot happen
Jul 8, 2024
4041fcb
:zap: improve CI speed, reduce stability (-> one refit only)
Jul 8, 2024
2d004e0
:bug: use default no. of epochs + t-test needs 4 refits
Jul 8, 2024
8d65528
:zap: do not run t-test check (for now)
Jul 9, 2024
1dd6788
:zap: bump up bayes factor training
Jul 9, 2024
dc9020e
:art: train both refits with 100 epochs
Jul 9, 2024
9cd2a7b
:sparkles: add log2 option
Jul 9, 2024
a4911d7
:art: document some more
Jul 9, 2024
e0421bd
:zap: test multiprocess on continuous tutorial
Jul 9, 2024
c70d328
:bug: remove non-existing key
Jul 9, 2024
1c72316
:sparkles: build dataloader fct
Jul 9, 2024
58f08e4
:bug: fix minor bug (wrongly assigned feat)
Jul 9, 2024
f895237
:zap: move masking code into main fct of module
Jul 9, 2024
5eb7954
:art: move feat_mask creation out
Jul 9, 2024
8c4e53b
:ambulance: temp. fix of CI
Jul 9, 2024
49a93d0
:zap: do not build dataloaders for multiprocessing
Jul 9, 2024
709c674
:construction: test t-test again, re-run pert. w/o model training
Jul 10, 2024
6e65cc6
:sparkles: add categorical pert. to multiprocessing
Jul 10, 2024
4efbdd9
:fire: remove unused code
Jul 10, 2024
dab767a
:art: remove unused argument
Jul 10, 2024
980bbce
:rewind: checkout developer version
Jul 12, 2024
c26b2dd
:art: move shared key to base class
Jul 12, 2024
05c1735
:fire: remove comments and code duplications
Jul 12, 2024
c5002cd
:art: update type hints, remove unused import
Jul 12, 2024
fe8c48b
Merge branch 'developer' into main
enryH Aug 12, 2024
229 changes: 219 additions & 10 deletions src/move/data/perturbations.py
@@ -1,4 +1,4 @@
__all__ = ["perturb_categorical_data", "perturb_continuous_data"]
__all__ = ["perturb_categorical_data", "perturb_continuous_data", "perturb_continuous_data_extended_one", "perturb_continuous_data_extended"]

from pathlib import Path
from typing import Literal, Optional, cast
@@ -14,6 +14,207 @@

ContinuousPerturbationType = Literal["minimum", "maximum", "plus_std", "minus_std"]

def perturb_continuous_data_one(
    baseline_dataloader: DataLoader,
    con_dataset_names: list[str],
    target_dataset_name: str,
    target_value: float,
    index_pert_feat: int,
) -> DataLoader:
    """Add a perturbation to continuous data. Change the value of a single
    feature (index_pert_feat) of the target dataset to the target value in
    all samples.

    Args:
        baseline_dataloader: Baseline dataloader
        con_dataset_names: List of continuous dataset names
        target_dataset_name: Target continuous dataset to perturb
        target_value: Target value. In analyze_latent, it will be 0
        index_pert_feat: Index of the feature to perturb

    Returns:
        One dataloader, with the feature at index_pert_feat perturbed.
        Unlike perturb_continuous_data, which returns one dataloader per
        feature, this returns a single dataloader.
    """

    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
    assert baseline_dataset.con_shapes is not None
    assert baseline_dataset.con_all is not None

    target_idx = con_dataset_names.index(target_dataset_name)
    splits = np.cumsum([0] + baseline_dataset.con_shapes)
    slice_ = slice(*splits[target_idx : target_idx + 2])

    # Perturb only the feature at index_pert_feat instead of looping over
    # every feature of the target dataset.
    perturbed_con = baseline_dataset.con_all.clone()
    target_dataset = perturbed_con[:, slice_]
    target_dataset[:, index_pert_feat] = torch.FloatTensor([target_value])
    perturbed_dataset = MOVEDataset(
        baseline_dataset.cat_all,
        perturbed_con,
        baseline_dataset.cat_shapes,
        baseline_dataset.con_shapes,
    )
    perturbed_dataloader = DataLoader(
        perturbed_dataset,
        shuffle=False,
        batch_size=baseline_dataloader.batch_size,
    )

    return perturbed_dataloader




def perturb_categorical_data_one(
    baseline_dataloader: DataLoader,
    cat_dataset_names: list[str],
    target_dataset_name: str,
    target_value: np.ndarray,
    index_pert_feat: int,
) -> DataLoader:
    """Add a perturbation to categorical data. Change the value of a single
    feature (index_pert_feat) of the target dataset to the target value in
    all samples.

    Args:
        baseline_dataloader: Baseline dataloader
        cat_dataset_names: List of categorical dataset names
        target_dataset_name: Target categorical dataset to perturb
        target_value: Target value
        index_pert_feat: Index of the feature to perturb

    Returns:
        One dataloader, with the feature at index_pert_feat perturbed
    """

    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
    assert baseline_dataset.cat_shapes is not None
    assert baseline_dataset.cat_all is not None

    target_idx = cat_dataset_names.index(target_dataset_name)
    splits = np.cumsum(
        [0] + [int.__mul__(*shape) for shape in baseline_dataset.cat_shapes]
    )
    slice_ = slice(*splits[target_idx : target_idx + 2])

    target_shape = baseline_dataset.cat_shapes[target_idx]

    # Perturb only the feature at index_pert_feat instead of looping over
    # every feature of the target dataset.
    perturbed_cat = baseline_dataset.cat_all.clone()
    target_dataset = perturbed_cat[:, slice_].view(
        baseline_dataset.num_samples, *target_shape
    )
    target_dataset[:, index_pert_feat, :] = torch.FloatTensor(target_value)
    perturbed_dataset = MOVEDataset(
        perturbed_cat,
        baseline_dataset.con_all,
        baseline_dataset.cat_shapes,
        baseline_dataset.con_shapes,
    )
    perturbed_dataloader = DataLoader(
        perturbed_dataset,
        shuffle=False,
        batch_size=baseline_dataloader.batch_size,
    )

    return perturbed_dataloader
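
Editorial aside, not part of the diff: the target_value written into every sample is a 1-D encoding that broadcasts over the samples axis (MOVE stores categorical datasets in encoded form, typically one-hot). A minimal hypothetical call, assuming a one-hot dataset named "drugs" with three categories and an existing baseline_dataloader; all names and values here are illustrative only:

import numpy as np

one_hot_value = np.array([0.0, 0.0, 1.0])  # illustrative: encode the third category

perturbed_dataloader = perturb_categorical_data_one(
    baseline_dataloader,
    cat_dataset_names=["drugs"],
    target_dataset_name="drugs",
    target_value=one_hot_value,
    index_pert_feat=0,  # perturb the first feature of the target dataset
)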


def perturb_continuous_data_extended_one(
    baseline_dataloader: DataLoader,
    con_dataset_names: list[str],
    target_dataset_name: str,
    perturbation_type: ContinuousPerturbationType,
    index_pert_feat: int,
) -> DataLoader:
    """Add a perturbation to continuous data. Change the value of a single
    feature (index_pert_feat) of the target dataset in all samples (rows) by:
    1,2) substituting the feature's value by the feature's minimum/maximum
    value, or
    3,4) adding/subtracting one standard deviation to/from the feature's
    value.

    The signature is kept almost identical to perturb_continuous_data_extended,
    with two changes: index_pert_feat selects the single feature to perturb,
    and the output directory argument is dropped because no figure is saved.
    Accordingly, the return type changes from list[DataLoader] to one
    DataLoader.

    Args:
        baseline_dataloader: Baseline dataloader
        con_dataset_names: List of continuous dataset names
        target_dataset_name: Target continuous dataset to perturb
        perturbation_type: 'minimum', 'maximum', 'plus_std' or 'minus_std'
        index_pert_feat: Index of the feature to perturb

    Returns:
        Dataloader with the feature at index_pert_feat perturbed.

    Note:
        This function was created so that it could generalize to
        non-normalized datasets. Scaling is done per dataset, not per
        feature -> slightly different stds feature to feature.
    """
    logger = get_logger(__name__)
    logger.debug(
        f"Inside perturb_continuous_data_extended_one for feature {index_pert_feat}"
    )

    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
    assert baseline_dataset.con_shapes is not None
    assert baseline_dataset.con_all is not None

    target_idx = con_dataset_names.index(target_dataset_name)  # dataset index
    splits = np.cumsum([0] + baseline_dataset.con_shapes)
    slice_ = slice(*splits[target_idx : target_idx + 2])

    # Perturb only the feature at index_pert_feat instead of looping over
    # every feature of the target dataset.
    perturbed_con = baseline_dataset.con_all.clone()
    target_dataset = perturbed_con[:, slice_]

    # Change the feature's value as requested by perturbation_type:
    min_feat_val_list, max_feat_val_list, std_feat_val_list = feature_stats(
        target_dataset
    )
    if perturbation_type == "minimum":
        target_dataset[:, index_pert_feat] = torch.FloatTensor(
            [min_feat_val_list[index_pert_feat]]
        )
    elif perturbation_type == "maximum":
        target_dataset[:, index_pert_feat] = torch.FloatTensor(
            [max_feat_val_list[index_pert_feat]]
        )
    elif perturbation_type == "plus_std":
        target_dataset[:, index_pert_feat] += torch.FloatTensor(
            [std_feat_val_list[index_pert_feat]]
        )
    elif perturbation_type == "minus_std":
        target_dataset[:, index_pert_feat] -= torch.FloatTensor(
            [std_feat_val_list[index_pert_feat]]
        )
    logger.debug(f"Perturbation successful for feature {index_pert_feat}")

    perturbed_dataset = MOVEDataset(
        baseline_dataset.cat_all,
        perturbed_con,
        baseline_dataset.cat_shapes,
        baseline_dataset.con_shapes,
    )
    perturbed_dataloader = DataLoader(
        perturbed_dataset,
        shuffle=False,
        batch_size=baseline_dataloader.batch_size,
    )

    logger.debug(
        f"Finished perturb_continuous_data_extended_one for feature {index_pert_feat}"
    )

    return perturbed_dataloader



def perturb_categorical_data(
    baseline_dataloader: DataLoader,
@@ -118,6 +319,7 @@ def perturb_continuous_data(

    return dataloaders

from move.core.logging import get_logger

def perturb_continuous_data_extended(
    baseline_dataloader: DataLoader,
@@ -126,6 +328,7 @@ def perturb_continuous_data_extended(
    perturbation_type: ContinuousPerturbationType,
    output_subpath: Optional[Path] = None,
) -> list[DataLoader]:
    """Add perturbations to continuous data. For each feature in the target
    dataset, change the feature's value in all samples (in rows):
@@ -149,22 +352,27 @@ def perturb_continuous_data_extended(
    datasets. Scaling is done per dataset, not per feature -> slightly different stds
    feature to feature.
    """
    logger = get_logger(__name__)

    logger.debug("Inside perturb_extended, creating baseline dataset")
    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
    assert baseline_dataset.con_shapes is not None
    assert baseline_dataset.con_all is not None

    logger.debug("Creating target_idx, splits, and slice")
    target_idx = con_dataset_names.index(target_dataset_name)  # dataset index
    splits = np.cumsum([0] + baseline_dataset.con_shapes)
    slice_ = slice(*splits[target_idx : target_idx + 2])

    num_features = baseline_dataset.con_shapes[target_idx]
Review comment (Member):
Is this correct? Or does it still need to be changed?

Reply (Author):
I think it's correct now. I cannot check right now because I've been having problems connecting to Esrum all morning, but I'll check as soon as it works again (hopefully soon).

Reply (Author):
I still can't connect; it's very annoying :( I'll let you know as soon as I can again, but I think the code should be fine.

Reply (Author):
I was able to connect finally today at noon :) The file was correct, because those functions are not used at all for multiprocessing; I had just changed that to test some things with the previous functions. But it is true that it led to confusion, so I reverted the changes so that the unused functions have their original code.

    logger.debug(f"Number of features to perturb: {num_features}")
    dataloaders = []
    perturbations_list = []

    for i in range(num_features):
        logger.debug(f"Getting perturbed dataset for feature {i}")
        perturbed_con = baseline_dataset.con_all.clone()
        target_dataset = perturbed_con[:, slice_]
        # Change the desired feature value by:
@@ -180,7 +388,7 @@
        elif perturbation_type == "minus_std":
            target_dataset[:, i] -= torch.FloatTensor([std_feat_val_list[i]])

        perturbations_list.append(target_dataset[:, i].numpy())

        perturbed_dataset = MOVEDataset(
            baseline_dataset.cat_all,
@@ -195,13 +403,14 @@
            batch_size=baseline_dataloader.batch_size,
        )
        dataloaders.append(perturbed_dataloader)
    logger.debug("Finished perturb_continuous_data_extended function")

    # Plot the perturbations for all features, collapsed in one plot:
    if output_subpath is not None:
        fig = plot_value_distributions(np.array(perturbations_list).transpose())
        fig_path = str(
            output_subpath / f"perturbation_distribution_{target_dataset_name}.png"
        )
        fig.savefig(fig_path)

    return dataloaders
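
Editorial note on how the new per-feature builders are meant to be used: returning a single perturbed DataLoader per call is what lets each feature be handled by its own worker process. Below is a minimal, hypothetical sketch of that fan-out; the PR's real task wiring lives in src/move/tasks/identify_associations.py, and the worker body, helper names, and pool size here are assumptions (only perturb_continuous_data_extended_one comes from this diff). Commit 5327223 above hints at one pitfall the sketch addresses: each worker pins its PyTorch thread count so the processes do not go stale by oversubscribing the CPU.

from concurrent.futures import ProcessPoolExecutor
from functools import partial

import torch

from move.data.perturbations import perturb_continuous_data_extended_one


def perturb_and_score_one(index_pert_feat, baseline_dataloader,
                          con_dataset_names, target_dataset_name):
    # Pin each worker to a single intra-op thread (cf. commit 5327223) so
    # that N workers do not each spawn a full set of PyTorch threads.
    torch.set_num_threads(1)
    dataloader = perturb_continuous_data_extended_one(
        baseline_dataloader,
        con_dataset_names,
        target_dataset_name,
        perturbation_type="minimum",
        index_pert_feat=index_pert_feat,
    )
    # Hypothetical: feed `dataloader` through the trained VAE and derive a
    # per-feature score (e.g. a bayes factor) here.
    return index_pert_feat


def run_all_features(baseline_dataloader, con_dataset_names,
                     target_dataset_name, num_features, max_workers=4):
    # partial of a module-level function keeps the worker picklable for the
    # process pool; pool.map supplies index_pert_feat positionally.
    worker = partial(
        perturb_and_score_one,
        baseline_dataloader=baseline_dataloader,
        con_dataset_names=con_dataset_names,
        target_dataset_name=target_dataset_name,
    )
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(worker, range(num_features)))

Design note, as far as this diff shows: building the perturbed dataloader inside each worker, rather than pre-building the full list as perturb_continuous_data_extended does, keeps only one perturbed copy of con_all alive per process at a time, which is presumably why the _one variants exist.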