Add SparseDice Loss #968
Conversation
Can you help us report the mean IoU after 30 epochs? (w/ or w/o mixed precision is fine)
Sure! Just to check: there are 9963 images in the set, right?
Yes, it's around that. I have yet to submit a "Step 4" PR to augment it to around 10,000 images, but it's on the right order of magnitude.
And the …?
You can use whatever the training script is using today.
Running the script. It might take a while, so I'll run it locally on my workstation to avoid interruptions in training. I'll report the val mean IoU after 30 epochs when they finish :)
On a 2-GPU setup, it takes ~4 hours, I think? Hopefully this can be done simply by changing …
Yup, the only difference is the change from …
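A minimal sketch of what that one-line swap might look like in the training script; everything here (the stand-in model, the loss import path, the class count) is an assumption for illustration, not taken from the actual script:

```python
import tensorflow as tf

from keras_cv.losses import SparseDice  # assumed import path for the new loss

# Tiny stand-in model; the actual script builds DeepLabV3.
model = tf.keras.Sequential(
    [tf.keras.layers.Conv2D(21, 1, input_shape=(512, 512, 3))]
)

model.compile(
    optimizer="sgd",
    # Previously something like:
    # loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    loss=SparseDice(),
    metrics=[tf.keras.metrics.MeanIoU(num_classes=21)],  # 21 classes assumed (VOC)
)
```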
Edit: got it running on a 2-GPU setup: …
What's the reference …? I think it's worth noting that the loss is still 0.9 out of a max of 1.0, while the …
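For context on those bounds: the loss is the complement of the standard Dice coefficient,

$$\mathrm{Dice}(A, B) = \frac{2\,\lvert A \cap B \rvert}{\lvert A \rvert + \lvert B \rvert} \in [0, 1],$$

so a perfect overlap scores 1 (loss 0), no overlap scores 0 (loss 1), and a loss around 0.9 indicates low overlap.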
Scope question: do we want to add a …?
The 30-epoch run is done; the best epoch was: …
And it tapered off at the end, with somewhat unstable metrics. The LR might've been a bit too high?
Shouldn't it report something like "sparse_dice" instead of "sparse_categorical_crossentropy"?
It's using a Dice loss, so the …
We don't currently have a Dice score as a metric.
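For reference, a standalone Dice score metric could be sketched as follows. This is hypothetical and not part of this PR; it assumes one-hot `y_true` and probability-like `y_pred`:

```python
import tensorflow as tf

def dice_score(y_true, y_pred, epsilon=1e-7):
    """Hypothetical Dice score: 2 * |A ∩ B| / (|A| + |B|), over the batch."""
    y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(y_true * y_pred)
    denominator = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    return (2.0 * intersection + epsilon) / (denominator + epsilon)
```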
Oh, I see. Yeah, the metrics and loss look solid to me.
Awesome! Wanna review the code and tests?
keras_cv/losses/dice.py (Outdated)
```python
@tf.keras.utils.register_keras_serializable(package="keras_cv")
class CategoricalDice(tf.keras.losses.Loss):
```
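For readers without the full diff, here is a minimal sketch of what a categorical Dice loss of this shape typically computes. The body below is illustrative only, not the PR's actual implementation:

```python
import tensorflow as tf

class CategoricalDiceSketch(tf.keras.losses.Loss):
    """Illustrative only; the PR's CategoricalDice carries more options."""

    def __init__(self, epsilon=1e-7, name="categorical_dice", **kwargs):
        super().__init__(name=name, **kwargs)
        self.epsilon = epsilon

    def call(self, y_true, y_pred):
        y_true = tf.cast(y_true, y_pred.dtype)
        # Assuming channels_last inputs of shape (batch, height, width, classes):
        # reduce over the spatial axes to get a per-class Dice score.
        intersection = tf.reduce_sum(y_true * y_pred, axis=[1, 2])
        denominator = tf.reduce_sum(y_true + y_pred, axis=[1, 2])
        dice = (2.0 * intersection + self.epsilon) / (denominator + self.epsilon)
        # The loss is the complement of the mean Dice score across classes.
        return 1.0 - tf.reduce_mean(dice, axis=-1)
```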
A high-level recommendation: can you split this up into separate PRs? For example, if you only use SparseDice, maybe start with that first. This would greatly help us expedite the review process (given we want to review for correctness as well).
I hope that we could enforce this policy for internally sourced PRs as well. Recently, I've seen some very large PRs.
Sure! They were in one PR originally, so I kept them that way. Separating them into three PRs.
Should I also add Dice as a metric?
Also, should it be one file for all three PRs, or three files? Three files might be a bit verbose for the filesystem, given the method(s) shared between all three losses.
Separated into three PRs, one for each loss type. Since they share two methods, I've put them all in the same file. We could have a separate file for util methods, but no other losses follow this format.
Is this okay? @tanzhenyu
Chaining onto @bhack's comment: what counts as a large PR?
I'll make mine smaller going forward, but some guidance would help in making them more manageable for review :)
Can we revisit this, since it should be near the finish line? @tanzhenyu
```python
loss_type=None,
label_smoothing=0.0,
epsilon=1e-07,
per_image=False,
```
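The `per_image` flag presumably toggles whether the Dice score is computed per sample and then averaged, versus over the whole batch at once. A hedged sketch of that distinction, under the same channels_last assumption as above (not the PR's actual code):

```python
import tensorflow as tf

def mean_dice(y_true, y_pred, per_image=False, epsilon=1e-7):
    """Sketch of how a per_image flag could change the reduction axes.

    Assumes channels_last inputs of shape (batch, height, width, classes).
    """
    y_true = tf.cast(y_true, y_pred.dtype)
    if per_image:
        # Score each image separately, then average: reduce spatial axes only.
        axes = [1, 2]
    else:
        # Score the whole batch as one volume: include the batch axis.
        axes = [0, 1, 2]
    intersection = tf.reduce_sum(y_true * y_pred, axis=axes)
    denominator = tf.reduce_sum(y_true + y_pred, axis=axes)
    dice = (2.0 * intersection + epsilon) / (denominator + epsilon)
    return tf.reduce_mean(dice)
```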
Note: I haven't been able to test the per_image argument. @innat mentioned that it provides a performance boost in some cases, but I'm not sure which tasks those are. Could you share some so that we can benchmark the argument? @jbischof
Seems interesting @innat, if @DavidLandup0 or someone else wants to continue the work!
What does this PR do?
As introduced in #296 and discussed in #371, the modified PR includes:

- an `axis` argument, which adjusts the channel format (whether the input is `channels_first` or `channels_last`)
- validation of `axis` values

Verified on Tan's training script for DeepLabV3.
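A hypothetical usage sketch of the `axis` argument; the import path and exact semantics are assumed from the description above, not confirmed against the diff:

```python
from keras_cv.losses import CategoricalDice  # assumed import path

# Default channels_last: class scores live on the last axis,
# e.g. predictions of shape (batch, height, width, classes).
loss = CategoricalDice(axis=-1)

# channels_first: class scores live on axis 1,
# e.g. predictions of shape (batch, classes, height, width).
loss_cf = CategoricalDice(axis=1)
```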
It's worth noting that since the Dice score is bound between `0` and `1`, it will generally be significantly lower than a crossentropy metric, and it converges more slowly than a crossentropy loss. This was somewhat mitigated by increasing the base LR in the script. I've personally had more luck with SparseCrossentropy in almost all cases over various Dice loss implementations in my own DeepLabV3+ implementations, so this checks out? Thoughts?

Tagging the people from the first PR and @innat, whose original PR was modified.
/cc @LukeWood @bhack @qlzh727
Note: I haven't been able to test the `per_image` argument. @innat mentioned that it provides a performance boost in some cases, but I'm not sure which tasks those are. Could you share some that we can benchmark the argument?

Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case. DICE loss #296, Add Dice Loss #371
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.