From 047c65e9197c97224c6d24a461c4e23c9d09765b Mon Sep 17 00:00:00 2001
From: Michael Benayoun
Date: Mon, 17 Jun 2024 16:36:18 +0200
Subject: [PATCH] Set bf16 to true when needed (#635)

Co-authored-by: David Corvoysier
---
 optimum/neuron/training_args.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/optimum/neuron/training_args.py b/optimum/neuron/training_args.py
index 9d2253a66..ce6e34a0b 100644
--- a/optimum/neuron/training_args.py
+++ b/optimum/neuron/training_args.py
@@ -32,6 +32,7 @@
 from .accelerate import NeuronAcceleratorState, NeuronPartialState
 from .accelerate.utils import ModelParallelismPlugin, patch_accelerate_is_torch_xla_available
 from .utils import is_main_worker
+from .utils.misc import is_precompilation
 from .utils.patching import Patcher, patch_within_function
 from .utils.torch_xla_and_neuronx_initialization import set_neuron_cc_optlevel
 
@@ -177,6 +178,25 @@ def __post_init__(self):
             async_save=self.async_save,
         )
 
+        # If the user did not specify bf16=True but the flags are set, we set bf16=True.
+        # Without this we can fall in the case where XLA will compile the graph in bf16 with torch.finfo unpatched,
+        # leading to NaNs.
+        if not self.bf16 and (
+            os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1"
+        ):
+            self.bf16 = True
+
+        if (
+            is_precompilation()
+            and self.bf16
+            and os.environ.get("XLA_USE_BF16", "0") == "0"
+            and os.environ.get("XLA_DOWNCAST_BF16", "0") == "0"
+        ):
+            raise ValueError(
+                "bf16=True but both of the environment variables XLA_USE_BF16 and XLA_DOWNCAST_BF16 are not set. You "
+                "must set them manually when using `neuron_parallel_compile`."
+            )
+
         if self.bf16 and self.half_precision_backend == "amp":
             os.environ["ACCELERATE_USE_AMP"] = "true"
         else:
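
For illustration only, a minimal standalone sketch of the coercion logic this patch adds to `__post_init__` (the function name `coerce_bf16_flag` is hypothetical and not part of the patch; the actual code lives in `NeuronTrainingArguments`):

import os


def coerce_bf16_flag(bf16: bool, is_precompilation: bool) -> bool:
    """Mirror of the patched logic: force bf16=True when either XLA bf16
    environment variable is set, and fail fast during precompilation when
    bf16 is requested but neither variable is set."""
    xla_use_bf16 = os.environ.get("XLA_USE_BF16", "0") == "1"
    xla_downcast_bf16 = os.environ.get("XLA_DOWNCAST_BF16", "0") == "1"

    # If the user did not pass bf16=True but the flags are set, enable bf16 so
    # torch.finfo is patched consistently with the bf16-compiled graph.
    if not bf16 and (xla_use_bf16 or xla_downcast_bf16):
        return True

    # During `neuron_parallel_compile`, the environment variables must be set manually.
    if is_precompilation and bf16 and not (xla_use_bf16 or xla_downcast_bf16):
        raise ValueError(
            "bf16=True but neither XLA_USE_BF16 nor XLA_DOWNCAST_BF16 is set."
        )

    return bf16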