Skip to content

Commit 4fc7f71

Browse files
committed
Merge branch 'hugging-face' of https://github.com/binliunls/tutorials into hugging-face
2 parents f285e0b + fd939be commit 4fc7f71

File tree

1 file changed

+32
-37
lines changed

1 file changed

+32
-37
lines changed

hugging_face/hugging_face_pipeline_for_monai.ipynb

Lines changed: 32 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@
170170
" strides=(2, 2, 2, 2),\n",
171171
" num_res_units=2,\n",
172172
" norm=\"batch\",\n",
173-
" **kwargs\n",
173+
" **kwargs,\n",
174174
" ):\n",
175175
" super().__init__(**kwargs)\n",
176176
" self.spatial_dims = spatial_dims\n",
@@ -179,7 +179,7 @@
179179
" self.channels = channels\n",
180180
" self.strides = strides\n",
181181
" self.num_res_units = num_res_units\n",
182-
" self.norm=norm\n"
182+
" self.norm = norm"
183183
]
184184
},
185185
{
@@ -211,7 +211,7 @@
211211
" channels=config.channels,\n",
212212
" strides=config.strides,\n",
213213
" num_res_units=config.num_res_units,\n",
214-
" norm=config.norm\n",
214+
" norm=config.norm,\n",
215215
" )\n",
216216
"\n",
217217
" def forward(self, x):\n",
@@ -268,20 +268,15 @@
268268
"\n",
269269
" def _init_preprocessing_transforms(self, image_key=\"image\", load_image=True):\n",
270270
" transform_list = [LoadImaged(keys=image_key)] if load_image else []\n",
271-
" transform_list = transform_list.extend([\n",
272-
" EnsureChannelFirstd(keys=image_key),\n",
273-
" Orientationd(keys=image_key, axcodes=\"RAS\"),\n",
274-
" Spacingd(keys=image_key, pixdim=(1.5, 1.5, 2.0), mode=\"bilinear\"),\n",
275-
" ScaleIntensityRanged(\n",
276-
" keys=image_key,\n",
277-
" a_min=-57,\n",
278-
" a_max=164,\n",
279-
" b_min=0,\n",
280-
" b_max=1,\n",
281-
" clip=True\n",
282-
" ),\n",
283-
" EnsureTyped(keys=image_key)\n",
284-
" ])\n",
271+
" transform_list = transform_list.extend(\n",
272+
" [\n",
273+
" EnsureChannelFirstd(keys=image_key),\n",
274+
" Orientationd(keys=image_key, axcodes=\"RAS\"),\n",
275+
" Spacingd(keys=image_key, pixdim=(1.5, 1.5, 2.0), mode=\"bilinear\"),\n",
276+
" ScaleIntensityRanged(keys=image_key, a_min=-57, a_max=164, b_min=0, b_max=1, clip=True),\n",
277+
" EnsureTyped(keys=image_key),\n",
278+
" ]\n",
279+
" )\n",
285280
" preprocessing_transforms = Compose(transform_list)\n",
286281
" return preprocessing_transforms\n",
287282
"\n",
@@ -303,33 +298,35 @@
303298
" transform=copy.deepcopy(self.preprocessing_transforms),\n",
304299
" orig_keys=image_key,\n",
305300
" nearest_interp=False,\n",
306-
" to_tensor=True\n",
301+
" to_tensor=True,\n",
307302
" ),\n",
308303
" AsDiscreted(keys=pred_key, argmax=True),\n",
309304
" ]\n",
310-
" transform_list = transform_list.append(SaveImaged(\n",
311-
" keys=pred_key,\n",
312-
" output_dir=output_dir,\n",
313-
" output_ext=output_ext,\n",
314-
" output_dtype=output_dtype,\n",
315-
" output_postfix=output_postfix,\n",
316-
" separate_folder=separate_folder\n",
317-
" )) if save_output else transform_list\n",
305+
" transform_list = (\n",
306+
" transform_list.append(\n",
307+
" SaveImaged(\n",
308+
" keys=pred_key,\n",
309+
" output_dir=output_dir,\n",
310+
" output_ext=output_ext,\n",
311+
" output_dtype=output_dtype,\n",
312+
" output_postfix=output_postfix,\n",
313+
" separate_folder=separate_folder,\n",
314+
" )\n",
315+
" )\n",
316+
" if save_output\n",
317+
" else transform_list\n",
318+
" )\n",
318319
"\n",
319320
" postprocessing_transforms = Compose(transform_list)\n",
320321
" return postprocessing_transforms\n",
321-
" \n",
322+
"\n",
322323
" def _init_inferer(\n",
323324
" self,\n",
324325
" roi_size=(96, 96, 96),\n",
325326
" overlap=0.5,\n",
326327
" sw_batch_size=4,\n",
327328
" ):\n",
328-
" return SlidingWindowInferer(\n",
329-
" roi_size=roi_size,\n",
330-
" sw_batch_size=sw_batch_size,\n",
331-
" overlap=overlap\n",
332-
" )\n",
329+
" return SlidingWindowInferer(roi_size=roi_size, sw_batch_size=sw_batch_size, overlap=overlap)\n",
333330
"\n",
334331
" def _sanitize_parameters(self, **kwargs):\n",
335332
" preprocessing_kwargs = {}\n",
@@ -356,9 +353,7 @@
356353
" ):\n",
357354
" for key, value in kwargs.items():\n",
358355
" if key in self._preprocess_params and value != self._preprocess_params[key]:\n",
359-
" logging.warning(\n",
360-
" f\"Please set the parameter {key} during initialization.\"\n",
361-
" )\n",
356+
" logging.warning(f\"Please set the parameter {key} during initialization.\")\n",
362357
"\n",
363358
" if key not in self.PREPROCESSING_EXTRA_ARGS:\n",
364359
" logging.warning(f\"Cannot set parameter {key} for preprocessing.\")\n",
@@ -372,7 +367,7 @@
372367
" ):\n",
373368
" inputs.to(self.device)\n",
374369
" self.model.to(self.device)\n",
375-
" mode=eval_mode,\n",
370+
" mode = (eval_mode,)\n",
376371
" outputs = {Keys.IMAGE: inputs, Keys.LABEL: None}\n",
377372
" with mode(self.model):\n",
378373
" if amp:\n",
@@ -388,7 +383,7 @@
388383
" logging.warning(f\"Cannot set parameter {key} for postprocessing.\")\n",
389384
"\n",
390385
" outputs = self.postprocessing_transforms(decollate_batch(outputs))\n",
391-
" return outputs\n"
386+
" return outputs"
392387
]
393388
},
394389
{

0 commit comments

Comments
 (0)