Hi all!
I'd like to ask whether there are any plans to eventually support automatic mixed precision, the way PyTorch and TensorFlow do.
In PyTorch, all you have to do is wrap the forward pass (model and loss) with torch.autocast():
```python
import torch
from torch import optim

# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # Enables autocasting for the forward pass (model + loss)
    with torch.autocast(device_type="cuda"):
        output = model(input)
        loss = loss_fn(output, target)

    # Exits the context manager before backward()
    loss.backward()
    optimizer.step()
```
In TensorFlow, you simply set a global Keras policy:

```python
from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('mixed_float16')
```
As far as I can tell, Flax is missing a similar mechanism. If this statement is correct, how is one expected to train in mixed precision?
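For reference, the closest thing I've found so far is to thread a computation dtype through every module by hand and cast back to float32 around the loss. A rough sketch of that manual approach (the model, shapes, optimizer, and loss below are placeholders I made up, not taken from any docs):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class Net(nn.Module):
    @nn.compact
    def __call__(self, x):
        # dtype controls the computation precision, param_dtype the storage
        # precision; both have to be threaded through every layer manually.
        x = nn.Dense(128, dtype=jnp.bfloat16, param_dtype=jnp.float32)(x)
        x = nn.relu(x)
        x = nn.Dense(10, dtype=jnp.bfloat16, param_dtype=jnp.float32)(x)
        return x

model = Net()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
tx = optax.sgd(learning_rate=1e-3)
opt_state = tx.init(variables)

@jax.jit
def train_step(variables, opt_state, inputs, labels):
    def loss_fn(v):
        logits = model.apply(v, inputs)
        # Cast back to float32 for the loss, since there is no autocast
        # context that handles this automatically.
        logits = logits.astype(jnp.float32)
        return optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    loss, grads = jax.value_and_grad(loss_fn)(variables)
    updates, opt_state = tx.update(grads, opt_state, variables)
    variables = optax.apply_updates(variables, updates)
    return variables, opt_state, loss
```

That works, but it is easy to miss a layer, and there is nothing comparable to a single global switch like autocast or a Keras policy.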
Thanks in advance.