Skip to content

Commit

Permalink
nasa
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolashuynh committed Jan 15, 2024
1 parent a1bf300 commit 70b80ea
Show file tree
Hide file tree
Showing 4 changed files with 584 additions and 1 deletion.
1 change: 1 addition & 0 deletions cmd/conf/datamodule/nasa.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ random_seed: ${random_seed}
fourier_transform: ${fourier_transform}
standardize: ${standardize}
subdataset: charge
remove_outlier_feature: True
batch_size: 16
4 changes: 4 additions & 0 deletions cmd/run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/usr/bin/env bash

python sample.py model_id=20d9c1kc
python sample.py model_id=tip2g8eh
569 changes: 568 additions & 1 deletion notebooks/nasa_exploration.ipynb

Large diffs are not rendered by default.

11 changes: 11 additions & 0 deletions src/fdiff/dataloaders/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,11 @@ def __init__(
fourier_transform: bool = False,
standardize: bool = False,
subdataset: str = "charge",
remove_outlier_feature: bool = True,
) -> None:
self.subdataset = subdataset
self.remove_outlier_feature = remove_outlier_feature

super().__init__(
data_dir=data_dir,
random_seed=random_seed,
Expand Down Expand Up @@ -432,6 +435,14 @@ def setup(self, stage: str = "fit") -> None:
self.X_train = torch.load(self.data_dir / self.subdataset / "X_train.pt")
self.X_test = torch.load(self.data_dir / self.subdataset / "X_test.pt")

if self.remove_outlier_feature and self.subdataset == "charge":
# Remove the third feature which has a bad range
self.X_train = self.X_train[:, ::2, [0, 1, 3, 4]]
self.X_test = self.X_test[:, ::2, [0, 1, 3, 4]]

assert self.X_train.shape[2] == self.X_test.shape[2] == 4
assert self.X_train.shape[1] == 251
assert self.X_test.shape[1] == 251
assert isinstance(self.X_train, torch.Tensor)
assert isinstance(self.X_test, torch.Tensor)

Expand Down

0 comments on commit 70b80ea

Please sign in to comment.