
Add BlackForest Flux Support #815

Merged
JingyaHuang merged 73 commits into main from add-flux-support
Jul 17, 2025

Conversation

@JingyaHuang
Collaborator

@JingyaHuang JingyaHuang commented Mar 24, 2025

What does this PR do?

Fixes #763, #676

Compilation

  • Export of Flux pipeline (TP=8 for Flux transformer 2D) with neuronx_distributed.trace.model_builder.ModelBuilder

Export via CLI

  • Regular

```bash
optimum-cli export neuron --model black-forest-labs/FLUX.1-dev --tensor_parallel_size 8 --batch_size 1 --height 768 --width 1360 --num_images_per_prompt 1 --torch_dtype bfloat16 flux_neuron/
```

  • Tiny test

```bash
optimum-cli export neuron --model hf-internal-testing/tiny-flux-pipe-gated-silu --tensor_parallel_size 2 --batch_size 1 --height 8 --width 8 --num_images_per_prompt 1 --sequence_length 256 --torch_dtype bfloat16 tiny_flux_neuron/
```
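As a rough sanity check on these shapes, the image-token sequence length seen by the Flux transformer can be estimated from the requested resolution. This is an illustrative sketch only (it assumes Flux's usual 8x VAE downsampling and 2x2 latent patchification), not part of the exporter:

```python
def flux_image_seq_len(height: int, width: int, vae_scale: int = 8, patch: int = 2) -> int:
    """Approximate number of image tokens the Flux transformer processes.

    The VAE downsamples the image by `vae_scale`, and the latents are then
    packed into `patch` x `patch` patches before entering the transformer.
    """
    return (height // (vae_scale * patch)) * (width // (vae_scale * patch))

# The 768x1360 "Regular" export above: 48 * 85 = 4080 image tokens
print(flux_image_seq_len(768, 1360))
```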

Export with NeuronFluxPipeline API

```python
import torch

from optimum.neuron import NeuronFluxPipeline

if __name__ == "__main__":
    compiler_args = {"auto_cast": "none"}
    input_shapes = {"batch_size": 1, "height": 1024, "width": 1024}

    pipe = NeuronFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
        export=True,
        tensor_parallel_size=8,
        # disable_neuron_cache=True,
        **compiler_args,
        **input_shapes,
    )

    # Save locally
    pipe.save_pretrained("flux_dev_neuron_1024_tp8/")

    # Upload to the HuggingFace Hub
    pipe.push_to_hub(
        "flux_dev_neuron_1024_tp8/",
        repository_id="Jingya/FLUX.1-dev-neuronx-1024x1024-tp8",  # Replace with your HF Hub repo id
    )
```

You can find an example of compiled artifacts here: Jingya/flux.1-dev_neuronx_tp8.

Inference

  • Flux Inference

For generating an image with NeuronFluxPipeline:

```python
import torch

from optimum.neuron import NeuronFluxPipeline

pipe = NeuronFluxPipeline.from_pretrained("flux_neuron")
prompt = "A cat holding a sign that says hello world"
out = pipe(
    prompt,
    guidance_scale=3.5,
    num_inference_steps=50,
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
out.save("flux_optimum.png")
```

Other

  • Tests
  • Doc

cc. @yahavb

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@github-actions

This PR is stale because it has been open 15 days with no activity. Remove stale label or comment or this will be closed in 5 days.

@github-actions github-actions Bot added the Stale label Apr 10, 2025
@github-actions github-actions Bot removed the Stale label Apr 11, 2025
@github-actions

github-actions Bot commented May 9, 2025

This PR is stale because it has been open 15 days with no activity. Remove stale label or comment or this will be closed in 5 days.

@github-actions github-actions Bot added the Stale label May 9, 2025
@github-actions github-actions Bot removed the Stale label May 14, 2025
@JingyaHuang JingyaHuang marked this pull request as ready for review July 16, 2025 22:48
Collaborator

@dacorvo dacorvo left a comment


Massive pull-request: impressive! I only have a few minor optional comments, and I think one test should be renamed.
Eventually, we should think about how some of the ModelBuilder-related code can be reused for non-decoder models: it is unclear to me yet whether there are any specifics that would prevent that.

```python
)
neuron_model = model_builder.trace(initialize_model_weights=False)

model_builder.shard_checkpoint(serialize_path=output.parent / "weights/")
```
Collaborator

@dacorvo dacorvo Jul 17, 2025


Eventually this could be omitted: the weights can be sharded at loading time only, which makes export a lot faster. You only need to remember where the weights are (local dir or hub).

Collaborator Author


Do you mean we shard the weights at loading time instead of after tracing? Indeed, sharding during export increases the whole export time, but somehow, I would rather spend more time exporting once and have faster loading + warmup during deployment.

Collaborator


Maybe in your case it is not that bad, because sharding does not mean loading the weights on the device, does it?
It is just that you don't cache the sharded weights, so it is only useful when you want to push the exported model to the Hub.
This is what I do for decoders:

  • when using the optimum-cli: export (cache or fetch NEFFs)
  • when using from_pretrained: export (cache or fetch NEFFs) + load_weights

For large models like Llama 70B this makes a huge difference, as loading weights takes several minutes.
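The two decoder paths described above can be sketched roughly as follows. All names here are purely illustrative stand-ins, not the actual optimum-neuron API:

```python
# Hypothetical sketch of the export/load split for decoders:
# `export` only compiles (or fetches cached) NEFF graphs; weights are
# sharded and loaded lazily, at from_pretrained time.

def export(model_id: str, neff_cache: dict) -> str:
    """Compile the model to NEFFs, reusing the cache when possible."""
    if model_id not in neff_cache:
        neff_cache[model_id] = f"neff({model_id})"  # stand-in for compilation
    return neff_cache[model_id]

def from_pretrained(model_id: str, neff_cache: dict) -> tuple:
    """Reuse cached NEFFs, then shard and load the weights on demand."""
    neff = export(model_id, neff_cache)
    weights = f"sharded({model_id})"  # stand-in for shard + load_weights
    return neff, weights
```

The point of the split is that a second `from_pretrained` call skips compilation entirely and only pays the weight-loading cost.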

Collaborator Author


Na na, shard_checkpoint just shards the weights; then we either serialize them to disk like I do here, or load them to Neuron.

Collaborator Author


Yeah, I should definitely look into the NEFF cache part of ModelBuilder (not yet the case) and the work you have done that I could reuse. Will do it next!

Comment thread optimum/neuron/models/inference/flux/modeling_flux.py
Comment thread tests/inference/test_nxd.py Outdated
Comment thread tests/inference/test_nxd.py Outdated
Comment thread tests/inference/test_nxd.py Outdated
Collaborator

@tengomucho tengomucho left a comment


A few nits, otherwise LGTM!

Flux is a series of text-to-image generation models based on diffusion transformers.

> [!TIP]
> We recommend using an `inf2.24xlarge` instance with tensor parallel size 8 for model compilation and inference.
Collaborator


why 24x if we only do TP8?

Copy link
Copy Markdown
Collaborator Author


Because inf2 instance sizes jump: you either get just 1 Neuron device, or 6.

[image: table of inf2 instance sizes and their Neuron device counts]
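In numbers, assuming the published inf2 sizes (2 NeuronCores per device; 1 device on inf2.xlarge and inf2.8xlarge, 6 on inf2.24xlarge, 12 on inf2.48xlarge), TP=8 needs 8 cores, and there is no instance between 2 and 12 cores:

```python
# Illustrative sketch: pick the smallest inf2 instance with enough
# NeuronCores for a given tensor parallel size. Core counts are taken
# from the public inf2 instance specs (2 cores per Neuron device).
NEURON_CORES = {
    "inf2.xlarge": 2,
    "inf2.8xlarge": 2,
    "inf2.24xlarge": 12,
    "inf2.48xlarge": 24,
}

def smallest_instance(tensor_parallel_size: int) -> str:
    candidates = [n for n, c in NEURON_CORES.items() if c >= tensor_parallel_size]
    return min(candidates, key=lambda n: NEURON_CORES[n])

print(smallest_instance(8))  # inf2.24xlarge
```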

Collaborator


oh, right

Comment thread optimum/exporters/neuron/base.py Outdated
Comment thread optimum/exporters/neuron/convert.py
Comment thread optimum/exporters/neuron/utils.py
Collaborator

@dacorvo dacorvo left a comment


LGTM, thanks !

@JingyaHuang JingyaHuang merged commit b898cab into main Jul 17, 2025
8 checks passed
@JingyaHuang JingyaHuang deleted the add-flux-support branch July 17, 2025 15:38
JingyaHuang added a commit that referenced this pull request Aug 25, 2025
# What does this PR do?

This PR will allow Flux Kontext to be used for text2img by building upon
the newly added Flux support here:
#815

Note: This depends on `diffusers >0.34` and the following PR to be
merged huggingface/diffusers#11985

@tengomucho @JingyaHuang

---------

Co-authored-by: JingyaHuang <huang_jingya@outlook.com>
Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>


Development

Successfully merging this pull request may close these issues.

Support for Flux model in diffusers

5 participants