From aafeecc9eb56e733728157ebab9f90ff238fa3a8 Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Fri, 9 Feb 2024 18:56:14 +0100
Subject: [PATCH] Cleanup

---
 optimum/neuron/trainers.py                | 12 +++++++++---
 optimum/neuron/utils/hub_neuronx_cache.py | 14 +++++++++++---
 2 files changed, 20 insertions(+), 6 deletions(-)

diff --git a/optimum/neuron/trainers.py b/optimum/neuron/trainers.py
index e3420b493..e99fbe309 100755
--- a/optimum/neuron/trainers.py
+++ b/optimum/neuron/trainers.py
@@ -461,10 +461,16 @@ def _save_xla(self, output_dir: Optional[str] = None):
 
     def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
         if not os.environ.get("NEURON_PARALLEL_COMPILE"):  # Avoid unnecessary model saving during precompilation
-            if output_dir is None:
-                output_dir = self.args.output_dir
+            with patch_neuron_cc_wrapper():
+                with hub_neuronx_cache(cache_repo_id=get_hf_hub_cache_repos()[0]):
+                    if output_dir is None:
+                        output_dir = self.args.output_dir
 
-            self._save_xla(output_dir)
+                    self._save_xla(output_dir)
+
+            if xm.get_ordinal() == 0:
+                synchronize_hub_cache(get_hf_hub_cache_repos()[0])
+            xm.rendezvous("Hub cache synchronization done")
 
             # Push to the Hub when `save_model` is called by the user.
             if self.args.push_to_hub and not _internal_call:
diff --git a/optimum/neuron/utils/hub_neuronx_cache.py b/optimum/neuron/utils/hub_neuronx_cache.py
index a88283cec..6e3911bf1 100644
--- a/optimum/neuron/utils/hub_neuronx_cache.py
+++ b/optimum/neuron/utils/hub_neuronx_cache.py
@@ -28,7 +28,7 @@
 from ..version import __version__
 from .import_utils import is_neuronx_available
 from .patching import patch_everywhere
-from .require_utils import requires_torch_neuronx
+from .require_utils import requires_torch_neuronx, requires_torch_xla
 
 
 if is_neuronx_available():
@@ -277,6 +277,8 @@ def hf_create_compile_cache(cache_url):
         patch_everywhere("create_compile_cache", create_compile_cache, "libneuronxla")
 
 
+@requires_torch_neuronx
+@requires_torch_xla
 @contextmanager
 def patch_neuron_cc_wrapper():
     """
@@ -284,6 +286,8 @@ def patch_neuron_cc_wrapper():
     uses our caching system.
     """
 
+    import torch_xla.core.xla_model as xm
+
     def patch(restore: bool = False):
         path = os.environ["PATH"]
         main_dir = Path(path.split(":")[0])
@@ -301,10 +305,14 @@ def patch(restore: bool = False):
         shutil.copy(src, dst)
 
     try:
-        patch()
+        if xm.get_ordinal() == 0:
+            patch()
+        xm.rendezvous("Patch neuron_cc_wrapper")
         yield
     finally:
-        patch(restore=True)
+        if xm.get_ordinal() == 0:
+            patch(restore=True)
+        xm.rendezvous("Restore neuron_cc_wrapper")
 
 
 @requires_torch_neuronx
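
Note: below is a minimal, self-contained sketch of the rank-0 + rendezvous pattern this patch applies in both `save_model` and `patch_neuron_cc_wrapper`, assuming a torch_xla distributed run. `mutate_shared_files` and `rank_zero_patch` are hypothetical names introduced for illustration; only `xm.get_ordinal` and `xm.rendezvous` come from the patch itself.

    from contextlib import contextmanager

    import torch_xla.core.xla_model as xm


    def mutate_shared_files(restore: bool = False):
        """Hypothetical stand-in for a change to shared state, e.g. copying a
        custom neuron_cc_wrapper over the original one."""


    @contextmanager
    def rank_zero_patch():
        try:
            if xm.get_ordinal() == 0:
                # Only the first worker touches the shared files; doing this
                # on every worker would race on the same paths.
                mutate_shared_files()
            # Every worker blocks here until all workers arrive, so none of
            # them runs the body against a half-patched environment.
            xm.rendezvous("patch done")
            yield
        finally:
            if xm.get_ordinal() == 0:
                mutate_shared_files(restore=True)
            # Same barrier on the way out before any worker moves on.
            xm.rendezvous("restore done")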