Add BlackForest Flux Support #815
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
This PR is stale because it has been open 15 days with no activity. Remove stale label or comment or this will be closed in 5 days.
Massive pull request, impressive! I only have a few minor optional comments, and I think one test should be renamed.
Eventually, we should think about how some of the ModelBuilder-related code can be reused for non-decoder models: it is unclear to me yet whether there are specifics that would prevent that, though.
```python
neuron_model = model_builder.trace(initialize_model_weights=False)

model_builder.shard_checkpoint(serialize_path=output.parent / "weights/")
```
Eventually this could be omitted: the weights can be sharded at loading time only, which makes export a lot faster. You only need to remember where the weights are (local dir or hub).
Do you mean we shard the weights at loading time instead of after the tracing? Indeed, it increases the whole export time, but I would rather spend more time exporting in one shot and have faster loading + warmup during deployment.
Maybe in your case it is not that bad, because sharding does not mean loading the weights on the device, does it?
It is just that you don't cache the sharded weights, so it is only useful when you want to push the exported model to the hub.
This is what I do for decoders:
- when using the optimum-cli: export (cache or fetch NEFFs)
- when using from_pretrained: export (cache or fetch NEFFs) + load_weights
For large models like llama 70B this makes a huge difference as loading weights takes several minutes.
Na na, `shard_checkpoint` just shards the weights; then we either serialize them to disk like I do here or load them to Neuron.
Yeah, I should definitely look into the NEFF cache part of ModelBuilder (not yet the case here) and the work you have done that I could reuse. Will do it next!
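To make the decoder flow described above concrete, here is a minimal sketch of the `from_pretrained` path (export plus weight loading in one call), assuming the standard `NeuronModelForCausalLM` API; the model id and shape values are only examples.

```python
from optimum.neuron import NeuronModelForCausalLM

# Compiles the model (or fetches cached NEFFs from the hub), then loads the
# weights onto the Neuron devices. For very large models the weight loading
# is the slow part, hence the caching discussion above.
model = NeuronModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",  # example model id
    export=True,
    batch_size=1,
    sequence_length=4096,
    num_cores=24,
)
```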
tengomucho left a comment:
Few nits, otherwise LGTM!
Flux is a series of text-to-image generation models based on diffusion transformers.

> [!TIP]
> We recommend using an `inf2.24xlarge` instance with tensor parallel size 8 for the model compilation and inference.
why 24x if we only do TP8?
# What does this PR do?

This PR will allow Flux Kontext to be used for text2img by building upon the newly added Flux support here: #815

Note: This depends on `diffusers >0.34` and the following PR to be merged: huggingface/diffusers#11985

@tengomucho @JingyaHuang

Co-authored-by: JingyaHuang <huang_jingya@outlook.com>
Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>

What does this PR do?
Fixes #763, #676
Compilation (TP=8 for the Flux transformer 2D) with `neuronx_distributed.trace.model_builder.ModelBuilder`

Export via CLI
```bash
optimum-cli export neuron --model black-forest-labs/FLUX.1-dev --tensor_parallel_size 8 --batch_size 1 --height 768 --width 1360 --num_images_per_prompt 1 --torch_dtype bfloat16 flux_neuron/
```

```bash
optimum-cli export neuron --model hf-internal-testing/tiny-flux-pipe-gated-silu --tensor_parallel_size 2 --batch_size 1 --height 8 --width 8 --num_images_per_prompt 1 --sequence_length 256 --torch_dtype bfloat16 tiny_flux_neuron/
```
Export with `NeuronFluxPipeline` API

You can find an example of compiled artifacts here (`Jingya/flux.1-dev_neuronx_tp8`).
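A minimal sketch of what the API-based export could look like, assuming `NeuronFluxPipeline.from_pretrained` accepts `export=True` together with the same shape and parallelism arguments as the CLI flags above; treat the exact keyword names as assumptions:

```python
import torch
from optimum.neuron import NeuronFluxPipeline

# Compile FLUX.1-dev with the same shapes as the CLI example, then save the
# compiled artifacts so they can be reloaded (or pushed to the hub) later.
pipe = NeuronFluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    export=True,
    tensor_parallel_size=8,
    batch_size=1,
    height=768,
    width=1360,
    num_images_per_prompt=1,
    torch_dtype=torch.bfloat16,
)
pipe.save_pretrained("flux_neuron/")
```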
Inference

For generating an image with `NeuronFluxPipeline`:
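(A minimal sketch, assuming the pipeline follows the usual diffusers-style call signature once the compiled artifacts, e.g. the repo above, are loaded; the prompt and step count are only examples.)

```python
from optimum.neuron import NeuronFluxPipeline

# Load pre-compiled artifacts and generate a single image.
pipe = NeuronFluxPipeline.from_pretrained("Jingya/flux.1-dev_neuronx_tp8")
image = pipe(
    prompt="Astronaut riding a horse on Mars, cinematic lighting",
    num_inference_steps=28,
).images[0]
image.save("flux_dev_neuron.png")
```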
Other
cc. @yahavb