Add Pickscore/HPSv2 style reward training and evaluation (#249)
* added basic training/testing scripts for megatron clip model

Signed-off-by: Rohit Jena <[email protected]>

* added pickscore dataset code + additional functions for reward model

Signed-off-by: Rohit Jena <[email protected]>

* update config

Signed-off-by: Rohit Jena <[email protected]>

* add forward to MegatronCLIPRewardModel

Signed-off-by: Rohit Jena <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* incorporate comments

Signed-off-by: Rohit Jena <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* changed path names

Signed-off-by: Rohit Jena <[email protected]>

---------

Signed-off-by: Rohit Jena <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and gshennvm committed Aug 30, 2024
1 parent 49ba69e commit 48c39db
Showing 5 changed files with 708 additions and 7 deletions.
examples/mm/clip/conf/baseline.yaml (254 additions, 0 deletions)
@@ -0,0 +1,254 @@
# An example model that works with this config is "https://huggingface.co/yuvalkirstain/PickScore_v1"
# This is the GAP (global average pooling) config;
# the nocrop config simply trains on the full 224x224 image.
name: pickscorev2
restore_from_path: null # used when starting from a .nemo file
multicrop: False

trainer:
  devices: 8
  num_nodes: 1
  accelerator: gpu
  precision: bf16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  use_distributed_sampler: False
  max_epochs: -1 # PTL default. In practice, max_steps will be reached first.
  max_steps: 4000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
  val_check_interval: 200
  check_val_every_n_epoch: null
  limit_val_batches: 50
  limit_test_batches: 500
  accumulate_grad_batches: 1 # do not modify; gradient accumulation is automatic when training megatron models
  gradient_clip_val: 1.0
  benchmark: False
  enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually

exp_manager:
  explicit_log_dir: null
  exp_dir: /path/to/checkpoints/${name}
  name: megatron_rm_baseline
  create_wandb_logger: False
  wandb_logger_kwargs:
    project: Pickscore
    name: ${name}
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  resume_from_checkpoint: ${model.resume_from_checkpoint}
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: val_loss
    save_top_k: 10
    mode: min
    always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
    save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits
    filename: 'megatron_clip--{val_loss:.2f}--{step}-{consumed_samples}'
    model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}
  ema:
    enable: False
    decay: 0.9999
    validate_original_weights: False
    every_n_steps: 1
    cpu_offload: False

model:
  precision: 32
  # specify micro_batch_size, global_batch_size, and model parallelism
  # gradient accumulation will be done automatically based on data_parallel_size
  micro_batch_size: 16 # limited by GPU memory
  global_batch_size: 16 # will use more micro batches to reach global batch size
  tensor_model_parallel_size: 1 # intra-layer model parallelism
  pipeline_model_parallel_size: 1 # inter-layer model parallelism
  virtual_pipeline_model_parallel_size: null # interleaved pipeline

  from_pretrained: /path/to/clip_checkpoint.ckpt # used in fine-tuning
  # multimodal configs
  output_dim: 1024
  # As the number of devices used for training increases, so does the space complexity
  # of the logit matrix. With a naïve all-gather scheme, the space complexity is O(n^2).
  # With `gather_with_grad` and `local_loss` enabled, the complexity becomes effectively
  # linear while producing results numerically identical to the naïve method.
  # (An illustrative sketch of this trade-off follows after this config file.)
  local_loss: False # calculate loss with local features @ global features (instead of realizing the full global @ global matrix)
  gather_with_grad: True # enable full distributed gradient for feature gather; setting this to False may cause convergence issues

  vision:
    precision: 32
    # vision configs
    patch_dim: 14
    img_h: 224
    img_w: 224
    image_mean: null
    image_std: null
    num_channels: 3
    drop_patch_rate: 0.0
    drop_path_rate: 0.0
    global_average_pool: False
    output_dim: ${model.output_dim}
    class_token_length: 1
    preprocess_layernorm: True # apply layer norm to embedded tokens
    freeze: False

    # model architecture
    encoder_seq_length: 196
    max_position_embeddings: ${.encoder_seq_length}
    position_embedding_type: learned_parameters
    num_layers: 32
    hidden_size: 1280
    ffn_hidden_size: 5120 # Transformer FFN hidden size. Usually 4 * hidden_size.
    num_attention_heads: 16
    init_method_std: 0.02 # Standard deviation of the zero-mean normal distribution used for weight initialization.
    use_scaled_init_method: True # use scaled residuals initialization
    hidden_dropout: 0. # Dropout probability for hidden state transformer.
    attention_dropout: 0.
    kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
    apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
    normalization: layernorm # Type of normalization layers
    layernorm_epsilon: 1e-5
    do_layer_norm_weight_decay: False # True means weight decay on all params
    pre_process: True # add embedding
    post_process: True # add pooler
    persist_layer_norm: True # Use of persistent fused layer norm kernel.

    ## Activation Checkpointing
    activations_checkpoint_granularity: null # 'selective' or 'full'
    activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
    activations_checkpoint_num_layers: null # not used with 'selective'
    sequence_parallel: False

    # precision
    native_amp_init_scale: 4294967296 # 2 ** 32
    native_amp_growth_interval: 1000
    hysteresis: 2 # Gradient scale hysteresis
    fp32_residual_connection: False # Move residual connections to fp32
    fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16

    # model fusions
    masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with its mask.
    bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.

    use_cpu_initialization: False # Init weights on the CPU (slow for large models)
    onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
    gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism.
    openai_gelu: False
    bias_activation_fusion: False
    megatron_legacy: True
    activation: gelu

  text:
    precision: 32
    # text configs
    output_dim: ${model.output_dim}
    freeze: False

    # model architecture
    encoder_seq_length: 77
    max_position_embeddings: ${.encoder_seq_length}
    position_embedding_type: learned_parameters
    num_layers: 24
    hidden_size: 1024
    ffn_hidden_size: 4096 # Transformer FFN hidden size. Usually 4 * hidden_size.
    num_attention_heads: 16
    init_method_std: 0.02 # Standard deviation of the zero-mean normal distribution used for weight initialization.
    use_scaled_init_method: True # use scaled residuals initialization
    hidden_dropout: 0. # Dropout probability for hidden state transformer.
    attention_dropout: 0.
    kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null
    apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
    normalization: layernorm # Type of normalization layers
    layernorm_epsilon: 1e-5
    do_layer_norm_weight_decay: False # True means weight decay on all params
    pre_process: True # add embedding
    post_process: True # add pooler
    persist_layer_norm: True # Use of persistent fused layer norm kernel.

    ## Activation Checkpointing
    activations_checkpoint_granularity: null # 'selective' or 'full'
    activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective'
    activations_checkpoint_num_layers: null # not used with 'selective'
    num_micro_batches_with_partial_activation_checkpoints: null
    activations_checkpoint_layers_per_pipeline: null
    sequence_parallel: False

    # precision
    native_amp_init_scale: 4294967296 # 2 ** 32
    native_amp_growth_interval: 1000
    hysteresis: 2 # Gradient scale hysteresis
    fp32_residual_connection: False # Move residual connections to fp32
    fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16

    # model fusions
    masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with its mask.
    bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition.

    use_cpu_initialization: False # Init weights on the CPU (slow for large models)
    onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter.
    gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism.
    openai_gelu: False
    bias_activation_fusion: False
    megatron_legacy: True

    transformer_engine: False
    fp8: False # enables fp8 in TransformerLayer forward
    fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3
    fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID
    fp8_margin: 0 # scaling margin
    fp8_interval: 1 # scaling update interval
    fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor
    fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history
    use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False.
    activation: gelu

  # Megatron O2-style half-precision
  megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters
  grad_allreduce_chunk_size_mb: 125
  grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce

  # miscellaneous
  seed: 1234
  resume_from_checkpoint: null # manually set the checkpoint file to load from
  apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this
  gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)

  tokenizer:
    library: 'huggingface'
    type: 'laion/CLIP-ViT-H-14-laion2B-s32B-b79K'
    # type: 'openai/clip-vit-large-patch14'
    model: null
    vocab_file: null
    merge_file: null
    delimiter: null # only used for tabular tokenizer
    sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers.
    make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.

  data:
    num_workers: 8
    data_path: /path/to/dataset
    no_crop_images: True
    train:
      drop_last: True
    val:
      drop_last: True

  # Nsys profiling options
  nsys_profile:
    enabled: False
    start_step: 10 # Global batch to start profiling
    end_step: 10 # Global batch to end profiling
    ranks: [ 0 ] # Global rank IDs to profile
    gen_shape: False # Generate model and kernel details including input shapes

  optim:
    name: fused_adam
    lr: 3e-6
    weight_decay: 0.01
    betas:
      - 0.9
      - 0.999
    sched:
      name: PolynomialDecayAnnealing
      warmup_steps: 500
      power: 1.0
      min_lr: 0
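The `local_loss` / `gather_with_grad` trade-off described in the comments above is easiest to see in code. The sketch below is not part of this PR; it is a minimal, open_clip-style illustration that assumes `torch.distributed` is initialized and that each rank holds L2-normalized `[local_batch, dim]` image and text features.

import torch
import torch.distributed
import torch.distributed.nn  # differentiable all_gather, i.e. what gather_with_grad enables
import torch.nn.functional as F


def contrastive_loss(image_feats, text_feats, logit_scale, local_loss=False):
    rank = torch.distributed.get_rank()
    local_batch = image_feats.shape[0]

    # gather_with_grad=True: gather features from every rank while keeping gradients
    all_image = torch.cat(torch.distributed.nn.all_gather(image_feats), dim=0)
    all_text = torch.cat(torch.distributed.nn.all_gather(text_feats), dim=0)

    if local_loss:
        # local features vs. gathered features: [local_batch, global_batch] logits,
        # so per-rank memory grows roughly linearly with the number of ranks
        logits_per_image = logit_scale * image_feats @ all_text.t()
        logits_per_text = logit_scale * text_feats @ all_image.t()
        labels = torch.arange(local_batch, device=image_feats.device) + rank * local_batch
    else:
        # the full [global_batch, global_batch] logit matrix is realized on every rank: O(n^2) memory
        logits_per_image = logit_scale * all_image @ all_text.t()
        logits_per_text = logits_per_image.t()
        labels = torch.arange(all_image.shape[0], device=image_feats.device)

    return 0.5 * (F.cross_entropy(logits_per_image, labels) + F.cross_entropy(logits_per_text, labels))

With the settings in this config (`local_loss: False`, `gather_with_grad: True`), the full matrix is built but gradients still flow through the gathered features.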

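The test script below only evaluates pairwise accuracy; the training objective added elsewhere in this PR is not rendered on this page. For orientation, the PickScore/HPSv2 recipe named in the commit title treats the scaled CLIP similarity as a reward and trains it with a cross-entropy loss over each preference pair. The snippet below is a hedged illustration of that recipe, not this repository's implementation; `clip_reward` and `preference_loss` are hypothetical helpers.

import torch
import torch.nn.functional as F


def clip_reward(image_emb, text_emb, logit_scale):
    # reward = temperature-scaled cosine similarity between image and prompt embeddings
    image_emb = F.normalize(image_emb, dim=-1)
    text_emb = F.normalize(text_emb, dim=-1)
    return logit_scale * (image_emb * text_emb).sum(dim=-1)  # [batch]


def preference_loss(r0, r1, label):
    # r0, r1: rewards for the two candidate images of each prompt, shape [batch]
    # label: soft preference targets, e.g. [1, 0], [0, 1], or [0.5, 0.5] for ties -> [batch, 2]
    logits = torch.stack([r0, r1], dim=1)          # [batch, 2]
    log_probs = F.log_softmax(logits, dim=1)
    return -(label * log_probs).sum(dim=1).mean()  # cross-entropy against soft labels

At evaluation time, the script below applies the same softmax over `[r0, r1]` to turn the two rewards into a preference probability.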
examples/mm/clip/test_reward_model.py (120 additions, 0 deletions)
@@ -0,0 +1,120 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pdb

import numpy as np
import torch
import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor

from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from nemo_aligner.data.mm.pickscore_dataset import build_train_valid_datasets
from nemo_aligner.models.mm.stable_diffusion.image_text_rms import get_reward_model
from nemo_aligner.utils.distributed import Timer


@hydra_runner(config_path="conf", config_name="baseline")
@torch.no_grad()
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f"\n{OmegaConf.to_yaml(cfg)}")

    cfg.model.global_batch_size = cfg.trainer.devices * cfg.trainer.num_nodes * cfg.model.micro_batch_size

    model = get_reward_model(cfg, cfg.model.micro_batch_size, cfg.model.global_batch_size).cuda()
    model.eval()
    batch_size = cfg.model.micro_batch_size
    _, val_ds, test_ds = build_train_valid_datasets(cfg.model, 0, return_test_data=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, drop_last=False, shuffle=False, collate_fn=model.dl_collate_fn)
    test_dl = DataLoader(
        test_ds, batch_size=batch_size, drop_last=False, shuffle=False, collate_fn=model.dl_collate_fn
    )

    # collect all probabilities and labels here
    all_val_probs = []
    all_val_labels = []

    # run through the val and test datasets
    thresholds = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]
    for batch in tqdm(val_dl, total=len(val_dl)):
        img_0, img_1 = batch["img_0"], batch["img_1"]
        label = batch["label"]
        prompt = batch["prompt"]
        # move to device
        img_0, img_1 = [x.cuda() for x in img_0], [x.cuda() for x in img_1]
        r0 = model.get_reward(img_0, prompt)[:, None]
        r1 = model.get_reward(img_1, prompt)[:, None]
        prob = F.softmax(torch.cat([r0, r1], dim=1), dim=1)  # [b, 2]
        # append
        all_val_probs.append(prob.detach().cpu())
        all_val_labels.append(label)

    all_val_probs = torch.cat(all_val_probs, 0)
    all_val_labels = torch.cat(all_val_labels, 0)
    logging.info(f"{all_val_labels.shape}, {all_val_probs.shape}")
    best_thres, accuracies = calc_thres(all_val_probs, all_val_labels, thresholds)
    logging.info(f"Best computed threshold from validation set is {best_thres}.")
    logging.info(f"All val accuracies: {accuracies}")

    # run on test set
    all_test_probs, all_test_labels = [], []
    for batch in tqdm(test_dl, total=len(test_dl)):
        img_0, img_1 = batch["img_0"], batch["img_1"]
        label = batch["label"]
        prompt = batch["prompt"]
        # move to device
        img_0, img_1 = [x.cuda() for x in img_0], [x.cuda() for x in img_1]
        r0 = model.get_reward(img_0, prompt)[:, None]
        r1 = model.get_reward(img_1, prompt)[:, None]
        prob = F.softmax(torch.cat([r0, r1], dim=1), dim=1)  # [b, 2]
        # append
        all_test_probs.append(prob.detach().cpu())
        all_test_labels.append(label)
    # concat and evaluate with the threshold picked on the validation set
    all_test_labels = torch.cat(all_test_labels, 0)
    all_test_probs = torch.cat(all_test_probs, 0)
    _, acc = calc_thres(all_test_probs, all_test_labels, [best_thres])
    logging.info(f"Test acc: {acc}.")


def calc_thres(probs, labels, thresholds):
    # both are of size [B, 2] and thresholds is a list
    scores = []
    arange = torch.arange(probs.shape[0])
    argmax = torch.argmax(probs, dim=1)
    batch_size = probs.shape[0]
    # compute ties
    for t in thresholds:
        ties = 1.0 * (torch.abs(probs[:, 0] - probs[:, 1]) <= t)  # [B, ]
        label_ties = 1.0 * (torch.abs(labels[:, 0] - labels[:, 1]) <= 0.01)
        # the first term gives 1, 0.5, or 0 points for all non-ambiguous (non-tie) predictions;
        # for predicted ties, a full point is given if the label is also a tie, else half a point;
        # if the label is a tie but the prediction is not, 0.5 is contributed by the first term
        score = (labels[arange, argmax] * (1 - ties)).sum() + (ties * (label_ties + 0.5 * (1 - label_ties))).sum()
        score /= batch_size
        scores.append(score.item())
    idx = int(np.argmax(scores))
    return thresholds[idx], scores


if __name__ == "__main__":
    main()
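A small worked example makes the tie-aware scoring in `calc_thres` concrete. The numbers below are made up, and the snippet assumes the script above is importable as `test_reward_model` (e.g. it is on `PYTHONPATH`):

import torch

from test_reward_model import calc_thres  # the script above

probs = torch.tensor([[0.90, 0.10],   # confident and correct (label prefers image 0)
                      [0.52, 0.48],   # near-tie prediction
                      [0.30, 0.70]])  # confident but wrong (label prefers image 0)
labels = torch.tensor([[1.0, 0.0],
                       [0.5, 0.5],    # annotated tie
                       [1.0, 0.0]])

best_t, scores = calc_thres(probs, labels, [0.0, 0.1])
# t=0.0: no predicted ties, score = (1 + 0.5 + 0) / 3 = 0.5
# t=0.1: the second row becomes a predicted tie on an annotated tie and earns a full point,
#        so score = (1 + 1 + 0) / 3 ~= 0.67 and best_t == 0.1
print(best_t, scores)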
(The remaining changed files in this commit were not loaded on this page.)
