Dice function with batch_size>1 #16
Comments
It seems it is solved by re-writing `top[0].data[0] = np.sum(dice) / float(bottom[0].data.shape[0])`. I still have some problems... I will investigate and let you know.
Have you tried changing the learning rate when you change the batch size? This conversation says that this paper indicates the learning rate should be changed: if you are increasing the batch size, the learning rate should be decreased. Their relationship indicates that if you double the batch size, the learning rate should decrease to ~0.7 of the original. This likely won't fix everything you are describing, but it might help. I also thought there was another comment in the V-Net issues indicating that V-Net worked with batch size > 1 out of the box. That issue says the Dice would be reported as the sum over the 2 volumes in the batch; if I'm not mistaken, this could result in Dice scores up to 2.0, but it shouldn't make a difference: you can just divide by 2 to get the average Dice of the 2 volumes.
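The ~0.7 factor for a doubled batch corresponds to scaling the learning rate by the inverse square root of the batch-size ratio (1/sqrt(2) ≈ 0.71). A minimal sketch of that rule, assuming this is the relationship meant above; the function name is hypothetical:

```python
import math

def scaled_lr(base_lr, base_batch_size, new_batch_size):
    """Scale the learning rate by 1/sqrt(batch-size ratio).

    Doubling the batch size gives a factor of 1/sqrt(2) ~= 0.71,
    matching the ~0.7 mentioned above.
    """
    return base_lr * math.sqrt(base_batch_size / float(new_batch_size))

# e.g. base_lr 0.01 tuned at batch_size 1, moving to batch_size 10:
print(scaled_lr(0.01, 1, 10))  # ~0.0032
```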
@getta, thanks a lot for the suggestion and the links. I tried the 0.7 decrease and it actually helps: with this trick, the results with batch_size 1 and batch_size 10 are now comparable, nice! However, I was hoping that a larger batch size would act as a data-balancing trick, since I have very unbalanced data (90% class A, 10% class B), but I still don't get better results using batch_size > 1... Regarding the V-Net Dice out of the box: it is true that it can work, but with batch_size > 1 the Dice won't be normalized, so it cannot be used together with other normalized loss functions. This is the main reason it is a good idea to normalize it by the batch_size, so the Dice loss stays ≤ 1.
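A minimal sketch of that batch-normalized forward pass, written as a standalone NumPy function rather than the repository's actual Caffe layer (the function name and epsilon term are assumptions):

```python
import numpy as np

def batch_dice_forward(pred, gt, eps=1e-8):
    """Per-volume soft Dice averaged over the batch.

    pred, gt: arrays of shape (batch, ...) holding predicted
    probabilities and ground-truth labels. Averaging (instead of
    summing) keeps the score in [0, 1] for any batch_size.
    """
    batch_size = pred.shape[0]
    dice = np.zeros(batch_size)
    for i in range(batch_size):
        p = pred[i].ravel()
        g = gt[i].ravel()
        # V-Net style soft Dice with squared terms in the denominator
        dice[i] = 2.0 * np.dot(p, g) / (np.sum(p**2) + np.sum(g**2) + eps)
    return np.sum(dice) / float(batch_size)

# two volumes of 4 voxels each; result is a single score <= 1
pred = np.array([[0.9, 0.8, 0.1, 0.0], [0.2, 0.7, 0.9, 0.1]])
gt = np.array([[1.0, 1.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0]])
print(batch_dice_forward(pred, gt))
```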
Glad it helped!
Hi,
I tried to correct the Dice function to work for batch_size > 1, but didn't have much success, in particular with the backward function...
Any chance you could consider updating your implementation for batch_size > 1?
That can really help :)
Thanks
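For reference, a minimal sketch of how the backward pass could handle batch_size > 1 under the V-Net soft-Dice formulation (squared terms in the denominator): compute each volume's gradient exactly as in the single-volume case, then scale by 1/batch_size so it matches a forward pass that averages the per-volume Dice scores. Function names and the epsilon term are assumptions, not the repository's code:

```python
import numpy as np

def dice_gradient(p, g, eps=1e-8):
    """Gradient of one volume's soft Dice w.r.t. its predictions p:
    d Dice / d p_j = 2 * (g_j * U - 2 * p_j * I) / U**2,
    with I = sum(p * g) and U = sum(p**2) + sum(g**2)."""
    union = np.sum(p**2) + np.sum(g**2) + eps
    intersect = np.sum(p * g)
    return 2.0 * (g * union - 2.0 * p * intersect) / union**2

def batch_dice_backward(pred, gt):
    """Per-volume gradients scaled by 1/batch_size, matching a
    forward pass that averages the per-volume Dice scores."""
    batch_size = pred.shape[0]
    grads = np.zeros_like(pred)
    for i in range(batch_size):
        # negative sign: maximizing Dice == minimizing -Dice
        grads[i] = -dice_gradient(pred[i], gt[i]) / float(batch_size)
    return grads

pred = np.array([[0.9, 0.8, 0.1, 0.0], [0.2, 0.7, 0.9, 0.1]])
gt = np.array([[1.0, 1.0, 0.0, 0.0], [0.0, 1.0, 1.0, 0.0]])
print(batch_dice_backward(pred, gt))  # gradients of shape (2, 4)
```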