Status: Closed
Labels: feature (Is an improvement or enhancement) · help wanted (Open to be worked on) · won't fix (This will not be worked on)
🚀 Feature
Fine-grained memory profiling in pytorch-lightning that explains:
- what specifically drives GPU utilization.memory (the fraction of time device memory is being read or written)
- what specifically drives GPU memory.used (the amount of device memory actually allocated); see the sketch below
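For reference, both numbers come from NVML (the same source nvidia-smi and wandb read from), while PyTorch only knows about its own caching allocator. A minimal sketch of querying both views, assuming the pynvml package is installed and GPU 0 is the device in use:

```python
import torch
import pynvml  # assumes the pynvml package is installed

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # GPU 0 is an assumption

# NVML-level view: what nvidia-smi / wandb report
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
print(f"utilization.memory: {util.memory}%  (time spent reading/writing device memory)")
print(f"memory.used: {mem.used / 2**20:.0f} MiB of {mem.total / 2**20:.0f} MiB")

# PyTorch-level view: what the caching allocator knows about
print(f"torch allocated: {torch.cuda.memory_allocated() / 2**20:.0f} MiB")
print(f"torch reserved:  {torch.cuda.memory_reserved() / 2**20:.0f} MiB")
```

The gap between the NVML numbers and the allocator counters is exactly the part that is hard to attribute today.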
Motivation
- If data is being moved back and forth unnecessarily between the GPU and the host, this slows things down. However, it is hard to pinpoint where these transfers come from.
- Certain architectures and batch sizes cause OOM errors. It would be useful to pinpoint exactly where the memory consumption comes from, in order to design more compact networks and use larger batch sizes.
Pitch
pytorch-lightning is designed to make it easy to train networks using pytorch. However, debugging utilization.memory and memory.used is very ad hoc and tricky. Best practices don't always work, and a very simple fine-grained profiler would be very useful, even for experts writing complicated networks.
Alternatives
- Instrumenting every single line of code with GPUStatsMonitor.
- pytorch_memlab, but it doesn't have good support for pytorch-lightning; better native integration of this tool in pytorch-lightning would be very beneficial (rough sketches of both workarounds are below).
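For concreteness, here is roughly what those two workarounds look like today. This is a sketch only, using pytorch-lightning's GPUStatsMonitor callback and pytorch_memlab's profile decorator and MemReporter, with LitModel as a placeholder module:

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import GPUStatsMonitor
from pytorch_memlab import MemReporter, profile


class LitModel(pl.LightningModule):          # placeholder LightningModule
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 32)

    @profile                                 # pytorch_memlab: per-line CUDA memory report
    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# Alternative 1: coarse per-batch GPU stats logged by the built-in callback
trainer = pl.Trainer(
    gpus=1,
    callbacks=[GPUStatsMonitor(memory_utilization=True, gpu_utilization=True)],
)

# Alternative 2: a one-off snapshot of which tensors currently hold the memory
model = LitModel()
reporter = MemReporter(model)
reporter.report()
```

Neither of these ties the numbers back to specific parts of the LightningModule without a lot of manual wiring, which is the gap this feature request is about.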
Additional context
Attached is a graph of my utilization.memory from the wandb.ai dashboard:
I am tearing my hair out figuring out why this is the case. As far as I can tell everything is on the GPU, and I don't know where the memory accesses are coming from. I'd love a one-liner tool that explained this, rather than poking around blindly.
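For completeness, the closest existing approximation to such a one-liner that I'm aware of is PyTorch's built-in profiler with memory tracking enabled. A rough sketch, where the tiny model and batch are stand-ins for a real LightningModule and data:

```python
import torch
from torch.profiler import profile, ProfilerActivity

# tiny stand-in model and batch (placeholders for a real LightningModule and data)
model = torch.nn.Linear(1024, 1024).cuda()
batch = torch.randn(512, 1024, device="cuda")

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,       # track tensor allocations and frees
    record_shapes=True,
) as prof:
    loss = model(batch).sum()
    loss.backward()

# rank ops by how much CUDA memory they allocated themselves
print(prof.key_averages().table(sort_by="self_cuda_memory_usage", row_limit=15))
```

This gives per-op memory numbers, but it still doesn't map them back onto LightningModule hooks the way a native pytorch-lightning profiler could.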