Skip to content

Commit 524e048

Browse files
authored
Fix additional dataloader creation (#220)
1 parent 09b5b5a commit 524e048

File tree

1 file changed

+2
-13
lines changed

1 file changed

+2
-13
lines changed

DeepCrazyhouse/src/training/train_cnn.ipynb

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -399,18 +399,7 @@
399399
" for phase in [str(phase) for phase in to.phase_weights.keys()] + [\"None\"]:\n",
400400
" pgn_dataset_arrays_dict = load_pgn_dataset(dataset_type='test', part_id=0,\n",
401401
" verbose=True, normalize=tc.normalize, phase=phase)\n",
402-
" s_idcs_val_tmp = pgn_dataset_arrays_dict[\"start_indices\"]\n",
403-
" x_val_tmp = pgn_dataset_arrays_dict[\"x\"]\n",
404-
" yv_val_tmp = pgn_dataset_arrays_dict[\"y_value\"]\n",
405-
" yp_val_tmp = pgn_dataset_arrays_dict[\"y_policy\"]\n",
406-
" plys_to_end_tmp = pgn_dataset_arrays_dict[\"plys_to_end\"]\n",
407-
" pgn_datasets_val_tmp = pgn_dataset_arrays_dict[\"pgn_dataset\"]\n",
408-
" phase_vector_tmp = pgn_dataset_arrays_dict[\"phase_vector\"]\n",
409-
"\n",
410-
" if tc.discount != 1:\n",
411-
" yv_val_tmp *= tc.discount**plys_to_end_tmp\n",
412-
"\n",
413-
" data_loader = get_data_loader(x_val_tmp, yv_val_tmp, yp_val_tmp, plys_to_end_tmp, phase_vector_tmp, tc, shuffle=False)\n",
402+
" data_loader = get_data_loader(pgn_dataset_arrays_dict, tc, shuffle=False)\n",
414403
" additional_data_loaders[f\"Phase{phase}Test\"] = data_loader"
415404
]
416405
},
@@ -1726,4 +1715,4 @@
17261715
},
17271716
"nbformat": 4,
17281717
"nbformat_minor": 4
1729-
}
1718+
}

0 commit comments

Comments
 (0)