Skip to content

Commit 662e122

Browse files
committed
Updates to run notebook on cpu
Signed-off-by: Eric Kerfoot <[email protected]>
1 parent a06376c commit 662e122

File tree

3 files changed

+14
-12
lines changed

3 files changed

+14
-12
lines changed

2d_classification/monai_101.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,11 +242,12 @@
242242
"outputs": [],
243243
"source": [
244244
"max_epochs = 5\n",
245-
"model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(\"cuda:0\")\n",
245+
"device = torch.device(\"cuda:0\" if torch.cuda.device_count() > 0 else \"cpu\")\n",
246+
"model = densenet121(spatial_dims=2, in_channels=1, out_channels=6).to(device)\n",
246247
"\n",
247248
"logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
248249
"trainer = SupervisedTrainer(\n",
249-
" device=torch.device(\"cuda:0\"),\n",
250+
" device=device,\n",
250251
" max_epochs=max_epochs,\n",
251252
" train_data_loader=DataLoader(dataset, batch_size=512, shuffle=True, num_workers=4),\n",
252253
" network=model,\n",
@@ -312,7 +313,7 @@
312313
"max_items_to_print = 10\n",
313314
"with eval_mode(model):\n",
314315
" for item in DataLoader(testdata, batch_size=1, num_workers=0):\n",
315-
" prob = np.array(model(item[\"image\"].to(\"cuda:0\")).detach().to(\"cpu\"))[0]\n",
316+
" prob = np.array(model(item[\"image\"].to(device)).detach().to(\"cpu\"))[0]\n",
316317
" pred = class_names[prob.argmax()]\n",
317318
" gt = item[\"class_name\"][0]\n",
318319
" print(f\"Class prediction is {pred}. Ground-truth: {gt}\")\n",

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ autopep8
44
jupytext<=1.16.3
55
autoflake
66
ipywidgets
7+
ipykernel
78
tensorboard>=2.4.0

runner.sh

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,15 @@ trap finish EXIT
466466
# After setup, don't want to exit immediately after error
467467
set +e
468468

469+
# FIXME: https://github.com/Project-MONAI/MONAI/issues/4354
470+
protobuf_major_version=$(${PY_EXE} -m pip list | grep '^protobuf ' | tr -s ' ' | cut -d' ' -f2 | cut -d'.' -f1)
471+
if [ "$protobuf_major_version" -ge "4" ]
472+
then
473+
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
474+
else
475+
unset PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION
476+
fi
477+
469478
########################################################################
470479
# #
471480
# loop over files #
@@ -559,15 +568,6 @@ for file in "${files[@]}"; do
559568

560569
python -c 'import monai; monai.config.print_config()'
561570

562-
# FIXME: https://github.com/Project-MONAI/MONAI/issues/4354
563-
protobuf_major_version=$(${PY_EXE} -m pip list | grep '^protobuf ' | tr -s ' ' | cut -d' ' -f2 | cut -d'.' -f1)
564-
if [ "$protobuf_major_version" -ge "4" ]
565-
then
566-
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
567-
else
568-
unset PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION
569-
fi
570-
571571
cmd=$(echo "papermill ${papermill_opt} --progress-bar --log-output -k ${kernelspec}")
572572
echo "$cmd"
573573
time out=$(echo "$notebook" | eval "$cmd")

0 commit comments

Comments
 (0)