|
173 | 173 | " strides=(2, 2, 2, 2),\n", |
174 | 174 | " num_res_units=2,\n", |
175 | 175 | " norm=\"batch\",\n", |
176 | | - " **kwargs\n", |
| 176 | + " **kwargs,\n", |
177 | 177 | " ):\n", |
178 | 178 | " super().__init__(**kwargs)\n", |
179 | 179 | " self.spatial_dims = spatial_dims\n", |
|
182 | 182 | " self.channels = channels\n", |
183 | 183 | " self.strides = strides\n", |
184 | 184 | " self.num_res_units = num_res_units\n", |
185 | | - " self.norm=norm\n" |
| 185 | + " self.norm = norm" |
186 | 186 | ] |
187 | 187 | }, |
188 | 188 | { |
|
214 | 214 | " channels=config.channels,\n", |
215 | 215 | " strides=config.strides,\n", |
216 | 216 | " num_res_units=config.num_res_units,\n", |
217 | | - " norm=config.norm\n", |
| 217 | + " norm=config.norm,\n", |
218 | 218 | " )\n", |
219 | 219 | "\n", |
220 | 220 | " def forward(self, x):\n", |
|
271 | 271 | "\n", |
272 | 272 | " def _init_preprocessing_transforms(self, image_key=\"image\", load_image=True):\n", |
273 | 273 | " transform_list = [LoadImaged(keys=image_key)] if load_image else []\n", |
274 | | - " transform_list = transform_list.extend([\n", |
275 | | - " EnsureChannelFirstd(keys=image_key),\n", |
276 | | - " Orientationd(keys=image_key, axcodes=\"RAS\"),\n", |
277 | | - " Spacingd(keys=image_key, pixdim=(1.5, 1.5, 2.0), mode=\"bilinear\"),\n", |
278 | | - " ScaleIntensityRanged(\n", |
279 | | - " keys=image_key,\n", |
280 | | - " a_min=-57,\n", |
281 | | - " a_max=164,\n", |
282 | | - " b_min=0,\n", |
283 | | - " b_max=1,\n", |
284 | | - " clip=True\n", |
285 | | - " ),\n", |
286 | | - " EnsureTyped(keys=image_key)\n", |
287 | | - " ])\n", |
| 274 | + " transform_list = transform_list.extend(\n", |
| 275 | + " [\n", |
| 276 | + " EnsureChannelFirstd(keys=image_key),\n", |
| 277 | + " Orientationd(keys=image_key, axcodes=\"RAS\"),\n", |
| 278 | + " Spacingd(keys=image_key, pixdim=(1.5, 1.5, 2.0), mode=\"bilinear\"),\n", |
| 279 | + " ScaleIntensityRanged(keys=image_key, a_min=-57, a_max=164, b_min=0, b_max=1, clip=True),\n", |
| 280 | + " EnsureTyped(keys=image_key),\n", |
| 281 | + " ]\n", |
| 282 | + " )\n", |
288 | 283 | " preprocessing_transforms = Compose(transform_list)\n", |
289 | 284 | " return preprocessing_transforms\n", |
290 | 285 | "\n", |
|
306 | 301 | " transform=copy.deepcopy(self.preprocessing_transforms),\n", |
307 | 302 | " orig_keys=image_key,\n", |
308 | 303 | " nearest_interp=False,\n", |
309 | | - " to_tensor=True\n", |
| 304 | + " to_tensor=True,\n", |
310 | 305 | " ),\n", |
311 | 306 | " AsDiscreted(keys=pred_key, argmax=True),\n", |
312 | 307 | " ]\n", |
313 | | - " transform_list = transform_list.append(SaveImaged(\n", |
314 | | - " keys=pred_key,\n", |
315 | | - " output_dir=output_dir,\n", |
316 | | - " output_ext=output_ext,\n", |
317 | | - " output_dtype=output_dtype,\n", |
318 | | - " output_postfix=output_postfix,\n", |
319 | | - " separate_folder=separate_folder\n", |
320 | | - " )) if save_output else transform_list\n", |
| 308 | + " transform_list = (\n", |
| 309 | + " transform_list.append(\n", |
| 310 | + " SaveImaged(\n", |
| 311 | + " keys=pred_key,\n", |
| 312 | + " output_dir=output_dir,\n", |
| 313 | + " output_ext=output_ext,\n", |
| 314 | + " output_dtype=output_dtype,\n", |
| 315 | + " output_postfix=output_postfix,\n", |
| 316 | + " separate_folder=separate_folder,\n", |
| 317 | + " )\n", |
| 318 | + " )\n", |
| 319 | + " if save_output\n", |
| 320 | + " else transform_list\n", |
| 321 | + " )\n", |
321 | 322 | "\n", |
322 | 323 | " postprocessing_transforms = Compose(transform_list)\n", |
323 | 324 | " return postprocessing_transforms\n", |
324 | | - " \n", |
| 325 | + "\n", |
325 | 326 | " def _init_inferer(\n", |
326 | 327 | " self,\n", |
327 | 328 | " roi_size=(96, 96, 96),\n", |
328 | 329 | " overlap=0.5,\n", |
329 | 330 | " sw_batch_size=4,\n", |
330 | 331 | " ):\n", |
331 | | - " return SlidingWindowInferer(\n", |
332 | | - " roi_size=roi_size,\n", |
333 | | - " sw_batch_size=sw_batch_size,\n", |
334 | | - " overlap=overlap\n", |
335 | | - " )\n", |
| 332 | + " return SlidingWindowInferer(roi_size=roi_size, sw_batch_size=sw_batch_size, overlap=overlap)\n", |
336 | 333 | "\n", |
337 | 334 | " def _sanitize_parameters(self, **kwargs):\n", |
338 | 335 | " preprocessing_kwargs = {}\n", |
|
359 | 356 | " ):\n", |
360 | 357 | " for key, value in kwargs.items():\n", |
361 | 358 | " if key in self._preprocess_params and value != self._preprocess_params[key]:\n", |
362 | | - " logging.warning(\n", |
363 | | - " f\"Please set the parameter {key} during initialization.\"\n", |
364 | | - " )\n", |
| 359 | + " logging.warning(f\"Please set the parameter {key} during initialization.\")\n", |
365 | 360 | "\n", |
366 | 361 | " if key not in self.PREPROCESSING_EXTRA_ARGS:\n", |
367 | 362 | " logging.warning(f\"Cannot set parameter {key} for preprocessing.\")\n", |
|
375 | 370 | " ):\n", |
376 | 371 | " inputs.to(self.device)\n", |
377 | 372 | " self.model.to(self.device)\n", |
378 | | - " mode=eval_mode,\n", |
| 373 | + " mode = (eval_mode,)\n", |
379 | 374 | " outputs = {Keys.IMAGE: inputs, Keys.LABEL: None}\n", |
380 | 375 | " with mode(self.model):\n", |
381 | 376 | " if amp:\n", |
|
391 | 386 | " logging.warning(f\"Cannot set parameter {key} for postprocessing.\")\n", |
392 | 387 | "\n", |
393 | 388 | " outputs = self.postprocessing_transforms(decollate_batch(outputs))\n", |
394 | | - " return outputs\n" |
| 389 | + " return outputs" |
395 | 390 | ] |
396 | 391 | }, |
397 | 392 | { |
|
0 commit comments