
Native integration of pytorch_memlab or something like it #5189

@turian

Description

🚀 Feature

Fine-grained memory profiling in pytorch-lightning that explains:

  1. what specifically causes GPU utilization.memory (how often device memory is being read or written)
  2. what specifically causes GPU memory.used (how much device memory is allocated; see the torch.cuda sketch after this list)
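
For reference, this is roughly what checking the memory.used side looks like by hand today with plain PyTorch (no pytorch-lightning involvement). The calls below are standard torch.cuda counters; the nvidia-smi comparison is only approximate, since the CUDA context itself also consumes device memory:

```python
import torch

# Current and peak allocator state on the default CUDA device.
print(torch.cuda.memory_allocated())      # bytes held by live tensors right now
print(torch.cuda.memory_reserved())       # bytes reserved by the caching allocator
                                          # (roughly what nvidia-smi reports as memory.used,
                                          #  minus CUDA context overhead)
print(torch.cuda.max_memory_allocated())  # peak allocation since the last reset; useful for OOM hunting
print(torch.cuda.memory_summary())        # human-readable allocator report
```

None of this says which line, module, or hook is responsible, which is the point of this request.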

Motivation

  1. If data is being moved back and forth unnecessarily between the CPU and GPU, training slows down. However, it is hard to pinpoint where these transfers come from.
  2. Certain architectures and batch sizes cause OOM errors, and it would be useful to pinpoint exactly where the memory consumption comes from, in order to build more compact networks and train with larger batch sizes.

Pitch

pytorch-lightning is designed to make it easy to train networks with pytorch. However, debugging utilization.memory and memory.used is very ad hoc and tricky: best practices don't always work, and even a very simple fine-grained memory profiler would be very useful, including for experts writing complicated nets.
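
To make the pitch concrete, here is a minimal sketch of what a native integration could look like. MemProfilerCallback is hypothetical (it does not exist in pytorch-lightning); it simply wraps pytorch_memlab's LineProfiler around the LightningModule's training_step, assuming LineProfiler's enable()/disable()/print_stats() interface:

```python
# Hypothetical sketch -- MemProfilerCallback is NOT an existing pytorch-lightning API.
import pytorch_lightning as pl
from pytorch_memlab import LineProfiler


class MemProfilerCallback(pl.Callback):
    """Line-by-line GPU memory profiling of training_step via pytorch_memlab."""

    def on_train_start(self, trainer, pl_module):
        # Trace per-line memory usage of the user's training_step.
        self._profiler = LineProfiler(pl_module.training_step)
        self._profiler.enable()

    def on_train_end(self, trainer, pl_module):
        self._profiler.disable()
        self._profiler.print_stats()  # per-line active/peak memory table


# The hoped-for "one-liner":
# trainer = pl.Trainer(gpus=1, callbacks=[MemProfilerCallback()])
```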

Alternatives

  • Instrumenting every single line of code with GPUStatsMonitor
  • pytorch_memlab, but it doesn't have good support for pytorch-lightning. Better native integration of this tool into pytorch-lightning would be very beneficial. (A sketch of how both alternatives are wired up by hand today follows this list.)
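
For context, this is roughly how the two alternatives are used standalone today: GPUStatsMonitor only logs coarse per-step nvidia-smi stats to the experiment logger, while pytorch_memlab's MemReporter gives a one-shot per-tensor breakdown but knows nothing about Lightning's hooks. TinyModule is just a placeholder to keep the sketch self-contained, and the Trainer arguments assume current (1.x-era) pytorch-lightning:

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import GPUStatsMonitor  # coarse per-step nvidia-smi stats
from pytorch_memlab import MemReporter                   # one-shot per-tensor memory report


class TinyModule(pl.LightningModule):
    """Minimal placeholder LightningModule, only here to make the example runnable."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


model = TinyModule()

# 1) Coarse, per-step GPU stats pushed to the experiment logger (e.g. wandb):
trainer = pl.Trainer(
    gpus=1,
    max_epochs=1,
    callbacks=[GPUStatsMonitor(memory_utilization=True, gpu_utilization=True)],
)

# 2) Manual, one-shot breakdown of which tensors own the allocated memory
#    -- useful, but it has no knowledge of Lightning's training loop:
MemReporter(model.cuda()).report()
```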

Additional context

Attached is a graph from the wandb.ai dashboard of my utilization.memory:

[screenshot: utilization.memory graph from the wandb.ai dashboard]

I am tearing my hair out trying to figure out why this is the case. As far as I can tell, everything is on the GPU, and I don't know where the memory accesses are coming from. I'd love a one-liner tool that explained this, rather than poking around blindly in a haphazard way.

Labels

feature (Is an improvement or enhancement) · help wanted (Open to be worked on) · won't fix (This will not be worked on)
