
Add SparseDice Loss #968

Open · wants to merge 64 commits into master

Conversation

@DavidLandup0 (Contributor) commented Oct 28, 2022

What does this PR do?

As introduced in #296 and discussed in #371, this modified PR includes:

  • Fixed typos and extended the documentation
  • Changed the input shapes to conform to the shapes seen in the training scripts and other Keras losses
  • Test cases for the output shapes and numerical tests of CategoricalDice, SparseDice, and BinaryDice
  • A check on axis that adjusts for the channel format (whether the input is channels_first or channels_last)
  • Validation of axis values
  • Added support for calculating the Dice loss per image in a batch and computing the overall loss as the mean of those per-image scores, as discussed in DICE loss #296 (comment) (see the sketch below)
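
For reference, since the per-image and channel-format behavior above is easy to misread, here's a minimal sketch of the idea (not the PR's exact implementation; the shapes and argument names here are assumptions):

import tensorflow as tf

def dice_loss(y_true, y_pred, per_image=False, epsilon=1e-7):
    # Sketch: assumes one-hot tensors of shape (batch, height, width, classes),
    # i.e. channels_last; the axis check mentioned above would adjust these
    # reduction axes for channels_first inputs.
    # per_image=False reduces over the whole batch at once; per_image=True
    # keeps the batch axis and averages the per-image scores afterwards.
    axes = (1, 2, 3) if per_image else (0, 1, 2, 3)
    intersection = tf.reduce_sum(y_true * y_pred, axis=axes)
    union = tf.reduce_sum(y_true, axis=axes) + tf.reduce_sum(y_pred, axis=axes)
    # The Dice score is bounded in [0, 1]; the loss is 1 - score.
    dice = (2.0 * intersection + epsilon) / (union + epsilon)
    return 1.0 - tf.reduce_mean(dice)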

Verified on Tan's training script for DeepLabV3.
It's worth noting that since the Dice score is bounded between 0 and 1, the loss will generally be significantly lower than a crossentropy metric, and will converge more slowly than a crossentropy loss. This was somewhat mitigated by increasing the base LR in the script. I've personally had more luck with SparseCategoricalCrossentropy in almost all cases over various Dice loss implementations in my own DeepLabV3+ implementations, so this checks out? Thoughts?

Tagging the people from the first PR and @innat whose original PR was modified.
/cc @LukeWood @bhack @qlzh727

Note: I haven't been able to test the per_image argument. @innat mentioned that it provides a performance boost in some cases, but I'm not sure which tasks these are. Could you share some that we can benchmark the argument?

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case. DICE loss #296 Add Dice Loss #371
  • Did you write any new necessary tests?
  • If this adds a new model, can you run a few training steps on TPU in Colab to ensure that no XLA-incompatible ops are used?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@DavidLandup0 DavidLandup0 marked this pull request as ready for review November 22, 2022 00:28
@DavidLandup0 DavidLandup0 mentioned this pull request Nov 22, 2022
@tanzhenyu (Contributor) left a comment


Can you help us report the mean IoU after 30 epochs? (w/ or w/o mixed precision is fine)

@DavidLandup0 (Contributor, Author)

Sure! Just to check - there are 9963 images in the set, right?

@tanzhenyu (Contributor)

> Sure! Just to check - there are 9963 images in the set, right?

Yes, it's around that. I have yet to submit a "Step 4" PR to augment it to ~10,000 images, but it's the right order of magnitude.

@DavidLandup0 (Contributor, Author)

And is sbd_train the train+val split? I can't seem to find the exact info 🤔

@tanzhenyu (Contributor)

> And is sbd_train the train+val split? I can't seem to find the exact info 🤔

You can use whatever the training script is using today.

@DavidLandup0 (Contributor, Author)

Running the script now. It might take a while, so I'll run it locally on my workstation to avoid interruptions in training. I'll report the val mean IoU after the 30 epochs finish :)

@tanzhenyu (Contributor)

> Running the script now. It might take a while, so I'll run it locally on my workstation to avoid interruptions in training. I'll report the val mean IoU after the 30 epochs finish :)

On a 2-GPU setup, it takes ~4 hours, I think. Hopefully this can be done simply by changing model.compile in the training script.

@DavidLandup0 (Contributor, Author) commented Nov 22, 2022

Yup, the only difference is the change from SparseCategoricalCrossentropy to SparseDice! (and the higher learning rate)

I'm running it on a single-GPU system, since Kaggle's 2-GPU environment throws errors for unsupported Conv2D operations (a recent example by @bhack worked, but I haven't been able to get it running for KerasCV models yet). Due to the limited VRAM, I can only fit the training with a batch_size of 1, so hopefully the training doesn't go awry because of the very small batch size.

Edit: got it running on a 2-GPU setup:

Epoch 24/30
932/932 [==============================] - 1138s 1s/step - loss: 0.9034 - sparse_categorical_crossentropy: 0.4876 - mean_io_u: 0.5997 - sparse_categorical_accuracy: 0.8985 - val_loss: 0.9086 - val_sparse_categorical_crossentropy: 0.5602 - val_mean_io_u: 0.5854 - val_sparse_categorical_accuracy: 0.8885

What's the reference val_mean_io_u to aim for? The 63% from #975?
It's still running, but the current run's max so far was 59.37% in epoch 22. We could probably tweak the script a bit more to get it up, but I wanted to keep it the same for a fair benchmark.

I think it's worth noting that the loss is still 0.9 out of a max of 1.0, while the sparse_categorical_crossentropy is 0.48. Dice seems to be much more aware of the "wrong parts" than CE 🤔

@DavidLandup0 (Contributor, Author)

Scope question: do we want to add a DiceScore metric as well? The loss is 1 - score, so we've got everything needed to expose the score as a metric if the loss looks okay :)
@tanzhenyu @LukeWood @ianstenbit
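
As a rough, hypothetical sketch of what such a metric could look like (this is not existing KerasCV API; it assumes the SparseDice loss from this PR):

import tensorflow as tf
import keras_cv

class DiceScore(tf.keras.metrics.Mean):
    # Hypothetical: tracks the mean Dice score, computed as 1 - Dice loss.
    def __init__(self, name="dice_score", **kwargs):
        super().__init__(name=name, **kwargs)
        self.dice_loss = keras_cv.losses.SparseDice()

    def update_state(self, y_true, y_pred, sample_weight=None):
        score = 1.0 - self.dice_loss(y_true, y_pred)
        return super().update_state(score, sample_weight=sample_weight)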

@DavidLandup0 (Contributor, Author)

The 30-epoch run is done; the best epoch was:

Epoch 22/30
932/932 [==============================] - 1137s 1s/step - loss: 0.9042 - sparse_categorical_crossentropy: 0.5024 - mean_io_u: 0.5802 - sparse_categorical_accuracy: 0.8944 - val_loss: 0.9084 - val_sparse_categorical_crossentropy: 0.5301 - val_mean_io_u: 0.5937 - val_sparse_categorical_accuracy: 0.8920

And it tapered off in the end, with somewhat unstable metrics. The LR might've been a bit too high?
Does this look okay? @tanzhenyu

@tanzhenyu (Contributor)

> The 30-epoch run is done; the best epoch was:
>
> Epoch 22/30
> 932/932 [==============================] - 1137s 1s/step - loss: 0.9042 - sparse_categorical_crossentropy: 0.5024 - mean_io_u: 0.5802 - sparse_categorical_accuracy: 0.8944 - val_loss: 0.9084 - val_sparse_categorical_crossentropy: 0.5301 - val_mean_io_u: 0.5937 - val_sparse_categorical_accuracy: 0.8920
>
> And it tapered off in the end, with somewhat unstable metrics. The LR might've been a bit too high? Does this look okay? @tanzhenyu

Shouldn't it report something like "sparse_dice" instead of "sparse_categorical_crossentropy"?

@DavidLandup0 (Contributor, Author) commented Nov 22, 2022

> Shouldn't it report something like "sparse_dice" instead of "sparse_categorical_crossentropy"?

It's using a Dice loss, so loss and val_loss are the main metrics tied to the Dice scores. I've included sparse categorical crossentropy as a metric (keras.metrics.SparseCategoricalCrossentropy) from the original script, because I thought it would help compare the Dice loss and CE loss runs with a common metric :)

loss_fn = keras_cv.losses.SparseDice()
metrics = [
    tf.keras.metrics.SparseCategoricalCrossentropy(ignore_class=255),
    tf.keras.metrics.MeanIoU(num_classes=21, sparse_y_pred=False),
    tf.keras.metrics.SparseCategoricalAccuracy(),
]

We don't currently have a dice score as a metric.
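
For context, the rest of the change is just the compile call (a sketch; the optimizer is whatever the original training script sets up):

model.compile(
    optimizer=optimizer,  # unchanged from the original training script
    loss=loss_fn,
    metrics=metrics,
)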

@tanzhenyu (Contributor)

> > Shouldn't it report something like "sparse_dice" instead of "sparse_categorical_crossentropy"?
>
> It's using a Dice loss, so loss and val_loss are the main metrics tied to the Dice scores. I've included sparse categorical crossentropy as a metric (keras.metrics.SparseCategoricalCrossentropy) from the original script, because I thought it would help compare the Dice loss and CE loss runs with a common metric :)
>
> loss_fn = keras_cv.losses.SparseDice()
> metrics = [
>     tf.keras.metrics.SparseCategoricalCrossentropy(ignore_class=255),
>     tf.keras.metrics.MeanIoU(num_classes=21, sparse_y_pred=False),
>     tf.keras.metrics.SparseCategoricalAccuracy(),
> ]
>
> We don't currently have a dice score as a metric.

Oh I see. Yeah, the metrics and loss look solid to me.

@DavidLandup0 (Contributor, Author)

Awesome! Wanna review the code and tests?
I can add dice variants as metrics if you think they'd be useful to have 🤔



@tf.keras.utils.register_keras_serializable(package="keras_cv")
class CategoricalDice(tf.keras.losses.Loss):
Contributor

A high-level recommendation: can you split this up into separate PRs? For example, if you only use SparseDice, maybe start with that first. This would greatly help us expedite the review process (given we want to review for correctness as well).

Contributor

I hope we can enforce this policy with internally-sourced PRs as well.

Recently I've seen some very large PRs.

Contributor Author

Sure! They were in one PR originally, so I kept them that way. Separating them into three PRs now.
Should I also add Dice as a metric?

Contributor Author

Also, should it be one file for all three PRs or three files? Three files might be a bit verbose for the file structure, given the method(s) shared between all three losses.

Contributor Author

Separated into three PRs, one for each loss type. Since they share two methods, I've put them all in the same file. We could have a separate file for util methods, but no other losses have this format.
Is this okay? @tanzhenyu

Contributor Author

Chaining onto @bhack's comment: what counts as a large PR?
I'll make mine smaller going forward, but a rough guideline would help in making them more manageable for review :)

Contributor Author

Can we revisit this since it should be near the finish line? @tanzhenyu

@DavidLandup0 DavidLandup0 changed the title Add Dice Loss Add SparseDice Loss Nov 23, 2022
loss_type=None,
label_smoothing=0.0,
epsilon=1e-07,
per_image=False,
Contributor

Note: I haven't been able to test the per_image argument. @innat mentioned that it provides a performance boost in some cases, but I'm not sure which tasks these are. Could you share some that we can benchmark the argument?

#296 (comment)

@innat (Contributor) commented May 28, 2023

@jbischof
Any thoughts on these PRs? #1050 #1049
cc @DavidLandup0

@jbischof (Contributor) commented Jun 8, 2023

Seems interesting, @innat, if @DavidLandup0 or someone else wants to continue the work!
