scripts/data_train_test_split.py: 61 changes (39 additions & 22 deletions)
@@ -1,4 +1,6 @@
 import argparse
+import json
+from itertools import chain
 from pathlib import Path
 from typing import Optional, Sequence

@@ -9,17 +11,16 @@

 logger = get_logger()


 DEFAULT_FORMAT = ".csv.gz"
 PARQUET_SUFFIX = ".parquet"


 def dump_splits(
     col: list[str],
     orig_path: Path,
     split_paths: list[Path],
     cutoff_date: Optional[pd.Timestamp],
     subject_id_split: Optional[tuple[Sequence, Sequence]],
 ):
     if all(p.exists() for p in split_paths):
         logger.warning(f"All output files already exist, skipping: {orig_path}")
@@ -50,7 +51,7 @@ def dump_splits(
         cond = _df[date_col] < cutoff_date
         fold_name = split_path.parts[1]
         if fold_name.endswith("prospective") or (
             len(split_paths) == 2 and fold_name.startswith("test")
         ):
             cond = ~cond
         _df = _df.loc[cond]
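
For context, the hunk above keeps rows strictly before the cutoff, and inverts the mask for a fold whose name ends in "prospective" (or the test fold of a two-way split), so that fold receives the cutoff date and everything after it. A minimal pandas sketch of that masking, with an invented date column name:

import pandas as pd

df = pd.DataFrame({
    "subject_id": [1, 2, 3],
    "admittime": pd.to_datetime(["2019-05-01", "2020-01-15", "2021-03-02"]),
})
cutoff_date = pd.Timestamp("2020-01-01")

cond = df["admittime"] < cutoff_date  # True for rows strictly before the cutoff
train_part = df.loc[cond]             # retrospective fold
test_part = df.loc[~cond]             # inverted mask: the cutoff date and later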
@@ -72,14 +73,15 @@


 def data_train_test_split(
     dataset_dir: str,
     col: str | list[str],
     test_size: float,
     id_data_path: str = None,
     cutoff_date: str = None,
     subset_format: str = DEFAULT_FORMAT,
+    patient_ids: str = None,
     seed: int = 42,
     n_jobs: int = 1,
 ):
     dataset_dir = Path(dataset_dir)
     assert dataset_dir.is_dir(), f"Path is not a directory: {dataset_dir}"
@@ -126,10 +128,17 @@ def data_train_test_split(
     if not df[id_col].is_unique:
         raise ValueError(f"Column '{id_col}' is not unique in '{id_data_path}'")

-    test_df = df.sample(frac=test_size, random_state=seed)
-    train_df = df.drop(test_df.index)
+    if patient_ids is None:
+        test_df = df.sample(frac=test_size, random_state=seed)
+        train_df = df.drop(test_df.index)
+        subject_id_split = (train_df[id_col], test_df[id_col])
+    else:
+        with Path(patient_ids).open("r") as f:
+            splits = json.load(f)
+        get_ids_from_split = lambda fold: list(chain.from_iterable(
+            ids for split_name, ids in splits.items() if fold in split_name))
+        subject_id_split = (get_ids_from_split("train"), get_ids_from_split("test"))

-    subject_id_split = (train_df[id_col], test_df[id_col])
     logger.info(
         "Subject number (train/test): {:,}/{:,} (test_size={:.0%})".format(
             len(subject_id_split[0]),
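
The get_ids_from_split helper added here matches fold names by substring, so keys such as "train_0" and "train_1" are pooled into one train set, and any key containing "test" into the test set. A self-contained sketch of that lookup (written as a def rather than the PR's lambda; fold names and ids are invented):

from itertools import chain

# Hypothetical contents of an already-loaded splits JSON:
# fold names map to lists of patient ids.
splits = {"train_0": [1, 2], "train_1": [3], "test": [4, 5]}

def get_ids_from_split(fold):
    # Pool the id lists of every fold whose name contains the substring `fold`.
    return list(chain.from_iterable(
        ids for split_name, ids in splits.items() if fold in split_name))

print(get_ids_from_split("train"))  # [1, 2, 3]
print(get_ids_from_split("test"))   # [4, 5]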
@@ -175,7 +184,7 @@
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Split the raw dataset into train and test files, defaults are for MIMIC. "
         "The split dataset is created in the same directory as the original dataset.",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument("dateset", type=str, help="Path to the dataset dir.")
@@ -185,14 +194,14 @@
nargs="+",
default="subject_id",
help="The name of the column with patient ids and/or the name of the time column if"
" performing a time-based split.",
" performing a time-based split.",
)
parser.add_argument(
"--id_data_path",
type=str,
default="hosp/patients.csv.gz",
help="Path of the file with unique subject IDs to evaluate the train/test ids. "
"Relative path from the dataset directory.",
"Relative path from the dataset directory.",
)
parser.add_argument(
"--test_size",
@@ -210,7 +219,14 @@
         type=str,
         default=DEFAULT_FORMAT,
         help="Format of the subsets in the dataset directory, supported are: "
         "'parquet' or any type of CSV (comma).",
     )
+    parser.add_argument(
+        "--patient_ids",
+        type=str,
+        default=None,
+        help="File path to the JSON file with the patient ids split. Keys are the fold names, values"
+        " are the lists of patient ids.",
+    )
     parser.add_argument(
         "-s", "--seed", type=int, default=42, help="Random seed of the split by patients."
@@ -224,6 +240,7 @@
         args.id_data_path,
         args.cutoff_date,
         args.data_format,
+        args.patient_ids,
         args.seed,
         args.n_jobs,
     )
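
Putting the new parameter together with the existing ones, here is a sketch of calling data_train_test_split directly, mirroring the CLI wiring above (it assumes the function is importable from this script; paths and values are placeholders):

data_train_test_split(
    "data/mimic-iv",                      # placeholder dataset directory
    col="subject_id",
    test_size=0.1,                        # not used to draw the split when patient_ids is set
    id_data_path="hosp/patients.csv.gz",
    patient_ids="patient_ids.json",       # JSON file as sketched earlier
    seed=42,
    n_jobs=1,
)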