|
40 | 40 | }, |
41 | 41 | { |
42 | 42 | "cell_type": "code", |
43 | | - "execution_count": null, |
| 43 | + "execution_count": 2, |
44 | 44 | "metadata": {}, |
45 | | - "outputs": [], |
| 45 | + "outputs": [ |
| 46 | + { |
| 47 | + "name": "stderr", |
| 48 | + "output_type": "stream", |
| 49 | + "text": [ |
| 50 | + "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", |
| 51 | + " from .autonotebook import tqdm as notebook_tqdm\n" |
| 52 | + ] |
| 53 | + }, |
| 54 | + { |
| 55 | + "name": "stdout", |
| 56 | + "output_type": "stream", |
| 57 | + "text": [ |
| 58 | + "MONAI version: 1.4.0\n", |
| 59 | + "Numpy version: 1.24.4\n", |
| 60 | + "Pytorch version: 2.5.0a0+872d972e41.nv24.08\n", |
| 61 | + "MONAI flags: HAS_EXT = True, USE_COMPILED = False, USE_META_DICT = False\n", |
| 62 | + "MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008\n", |
| 63 | + "MONAI __file__: /opt/monai/monai/__init__.py\n", |
| 64 | + "\n", |
| 65 | + "Optional dependencies:\n", |
| 66 | + "Pytorch Ignite version: 0.4.11\n", |
| 67 | + "ITK version: 5.4.0\n", |
| 68 | + "Nibabel version: 5.3.1\n", |
| 69 | + "scikit-image version: 0.24.0\n", |
| 70 | + "scipy version: 1.14.0\n", |
| 71 | + "Pillow version: 10.4.0\n", |
| 72 | + "Tensorboard version: 2.16.2\n", |
| 73 | + "gdown version: 5.2.0\n", |
| 74 | + "TorchVision version: 0.20.0a0\n", |
| 75 | + "tqdm version: 4.66.5\n", |
| 76 | + "lmdb version: 1.5.1\n", |
| 77 | + "psutil version: 6.0.0\n", |
| 78 | + "pandas version: 2.2.2\n", |
| 79 | + "einops version: 0.8.0\n", |
| 80 | + "transformers version: 4.40.2\n", |
| 81 | + "mlflow version: 2.17.0\n", |
| 82 | + "pynrrd version: 1.0.0\n", |
| 83 | + "clearml version: 1.16.5rc2\n", |
| 84 | + "\n", |
| 85 | + "For details about installing the optional dependencies, please visit:\n", |
| 86 | + " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies\n", |
| 87 | + "\n" |
| 88 | + ] |
| 89 | + } |
| 90 | + ], |
46 | 91 | "source": [ |
47 | | - "import copy\n", |
48 | 92 | "import logging\n", |
49 | 93 | "\n", |
50 | 94 | "from monai.transforms import (\n", |
|
65 | 109 | "from monai.networks.nets import UNet\n", |
66 | 110 | "from monai.networks.utils import eval_mode\n", |
67 | 111 | "from monai.inferers import SlidingWindowInferer\n", |
68 | | - "from monai.data import decollate_batch\n", |
| 112 | + "from monai.data import decollate_batch, list_data_collate\n", |
69 | 113 | "from monai.config import print_config\n", |
70 | 114 | "from monai.apps import download_and_extract\n", |
71 | 115 | "import torch\n", |
|
111 | 155 | }, |
112 | 156 | { |
113 | 157 | "cell_type": "code", |
114 | | - "execution_count": null, |
| 158 | + "execution_count": 4, |
115 | 159 | "metadata": {}, |
116 | 160 | "outputs": [], |
117 | 161 | "source": [ |
|
133 | 177 | }, |
134 | 178 | { |
135 | 179 | "cell_type": "code", |
136 | | - "execution_count": null, |
| 180 | + "execution_count": 5, |
137 | 181 | "metadata": {}, |
138 | 182 | "outputs": [], |
139 | 183 | "source": [ |
|
154 | 198 | }, |
155 | 199 | { |
156 | 200 | "cell_type": "code", |
157 | | - "execution_count": null, |
| 201 | + "execution_count": 6, |
158 | 202 | "metadata": {}, |
159 | 203 | "outputs": [], |
160 | 204 | "source": [ |
|
195 | 239 | }, |
196 | 240 | { |
197 | 241 | "cell_type": "code", |
198 | | - "execution_count": null, |
| 242 | + "execution_count": 7, |
199 | 243 | "metadata": {}, |
200 | 244 | "outputs": [], |
201 | 245 | "source": [ |
|
231 | 275 | }, |
232 | 276 | { |
233 | 277 | "cell_type": "code", |
234 | | - "execution_count": null, |
| 278 | + "execution_count": 8, |
235 | 279 | "metadata": {}, |
236 | 280 | "outputs": [], |
237 | 281 | "source": [ |
|
259 | 303 | " \"save_output\",\n", |
260 | 304 | " ]\n", |
261 | 305 | "\n", |
262 | | - " def __init__(self, model, config):\n", |
263 | | - " super().__init__(model=model, config=config)\n", |
| 306 | + " def __init__(self, model, **kwargs):\n", |
| 307 | + " super().__init__(model=model, **kwargs)\n", |
264 | 308 | " self.model = model\n", |
265 | 309 | " self.preprocessing_transforms = self._init_preprocessing_transforms(**self._preprocess_params)\n", |
266 | 310 | " self.inferer = self._init_inferer(**self._forward_params)\n", |
267 | 311 | " self.postprocessing_transforms = self._init_postprocessing_transforms(**self._postprocess_params)\n", |
268 | 312 | "\n", |
269 | | - " def _init_preprocessing_transforms(self, image_key=\"image\", load_image=True):\n", |
| 313 | + " def _init_preprocessing_transforms(self, image_key=Keys.IMAGE, load_image=True):\n", |
270 | 314 | " transform_list = [LoadImaged(keys=image_key)] if load_image else []\n", |
271 | | - " transform_list = transform_list.extend(\n", |
| 315 | + " transform_list.extend(\n", |
272 | 316 | " [\n", |
273 | 317 | " EnsureChannelFirstd(keys=image_key),\n", |
274 | 318 | " Orientationd(keys=image_key, axcodes=\"RAS\"),\n", |
|
282 | 326 | "\n", |
283 | 327 | " def _init_postprocessing_transforms(\n", |
284 | 328 | " self,\n", |
285 | | - " pred_key: str = \"pred\",\n", |
286 | | - " image_key: str = \"image\",\n", |
| 329 | + " pred_key: str = Keys.PRED,\n", |
| 330 | + " image_key: str = Keys.IMAGE,\n", |
287 | 331 | " output_dir: str = \"output_directory\",\n", |
288 | 332 | " output_ext: str = \".nii.gz\",\n", |
289 | 333 | " output_dtype: torch.dtype = torch.float32,\n", |
|
295 | 339 | " Activationsd(keys=pred_key, softmax=True),\n", |
296 | 340 | " Invertd(\n", |
297 | 341 | " keys=pred_key,\n", |
298 | | - " transform=copy.deepcopy(self.preprocessing_transforms),\n", |
| 342 | + " transform=self.preprocessing_transforms,\n", |
299 | 343 | " orig_keys=image_key,\n", |
300 | 344 | " nearest_interp=False,\n", |
301 | 345 | " to_tensor=True,\n", |
302 | 346 | " ),\n", |
303 | 347 | " AsDiscreted(keys=pred_key, argmax=True),\n", |
304 | 348 | " ]\n", |
305 | | - " transform_list = (\n", |
306 | | - " transform_list.append(\n", |
307 | | - " SaveImaged(\n", |
308 | | - " keys=pred_key,\n", |
309 | | - " output_dir=output_dir,\n", |
310 | | - " output_ext=output_ext,\n", |
311 | | - " output_dtype=output_dtype,\n", |
312 | | - " output_postfix=output_postfix,\n", |
313 | | - " separate_folder=separate_folder,\n", |
314 | | - " )\n", |
| 349 | + " \n", |
| 350 | + " transform_list.append(\n", |
| 351 | + " SaveImaged(\n", |
| 352 | + " keys=pred_key,\n", |
| 353 | + " output_dir=output_dir,\n", |
| 354 | + " output_ext=output_ext,\n", |
| 355 | + " output_dtype=output_dtype,\n", |
| 356 | + " output_postfix=output_postfix,\n", |
| 357 | + " separate_folder=separate_folder,\n", |
315 | 358 | " )\n", |
316 | | - " if save_output\n", |
317 | | - " else transform_list\n", |
318 | | - " )\n", |
319 | | - "\n", |
| 359 | + "        ) if save_output else transform_list\n", |
320 | 360 | " postprocessing_transforms = Compose(transform_list)\n", |
321 | 361 | " return postprocessing_transforms\n", |
322 | 362 | "\n", |
|
358 | 398 | " if key not in self.PREPROCESSING_EXTRA_ARGS:\n", |
359 | 399 | " logging.warning(f\"Cannot set parameter {key} for preprocessing.\")\n", |
360 | 400 | " inputs = self.preprocessing_transforms(inputs)\n", |
| 401 | + " inputs = list_data_collate([inputs])\n", |
361 | 402 | " return inputs\n", |
362 | 403 | "\n", |
363 | 404 | " def _forward(\n", |
364 | 405 | " self,\n", |
365 | 406 | " inputs,\n", |
366 | 407 | " amp: bool = True,\n", |
367 | 408 | " ):\n", |
368 | | - " inputs.to(self.device)\n", |
369 | | - " self.model.to(self.device)\n", |
370 | | - " mode = (eval_mode,)\n", |
371 | | - " outputs = {Keys.IMAGE: inputs, Keys.LABEL: None}\n", |
| 409 | + "        inputs[Keys.IMAGE] = inputs[Keys.IMAGE].to(self.device)\n", |
| 410 | + " self.model.unet.to(self.device)\n", |
| 411 | + " mode = eval_mode\n", |
372 | 412 | " with mode(self.model):\n", |
373 | 413 | " if amp:\n", |
374 | 414 | " with torch.autocast(\"cuda\"):\n", |
375 | | - " outputs[Keys.LABEL] = self.inferer(inputs, self.model)\n", |
| 415 | + " inputs[Keys.PRED] = self.inferer(inputs[Keys.IMAGE], self.model)\n", |
376 | 416 | " else:\n", |
377 | | - " outputs[Keys.LABEL] = self.inferer(inputs, self.model)\n", |
378 | | - " return outputs\n", |
| 417 | + " inputs[Keys.PRED] = self.inferer(inputs[Keys.IMAGE], self.model)\n", |
| 418 | + " return inputs\n", |
379 | 419 | "\n", |
380 | 420 | " def postprocess(self, outputs, **kwargs):\n", |
381 | | - " for key, _ in kwargs.items():\n", |
| 421 | + " for key, value in kwargs.items():\n", |
382 | 422 | " if key not in self.POSTPROCESSING_EXTRA_ARGS:\n", |
383 | 423 | " logging.warning(f\"Cannot set parameter {key} for postprocessing.\")\n", |
384 | | - "\n", |
385 | | - " outputs = self.postprocessing_transforms(decollate_batch(outputs))\n", |
| 424 | + " if (\n", |
| 425 | + " key in self._postprocess_params\n", |
| 426 | + " and value != self._postprocess_params[key]\n", |
| 427 | + " ) or (key not in self._postprocess_params):\n", |
| 428 | + " self._postprocess_params.update(kwargs)\n", |
| 429 | + " self.postprocessing_transforms = self._init_postprocessing_transforms(\n", |
| 430 | + " **self._postprocess_params\n", |
| 431 | + " )\n", |
| 432 | + " outputs = decollate_batch(outputs)\n", |
| 433 | + " outputs = self.postprocessing_transforms(outputs)\n", |
386 | 434 | " return outputs" |
387 | 435 | ] |
388 | 436 | }, |
|
397 | 445 | }, |
398 | 446 | { |
399 | 447 | "cell_type": "code", |
400 | | - "execution_count": null, |
| 448 | + "execution_count": 9, |
401 | 449 | "metadata": {}, |
402 | | - "outputs": [], |
| 450 | + "outputs": [ |
| 451 | + { |
| 452 | + "name": "stdout", |
| 453 | + "output_type": "stream", |
| 454 | + "text": [ |
| 455 | + "2025-02-28 16:34:41,094 INFO image_writer.py:197 - writing: output_directory/spleen_1/spleen_1_seg.nii.gz\n", |
| 456 | + "2025-02-28 16:34:47,260 INFO image_writer.py:197 - writing: output_directory/spleen_11/spleen_11_seg.nii.gz\n" |
| 457 | + ] |
| 458 | + } |
| 459 | + ], |
403 | 460 | "source": [ |
404 | 461 | "config = MONAIUNetConfig()\n", |
405 | 462 | "monai_unet = MONAIUNet(config)\n", |
406 | | - "pipeline = SpleenCTSegmentationPipeline(model=monai_unet)\n", |
407 | | - "output = pipeline(data_dicts[0])" |
| 463 | + "pipeline = SpleenCTSegmentationPipeline(model=monai_unet, device=torch.device(\"cuda:0\"))\n", |
| 464 | + "output = pipeline(data_dicts[:2])" |
408 | 465 | ] |
409 | 466 | } |
410 | 467 | ], |
411 | 468 | "metadata": { |
412 | 469 | "kernelspec": { |
413 | | - "display_name": "nim-test", |
| 470 | + "display_name": "Python 3", |
414 | 471 | "language": "python", |
415 | 472 | "name": "python3" |
416 | 473 | }, |
417 | 474 | "language_info": { |
| 475 | + "codemirror_mode": { |
| 476 | + "name": "ipython", |
| 477 | + "version": 3 |
| 478 | + }, |
| 479 | + "file_extension": ".py", |
| 480 | + "mimetype": "text/x-python", |
418 | 481 | "name": "python", |
419 | | - "version": "3.10.15" |
| 482 | + "nbconvert_exporter": "python", |
| 483 | + "pygments_lexer": "ipython3", |
| 484 | + "version": "3.10.12" |
420 | 485 | } |
421 | 486 | }, |
422 | 487 | "nbformat": 4, |
|
0 commit comments