21 changes: 21 additions & 0 deletions ppdiffusers/examples/flow_grpo/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Jie Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
127 changes: 127 additions & 0 deletions ppdiffusers/examples/flow_grpo/config/base.py
@@ -0,0 +1,127 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ml_collections


def get_config():
config = ml_collections.ConfigDict()

# General #
# run name for wandb logging and checkpoint saving -- if not provided, it will be auto-generated from the current datetime.
config.run_name = ""
# random seed for reproducibility.
config.seed = 42
# top-level logging directory for checkpoint saving.
config.logdir = "logs"
# number of epochs between saving model checkpoints.
config.save_freq = 20
# number of epochs between evaluating the model.
config.eval_freq = 20
# number of checkpoints to keep before overwriting old ones.
config.num_checkpoint_limit = 5
# mixed precision training. options are "fp16", "bf16", and "no". half-precision speeds up training significantly.
config.mixed_precision = "bf16"
# allow tf32 on Ampere GPUs, which can speed up training.
# config.allow_tf32 = True
# the selected GPU id; -1 means use the CPU.
config.device_id = 0
# whether or not to use LoRA.
config.use_lora = True
# path to the prompt dataset directory.
config.dataset = "/root/paddlejob/workspace/env_run/gxl/flow_grpo/dataset/drawbench"
config.resolution = 768

# Pretrained Model #
config.pretrained = pretrained = ml_collections.ConfigDict()
# base model to load. either a path to a local directory, or a model name from the HuggingFace model hub.
pretrained.model = "runwayml/stable-diffusion-v1-5"
# revision of the model to load.
pretrained.revision = "main"

# Sampling #
config.sample = sample = ml_collections.ConfigDict()
# number of sampler inference steps for collecting the dataset.
sample.num_steps = 40
# number of sampler inference steps for evaluation.
sample.eval_num_steps = 40
# classifier-free guidance weight. 1.0 is no guidance.
sample.guidance_scale = 4.5
# batch size (per GPU!) to use for sampling.
sample.train_batch_size = 1
sample.num_image_per_prompt = 1
sample.test_batch_size = 1
# number of batches to sample per epoch. the total number of samples per epoch is `num_batches_per_epoch *
# batch_size * num_gpus`.
sample.num_batches_per_epoch = 2
# whether to use all samples in a batch to compute the std.
sample.global_std = True
# noise level
sample.noise_level = 0.7
# Whether to use the same noise for the same prompt
sample.same_latent = False

# Training #
config.train = train = ml_collections.ConfigDict()
# batch size (per GPU!) to use for training.
train.batch_size = 1
# whether to use the 8bit Adam optimizer from bitsandbytes.
train.use_8bit_adam = False
# learning rate.
train.learning_rate = 3e-4
# Adam beta1.
train.adam_beta1 = 0.9
# Adam beta2.
train.adam_beta2 = 0.999
# Adam weight decay.
train.adam_weight_decay = 1e-4
# Adam epsilon.
train.adam_epsilon = 1e-8
# number of gradient accumulation steps. the effective batch size is `batch_size * num_gpus *
# gradient_accumulation_steps`.
train.gradient_accumulation_steps = 1
# maximum gradient norm for gradient clipping.
train.max_grad_norm = 1.0
# number of inner epochs per outer epoch. each inner epoch is one iteration through the data collected during one
# outer epoch's round of sampling.
train.num_inner_epochs = 1
# whether or not to use classifier-free guidance during training. if enabled, the same guidance scale used during
# sampling will be used during training.
train.cfg = True
# clip advantages to the range [-adv_clip_max, adv_clip_max].
train.adv_clip_max = 5
# the PPO clip range.
train.clip_range = 1e-4
# the fraction of timesteps to train on. if set to less than 1.0, the model will be trained on a subset of the
# timesteps for each sample. this will speed up training but reduce the accuracy of policy gradient estimates.
train.timestep_fraction = 1.0
# KL regularization coefficient.
train.beta = 0.0
# path to a pretrained LoRA checkpoint.
train.lora_path = None
# whether to save an EMA (exponential moving average) model.
train.ema = False

# Prompt Function #
# prompt function to use. see `prompts.py` for available prompt functions.
config.prompt_fn = "imagenet_animals"
# kwargs to pass to the prompt function.
config.prompt_fn_kwargs = {}

# Reward Function #
# reward function to use. see `rewards.py` for available reward functions.
config.reward_fn = ml_collections.ConfigDict()
config.save_dir = ""

# Per-Prompt Stat Tracking #
config.per_prompt_stat_tracking = True

return config
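The comments in this file describe how the per-GPU batch sizes combine into effective totals. Below is a minimal sketch of reading the config and computing those derived quantities; the import path `config.base` and the `num_gpus` value are illustrative assumptions, since the device count comes from the launcher rather than from this file.

```python
from config.base import get_config  # assumed import path for this file

config = get_config()
num_gpus = 8  # hypothetical device count provided by the launcher, not stored in the config

# samples collected per outer epoch: num_batches_per_epoch * batch_size * num_gpus
samples_per_epoch = (
    config.sample.num_batches_per_epoch * config.sample.train_batch_size * num_gpus
)

# effective training batch size: batch_size * num_gpus * gradient_accumulation_steps
effective_train_batch = (
    config.train.batch_size * num_gpus * config.train.gradient_accumulation_steps
)

print(f"samples per epoch: {samples_per_epoch}")
print(f"effective train batch size: {effective_train_batch}")
```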
124 changes: 124 additions & 0 deletions ppdiffusers/examples/flow_grpo/config/dpo.py
@@ -0,0 +1,124 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import os

# load base.py from the same directory as a module named "base".
_base_spec = importlib.util.spec_from_file_location("base", os.path.join(os.path.dirname(__file__), "base.py"))
base = importlib.util.module_from_spec(_base_spec)
_base_spec.loader.exec_module(base)


def compressibility():
config = base.get_config()

config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")

config.use_lora = True

config.sample.batch_size = 8
config.sample.num_batches_per_epoch = 4

config.train.batch_size = 4
config.train.gradient_accumulation_steps = 2

# prompting
config.prompt_fn = "general_ocr"

# rewards
config.reward_fn = {"jpeg_compressibility": 1}
config.per_prompt_stat_tracking = True
return config


def geneval_sd3():
config = compressibility()
config.dataset = os.path.join(os.getcwd(), "dataset/geneval")

# sd3.5 medium
config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
config.sample.num_steps = 40
config.sample.eval_num_steps = 40
config.sample.guidance_scale = 4.5

config.resolution = 512
config.sample.train_batch_size = 24
config.sample.num_image_per_prompt = 24
config.sample.num_batches_per_epoch = 1
# This batch size is deliberate: the test set has 2212 samples, so gpu_num * bs * n should be as close to 2212 as possible.
# When the sample count is not evenly divisible by the number of cards, the last batch is padded so every card gets the same number of samples, which affects gradient synchronization.
config.sample.test_batch_size = 14

config.train.algorithm = "dpo"
# Change ref_update_step to a small number, e.g., 40, to switch to OnlineDPO.
config.train.ref_update_step = 10000000
config.train.batch_size = config.sample.train_batch_size
config.train.gradient_accumulation_steps = 1
config.train.num_inner_epochs = 1
config.train.timestep_fraction = 0.99
config.train.beta = 100
config.sample.global_std = True
config.train.ema = True
config.save_freq = 40 # epoch
config.eval_freq = 40
config.save_dir = "logs/geneval/sd3.5-M-dpo"
config.reward_fn = {
"geneval": 1.0,
}

config.prompt_fn = "geneval"

config.per_prompt_stat_tracking = True
return config


def pickscore_sd3():
config = compressibility()
config.dataset = os.path.join(os.getcwd(), "dataset/pickscore")

# sd3.5 medium
config.pretrained.model = "stabilityai/stable-diffusion-3.5-medium"
config.sample.num_steps = 40
config.sample.eval_num_steps = 40
config.sample.guidance_scale = 4.5

config.resolution = 512
config.sample.train_batch_size = 24
config.sample.num_image_per_prompt = 24
config.sample.num_batches_per_epoch = 1
# This batch size is deliberate: the test set has 2048 samples, so gpu_num * bs * n should be as close to 2048 as possible.
# When the sample count is not evenly divisible by the number of cards, the last batch is padded so every card gets the same number of samples, which affects gradient synchronization.
config.sample.test_batch_size = 16

config.train.algorithm = "dpo"
# Change ref_update_step to a small number, e.g., 40, to switch to OnlineDPO.
config.train.ref_update_step = 10000000

config.train.batch_size = config.sample.train_batch_size
config.train.gradient_accumulation_steps = 1
config.train.num_inner_epochs = 1
config.train.timestep_fraction = 0.99
config.train.beta = 100
config.sample.global_std = True
config.train.ema = True
config.save_freq = 60 # epoch
config.eval_freq = 60
config.save_dir = "logs/pickscore/sd3.5-M-dpo"
config.reward_fn = {
"pickscore": 1.0,
}

config.prompt_fn = "general_ocr"

config.per_prompt_stat_tracking = True
return config


def get_config(name):
return globals()[name]()
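As a usage sketch, the dispatcher above selects a recipe by function name. The example below picks `geneval_sd3` and applies the OnlineDPO override described in the comments (a small `ref_update_step` such as 40); the import path `config.dpo` is an assumption for illustration.

```python
from config.dpo import get_config  # assumed import path for this file

# select the GenEval SD3.5 recipe by name, as a training script would
config = get_config("geneval_sd3")

# per the comment above, a small ref_update_step (e.g., 40) switches the run to OnlineDPO
config.train.ref_update_step = 40

print(config.pretrained.model)  # stabilityai/stable-diffusion-3.5-medium
print(config.train.algorithm)   # dpo
print(config.save_dir)          # logs/geneval/sd3.5-M-dpo
```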