
I got a strange result from the training with HR dataset #7

Open
yusuke-ai opened this issue Jan 19, 2024 · 14 comments

yusuke-ai commented Jan 19, 2024

Hi,

Thank you for the awesome work!
I trained the model on the HR dataset with almost the same configuration as your code, using the command below. The only difference is the learning rate, which I set to 4e-5, and I obtained checkpoints at both iteration=10000 and around iteration=60000.
(I didn't use the same learning rate as your code because the loss sometimes spikes during training. I will comment more in #4.)

python tools/train.py configs/segrefiner/segrefiner_hr.py --resume-from segrefiner_hr_latest.pth

and I got different refinement results. The original segrefiner_hr_latest.pth model produces smooth segmentation along the line, but the retrained model produces jaggy segmentation along the line, as shown below. The expected result is that the retrained model should not output anything different.

Could you help me find the core issue?
Thank you!

Segmentation result with segrefiner_hr_latest.pth:

[Screenshot from 2024-01-19 11-10-58]

Segmentation result with the retrained model:

[Screenshot from 2024-01-19 11-13-11]

yusuke-ai (Author) commented Jan 19, 2024

@MengyuWang826
This is the training configuration, just in case.

checkpoint_config = dict(
    interval=5000, by_epoch=False, save_last=True, max_keep_ckpts=20)
log_config = dict(
    interval=50, hooks=[dict(type='TextLoggerHook', by_epoch=False)])
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = 'segrefiner_hr_latest.pth'
workflow = [('train', 5000)]
opencv_num_threads = 0
mp_start_method = 'fork'
auto_scale_lr = dict(enable=False, base_batch_size=16)
object_size = 256
task = 'instance'
model = dict(
    type='SegRefiner',
    task='instance',
    step=6,
    denoise_model=dict(
        type='DenoiseUNet',
        in_channels=4,
        out_channels=1,
        model_channels=128,
        num_res_blocks=2,
        num_heads=4,
        num_heads_upsample=-1,
        attention_strides=(16, 32),
        learn_time_embd=True,
        channel_mult=(1, 1, 2, 2, 4, 4),
        dropout=0.0),
    diffusion_cfg=dict(
        betas=dict(type='linear', start=0.8, stop=0, num_timesteps=6),
        diff_iter=False),
    test_cfg=dict())
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='LoadAnnotations',
        with_bbox=False,
        with_label=False,
        with_mask=True),
    dict(type='LoadPatchData', object_size=256, patch_size=256),
    dict(type='Resize', img_scale=(256, 256), keep_ratio=False),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='DefaultFormatBundle'),
    dict(
        type='Collect',
        keys=[
            'object_img', 'object_gt_masks', 'object_coarse_masks',
            'patch_img', 'patch_gt_masks', 'patch_coarse_masks'
        ])
]
dataset_type = 'HRCollectionDataset'
img_root = '/share/project/datasets/MSCOCO/coco2017/'
ann_root = '/share/project/datasets/LVIS/'
train_dataloader = dict(samples_per_gpu=4, workers_per_gpu=1)
data = dict(
    train=dict(
        type='HRCollectionDataset',
        pipeline=[
            dict(type='LoadImageFromFile'),
            dict(
                type='LoadAnnotations',
                with_bbox=False,
                with_label=False,
                with_mask=True),
            dict(type='LoadPatchData', object_size=256, patch_size=256),
            dict(type='Resize', img_scale=(256, 256), keep_ratio=False),
            dict(type='RandomFlip', flip_ratio=0.5),
            dict(
                type='Normalize',
                mean=[123.675, 116.28, 103.53],
                std=[58.395, 57.12, 57.375],
                to_rgb=True),
            dict(type='DefaultFormatBundle'),
            dict(
                type='Collect',
                keys=[
                    'object_img', 'object_gt_masks', 'object_coarse_masks',
                    'patch_img', 'patch_gt_masks', 'patch_coarse_masks'
                ])
        ],
        data_root='data/',
        collection_datasets=['thin', 'dis'],
        collection_json='data/collection_hr.json'),
    train_dataloader=dict(samples_per_gpu=4, workers_per_gpu=1),
    val=dict(),
    test=dict())
optimizer = dict(
    type='AdamW', lr=0.00004, weight_decay=0, eps=1e-08, betas=(0.9, 0.999))
optimizer_config = dict(grad_clip=None)
max_iters = 120000
runner = dict(type='IterBasedRunner', max_iters=120000)
lr_config = dict(
    policy='step',
    gamma=0.5,
    by_epoch=False,
    step=[80000, 100000],
    warmup='linear',
    warmup_by_epoch=False,
    warmup_ratio=1.0,
    warmup_iters=10)
interval = 5000
data_root = 'data/'
work_dir = './work_dirs/segrefiner_hr'
auto_resume = False
gpu_ids = [0]

MengyuWang826 (Owner) commented

@yusuke-ai
It appears that in this sample, jaggy segmentation occurs because some regions undergo only global refinement without local refinement. Taking the relevant parameters for collecting local patches in the segrefiner_big.py config as an example:
model = dict(
    type='SegRefinerSemantic',
    task=task,
    test_cfg=dict(
        model_size=object_size,
        fine_prob_thr=0.9,
        iou_thr=0.3,
        batch_max=32))
You can try reducing the fine_prob_thr to ensure that more local patches are collected.
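For instance (0.5 here is only an illustrative value, not a setting recommended for this repo), the change in the test_cfg would look like:

test_cfg=dict(
    model_size=object_size,
    fine_prob_thr=0.5,  # lowered from 0.9 so that more local patches are collected
    iou_thr=0.3,
    batch_max=32)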

MengyuWang826 (Owner) commented


Alternatively, you can also try implementing a different method for collecting local patches, such as along the edges of the mask. The specific implementation of this step only affects which local patches will be refined and does not impact the functioning of the model.
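For illustration only (this is not the implementation in this repository; the function name and the patch_size/stride defaults are assumptions), a minimal sketch of collecting patch boxes along the mask boundary could look like this:

import numpy as np
from scipy.ndimage import binary_erosion

def collect_edge_patches(mask, patch_size=256, stride=128):
    """Propose (y0, x0, y1, x1) patch boxes centered on mask-boundary pixels."""
    mask = mask.astype(bool)
    # Boundary pixels = mask pixels removed by a one-step erosion.
    boundary = mask & ~binary_erosion(mask)
    ys, xs = np.nonzero(boundary)
    h, w = mask.shape
    half = patch_size // 2
    boxes, seen = [], set()
    for y, x in zip(ys, xs):
        # Snap centers to a coarse grid so nearby boundary pixels share one
        # patch instead of producing many near-duplicate boxes.
        key = (y // stride, x // stride)
        if key in seen:
            continue
        seen.add(key)
        y0 = int(np.clip(y - half, 0, max(h - patch_size, 0)))
        x0 = int(np.clip(x - half, 0, max(w - patch_size, 0)))
        boxes.append((y0, x0, y0 + patch_size, x0 + patch_size))
    return boxes

As the comment above notes, this only changes which local patches get refined; it does not touch the refinement model itself.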

yusuke-ai (Author) commented

@MengyuWang826
Thank you for the answer!
I tried changing fine_prob_thr from 0.1 to 1.0, but I got similarly jaggy results.
Can you think of any other reason for the jaggy result?

MengyuWang826 (Owner) commented

@yusuke-ai
[image: segmentation result with the normal-output region marked by a red box]

Since not all positions exhibit jaggy segmentation, I speculate that the areas within the red box represent the normal output, and the jaggy segmentation appears in regions that were not subjected to local refinement. You can start by visualizing which local patches have undergone local refinement.
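A rough sketch of such a visualization could look like the following (not code from this repository; patch_boxes is a hypothetical list of (y0, x0, y1, x1) boxes gathered from wherever you hook into the local-refinement step):

import cv2

def draw_refined_patches(img_bgr, patch_boxes, out_path='refined_patches.png'):
    """Overlay the boxes of the patches that were sent to local refinement."""
    vis = img_bgr.copy()
    for y0, x0, y1, x1 in patch_boxes:
        cv2.rectangle(vis, (int(x0), int(y0)), (int(x1), int(y1)),
                      color=(0, 0, 255), thickness=2)  # red boxes in BGR
    cv2.imwrite(out_path, vis)
    return vis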

yusuke-ai (Author) commented

@MengyuWang826
Thank you! I played with the local refinement code and concluded that the model I trained has a lower capability than the one you provided. I will keep experimenting with the training, but if you could retrain the model with the code in this repository, it would be really helpful.

wzx0720 commented Jan 26, 2024

@yusuke-ai
Hey! I also tried to train the HR-SegRefiner from the pretrained model segrefiner_hr_latest.pth, but I ran into this problem:
self._epoch = checkpoint['meta']['epoch']
KeyError: 'meta'
I wonder whether you have run into a problem like this and how you solved it? Thanks a lot!

yusuke-ai (Author) commented

@wzx0720
Sorry for the late reply.
I just edited the lines around that one so that training starts from epoch 0.
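A minimal sketch of that kind of edit, based on the traceback line quoted above (the exact change may differ; this assumes an mmcv-style runner's resume()):

# Replace the direct 'meta' lookups in resume(), e.g.
#   self._epoch = checkpoint['meta']['epoch']
# with a fallback so checkpoints that carry no 'meta' dict start from zero:
meta = checkpoint.get('meta', {})
self._epoch = meta.get('epoch', 0)  # fall back to epoch 0 when 'meta' is missing
self._iter = meta.get('iter', 0)    # likewise for the iteration counter, if your runner reads it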

wzx0720 commented Feb 1, 2024

> @wzx0720 Sorry for the late reply. I just edited the lines around that one so that training starts from epoch 0.

Sorry for bothering you again. I can't understand how to edit the lines. Do you mean editing segrefiner_hr_latest.pth or editing the .py file? Could you show me how you "edit the lines"? Thanks a lot again!

wang21jun commented

same problem, have you solved it?

wzx0720 commented May 6, 2024

> same problem, have you solved it?

@wang21jun I haven't tried it, but maybe you could modify load_from in configs/_base_/default_runtime.py to fine-tune the model.
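For example, the relevant lines in configs/_base_/default_runtime.py would be something like this (a sketch; point the path at wherever you keep the released checkpoint):

# Fine-tune from the released weights instead of resuming, so the runner
# never goes through the resume path that needs the 'meta' (epoch/iter) entries.
load_from = 'segrefiner_hr_latest.pth'  # path to the released checkpoint
resume_from = None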

wang21jun commented

> @wang21jun I haven't tried it, but maybe you could modify load_from in configs/_base_/default_runtime.py to fine-tune the model.

Thanks, I will try it.

wang21jun commented

> @wang21jun I haven't tried it, but maybe you could modify load_from in configs/_base_/default_runtime.py to fine-tune the model.

I loaded from segrefiner_lr_latest.pth and trained on DIS5K, and the results are still jaggy.

[image: segmentation result after fine-tuning on DIS5K, still showing jaggy edges]

wzx0720 commented May 7, 2024

> I loaded from segrefiner_lr_latest.pth and trained on DIS5K, and the results are still jaggy.

Yes, I had the same problem in my work, so I have given up on this method for now. Maybe you can open an issue and ask the author.
