Implement PyTorch Metrics #158
Comments
I don't know anything about it, but it sounds like a positive.
It's just a suite of metrics that torch sets up to manage multi-GPU runs under the hood. You just pass it your tensors during the training loop and it stores the sub-metric state. Then when you need the actual metric you call it and it does the calculation and memory collection under the hood. It saves you from desync issues if you have more than one GPU for training. Also, PTL supports it, so it helps reduce boilerplate for other metrics.
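For anyone unfamiliar, here is a minimal sketch of the update/compute pattern being described, using a stock `torchmetrics` class; the loop, shapes, and tensors are stand-ins for illustration only:

```python
import torch
from torchmetrics.classification import MulticlassAccuracy

metric = MulticlassAccuracy(num_classes=5)

# Simulated training loop: update() only accumulates per-process state
# (no cross-process sync, no final computation yet).
for _ in range(3):
    preds = torch.randn(8, 5)            # stand-in logits
    target = torch.randint(0, 5, (8,))   # stand-in labels
    metric.update(preds, target)

# compute() syncs state across processes (if any) and returns the value.
print(metric.compute())
metric.reset()
```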
I think I tried to do this last year and had some issues getting it to track the features I actually wanted, so I gave up, but I think it was in beta or something then. I'd be happy if you got it working :)
I just got it running for BLEU scores on another project, so CER/WER should be stable by now.
I piloted this a bit. We can do our form of accuracy using
Yeah, they have a bias towards strings as the final output, which is a bit annoying since you need to do the metric calculation on CPU at the very end. It's somewhat painless to implement new metrics so long as the distributed training is managed properly. I support implementing it on our side, and once it's robust enough we can push it to the library if they ever allow more robust metric balancing.
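To illustrate the string-at-the-end point, a small sketch using one of the built-in text metrics (the example strings are arbitrary): the character error rate metric takes decoded strings, so this step necessarily happens on CPU after detokenization.

```python
from torchmetrics.text import CharErrorRate

cer = CharErrorRate()
# update() takes hypothesis and reference strings (or lists of strings),
# so it runs after decoding, off the GPU.
cer.update(["kitten"], ["sitting"])
cer.update(["flaw"], ["lawn"])
print(cer.compute())  # aggregate CER over everything seen so far
```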
Closes CUNY-CL#158.

* Loss is computed as before, but streamlined somewhat.
* `torchmetrics`' implementation of exact match accuracy is lightly adapted. This does everything in tensor-land and should keep things on device. My tests confirm that accuracy is EXACTLY what it was before.
* A `torchmetrics`-compatible implementation of symbol error rate (here defined as the edit distance divided by the sum of target lengths) is added. It is heavily documented and compatible with our existing implementation. The hot inner loop is still on CPU, but as noted in the documentation, this is probably the best option, and I don't observe any obvious performance penalty when enabling it.
* The `evaluation` module is removed altogether. Instead, the metric objects are treated as nullables living in the base class, a design adapted from UDTube. The CLI interface is unaffected, and my side-by-side comparison shows the metrics are exactly the same as before this change.
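This is not the PR's actual code, but for reviewers, here is a rough sketch of what a `torchmetrics`-compatible symbol error rate of this shape (total edit distance over total target length, with the edit-distance loop on CPU) can look like; the names and the symbol-ID sequence representation are illustrative assumptions.

```python
import torch
from torchmetrics import Metric


def _edit_distance(pred: list[int], target: list[int]) -> int:
    """Classic Levenshtein distance over symbol ID sequences (CPU loop)."""
    dp = list(range(len(target) + 1))
    for i, p in enumerate(pred, 1):
        prev, dp[0] = dp[0], i
        for j, t in enumerate(target, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (p != t))
    return dp[-1]


class SymbolErrorRate(Metric):
    """Sketch only: edit distance summed over the batch / total target length."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # add_state registers accumulators that torchmetrics sum-reduces
        # across processes when compute() is called.
        self.add_state("edits", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("length", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: list[list[int]], target: list[list[int]]) -> None:
        for p, t in zip(preds, target):
            self.edits += _edit_distance(p, t)
            self.length += len(t)

    def compute(self) -> torch.Tensor:
        return self.edits.float() / self.length


# Toy usage with made-up symbol IDs.
ser = SymbolErrorRate()
ser.update([[1, 2, 3]], [[1, 3, 3, 4]])
print(ser.compute())
```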
TorchMetrics support is pretty reliable nowadays and makes distributed training less annoying (no more world sizes, yay!). It also syncs well with Wandb logging and allows monitoring of training batch performance. Any complaints about me migrating validation logging to this?
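For reference, a hedged sketch of the standard Lightning + torchmetrics logging pattern such a migration would typically follow; the module, layer, and attribute names below are made up for illustration and are not the project's actual ones.

```python
import pytorch_lightning as pl
import torch
from torchmetrics.classification import MulticlassAccuracy


class DemoModule(pl.LightningModule):
    """Illustrative module; layers and attribute names are placeholders."""

    def __init__(self, hidden_size: int = 32, num_classes: int = 10):
        super().__init__()
        self.classifier = torch.nn.Linear(hidden_size, num_classes)
        # Assigning the metric as a module attribute lets Lightning move it
        # to the right device and handle resetting/syncing across processes.
        self.val_accuracy = MulticlassAccuracy(num_classes=num_classes)

    def forward(self, features):
        return self.classifier(features)

    def validation_step(self, batch, batch_idx):
        features, target = batch
        preds = self(features)
        self.val_accuracy.update(preds, target)
        # Logging the metric object (not a plain value) tells Lightning to
        # call compute() at epoch end and forward the result to the attached
        # logger (e.g. the Wandb logger).
        self.log("val_accuracy", self.val_accuracy, on_epoch=True)
```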