
[Feature Request] Will this support torch.compile to accelerate? #4853

@OutisLi

Description

Problem Description

I'm trying to add torch.compile support to accelerate DeePMD-kit training, but I'm hitting a runtime error from double backward operations that prevents training from completing.

Modifications Made

I've implemented torch.compile support with the following changes:

1. CLI argument added (deepmd/main.py:287-290):

parser_train.add_argument(
    "--compile",
    action="store_true",
    help="(Supported backend: PyTorch) Use torch.compile for model optimization during training",
)
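
For reference, here is a minimal, self-contained sketch of how this flag surfaces as FLAGS.compile (the parser setup around it is a stand-in; only the add_argument call matches the actual change in deepmd/main.py):

import argparse

# Stand-in parser setup; DeePMD-kit's real CLI defines many more
# subcommands and options.
parser = argparse.ArgumentParser(prog="dp")
subparsers = parser.add_subparsers(dest="command")
parser_train = subparsers.add_parser("train")
parser_train.add_argument(
    "--compile",
    action="store_true",
    help="(Supported backend: PyTorch) Use torch.compile for model optimization during training",
)

FLAGS = parser.parse_args(["train", "--compile"])
print(FLAGS.compile)  # True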

2. Training pipeline integration (deepmd/pt/entrypoints/main.py):

  • Added compile_model parameter to get_trainer() function (line 105)
  • Added compile_model parameter to train() function (line 253)
  • Passed compile_model=FLAGS.compile to the main training call (line 546); a simplified sketch of this threading follows
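
A simplified sketch of the threading (the real get_trainer() and train() in deepmd/pt/entrypoints/main.py take many more parameters; only the compile_model plumbing is shown):

# Schematic signatures; not the actual DeePMD-kit functions.
def get_trainer(config, compile_model=False):
    # The trainer stores the flag; training.py then decides whether to
    # wrap self.model with torch.compile (see step 3 below).
    return {"config": config, "compile_model": compile_model}

def train(input_file, compile_model=False):
    config = {"input": input_file}
    return get_trainer(config, compile_model=compile_model)

# In the entry point, the argparse flag from step 1 is forwarded:
trainer = train("input_static.json", compile_model=True)  # i.e. FLAGS.compile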

3. Model compilation implementation (deepmd/pt/train/training.py:417-429):

elif compile_model and hasattr(torch, "compile"):
    # Set dynamo config to handle scalar outputs from tensor.item()
    import torch._dynamo.config as dynamo_config
    dynamo_config.capture_scalar_outputs = True
    # Keep buffers needed by a second backward pass alive in aot_autograd
    torch._functorch.config.donated_buffer = False

    # For models requiring double backward (e.g., force calculation),
    # mode="reduce-overhead" was expected to avoid aot_autograd
    # (in practice it does not; see Analysis below).
    # dynamic=True because batch/atom sizes can vary.
    self.model = torch.compile(
        self.model,
        backend="inductor",
        mode="reduce-overhead",
        dynamic=True,
    )

4. Training loop adjustment (deepmd/pt/train/training.py:1015,1059):

  • Modified the JIT compilation check to also cover compile_model: if JIT or (self.compile_model and hasattr(torch, 'compile')) (shown in context below)
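
In context, the adjusted guard looks roughly like this standalone sketch (self. is dropped and the surrounding training-loop code is assumed):

import torch

JIT = False           # existing JIT option
compile_model = True  # set from --compile via FLAGS.compile

# Treat a torch.compile-wrapped model like a JIT-scripted one; the
# hasattr(torch, "compile") check keeps older PyTorch versions working.
if JIT or (compile_model and hasattr(torch, "compile")):
    pass  # take the code path used for scripted/compiled models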

Error Details

When running training with the --compile flag:

dp --pt train input_static.json --skip-neighbor-stat --compile

The training starts successfully but fails during the backward pass with:

RuntimeError: torch.compile with aot_autograd does not currently support double backward

Full error traceback:

File "/home/outis/Software/deepmd-kit/deepmd/pt/train/training.py", line 748, in step
    loss.backward()
File "/home/outis/miniconda3/envs/dpmd_conda/lib/python3.13/site-packages/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/outis/miniconda3/envs/dpmd_conda/lib/python3.13/site-packages/torch/autograd/__init__.py", line 353, in backward
    _engine_run_backward(tensors, retain_graph, create_graph, inputs, accumulate_grad=True)
File "/home/outis/miniconda3/envs/dpmd_conda/lib/python3.13/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
File "/home/outis/miniconda3/envs/dpmd_conda/lib/python3.13/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2128, in backward
    raise RuntimeError(
        "torch.compile with aot_autograd does not currently support double backward"
    )

Analysis

The issue occurs because:

  1. DeePMD-kit's loss requires double backward for force calculations: forces are themselves gradients of the energy, so training on them differentiates through an inner autograd call
  2. mode='reduce-overhead' does not bypass aot_autograd; it only enables CUDA graphs, and the inductor backend still routes differentiable graphs through aot_autograd
  3. The error appears during the first loss.backward() call (see the sketch after this list)
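
The failure can be illustrated with a minimal standalone analogue (an assumption about the pattern, not DeePMD-kit's actual code: a toy energy function stands in for the model, and a force-matching term stands in for the real loss):

import torch

# Toy stand-in for the model: energy as a function of coordinates.
def energy(coord):
    return (coord ** 2).sum()

compiled_energy = torch.compile(energy)

coord = torch.randn(10, 3, requires_grad=True)
e = compiled_energy(coord)

# Forces are gradients of the energy; create_graph=True keeps them
# differentiable so a loss on forces can itself be backpropagated.
(force,) = torch.autograd.grad(e, coord, create_graph=True)
loss = (force ** 2).sum()

# This backward pass differentiates through the compiled first derivative,
# i.e. a double backward, and raises:
# RuntimeError: torch.compile with aot_autograd does not currently
# support double backward
loss.backward()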

Environment

  • DeePMD-kit version: 3.1.1.dev27+g46e95428.d20250729
  • PyTorch version: v2.7.1+cu128-ge2d141dbde5
  • CUDA version: 12.8
  • Python version: 3.13
  • OS: Linux 5.15.153.1-microsoft-standard-WSL2

Additional Warnings Observed

Before the failure, several warnings appear:

  • Multiple "failed during evaluate_expr(Eq(u0 - u1, 0))" warnings from torch.fx.experimental.symbolic_shapes
  • "skipping cudagraphs due to cpu device" warnings
  • A "torch._dynamo hit config.recompile_limit (8)" warning tied to tensor-rank mismatches
  • A warning that Dynamo does not know how to trace the list.append builtin

