-
Just double-checking: you're not changing the batch size within the same training loop, right?
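For context, a jitted function is re-traced and re-compiled for every new input shape, so varying the batch size inside the loop means paying the compile cost repeatedly. A tiny illustration with a toy step function (not your code):

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # The compiled program is specialized to x's shape.
    return jnp.sum(x ** 2)


for batch_size in (2, 4, 8):
    x = jnp.ones((batch_size, 32))
    t0 = time.perf_counter()
    step(x).block_until_ready()  # new shape -> new trace + compile
    print(f"batch_size={batch_size}: first call took {time.perf_counter() - t0:.3f}s")
```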
-
Hi!
When working on a current project I found that the compilation time of my train_step function increased drastically when I increased the batch size. At first it took about 20s with a batch size of 2, but it went up to around 4000s when I bumped the batch size to 64. The train_step function I'm using may be a little confusing because I'm working on an IQA task with a custom model that carries some state; a rough sketch of its structure follows the note below.

As a note, I have a different function to calculate the metrics during validation, and that function doesn't show the same behavior, so I thought the issue might be related to the gradient calculation, but I'm not sure whether that makes sense.
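Roughly, the structure is the following. In this sketch, apply_model, the MSE loss, and the optax optimizer are placeholders rather than my actual code:

```python
import jax
import jax.numpy as jnp
import optax  # placeholder: any optax optimizer works here


# Placeholder for the custom stateful model's forward pass -- not the real model.
def apply_model(params, model_state, images):
    preds = images @ params["w"] + params["b"]
    return preds, model_state  # a real model would also return its updated state


tx = optax.adam(1e-3)


@jax.jit
def train_step(params, model_state, opt_state, batch):
    def loss_fn(p, s):
        preds, new_s = apply_model(p, s, batch["images"])
        # Placeholder MSE loss; the real IQA loss is different.
        loss = jnp.mean((preds - batch["targets"]) ** 2)
        return loss, new_s

    # has_aux=True lets the updated model state ride along with the loss.
    (loss, new_model_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        params, model_state
    )
    updates, new_opt_state = tx.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_model_state, new_opt_state, loss
```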
I was under the assumption that changing the batch size shouldn't have this big an influence on compilation time, and since I couldn't narrow down the problem, I tried to replicate it in a very simple MNIST classifier example in Colab (here). What I found was basically the same: the compilation time goes up with the batch size, as you can see in this quick wandb dashboard I set up for the experiment: https://wandb.ai/jorgvt/JaX_Compile?workspace=user-jorgvt
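The numbers there come from timing the first jitted call for each batch size, roughly along the lines of this simplified sketch (not the exact notebook code):

```python
import time

import jax
import jax.numpy as jnp


def make_train_step():
    # Fresh jit per batch size so each measurement includes a full compile.
    @jax.jit
    def train_step(params, batch):
        def loss_fn(p):
            logits = batch["x"] @ p["w"] + p["b"]
            one_hot = jax.nn.one_hot(batch["y"], 10)
            return -jnp.mean(jnp.sum(one_hot * jax.nn.log_softmax(logits), axis=-1))

        grads = jax.grad(loss_fn)(params)
        return jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

    return train_step


params = {"w": jnp.zeros((784, 10)), "b": jnp.zeros(10)}

for batch_size in (2, 8, 32, 64):
    batch = {
        "x": jnp.ones((batch_size, 784)),
        "y": jnp.zeros(batch_size, dtype=jnp.int32),
    }
    step = make_train_step()
    t0 = time.perf_counter()
    # First call = trace + compile; block on all outputs before stopping the timer.
    jax.tree_util.tree_map(lambda a: a.block_until_ready(), step(params, batch))
    print(f"batch_size={batch_size}: {time.perf_counter() - t0:.2f}s")
```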
I'd be more than willing to share more information with anyone that can shed some light!