|
170 | 170 | " strides=(2, 2, 2, 2),\n", |
171 | 171 | " num_res_units=2,\n", |
172 | 172 | " norm=\"batch\",\n", |
173 | | - " **kwargs\n", |
| 173 | + " **kwargs,\n", |
174 | 174 | " ):\n", |
175 | 175 | " super().__init__(**kwargs)\n", |
176 | 176 | " self.spatial_dims = spatial_dims\n", |
|
179 | 179 | " self.channels = channels\n", |
180 | 180 | " self.strides = strides\n", |
181 | 181 | " self.num_res_units = num_res_units\n", |
182 | | - " self.norm=norm\n" |
| 182 | + " self.norm = norm" |
183 | 183 | ] |
184 | 184 | }, |
185 | 185 | { |
|
211 | 211 | " channels=config.channels,\n", |
212 | 212 | " strides=config.strides,\n", |
213 | 213 | " num_res_units=config.num_res_units,\n", |
214 | | - " norm=config.norm\n", |
| 214 | + " norm=config.norm,\n", |
215 | 215 | " )\n", |
216 | 216 | "\n", |
217 | 217 | " def forward(self, x):\n", |
|
268 | 268 | "\n", |
269 | 269 | " def _init_preprocessing_transforms(self, image_key=\"image\", load_image=True):\n", |
270 | 270 | " transform_list = [LoadImaged(keys=image_key)] if load_image else []\n", |
271 | | - " transform_list = transform_list.extend([\n", |
272 | | - " EnsureChannelFirstd(keys=image_key),\n", |
273 | | - " Orientationd(keys=image_key, axcodes=\"RAS\"),\n", |
274 | | - " Spacingd(keys=image_key, pixdim=(1.5, 1.5, 2.0), mode=\"bilinear\"),\n", |
275 | | - " ScaleIntensityRanged(\n", |
276 | | - " keys=image_key,\n", |
277 | | - " a_min=-57,\n", |
278 | | - " a_max=164,\n", |
279 | | - " b_min=0,\n", |
280 | | - " b_max=1,\n", |
281 | | - " clip=True\n", |
282 | | - " ),\n", |
283 | | - " EnsureTyped(keys=image_key)\n", |
284 | | - " ])\n", |
| 271 | + " transform_list = transform_list.extend(\n", |
| 272 | + " [\n", |
| 273 | + " EnsureChannelFirstd(keys=image_key),\n", |
| 274 | + " Orientationd(keys=image_key, axcodes=\"RAS\"),\n", |
| 275 | + " Spacingd(keys=image_key, pixdim=(1.5, 1.5, 2.0), mode=\"bilinear\"),\n", |
| 276 | + " ScaleIntensityRanged(keys=image_key, a_min=-57, a_max=164, b_min=0, b_max=1, clip=True),\n", |
| 277 | + " EnsureTyped(keys=image_key),\n", |
| 278 | + " ]\n", |
| 279 | + " )\n", |
285 | 280 | " preprocessing_transforms = Compose(transform_list)\n", |
286 | 281 | " return preprocessing_transforms\n", |
287 | 282 | "\n", |
|
303 | 298 | " transform=copy.deepcopy(self.preprocessing_transforms),\n", |
304 | 299 | " orig_keys=image_key,\n", |
305 | 300 | " nearest_interp=False,\n", |
306 | | - " to_tensor=True\n", |
| 301 | + " to_tensor=True,\n", |
307 | 302 | " ),\n", |
308 | 303 | " AsDiscreted(keys=pred_key, argmax=True),\n", |
309 | 304 | " ]\n", |
310 | | - " transform_list = transform_list.append(SaveImaged(\n", |
311 | | - " keys=pred_key,\n", |
312 | | - " output_dir=output_dir,\n", |
313 | | - " output_ext=output_ext,\n", |
314 | | - " output_dtype=output_dtype,\n", |
315 | | - " output_postfix=output_postfix,\n", |
316 | | - " separate_folder=separate_folder\n", |
317 | | - " )) if save_output else transform_list\n", |
| 305 | + " transform_list = (\n", |
| 306 | + " transform_list.append(\n", |
| 307 | + " SaveImaged(\n", |
| 308 | + " keys=pred_key,\n", |
| 309 | + " output_dir=output_dir,\n", |
| 310 | + " output_ext=output_ext,\n", |
| 311 | + " output_dtype=output_dtype,\n", |
| 312 | + " output_postfix=output_postfix,\n", |
| 313 | + " separate_folder=separate_folder,\n", |
| 314 | + " )\n", |
| 315 | + " )\n", |
| 316 | + " if save_output\n", |
| 317 | + " else transform_list\n", |
| 318 | + " )\n", |
318 | 319 | "\n", |
319 | 320 | " postprocessing_transforms = Compose(transform_list)\n", |
320 | 321 | " return postprocessing_transforms\n", |
321 | | - " \n", |
| 322 | + "\n", |
322 | 323 | " def _init_inferer(\n", |
323 | 324 | " self,\n", |
324 | 325 | " roi_size=(96, 96, 96),\n", |
325 | 326 | " overlap=0.5,\n", |
326 | 327 | " sw_batch_size=4,\n", |
327 | 328 | " ):\n", |
328 | | - " return SlidingWindowInferer(\n", |
329 | | - " roi_size=roi_size,\n", |
330 | | - " sw_batch_size=sw_batch_size,\n", |
331 | | - " overlap=overlap\n", |
332 | | - " )\n", |
| 329 | + " return SlidingWindowInferer(roi_size=roi_size, sw_batch_size=sw_batch_size, overlap=overlap)\n", |
333 | 330 | "\n", |
334 | 331 | " def _sanitize_parameters(self, **kwargs):\n", |
335 | 332 | " preprocessing_kwargs = {}\n", |
|
356 | 353 | " ):\n", |
357 | 354 | " for key, value in kwargs.items():\n", |
358 | 355 | " if key in self._preprocess_params and value != self._preprocess_params[key]:\n", |
359 | | - " logging.warning(\n", |
360 | | - " f\"Please set the parameter {key} during initialization.\"\n", |
361 | | - " )\n", |
| 356 | + " logging.warning(f\"Please set the parameter {key} during initialization.\")\n", |
362 | 357 | "\n", |
363 | 358 | " if key not in self.PREPROCESSING_EXTRA_ARGS:\n", |
364 | 359 | " logging.warning(f\"Cannot set parameter {key} for preprocessing.\")\n", |
|
372 | 367 | " ):\n", |
373 | 368 | " inputs.to(self.device)\n", |
374 | 369 | " self.model.to(self.device)\n", |
375 | | - " mode=eval_mode,\n", |
| 370 | + " mode = (eval_mode,)\n", |
376 | 371 | " outputs = {Keys.IMAGE: inputs, Keys.LABEL: None}\n", |
377 | 372 | " with mode(self.model):\n", |
378 | 373 | " if amp:\n", |
|
388 | 383 | " logging.warning(f\"Cannot set parameter {key} for postprocessing.\")\n", |
389 | 384 | "\n", |
390 | 385 | " outputs = self.postprocessing_transforms(decollate_batch(outputs))\n", |
391 | | - " return outputs\n" |
| 386 | + " return outputs" |
392 | 387 | ] |
393 | 388 | }, |
394 | 389 | { |
|
0 commit comments