[INF] Enable torch compile for inference #5612

Open · wants to merge 13 commits into base: master
18 changes: 18 additions & 0 deletions deepspeed/inference/engine.py
@@ -13,6 +13,7 @@
from packaging import version as pkg_version
from deepspeed.runtime.checkpoint_engine.torch_checkpoint_engine import TorchCheckpointEngine
from deepspeed.utils.timer import SynchronizedWallClockTimer
from deepspeed.runtime.compiler import is_compile_supported

from ..runtime.state_dict_factory import SDLoaderFactory
from ..runtime.weight_quantizer import WeightQuantization
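The added import guards the new API: is_compile_supported lives in deepspeed/runtime/compiler.py and reports whether the running PyTorch offers torch.compile. A minimal sketch of such a version guard (an illustration, not the exact DeepSpeed implementation):

import torch

def is_compile_supported() -> bool:
    # torch.compile landed in PyTorch 2.0; earlier versions lack the
    # attribute entirely, so a hasattr check is a sufficient guard.
    return hasattr(torch, "compile")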
@@ -185,6 +186,7 @@ def __init__(self, model, config):

        # Check if local CUDA graphs can be created in replacement modules
        self.local_cuda_graph = self._local_cuda_graph_used(self.module)
        self._is_compiled = False

    def destroy(self):
        # Have to import here because inference_module is a global, but python
@@ -634,3 +636,19 @@ def _generate(self, *inputs, **kwargs):
        )

        return self.module.generate(*inputs, **kwargs)

    def compile(self, backend=None, compile_kwargs=None) -> None:
        """
        Compile the wrapped module with torch.compile using the specified
        backend and compile kwargs. The call is a no-op if the module has
        already been compiled.
        """
        if not is_compile_supported():
            raise RuntimeError("compile is not supported in your version of PyTorch.")

        if self._is_compiled:
            return

        # Resolve defaults at call time rather than at function definition,
        # which avoids sharing a mutable default dict and avoids querying
        # the accelerator when the class is defined.
        if backend is None:
            backend = get_accelerator().get_compile_backend()
        compile_kwargs = compile_kwargs or {}

        self.module.compile(backend=backend, **compile_kwargs)
        self._is_compiled = True

    @property
    def is_compiled(self) -> bool:
        return self._is_compiled
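Taken together, the new API lets callers opt in to torch.compile after building the engine. A usage sketch, assuming a Hugging Face-style model and hypothetical input_ids (illustrative only, not part of this PR):

import torch
import deepspeed

# deepspeed.init_inference wraps the model in an InferenceEngine.
engine = deepspeed.init_inference(model, dtype=torch.float16)

# Opt in to compilation; safe to call unconditionally because the
# _is_compiled guard makes repeat calls a no-op.
engine.compile(compile_kwargs={"mode": "reduce-overhead"})
assert engine.is_compiled

output = engine.generate(input_ids, max_new_tokens=32)

Keeping compile() explicit, rather than compiling inside init_inference, mirrors the opt-in compile() on the training-side engine and leaves the default inference path unchanged.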