Skip to content

Commit 1a245ad

Browse files
committed
update the hugging face pipeline for monai tutorial
Signed-off-by: binliu <[email protected]>
1 parent 6f1119b commit 1a245ad

File tree

1 file changed

+111
-46
lines changed

1 file changed

+111
-46
lines changed

hugging_face/hugging_face_pipeline_for_monai.ipynb

Lines changed: 111 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,55 @@
4040
},
4141
{
4242
"cell_type": "code",
43-
"execution_count": null,
43+
"execution_count": 2,
4444
"metadata": {},
45-
"outputs": [],
45+
"outputs": [
46+
{
47+
"name": "stderr",
48+
"output_type": "stream",
49+
"text": [
50+
"/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
51+
" from .autonotebook import tqdm as notebook_tqdm\n"
52+
]
53+
},
54+
{
55+
"name": "stdout",
56+
"output_type": "stream",
57+
"text": [
58+
"MONAI version: 1.4.0\n",
59+
"Numpy version: 1.24.4\n",
60+
"Pytorch version: 2.5.0a0+872d972e41.nv24.08\n",
61+
"MONAI flags: HAS_EXT = True, USE_COMPILED = False, USE_META_DICT = False\n",
62+
"MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008\n",
63+
"MONAI __file__: /opt/monai/monai/__init__.py\n",
64+
"\n",
65+
"Optional dependencies:\n",
66+
"Pytorch Ignite version: 0.4.11\n",
67+
"ITK version: 5.4.0\n",
68+
"Nibabel version: 5.3.1\n",
69+
"scikit-image version: 0.24.0\n",
70+
"scipy version: 1.14.0\n",
71+
"Pillow version: 10.4.0\n",
72+
"Tensorboard version: 2.16.2\n",
73+
"gdown version: 5.2.0\n",
74+
"TorchVision version: 0.20.0a0\n",
75+
"tqdm version: 4.66.5\n",
76+
"lmdb version: 1.5.1\n",
77+
"psutil version: 6.0.0\n",
78+
"pandas version: 2.2.2\n",
79+
"einops version: 0.8.0\n",
80+
"transformers version: 4.40.2\n",
81+
"mlflow version: 2.17.0\n",
82+
"pynrrd version: 1.0.0\n",
83+
"clearml version: 1.16.5rc2\n",
84+
"\n",
85+
"For details about installing the optional dependencies, please visit:\n",
86+
" https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n",
87+
"\n"
88+
]
89+
}
90+
],
4691
"source": [
47-
"import copy\n",
4892
"import logging\n",
4993
"\n",
5094
"from monai.transforms import (\n",
@@ -65,7 +109,7 @@
65109
"from monai.networks.nets import UNet\n",
66110
"from monai.networks.utils import eval_mode\n",
67111
"from monai.inferers import SlidingWindowInferer\n",
68-
"from monai.data import decollate_batch\n",
112+
"from monai.data import decollate_batch, list_data_collate\n",
69113
"from monai.config import print_config\n",
70114
"from monai.apps import download_and_extract\n",
71115
"import torch\n",
@@ -111,7 +155,7 @@
111155
},
112156
{
113157
"cell_type": "code",
114-
"execution_count": null,
158+
"execution_count": 4,
115159
"metadata": {},
116160
"outputs": [],
117161
"source": [
@@ -133,7 +177,7 @@
133177
},
134178
{
135179
"cell_type": "code",
136-
"execution_count": null,
180+
"execution_count": 5,
137181
"metadata": {},
138182
"outputs": [],
139183
"source": [
@@ -154,7 +198,7 @@
154198
},
155199
{
156200
"cell_type": "code",
157-
"execution_count": null,
201+
"execution_count": 6,
158202
"metadata": {},
159203
"outputs": [],
160204
"source": [
@@ -195,7 +239,7 @@
195239
},
196240
{
197241
"cell_type": "code",
198-
"execution_count": null,
242+
"execution_count": 7,
199243
"metadata": {},
200244
"outputs": [],
201245
"source": [
@@ -231,7 +275,7 @@
231275
},
232276
{
233277
"cell_type": "code",
234-
"execution_count": null,
278+
"execution_count": 8,
235279
"metadata": {},
236280
"outputs": [],
237281
"source": [
@@ -259,16 +303,16 @@
259303
" \"save_output\",\n",
260304
" ]\n",
261305
"\n",
262-
" def __init__(self, model, config):\n",
263-
" super().__init__(model=model, config=config)\n",
306+
" def __init__(self, model, **kwargs):\n",
307+
" super().__init__(model=model, **kwargs)\n",
264308
" self.model = model\n",
265309
" self.preprocessing_transforms = self._init_preprocessing_transforms(**self._preprocess_params)\n",
266310
" self.inferer = self._init_inferer(**self._forward_params)\n",
267311
" self.postprocessing_transforms = self._init_postprocessing_transforms(**self._postprocess_params)\n",
268312
"\n",
269-
" def _init_preprocessing_transforms(self, image_key=\"image\", load_image=True):\n",
313+
" def _init_preprocessing_transforms(self, image_key=Keys.IMAGE, load_image=True):\n",
270314
" transform_list = [LoadImaged(keys=image_key)] if load_image else []\n",
271-
" transform_list = transform_list.extend(\n",
315+
" transform_list.extend(\n",
272316
" [\n",
273317
" EnsureChannelFirstd(keys=image_key),\n",
274318
" Orientationd(keys=image_key, axcodes=\"RAS\"),\n",
@@ -282,8 +326,8 @@
282326
"\n",
283327
" def _init_postprocessing_transforms(\n",
284328
" self,\n",
285-
" pred_key: str = \"pred\",\n",
286-
" image_key: str = \"image\",\n",
329+
" pred_key: str = Keys.PRED,\n",
330+
" image_key: str = Keys.IMAGE,\n",
287331
" output_dir: str = \"output_directory\",\n",
288332
" output_ext: str = \".nii.gz\",\n",
289333
" output_dtype: torch.dtype = torch.float32,\n",
@@ -295,28 +339,24 @@
295339
" Activationsd(keys=pred_key, softmax=True),\n",
296340
" Invertd(\n",
297341
" keys=pred_key,\n",
298-
" transform=copy.deepcopy(self.preprocessing_transforms),\n",
342+
" transform=self.preprocessing_transforms,\n",
299343
" orig_keys=image_key,\n",
300344
" nearest_interp=False,\n",
301345
" to_tensor=True,\n",
302346
" ),\n",
303347
" AsDiscreted(keys=pred_key, argmax=True),\n",
304348
" ]\n",
305-
" transform_list = (\n",
306-
" transform_list.append(\n",
307-
" SaveImaged(\n",
308-
" keys=pred_key,\n",
309-
" output_dir=output_dir,\n",
310-
" output_ext=output_ext,\n",
311-
" output_dtype=output_dtype,\n",
312-
" output_postfix=output_postfix,\n",
313-
" separate_folder=separate_folder,\n",
314-
" )\n",
349+
" \n",
350+
" transform_list.append(\n",
351+
" SaveImaged(\n",
352+
" keys=pred_key,\n",
353+
" output_dir=output_dir,\n",
354+
" output_ext=output_ext,\n",
355+
" output_dtype=output_dtype,\n",
356+
" output_postfix=output_postfix,\n",
357+
" separate_folder=separate_folder,\n",
315358
" )\n",
316-
" if save_output\n",
317-
" else transform_list\n",
318-
" )\n",
319-
"\n",
359+
" )if save_output else transform_list\n",
320360
" postprocessing_transforms = Compose(transform_list)\n",
321361
" return postprocessing_transforms\n",
322362
"\n",
@@ -358,31 +398,39 @@
358398
" if key not in self.PREPROCESSING_EXTRA_ARGS:\n",
359399
" logging.warning(f\"Cannot set parameter {key} for preprocessing.\")\n",
360400
" inputs = self.preprocessing_transforms(inputs)\n",
401+
" inputs = list_data_collate([inputs])\n",
361402
" return inputs\n",
362403
"\n",
363404
" def _forward(\n",
364405
" self,\n",
365406
" inputs,\n",
366407
" amp: bool = True,\n",
367408
" ):\n",
368-
" inputs.to(self.device)\n",
369-
" self.model.to(self.device)\n",
370-
" mode = (eval_mode,)\n",
371-
" outputs = {Keys.IMAGE: inputs, Keys.LABEL: None}\n",
409+
" inputs[Keys.IMAGE].to(self.device)\n",
410+
" self.model.unet.to(self.device)\n",
411+
" mode = eval_mode\n",
372412
" with mode(self.model):\n",
373413
" if amp:\n",
374414
" with torch.autocast(\"cuda\"):\n",
375-
" outputs[Keys.LABEL] = self.inferer(inputs, self.model)\n",
415+
" inputs[Keys.PRED] = self.inferer(inputs[Keys.IMAGE], self.model)\n",
376416
" else:\n",
377-
" outputs[Keys.LABEL] = self.inferer(inputs, self.model)\n",
378-
" return outputs\n",
417+
" inputs[Keys.PRED] = self.inferer(inputs[Keys.IMAGE], self.model)\n",
418+
" return inputs\n",
379419
"\n",
380420
" def postprocess(self, outputs, **kwargs):\n",
381-
" for key, _ in kwargs.items():\n",
421+
" for key, value in kwargs.items():\n",
382422
" if key not in self.POSTPROCESSING_EXTRA_ARGS:\n",
383423
" logging.warning(f\"Cannot set parameter {key} for postprocessing.\")\n",
384-
"\n",
385-
" outputs = self.postprocessing_transforms(decollate_batch(outputs))\n",
424+
" if (\n",
425+
" key in self._postprocess_params\n",
426+
" and value != self._postprocess_params[key]\n",
427+
" ) or (key not in self._postprocess_params):\n",
428+
" self._postprocess_params.update(kwargs)\n",
429+
" self.postprocessing_transforms = self._init_postprocessing_transforms(\n",
430+
" **self._postprocess_params\n",
431+
" )\n",
432+
" outputs = decollate_batch(outputs)\n",
433+
" outputs = self.postprocessing_transforms(outputs)\n",
386434
" return outputs"
387435
]
388436
},
@@ -397,26 +445,43 @@
397445
},
398446
{
399447
"cell_type": "code",
400-
"execution_count": null,
448+
"execution_count": 9,
401449
"metadata": {},
402-
"outputs": [],
450+
"outputs": [
451+
{
452+
"name": "stdout",
453+
"output_type": "stream",
454+
"text": [
455+
"2025-02-28 16:34:41,094 INFO image_writer.py:197 - writing: output_directory/spleen_1/spleen_1_seg.nii.gz\n",
456+
"2025-02-28 16:34:47,260 INFO image_writer.py:197 - writing: output_directory/spleen_11/spleen_11_seg.nii.gz\n"
457+
]
458+
}
459+
],
403460
"source": [
404461
"config = MONAIUNetConfig()\n",
405462
"monai_unet = MONAIUNet(config)\n",
406-
"pipeline = SpleenCTSegmentationPipeline(model=monai_unet)\n",
407-
"output = pipeline(data_dicts[0])"
463+
"pipeline = SpleenCTSegmentationPipeline(model=monai_unet, device=torch.device(\"cuda:0\"))\n",
464+
"output = pipeline(data_dicts[:2])"
408465
]
409466
}
410467
],
411468
"metadata": {
412469
"kernelspec": {
413-
"display_name": "nim-test",
470+
"display_name": "Python 3",
414471
"language": "python",
415472
"name": "python3"
416473
},
417474
"language_info": {
475+
"codemirror_mode": {
476+
"name": "ipython",
477+
"version": 3
478+
},
479+
"file_extension": ".py",
480+
"mimetype": "text/x-python",
418481
"name": "python",
419-
"version": "3.10.15"
482+
"nbconvert_exporter": "python",
483+
"pygments_lexer": "ipython3",
484+
"version": "3.10.12"
420485
}
421486
},
422487
"nbformat": 4,

0 commit comments

Comments
 (0)