Add more streamlined local and single_gpu support, allow use of 8 bit adam optimizer #57

Open · wants to merge 16 commits into base: master
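This diff touches the training entry points (`train/train_b.py`, `train/train_c.py`, `train/train_c_controlnet.py`, `train/train_c_lora.py`), the config classes in `core/__init__.py` and `train/base.py`, and adds a sample LoRA config. In short: device selection falls back from the SLURM local rank to plain `cuda`/`cpu`, `single_gpu=True` skips the DDP setup, FSDP wrapping, and the `barrier()` calls around checkpointing, and a new `use_8bit_adam` flag switches the optimizer to `bitsandbytes` AdamW8bit. As a minimal sketch (not itself part of the diff), the launch pattern each `__main__` block now follows looks roughly like this:

import os
import sys

import torch

from train import WurstCoreC  # re-exported in train/__init__.py

# Prefer the SLURM-assigned local rank when present; otherwise fall back to
# the first CUDA device (or CPU) so the script also runs outside SLURM.
device = torch.device(
    int(os.environ['SLURM_LOCALID']) if 'SLURM_LOCALID' in os.environ
    else 'cuda' if torch.cuda.is_available() else 'cpu'
)

warpcore = WurstCoreC(
    config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
    device=device,
)

# With exactly one visible GPU, DDP/FSDP setup and checkpoint barriers are skipped.
warpcore(single_gpu=torch.cuda.device_count() == 1)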
38 changes: 38 additions & 0 deletions configs/training/lora_c_1b_bfloat16.yaml
@@ -0,0 +1,38 @@
# GLOBAL STUFF
experiment_id: stage_c_1b_lora
checkpoint_path: /tmp/cascade/chk
output_path: /tmp/cascade/lora_sample
model_version: 1B

# TRAINING PARAMS
lr: 1.0e-4
batch_size: 40
image_size: 768
multi_aspect_ratio: [1/1, 1/2, 1/3, 2/3, 3/4, 1/5, 2/5, 3/5, 4/5, 1/6, 5/6, 9/16]
grad_accum_steps: 4
updates: 10000
backup_every: 1000
save_every: 100
warmup_updates: 1
# use_fsdp: True -> FSDP doesn't work at the moment for LoRA
use_fsdp: False

# GDF
# adaptive_loss_weight: True

# LoRA specific
module_filters: ['.attn']
rank: 4
train_tokens:
# - ['^snail', null] # token starts with "snail" -> "snail" & "snails", don't need to be reinitialized
- ['[fernando]', '^dog</w>'] # custom token [fernando], initialized as the avg of tokens matching '^dog</w>'


# ema_start_iters: 5000
# ema_iters: 100
# ema_beta: 0.9

webdataset_path: file:/home/asutermo/cascade/data/dataset.tar
effnet_checkpoint_path: models/effnet_encoder.safetensors
previewer_checkpoint_path: models/previewer.safetensors
generator_checkpoint_path: models/stage_c_lite_bf16.safetensors
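Note that this sample config does not set the new `use_8bit_adam` flag added in `train/base.py` below. A hedged sketch of enabling it from Python, assuming the WarpCore constructor also accepts a `config_dict` (as the `setup_config(config_dict=...)` signature in the `core/__init__.py` hunk below suggests):

import torch
import yaml

from train import LoraCore  # train/train_c_lora.py, re-exported in train/__init__.py

# Load the sample config from this diff and enable the new optimizer flag.
with open('configs/training/lora_c_1b_bfloat16.yaml') as f:
    config = yaml.safe_load(f)
config['use_8bit_adam'] = True  # new field on the Config dataclass in train/base.py

core = LoraCore(config_dict=config,
                device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
core(single_gpu=True)  # skips DDP setup and checkpoint barriers, per core/__init__.py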
5 changes: 4 additions & 1 deletion core/__init__.py
@@ -34,6 +34,8 @@ class Config(Base):
wandb_project: str = None
wandb_entity: str = None

single_gpu: bool = False

@dataclass() # not frozen, means that fields are mutable
class Info(): # not inheriting from Base, because we don't want to enforce the default fields
wandb_run_id: str = None
@@ -141,6 +143,7 @@ def setup_config(self, config_file_path=None, config_dict=None, training=True) -
return self.Config(training=training)

def setup_ddp(self, experiment_id, single_gpu=False):
self.single_gpu = single_gpu
if not single_gpu:
local_rank = int(os.environ.get("SLURM_LOCALID"))
process_id = int(os.environ.get("SLURM_PROCID"))
@@ -297,7 +300,7 @@ def __call__(self, single_gpu=False):

if self.is_main_node:
print()
print("**STARTIG JOB WITH CONFIG:**")
print("**STARTING JOB WITH CONFIG:**")
print(yaml.dump(self.config.to_dict(), default_flow_style=False))
print("------------------------------------")
print()
3 changes: 2 additions & 1 deletion core/templates/diffusion.py
@@ -218,7 +218,8 @@ def models_to_save(self):
return ['generator', 'generator_ema']

def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
barrier()
if not self.single_gpu:
barrier()
suffix = '' if suffix is None else suffix
self.save_info(self.info, suffix=suffix)
models_dict = models.to_dict()
3 changes: 2 additions & 1 deletion train/__init__.py
@@ -1,4 +1,5 @@
from .train_b import WurstCore as WurstCoreB
from .train_c import WurstCore as WurstCoreC
from .train_c_controlnet import WurstCore as ControlNetCore
from .train_c_lora import WurstCore as LoraCore
from .train_c_lora import WurstCore as LoraCore

6 changes: 5 additions & 1 deletion train/base.py
@@ -195,6 +195,9 @@ class Config(DataCore.Config, WarpCore.Config):

use_fsdp: bool = None

# Optimizer Params
use_8bit_adam: bool = None

@dataclass() # not frozen, means that fields are mutable. Doesn't support EXPECTED
class Info(WarpCore.Info):
ema_loss: float = None
@@ -310,7 +313,8 @@ def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, op
self.sample(models, data, extras)

def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
barrier()
if not self.single_gpu:
barrier()
suffix = '' if suffix is None else suffix
self.save_info(self.info, suffix=suffix)
models_dict = models.to_dict()
18 changes: 14 additions & 4 deletions train/train_b.py
@@ -189,7 +189,7 @@ def dummy_context():
generator_ema = self.load_model(generator_ema, 'generator_ema')
generator_ema.to(dtype).to(self.device).eval().requires_grad_(False)

if self.config.use_fsdp:
if not self.single_gpu and self.config.use_fsdp:
fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock])
generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
if generator_ema is not None:
@@ -209,7 +209,15 @@ def dummy_context():
)

def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
if self.config.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
optimizer = bnb.optim.AdamW8bit(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
else:
optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
optimizer = self.load_optimizer(optimizer, 'generator_optim',
fsdp_model=models.generator if self.config.use_fsdp else None)
return self.Optimizers(generator=optimizer)
@@ -294,11 +302,13 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext

if __name__ == '__main__':
print("Launching Script")
device = torch.device(int(os.environ.get('SLURM_LOCALID')) if 'SLURM_LOCALID' in os.environ else "cuda" if torch.cuda.is_available() else "cpu")
warpcore = WurstCore(
config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
device=torch.device(int(os.environ.get("SLURM_LOCALID")))
device=device
)
# core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD

# RUN TRAINING
warpcore()
use_single_gpu = torch.cuda.device_count() == 1
warpcore(single_gpu=use_single_gpu)
20 changes: 16 additions & 4 deletions train/train_c.py
@@ -174,7 +174,7 @@ def dummy_context():
generator_ema = self.load_model(generator_ema, 'generator_ema')
generator_ema.to(dtype).to(self.device).eval().requires_grad_(False)

if self.config.use_fsdp:
if not self.single_gpu and self.config.use_fsdp:
fsdp_auto_wrap_policy = ModuleWrapPolicy([ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock])
generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
if generator_ema is not None:
@@ -192,7 +192,15 @@ def dummy_context():
)

def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
if self.config.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
optimizer = bnb.optim.AdamW8bit(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
else:
optimizer = optim.AdamW(models.generator.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
optimizer = self.load_optimizer(optimizer, 'generator_optim',
fsdp_model=models.generator if self.config.use_fsdp else None)
return self.Optimizers(generator=optimizer)
@@ -256,11 +264,15 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext

if __name__ == '__main__':
print("Launching Script")

device = torch.device(int(os.environ.get('SLURM_LOCALID')) if 'SLURM_LOCALID' in os.environ else "cuda" if torch.cuda.is_available() else "cpu")
warpcore = WurstCore(
config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
device=torch.device(int(os.environ.get("SLURM_LOCALID")))
device=device
)
# core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD

# RUN TRAINING
warpcore()
use_single_gpu = torch.cuda.device_count() == 1
warpcore(single_gpu=use_single_gpu)

23 changes: 16 additions & 7 deletions train/train_c_controlnet.py
@@ -15,9 +15,8 @@

from modules import EfficientNetEncoder
from modules import StageC
from modules import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
from modules import Previewer
from modules import ControlNet, ControlNetDeliverer
from modules import ControlNet
from modules import controlnet_filters

from train.base import DataCore, TrainingCore
@@ -26,7 +25,6 @@
from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
import functools
from accelerate import init_empty_weights
@@ -223,7 +221,7 @@ def dummy_context():
controlnet = self.load_model(controlnet, 'controlnet')
controlnet.backbone.eval().requires_grad_(True)

if self.config.use_fsdp:
if not self.single_gpu and self.config.use_fsdp:
fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000)
controlnet = FSDP(controlnet, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)

@@ -235,7 +233,15 @@ def dummy_context():
)

def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
optimizer = optim.AdamW(models.controlnet.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
if self.config.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
optimizer = bnb.optim.AdamW8bit(models.controlnet.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
else:
optimizer = optim.AdamW(models.controlnet.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
optimizer = self.load_optimizer(optimizer, 'controlnet_optim',
fsdp_model=models.controlnet if self.config.use_fsdp else None)
return self.Optimizers(generator=None, controlnet=optimizer)
@@ -372,11 +378,14 @@ def sample(self, models: Models, data: WarpCore.Data, extras: Extras):

if __name__ == '__main__':
print("Launching Script")
device = torch.device(int(os.environ.get('SLURM_LOCALID')) if 'SLURM_LOCALID' in os.environ else "cuda" if torch.cuda.is_available() else "cpu")
warpcore = WurstCore(
config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
device=torch.device(int(os.environ.get("SLURM_LOCALID")))
device=device
)
warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD

# RUN TRAINING
warpcore()
use_single_gpu = torch.cuda.device_count() == 1
warpcore(single_gpu=use_single_gpu)

24 changes: 15 additions & 9 deletions train/train_c_lora.py
@@ -15,7 +15,6 @@

from modules.effnet import EfficientNetEncoder
from modules.stage_c import StageC
from modules.stage_c import ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
from modules.previewer import Previewer
from modules.lora import apply_lora, apply_retoken, LoRA, ReToken

@@ -26,8 +25,6 @@

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
import functools
from accelerate import init_empty_weights
from accelerate.utils import set_module_tensor_to_device
from contextlib import contextmanager
@@ -166,7 +163,6 @@ def dummy_context():
yield None

loading_context = dummy_context if self.config.training else init_empty_weights

with loading_context():
# Diffusion models
if self.config.model_version == '3.6B':
@@ -185,7 +181,7 @@ def dummy_context():
generator = generator.to(dtype).to(self.device)
generator = self.load_model(generator, 'generator')

# if self.config.use_fsdp:
# if not self.single_gpu and self.config.use_fsdp:
# fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000)
# generator = FSDP(generator, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)

@@ -239,7 +235,7 @@ def dummy_context():

lora = self.load_model(lora, 'lora')
lora.to(self.device).train().requires_grad_(True)
if self.config.use_fsdp:
if not self.single_gpu and self.config.use_fsdp:
# fsdp_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=3000)
fsdp_auto_wrap_policy = ModuleWrapPolicy([LoRA, ReToken])
lora = FSDP(lora, **self.fsdp_defaults, auto_wrap_policy=fsdp_auto_wrap_policy, device_id=self.device)
@@ -252,7 +248,15 @@ def dummy_context():
)

def setup_optimizers(self, extras: Extras, models: Models) -> Optimizers:
optimizer = optim.AdamW(models.lora.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
if self.config.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
optimizer = bnb.optim.AdamW8bit(models.lora.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
else:
optimizer = optim.AdamW(models.lora.parameters(), lr=self.config.lr) # , eps=1e-7, betas=(0.9, 0.95))
optimizer = self.load_optimizer(optimizer, 'lora_optim',
fsdp_model=models.lora if self.config.use_fsdp else None)
return self.Optimizers(generator=None, lora=optimizer)
@@ -320,11 +324,13 @@ def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, ext

if __name__ == '__main__':
print("Launching Script")
device = torch.device(int(os.environ.get('SLURM_LOCALID')) if 'SLURM_LOCALID' in os.environ else "cuda" if torch.cuda.is_available() else "cpu")
warpcore = WurstCore(
config_file_path=sys.argv[1] if len(sys.argv) > 1 else None,
device=torch.device(int(os.environ.get("SLURM_LOCALID")))
device=device
)
warpcore.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD

# RUN TRAINING
warpcore()
use_single_gpu = torch.cuda.device_count() == 1
warpcore(single_gpu=use_single_gpu)