diff --git a/monai/data/utils.py b/monai/data/utils.py index 988b813272..14217e9103 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -1473,7 +1473,7 @@ def convert_tables_to_dicts( # parse row indices rows: list[int | str] = [] if row_indices is None: - rows = slice(df.shape[0]) # type: ignore + rows = df.index.tolist() else: for i in row_indices: if isinstance(i, (tuple, list)): diff --git a/tests/data/test_csv_dataset.py b/tests/data/test_csv_dataset.py index 71be4fdd22..56bac5cead 100644 --- a/tests/data/test_csv_dataset.py +++ b/tests/data/test_csv_dataset.py @@ -179,6 +179,20 @@ def prepare_csv_file(data, filepath): }, ) + # test pre-loaded DataFrame subset + df = pd.read_csv(filepath1) + df_subset = df.iloc[[1, 3, 4]] + dataset = CSVDataset(src=df_subset, col_groups={"ehr": [f"ehr_{i}" for i in range(3)]}) + self.assertEqual(len(dataset), 3) + np.testing.assert_allclose([round(i, 4) for i in dataset[1]["ehr"]], [3.3333, 3.2353, 3.4000]) + + # test pre-loaded DataFrame subset with row_indices != None + df = pd.read_csv(filepath1) + df_subset = df.iloc[[1, 3, 4]] + dataset = CSVDataset(src=df_subset, row_indices=[1, 3], col_groups={"ehr": [f"ehr_{i}" for i in range(3)]}) + self.assertEqual(len(dataset), 2) + np.testing.assert_allclose([round(i, 4) for i in dataset[1]["ehr"]], [3.3333, 3.2353, 3.4000]) + # test pre-loaded multiple DataFrames, join tables with kwargs dfs = [pd.read_csv(i) for i in filepaths] dataset = CSVDataset(src=dfs, on="subject_id")