
Set bf16 to true when needed (#635)
Co-authored-by: David Corvoysier <[email protected]>
michaelbenayoun and dacorvo authored Jun 17, 2024
1 parent f467bcb commit 047c65e
Showing 1 changed file with 20 additions and 0 deletions: optimum/neuron/training_args.py
@@ -32,6 +32,7 @@
 from .accelerate import NeuronAcceleratorState, NeuronPartialState
 from .accelerate.utils import ModelParallelismPlugin, patch_accelerate_is_torch_xla_available
 from .utils import is_main_worker
+from .utils.misc import is_precompilation
 from .utils.patching import Patcher, patch_within_function
 from .utils.torch_xla_and_neuronx_initialization import set_neuron_cc_optlevel
 
@@ -177,6 +178,25 @@ def __post_init__(self):
             async_save=self.async_save,
         )
 
+        # If the user did not specify bf16=True but the flags are set, we set bf16=True.
+        # Without this we can fall into the case where XLA will compile the graph in bf16 with torch.finfo
+        # unpatched, leading to NaNs.
+        if not self.bf16 and (
+            os.environ.get("XLA_USE_BF16", "0") == "1" or os.environ.get("XLA_DOWNCAST_BF16", "0") == "1"
+        ):
+            self.bf16 = True
+
+        if (
+            is_precompilation()
+            and self.bf16
+            and os.environ.get("XLA_USE_BF16", "0") == "0"
+            and os.environ.get("XLA_DOWNCAST_BF16", "0") == "0"
+        ):
+            raise ValueError(
+                "bf16=True but neither of the environment variables XLA_USE_BF16 and XLA_DOWNCAST_BF16 is "
+                "set. You must set them manually when using `neuron_parallel_compile`."
+            )
+
         if self.bf16 and self.half_precision_backend == "amp":
             os.environ["ACCELERATE_USE_AMP"] = "true"
         else:
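As a usage note (not part of the commit): after this change, exporting one of the XLA bf16 environment variables is enough to flip `bf16` on. A minimal sketch, assuming `NeuronTrainingArguments` keeps the standard `transformers.TrainingArguments` constructor (only `output_dir` required) and that it is run on a Neuron-enabled host:

    import os

    # User sets only the compiler-level flag, not bf16=True.
    os.environ["XLA_USE_BF16"] = "1"

    from optimum.neuron import NeuronTrainingArguments

    args = NeuronTrainingArguments(output_dir="out")
    # __post_init__ now promotes bf16 to True to match the XLA flag,
    # so torch.finfo gets patched consistently and the NaN issue is avoided.
    print(args.bf16)  # True

Conversely, the new check raises during precompilation unless the flags are exported explicitly, e.g. something like `XLA_USE_BF16=1 neuron_parallel_compile torchrun train.py --bf16` (illustrative command line).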
