Feature Request: Automatic Mixed Precision #4921

@Artoriuz

Description

Hi all!

I'd like to ask whether there are any plans to eventually support automatic mixed precision like PyTorch and TensorFlow.

In PyTorch, all you have to do is wrap the forward pass of your training loop with torch.autocast():

import torch
from torch import optim

# Creates model and optimizer in default precision
# (Net, data and loss_fn are defined elsewhere)
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # Enables autocasting for the forward pass (model + loss)
    with torch.autocast(device_type="cuda"):
        output = model(input)
        loss = loss_fn(output, target)

    # Exits the context manager before backward()
    loss.backward()
    optimizer.step()

In TensorFlow, you simply define a policy:

from tensorflow.keras import mixed_precision

mixed_precision.set_global_policy('mixed_float16')

As far as I can tell, Flax is missing a similar mechanism. If that's the case, how is one expected to train in mixed precision?
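For context, the closest thing I'm aware of today is setting the compute dtype on every module by hand through Linen's dtype/param_dtype arguments. A minimal sketch (the MLP and the bfloat16 choice are just for illustration):

import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Compute in bfloat16 while keeping the parameters in float32
        x = nn.Dense(128, dtype=jnp.bfloat16, param_dtype=jnp.float32)(x)
        x = nn.relu(x)
        x = nn.Dense(10, dtype=jnp.bfloat16, param_dtype=jnp.float32)(x)
        return x

This works, but it has to be repeated for every module, which is exactly the boilerplate an autocast-style context or a global policy would remove.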

Thanks in advance.
