Skip to content

Commit e464239

Browse files
Fix indexing in convert_tables_to_dicts() for col_groups != None
Also updated the unittest to pass after the fix and not before. Signed-off-by: Bartosz Grabowski <[email protected]>
1 parent 4cfd327 commit e464239

File tree

2 files changed

+6
-13
lines changed

2 files changed

+6
-13
lines changed

monai/data/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1500,7 +1500,7 @@ def convert_tables_to_dicts(
15001500
if col_groups is not None:
15011501
groups: dict[str, list] = {}
15021502
for name, cols in col_groups.items():
1503-
groups[name] = df.loc[rows, cols].values
1503+
groups[name] = df.iloc[rows][cols].values
15041504
# invert items of groups to every row of data
15051505
data = [dict(d, **{k: v[i] for k, v in groups.items()}) for i, d in enumerate(data)]
15061506

tests/data/test_csv_dataset.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -182,19 +182,12 @@ def prepare_csv_file(data, filepath):
182182
# test pre-loaded DataFrame subset
183183
df = pd.read_csv(filepath1)
184184
df_subset = df.iloc[[1, 3, 4]]
185-
dataset = CSVDataset(src=df_subset)
186-
self.assertDictEqual(
187-
{k: round(v, 4) if not isinstance(v, str) else v for k, v in dataset[2].items()},
188-
{
189-
"subject_id": "s000004",
190-
"label": 9,
191-
"image": "./imgs/s000004.png",
192-
"ehr_0": 6.4275,
193-
"ehr_1": 6.2549,
194-
"ehr_2": 5.9765,
195-
},
196-
)
185+
dataset = CSVDataset(src=df_subset, col_groups={"ehr": [f"ehr_{i}" for i in range(3)]})
197186
self.assertEqual(len(dataset), 3)
187+
np.testing.assert_allclose(
188+
[round(i, 4) for i in dataset[1]["ehr"]],
189+
[3.3333, 3.2353, 3.4000],
190+
)
198191

199192
# test pre-loaded multiple DataFrames, join tables with kwargs
200193
dfs = [pd.read_csv(i) for i in filepaths]

0 commit comments

Comments
 (0)