Skip to content

Commit

Permalink
Add top feature selection for MIMIC-III
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanCrabbe committed Jan 1, 2024
1 parent 3e6546c commit c810532
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
1 change: 1 addition & 0 deletions cmd/conf/datamodule/mimiciii.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ random_seed: ${random_seed}
fourier_transform: ${fourier_transform}
standardize: ${standardize}
batch_size: 64
n_feats: 20
13 changes: 13 additions & 0 deletions src/fdiff/dataloaders/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,7 @@ def __init__(
batch_size: int = 32,
fourier_transform: bool = False,
standardize: bool = False,
n_feats: int = 104,
) -> None:
super().__init__(
data_dir=data_dir,
Expand All @@ -285,6 +286,7 @@ def __init__(
fourier_transform=fourier_transform,
standardize=standardize,
)
self.n_feats = n_feats

def setup(self, stage: str = "fit") -> None:
if (
Expand All @@ -304,6 +306,17 @@ def setup(self, stage: str = "fit") -> None:
self.X_train = torch.load(self.data_dir / "X_train.pt")
self.X_test = torch.load(self.data_dir / "X_test.pt")

assert isinstance(self.X_train, torch.Tensor)
assert isinstance(self.X_test, torch.Tensor)

# Filter the tensors to keep the features with highest variance accross the population
# The variance for each feature is averaged accrossed all time steps
top_feats = torch.argsort(self.X_train.std(1).mean(0), descending=True)[
: self.n_feats
]
self.X_train = self.X_train[:, :, top_feats]
self.X_test = self.X_test[:, :, top_feats]

def download_data(self) -> None:
dataset_path = self.data_dir / "all_hourly_data.h5"
assert dataset_path.exists(), (
Expand Down

0 comments on commit c810532

Please sign in to comment.