Low validation Dice during CT-FM fine-tuning #38

@yws0322

Description

Hi, I am fine-tuning the CT-FM model for a 3D lesion segmentation task. The model converges well on the training data, but the validation Dice for the 'lesion' class remains low (under 0.15) and unstable throughout training.
I suspect that my configuration is causing overfitting, or that I have misconfigured the validation pipeline. I would appreciate your insight on whether my settings align with the recommended practices for this foundation model.
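To make concrete what "lesion Dice" means in the numbers below, here is a minimal self-contained sketch of the Dice computation as I understand it (plain PyTorch on synthetic binary masks; the `dice` helper is illustrative, not taken from my pipeline):

```python
import torch

def dice(pred: torch.Tensor, target: torch.Tensor, eps: float = 1e-5) -> float:
    """Dice for binary masks: 2|P ∩ T| / (|P| + |T|)."""
    inter = (pred * target).sum()
    return (2 * inter / (pred.sum() + target.sum() + eps)).item()

# Synthetic 8x8x8 volume with a 64-voxel cubic "lesion".
target = torch.zeros(8, 8, 8)
pred = torch.zeros(8, 8, 8)
target[2:6, 2:6, 2:6] = 1   # ground-truth lesion
pred[2:6, 2:6, 0:4] = 1     # prediction overlaps half of it (32 voxels)

print(round(dice(pred, target), 3))  # 0.5
```

So a lesion Dice under 0.15 means the predicted and true lesion masks barely overlap at validation time, even while the training loss keeps improving.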

WandB Logs


Full Configuration

project: .
_requires_:
    - "$import monai"
    - "$import torch"

args:
  validate:
    ckpt_path: $f"{@vars#save_dir}/best.ckpt"
  test: "%#validate"
  predict: "%#validate"


vars:
    batch_size: 2
    pin_memory: True
    num_workers: 8
    init_LR: 0.0002
    patch_size: [96, 160, 160]
    val_max_patch_size: [192, 240, 240]
    axcodes: "SPL"
    in_channels: 1
    out_channels: 2  
    class_indices: [0, 1]
    class_labels: ["background", "lesion"]
    # Percentage of the dataset to use
    percentage: 100
    intensity_range: [-1024, 2048]
    # System specific variables
    dataset_dir: "/home/yws0322/scratch/CT_FM/Dataset011_abdomen_lesion_seg_ctfm"
    cache_dir: '$f"/home/yws0322/scratch/CT_FM/cache/{@vars#project}/{@vars#name}"'
    format: "lighter"
    save_dir: '$f"/home/yws0322/scratch/CT_FM/evaluations/{@vars#project}/checkpoints/{@vars#name}_{@vars#wandb_group}"'

trainer:
    _target_: pytorch_lightning.Trainer
    benchmark: True
    max_epochs: 300
    check_val_every_n_epoch: 5
    accelerator: gpu
    devices: 4
    strategy: ddp
    sync_batchnorm: True
    precision: 16-mixed
    log_every_n_steps: 10
    logger:
        _target_: pytorch_lightning.loggers.WandbLogger
        project: "@vars#project"
        name: "@vars#name"
        save_dir:  "$@vars#save_dir.replace('checkpoints', 'logs')"
        group: "@vars#wandb_group"

    callbacks:
        - _target_: lighter.callbacks.Freezer
          name_starts_with: ["trunk.encoder"]

        - _target_: pytorch_lightning.callbacks.ModelCheckpoint
          dirpath: "@vars#save_dir"
          save_last: False
          monitor: "val/metrics/Macro_Dice/epoch"
          mode: "max"
          filename: "best"
          auto_insert_metric_name: False
          verbose: True
          every_n_epochs: 5

        - _target_: project.callbacks.WandbImageLogger

        - _target_: project.callbacks.SaveInputTargetImages
          save_dir: '$f"/home/yws0322/scratch/CT_FM/visualizations/{@vars#name}_{@vars#wandb_group}"'
          save_every_n_epochs: 1
          save_every_n_batches: 50
          max_samples: 1

system:
    _target_: lighter.System

    model:
        _target_: project.models.wrapper.TrunkHeadWrapper
        trunk:
        head: null

    criterion:
        _target_: monai.losses.DeepSupervisionLoss
        loss:
          _target_: monai.losses.DiceCELoss
          softmax: True
          to_onehot_y: True
          include_background: True
          squared_pred: True
          smooth_nr: 0
          smooth_dr: 1.0e-05

    optimizer:
        _target_: torch.optim.AdamW
        params: "$@system#model.parameters()"
        lr: "@vars#init_LR"
        weight_decay: 1.0e-05 

    scheduler:
        _target_: monai.optimizers.WarmupCosineSchedule
        optimizer: "@system#optimizer"
        warmup_steps: "$@trainer#max_epochs/100"
        end_lr:  "$@system#optimizer#lr * 0.01"
        warmup_multiplier: 0.1
        t_total: "@trainer#max_epochs"

    metrics:
        train:
          Macro_Dice:
              _target_: project.metrics.monai.DiceScore
              include_background: False
              per_class: False
          Classwise_Dice:
              _target_: torchmetrics.wrappers.ClasswiseWrapper
              metric:
                _target_: project.metrics.monai.DiceScore
                include_background: True
                per_class: True
              labels: "@vars#class_labels"
              
        val: "%#train"

    dataloaders:
        train:
          _target_: torch.utils.data.DataLoader
          batch_size: "%vars#batch_size"
          pin_memory: "%vars#pin_memory"
          num_workers: "%vars#num_workers"
          dataset:
            _target_: monai.data.PersistentDataset
            cache_dir: '$f"{@vars#cache_dir}/train"'
            hash_transform: "$monai.data.utils.json_hashing"
            data:
                _target_: project.data.get_ts_datalist
                data_dir: "@vars#dataset_dir"
                filter_fn: 
                    - "$lambda x: x[x['split'] == 'train']"
                percentage: "@vars#percentage"
            transform:
                _target_: monai.transforms.Compose
                map_items: False
                transforms: 
                    - _target_: monai.transforms.LoadImaged
                      reader: "ITKReader"
                      keys: ["image", "label"]
                      ensure_channel_first: True
                    - _target_: monai.transforms.EnsureTyped
                      keys: ["image", "label"]
                    - _target_: monai.transforms.Orientationd
                      keys: ["image", "label"]
                      axcodes: "@vars#axcodes"
                    - _target_: monai.transforms.LabelFilterd
                      keys: label
                      applied_labels: "@vars#class_indices"
                    - _target_: monai.transforms.MapLabelValued
                      keys: label
                      orig_labels: "@vars#class_indices"
                      target_labels: "$list(range(0, @vars#out_channels))"
                    - _target_: monai.transforms.ScaleIntensityRanged
                      keys: image
                      a_min: "$@vars#intensity_range[0]"
                      a_max: "$@vars#intensity_range[1]"
                      b_min: 0
                      b_max: 1
                      clip: True
                    - _target_: monai.transforms.CropForegroundd
                      keys: ["image", "label"]
                      source_key: image
                      margin: 10
                    - _target_: monai.transforms.SpatialPadd
                      keys: ["image", "label"]
                      spatial_size: "@vars#patch_size"
                      mode: constant
                    - _target_: monai.transforms.RandCropByLabelClassesd
                      keys: ["image", "label"]
                      label_key: "label"
                      image_key: "image"
                      ratios: "$[0] + [1]*(@vars#out_channels-1)"
                      num_classes: "@vars#out_channels"
                      num_samples: 1
                      spatial_size: "@vars#patch_size"   
                      warn: False
                    - _target_: monai.transforms.Lambda
                      func: '$lambda x: x[0]'
                      track_meta: False
                    - _target_: monai.transforms.RandAffined
                      keys: ["image", "label"]
                      mode: ["bilinear", "nearest"]
                      prob: 0.2
                      rotate_range: [0.26, 0.26, 0.26]
                      scale_range: [0.2, 0.2, 0.2]
                      cache_grid: True
                      padding_mode: constant
                    - _target_: monai.transforms.RandGaussianSmoothd
                      keys: image
                      prob: 0.2
                      sigma_x: [0.5, 1.0]
                      sigma_y: [0.5, 1.0]
                      sigma_z: [0.5, 1.0]
                    - _target_: monai.transforms.RandScaleIntensityd
                      keys: image
                      factors: 0.3
                      prob: 0.5
                    - _target_: monai.transforms.RandShiftIntensityd
                      keys: image
                      offsets: 0.1
                      prob: 0.5
                    - _target_: monai.transforms.RandGaussianNoised
                      keys: image
                      std: 0.1
                      prob: 0.2
                    - _target_: monai.transforms.Lambda
                      func: '$lambda x: {"input": x["image"].as_tensor(), "target": x["label"].as_tensor(), "id": x["id"]}'
                      track_meta: False
    
        val:
          _target_: torch.utils.data.DataLoader
          batch_size: 1
          pin_memory: "%vars#pin_memory"
          num_workers: "%vars#num_workers"
          dataset:
            _target_: monai.data.PersistentDataset
            cache_dir: '$f"{@vars#cache_dir}/val"'
            hash_transform: "$monai.data.utils.json_hashing"
            data:
                _target_: project.data.get_ts_datalist
                data_dir: "@vars#dataset_dir"
                filter_fn: 
                    - "$lambda x: x[x['split'] == 'val']"
                percentage: "@vars#percentage" 
            transform:
                _target_: monai.transforms.Compose
                map_items: False
                transforms: 
                    - _target_: monai.transforms.LoadImaged
                      reader: "ITKReader"
                      keys: ["image", "label"]
                      ensure_channel_first: True
                    - _target_: monai.transforms.EnsureTyped
                      keys: ["image", "label"]
                    - _target_: monai.transforms.Orientationd
                      keys: ["image", "label"]
                      axcodes: "@vars#axcodes"
                    - _target_: monai.transforms.LabelFilterd
                      keys: label
                      applied_labels: "@vars#class_indices"
                    - _target_: monai.transforms.MapLabelValued
                      keys: label
                      orig_labels: "@vars#class_indices"
                      target_labels: "$list(range(0, @vars#out_channels))"
                    - _target_: monai.transforms.ScaleIntensityRanged
                      keys: image
                      a_min: "$@vars#intensity_range[0]"
                      a_max: "$@vars#intensity_range[1]"
                      b_min: 0
                      b_max: 1
                      clip: True
                    - _target_: monai.transforms.CropForegroundd
                      keys: ["image", "label"]
                      source_key: image
                      margin: 10
                    - _target_: monai.transforms.SpatialPadd
                      keys: ["image", "label"]
                      spatial_size: "@vars#patch_size"
                      mode: constant
                    - _target_: monai.transforms.RandCropByLabelClassesd 
                      keys: ["image", "label"]
                      label_key: "label"
                      image_key: "image"
                      ratios: "$[0] + [1]*(@vars#out_channels-1)"
                      num_classes: "@vars#out_channels"
                      num_samples: 1
                      spatial_size: "@vars#patch_size"   
                      warn: False
                    - _target_: monai.transforms.Lambda
                      func: '$lambda x: x[0]'
                      track_meta: False
                    - _target_: monai.transforms.Lambda
                      func: '$lambda x: {"input": x["image"].as_tensor(), "target": x["label"].as_tensor(), "id": x["id"]}'
                      track_meta: False
        predict:
          _target_: torch.utils.data.DataLoader
          batch_size: 1
          pin_memory: "%vars#pin_memory"
          num_workers: "%vars#num_workers"
          dataset:
            _target_: monai.data.Dataset
            data:
                _target_: project.data.get_ts_datalist
                data_dir: "@vars#dataset_dir"
                filter_fn: 
                    - "$lambda x: x[x['split'] == 'val']"
                percentage: 100
            transform:
                _target_: monai.transforms.Compose
                transforms:
                    - _target_: monai.transforms.CopyItemsd
                      keys: ["image", "label"]
                      names: ["input", "target",]
                    - _target_: monai.transforms.LoadImaged
                      keys: ["input", "target"]
                      reader: "ITKReader"
                      ensure_channel_first: True
                    - _target_: monai.transforms.EnsureTyped
                      keys: ["input", "target"]
                    - _target_: monai.transforms.Orientationd
                      keys: ["input", "target"]
                      axcodes: "@vars#axcodes"
                    - _target_: monai.transforms.LabelFilterd
                      keys: target
                      applied_labels: "@vars#class_indices"
                    - _target_: monai.transforms.MapLabelValued
                      keys: target
                      orig_labels: "@vars#class_indices"
                      target_labels: "$list(range(0, @vars#out_channels))"
                    - _target_: monai.transforms.ScaleIntensityRanged
                      keys: input
                      a_min: "$@vars#intensity_range[0]"
                      a_max: "$@vars#intensity_range[1]"
                      b_min: 0
                      b_max: 1
                      clip: True
                    - _target_: monai.transforms.CropForegroundd
                      keys: ["input", "target"]
                      source_key: input
                      margin: 10
                    - _target_: monai.transforms.SelectItemsd
                      keys: ["input", "target", "id"]                      

    adapters:
      train:
        batch:
            _target_: lighter.adapters.BatchAdapter
            input_accessor: "input"
            target_accessor: "target"
            identifier_accessor: "id"
        metrics:
            _target_: lighter.adapters.MetricsAdapter
            pred_argument: 0
            target_argument: 1
            pred_transforms:
                - "$lambda x: x[0] if isinstance(x, list) else x" 
                - "$lambda x: torch.softmax(x, 1)"
            target_transforms:
                - "$lambda x: x.squeeze(1).long()"

        logging:        
            _target_: lighter.adapters.LoggingAdapter      
            pred_transforms:
                - "$lambda x: x[0] if isinstance(x, list) else x" 
                - "$lambda x: torch.softmax(x, 1)"
                - "$lambda x: x.argmax(dim=1, keepdim=True)"

      val: "%#train"

    inferer:
        _target_: monai.inferers.SlidingWindowInfererAdapt
        roi_size: "@vars#patch_size"
        sw_batch_size: "%vars#batch_size"
        overlap: 0.625
        mode: gaussian
