From 2fb5ea5ca7ca8ea887af2851cce80ab2545d3f4f Mon Sep 17 00:00:00 2001
From: regisss <15324346+regisss@users.noreply.github.com>
Date: Wed, 18 Sep 2024 22:48:34 +0200
Subject: [PATCH] Fix `is_torch_tpu_available` in ORT Trainer (#2028)

---
 optimum/onnxruntime/trainer.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/optimum/onnxruntime/trainer.py b/optimum/onnxruntime/trainer.py
index 86c333adb3..66273cbcf9 100644
--- a/optimum/onnxruntime/trainer.py
+++ b/optimum/onnxruntime/trainer.py
@@ -103,14 +103,14 @@
 from transformers.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_zero3_enabled
 
 if check_if_transformers_greater("4.39"):
-    from transformers.utils import is_torch_xla_available
+    from transformers.utils import is_torch_xla_available as is_torch_tpu_xla_available
 
-    if is_torch_xla_available():
+    if is_torch_tpu_xla_available():
         import torch_xla.core.xla_model as xm
 else:
-    from transformers.utils import is_torch_tpu_available
+    from transformers.utils import is_torch_tpu_available as is_torch_tpu_xla_available
 
-    if is_torch_tpu_available(check_device=False):
+    if is_torch_tpu_xla_available(check_device=False):
         import torch_xla.core.xla_model as xm
 
 if TYPE_CHECKING:
@@ -735,7 +735,7 @@ def get_dataloader_sampler(dataloader):
 
                 if (
                     args.logging_nan_inf_filter
-                    and not is_torch_tpu_available()
+                    and not is_torch_tpu_xla_available()
                     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                 ):
                     # if loss is nan or inf simply add the average of previous logged losses
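
Note: the change imports whichever availability check the installed transformers
provides (`is_torch_xla_available` in >= 4.39, the deprecated
`is_torch_tpu_available` before that) under a single local alias,
`is_torch_tpu_xla_available`, so the rest of the trainer is version-agnostic.
Below is a minimal standalone sketch of the same compatibility-alias pattern,
using `packaging.version` directly instead of optimum's
`check_if_transformers_greater` helper; the `_xla_available` variable is only
for illustration and does not appear in the patch:

    from packaging import version

    import transformers

    # transformers 4.39 replaced is_torch_tpu_available with is_torch_xla_available;
    # import whichever exists under one shared alias.
    if version.parse(transformers.__version__) >= version.parse("4.39"):
        from transformers.utils import is_torch_xla_available as is_torch_tpu_xla_available

        _xla_available = is_torch_tpu_xla_available()
    else:
        from transformers.utils import is_torch_tpu_available as is_torch_tpu_xla_available

        # check_device=False skips probing for an actual TPU device at import time
        _xla_available = is_torch_tpu_xla_available(check_device=False)

    if _xla_available:
        import torch_xla.core.xla_model as xm  # noqa: F401

Keeping the old name `is_torch_tpu_xla_available` at every call site (as in the
second hunk) means only the import block needs to know which transformers
version is installed.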