Hi, I am fine-tuning the CT-FM model for a 3D lesion segmentation task. The model converges well on the training data, but the validation Dice for the 'lesion' class stays low (under 0.15) and unstable throughout training.
I suspect my configuration is either overfitting or that I have misconfigured the validation pipeline. I would appreciate your insight on whether these settings match the recommended practices for this foundation model.
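For context, the 'lesion' score I'm referring to is the standard per-class Dice overlap. A minimal, framework-agnostic sketch of how I understand it is computed (a hypothetical helper for illustration, not the project's `DiceScore` implementation):

```python
def dice_per_class(pred, target, num_classes, eps=1e-5):
    """Per-class Dice between two flat sequences of integer voxel labels.

    Hypothetical illustration only; the project uses its own DiceScore metric.
    """
    scores = []
    for c in range(num_classes):
        p = [v == c for v in pred]    # binary mask for class c in the prediction
        t = [v == c for v in target]  # binary mask for class c in the ground truth
        inter = sum(a and b for a, b in zip(p, t))
        denom = sum(p) + sum(t)
        scores.append((2 * inter + eps) / (denom + eps))
    return scores

# Perfect prediction -> Dice 1.0 for both classes
print(dice_per_class([0, 1, 1, 0], [0, 1, 1, 0], num_classes=2))
```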
WandB Logs
Full Configuration
```yaml
project: .
_requires_:
  - "$import monai"
  - "$import torch"

args:
  validate:
    ckpt_path: '$f"{@vars#save_dir}/best.ckpt"'
  test: "%#validate"
  predict: "%#validate"

vars:
  batch_size: 2
  pin_memory: True
  num_workers: 8
  init_LR: 0.0002
  patch_size: [96, 160, 160]
  val_max_patch_size: [192, 240, 240]
  axcodes: "SPL"
  in_channels: 1
  out_channels: 2
  class_indices: [0, 1]
  class_labels: ["background", "lesion"]
  # Percentage of the dataset to use
  percentage: 100
  intensity_range: [-1024, 2048]
  # System-specific variables
  dataset_dir: "/home/yws0322/scratch/CT_FM/Dataset011_abdomen_lesion_seg_ctfm"
  cache_dir: '$f"/home/yws0322/scratch/CT_FM/cache/{@vars#project}/{@vars#name}"'
  format: "lighter"
  save_dir: '$f"/home/yws0322/scratch/CT_FM/evaluations/{@vars#project}/checkpoints/{@vars#name}_{@vars#wandb_group}"'

trainer:
  _target_: pytorch_lightning.Trainer
  benchmark: True
  max_epochs: 300
  check_val_every_n_epoch: 5
  accelerator: gpu
  devices: 4
  strategy: ddp
  sync_batchnorm: True
  precision: 16-mixed
  log_every_n_steps: 10
  logger:
    _target_: pytorch_lightning.loggers.WandbLogger
    project: "@vars#project"
    name: "@vars#name"
    save_dir: "$@vars#save_dir.replace('checkpoints', 'logs')"
    group: "@vars#wandb_group"
  callbacks:
    - _target_: lighter.callbacks.Freezer
      name_starts_with: ["trunk.encoder"]
    - _target_: pytorch_lightning.callbacks.ModelCheckpoint
      dirpath: "@vars#save_dir"
      save_last: False
      monitor: "val/metrics/Macro_Dice/epoch"
      mode: "max"
      filename: "best"
      auto_insert_metric_name: False
      verbose: True
      every_n_epochs: 5
    - _target_: project.callbacks.WandbImageLogger
    - _target_: project.callbacks.SaveInputTargetImages
      save_dir: '$f"/home/yws0322/scratch/CT_FM/visualizations/{@vars#name}_{@vars#wandb_group}"'
      save_every_n_epochs: 1
      save_every_n_batches: 50
      max_samples: 1

system:
  _target_: lighter.System
  model:
    _target_: project.models.wrapper.TrunkHeadWrapper
    trunk:
    head: null
  criterion:
    _target_: monai.losses.DeepSupervisionLoss
    loss:
      _target_: monai.losses.DiceCELoss
      softmax: True
      to_onehot_y: True
      include_background: True
      squared_pred: True
      smooth_nr: 0
      smooth_dr: 1.0e-05
  optimizer:
    _target_: torch.optim.AdamW
    params: "$@system#model.parameters()"
    lr: "@vars#init_LR"
    weight_decay: 1.0e-05
  scheduler:
    _target_: monai.optimizers.WarmupCosineSchedule
    optimizer: "@system#optimizer"
    warmup_steps: "$@trainer#max_epochs/100"
    end_lr: "$@system#optimizer#lr * 0.01"
    warmup_multiplier: 0.1
    t_total: "@trainer#max_epochs"
  metrics:
    train:
      Macro_Dice:
        _target_: project.metrics.monai.DiceScore
        include_background: False
        per_class: False
      Classwise_Dice:
        _target_: torchmetrics.wrappers.ClasswiseWrapper
        metric:
          _target_: project.metrics.monai.DiceScore
          include_background: True
          per_class: True
        labels: "@vars#class_labels"
    val: "%#train"
  dataloaders:
    train:
      _target_: torch.utils.data.DataLoader
      batch_size: "%vars#batch_size"
      pin_memory: "%vars#pin_memory"
      num_workers: "%vars#num_workers"
      dataset:
        _target_: monai.data.PersistentDataset
        cache_dir: '$f"{@vars#cache_dir}/train"'
        hash_transform: "$monai.data.utils.json_hashing"
        data:
          _target_: project.data.get_ts_datalist
          data_dir: "@vars#dataset_dir"
          filter_fn:
            - "$lambda x: x[x['split'] == 'train']"
          percentage: "@vars#percentage"
        transform:
          _target_: monai.transforms.Compose
          map_items: False
          transforms:
            - _target_: monai.transforms.LoadImaged
              reader: "ITKReader"
              keys: ["image", "label"]
              ensure_channel_first: True
            - _target_: monai.transforms.EnsureTyped
              keys: ["image", "label"]
            - _target_: monai.transforms.Orientationd
              keys: ["image", "label"]
              axcodes: "@vars#axcodes"
            - _target_: monai.transforms.LabelFilterd
              keys: label
              applied_labels: "@vars#class_indices"
            - _target_: monai.transforms.MapLabelValued
              keys: label
              orig_labels: "@vars#class_indices"
              target_labels: "$list(range(0, @vars#out_channels))"
            - _target_: monai.transforms.ScaleIntensityRanged
              keys: image
              a_min: "$@vars#intensity_range[0]"
              a_max: "$@vars#intensity_range[1]"
              b_min: 0
              b_max: 1
              clip: True
            - _target_: monai.transforms.CropForegroundd
              keys: ["image", "label"]
              source_key: image
              margin: 10
            - _target_: monai.transforms.SpatialPadd
              keys: ["image", "label"]
              spatial_size: "@vars#patch_size"
              mode: constant
            - _target_: monai.transforms.RandCropByLabelClassesd
              keys: ["image", "label"]
              label_key: "label"
              image_key: "image"
              ratios: "$[0] + [1]*(@vars#out_channels-1)"
              num_classes: "@vars#out_channels"
              num_samples: 1
              spatial_size: "@vars#patch_size"
              warn: False
            - _target_: monai.transforms.Lambda
              func: '$lambda x: x[0]'
              track_meta: False
            - _target_: monai.transforms.RandAffined
              keys: ["image", "label"]
              mode: ["bilinear", "nearest"]
              prob: 0.2
              rotate_range: [0.26, 0.26, 0.26]
              scale_range: [0.2, 0.2, 0.2]
              cache_grid: True
              padding_mode: constant
            - _target_: monai.transforms.RandGaussianSmoothd
              keys: image
              prob: 0.2
              sigma_x: [0.5, 1.0]
              sigma_y: [0.5, 1.0]
              sigma_z: [0.5, 1.0]
            - _target_: monai.transforms.RandScaleIntensityd
              keys: image
              factors: 0.3
              prob: 0.5
            - _target_: monai.transforms.RandShiftIntensityd
              keys: image
              offsets: 0.1
              prob: 0.5
            - _target_: monai.transforms.RandGaussianNoised
              keys: image
              std: 0.1
              prob: 0.2
            - _target_: monai.transforms.Lambda
              func: '$lambda x: {"input": x["image"].as_tensor(), "target": x["label"].as_tensor(), "id": x["id"]}'
              track_meta: False
    val:
      _target_: torch.utils.data.DataLoader
      batch_size: 1
      pin_memory: "%vars#pin_memory"
      num_workers: "%vars#num_workers"
      dataset:
        _target_: monai.data.PersistentDataset
        cache_dir: '$f"{@vars#cache_dir}/val"'
        hash_transform: "$monai.data.utils.json_hashing"
        data:
          _target_: project.data.get_ts_datalist
          data_dir: "@vars#dataset_dir"
          filter_fn:
            - "$lambda x: x[x['split'] == 'val']"
          percentage: "@vars#percentage"
        transform:
          _target_: monai.transforms.Compose
          map_items: False
          transforms:
            - _target_: monai.transforms.LoadImaged
              reader: "ITKReader"
              keys: ["image", "label"]
              ensure_channel_first: True
            - _target_: monai.transforms.EnsureTyped
              keys: ["image", "label"]
            - _target_: monai.transforms.Orientationd
              keys: ["image", "label"]
              axcodes: "@vars#axcodes"
            - _target_: monai.transforms.LabelFilterd
              keys: label
              applied_labels: "@vars#class_indices"
            - _target_: monai.transforms.MapLabelValued
              keys: label
              orig_labels: "@vars#class_indices"
              target_labels: "$list(range(0, @vars#out_channels))"
            - _target_: monai.transforms.ScaleIntensityRanged
              keys: image
              a_min: "$@vars#intensity_range[0]"
              a_max: "$@vars#intensity_range[1]"
              b_min: 0
              b_max: 1
              clip: True
            - _target_: monai.transforms.CropForegroundd
              keys: ["image", "label"]
              source_key: image
              margin: 10
            - _target_: monai.transforms.SpatialPadd
              keys: ["image", "label"]
              spatial_size: "@vars#patch_size"
              mode: constant
            - _target_: monai.transforms.RandCropByLabelClassesd
              keys: ["image", "label"]
              label_key: "label"
              image_key: "image"
              ratios: "$[0] + [1]*(@vars#out_channels-1)"
              num_classes: "@vars#out_channels"
              num_samples: 1
              spatial_size: "@vars#patch_size"
              warn: False
            - _target_: monai.transforms.Lambda
              func: '$lambda x: x[0]'
              track_meta: False
            - _target_: monai.transforms.Lambda
              func: '$lambda x: {"input": x["image"].as_tensor(), "target": x["label"].as_tensor(), "id": x["id"]}'
              track_meta: False
    predict:
      _target_: torch.utils.data.DataLoader
      batch_size: 1
      pin_memory: "%vars#pin_memory"
      num_workers: "%vars#num_workers"
      dataset:
        _target_: monai.data.Dataset
        data:
          _target_: project.data.get_ts_datalist
          data_dir: "@vars#dataset_dir"
          filter_fn:
            - "$lambda x: x[x['split'] == 'val']"
          percentage: 100
        transform:
          _target_: monai.transforms.Compose
          transforms:
            - _target_: monai.transforms.CopyItemsd
              keys: ["image", "label"]
              names: ["input", "target"]
            - _target_: monai.transforms.LoadImaged
              keys: ["input", "target"]
              reader: "ITKReader"
              ensure_channel_first: True
            - _target_: monai.transforms.EnsureTyped
              keys: ["input", "target"]
            - _target_: monai.transforms.Orientationd
              keys: ["input", "target"]
              axcodes: "@vars#axcodes"
            - _target_: monai.transforms.LabelFilterd
              keys: target
              applied_labels: "@vars#class_indices"
            - _target_: monai.transforms.MapLabelValued
              keys: target
              orig_labels: "@vars#class_indices"
              target_labels: "$list(range(0, @vars#out_channels))"
            - _target_: monai.transforms.ScaleIntensityRanged
              keys: input
              a_min: "$@vars#intensity_range[0]"
              a_max: "$@vars#intensity_range[1]"
              b_min: 0
              b_max: 1
              clip: True
            - _target_: monai.transforms.CropForegroundd
              keys: ["input", "target"]
              source_key: input
              margin: 10
            - _target_: monai.transforms.SelectItemsd
              keys: ["input", "target", "id"]
  adapters:
    train:
      batch:
        _target_: lighter.adapters.BatchAdapter
        input_accessor: "input"
        target_accessor: "target"
        identifier_accessor: "id"
      metrics:
        _target_: lighter.adapters.MetricsAdapter
        pred_argument: 0
        target_argument: 1
        pred_transforms:
          - "$lambda x: x[0] if isinstance(x, list) else x"
          - "$lambda x: torch.softmax(x, 1)"
        target_transforms:
          - "$lambda x: x.squeeze(1).long()"
      logging:
        _target_: lighter.adapters.LoggingAdapter
        pred_transforms:
          - "$lambda x: x[0] if isinstance(x, list) else x"
          - "$lambda x: torch.softmax(x, 1)"
          - "$lambda x: x.argmax(dim=1, keepdim=True)"
    val: "%#train"
  inferer:
    _target_: monai.inferers.SlidingWindowInfererAdapt
    roi_size: "@vars#patch_size"
    sw_batch_size: "%vars#batch_size"
    overlap: 0.625
    mode: gaussian
```
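For scale: with `roi_size: [96, 160, 160]`, `overlap: 0.625`, and validation volumes around `val_max_patch_size: [192, 240, 240]`, each validation case is covered by a few dozen sliding windows. A back-of-the-envelope sketch, assuming a MONAI-style scan interval of `roi * (1 - overlap)` per dimension (my assumption about the inferer's stride, not taken from the project code):

```python
import math

def sliding_window_count(image_size, roi_size, overlap):
    # Windows per spatial dim, assuming a scan interval of int(roi * (1 - overlap));
    # the last window is clipped to the volume edge.
    counts = []
    for dim, roi in zip(image_size, roi_size):
        interval = max(1, int(roi * (1 - overlap)))
        counts.append(math.ceil(max(0, dim - roi) / interval) + 1)
    return counts

# Config values: volume [192, 240, 240], roi [96, 160, 160], overlap 0.625
print(sliding_window_count([192, 240, 240], [96, 160, 160], 0.625))  # [4, 3, 3] -> 36 windows
```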
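To sanity-check my scheduler settings, here is a pure-Python approximation of a warmup-then-cosine schedule evaluated with the config's values (`init_LR = 2e-4`, `warmup_steps = 300/100 = 3`, `t_total = 300`, `warmup_multiplier = 0.1`, `end_lr = init_LR * 0.01`). This reflects my reading of the intended behavior, not MONAI's exact `WarmupCosineSchedule` implementation:

```python
import math

def warmup_cosine_lr(step, base_lr=2e-4, warmup_steps=3, t_total=300,
                     warmup_multiplier=0.1, end_lr=2e-6):
    # Linear warmup from warmup_multiplier * base_lr to base_lr over warmup_steps,
    # then cosine decay from base_lr down to end_lr over the remaining steps.
    if step < warmup_steps:
        frac = step / max(1, warmup_steps)
        return base_lr * (warmup_multiplier + (1.0 - warmup_multiplier) * frac)
    progress = (step - warmup_steps) / max(1, t_total - warmup_steps)
    return end_lr + (base_lr - end_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

for epoch in (0, 3, 150, 300):
    print(epoch, warmup_cosine_lr(epoch))
```

Under these values the warmup lasts only 3 of 300 epochs, so the schedule is effectively a cosine decay from 2e-4 to 2e-6.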