Draft: possible way of adding validation loop to training script #245

Draft · wants to merge 4 commits into main
70 changes: 70 additions & 0 deletions extensions_built_in/sd_trainer/SDTrainer.py
@@ -1711,6 +1711,76 @@ def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]):
{'loss': loss.item()}
)

if self.train_config.validation_every is not None and self.step_num % self.train_config.validation_every == 0:
validation_loss = self.hook_validation_loop(self.data_loader_val)
loss_dict['validation_loss'] = validation_loss.item()

self.end_of_training_loop()

return loss_dict

def hook_validation_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchDTO]]):
"""
Validation loop that evaluates model performance without updating weights.
Mirrors a training step, but with no gradient calculations or optimizer steps.
Accepts a single batch, a list of batches, or an iterable dataloader of batches.
"""
# a single batch gets wrapped; a list or dataloader (e.g. self.data_loader_val) is iterated as-is
if isinstance(batch, DataLoaderBatchDTO):
batch_list = [batch]
else:
batch_list = batch

total_loss = None

# Ensure model components are in eval mode
if self.network is not None:
self.network.eval()
if self.adapter is not None:
self.adapter.eval()
if self.sd.unet is not None:
self.sd.unet.eval()

with torch.no_grad():
for batch in batch_list:
# Preprocess batch and get required tensors
batch = self.preprocess_batch(batch)
dtype = get_torch_dtype(self.train_config.dtype)

# Process batch similar to training
noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)

# Encode prompts
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts,
long_prompts=self.do_long_prompts
).to(self.device_torch, dtype=dtype)

# Predict noise
noise_pred = self.predict_noise(
noisy_latents=noisy_latents,
timesteps=timesteps,
conditional_embeds=conditional_embeds
)

# Calculate loss
loss = self.calculate_loss(
noise_pred=noise_pred,
noise=noise,
noisy_latents=noisy_latents,
timesteps=timesteps,
batch=batch,
)

if total_loss is None:
total_loss = loss
else:
total_loss += loss

# Restore training mode
if self.network is not None:
self.network.train()
if self.adapter is not None:
self.adapter.train()
if self.sd.unet is not None:
self.sd.unet.train()

return total_loss.detach()
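
A note on the return value: total_loss is the sum over all validation batches, so the logged validation_loss scales with the size of the holdout set, and .item() would fail on an empty split. A minimal sketch of a per-batch mean with an empty-split guard — an illustration, not part of this diff:

import torch

# Hypothetical variant: average the per-batch validation losses instead of summing them.
def mean_validation_loss(batch_losses: list) -> torch.Tensor:
    # guard the empty-holdout case so .item() never runs on None
    if len(batch_losses) == 0:
        return torch.tensor(0.0)
    return torch.stack([loss.detach() for loss in batch_losses]).mean()

# usage: append each `loss` computed inside the torch.no_grad() loop to a list,
# then log mean_validation_loss(collected_losses).item() as validation_loss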
34 changes: 21 additions & 13 deletions jobs/process/BaseSDTrainProcess.py
@@ -214,7 +214,8 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None
self.named_lora = True
self.snr_gos: Union[LearnableSNRGamma, None] = None
self.ema: ExponentialMovingAverage = None


self.validation_every = self.train_config.validation_every
validate_configs(self.train_config, self.model_config, self.save_config)

def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
@@ -1056,15 +1057,15 @@ def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
raise ValueError(f"Unknown content_or_style {content_or_style}")

# do flow matching
# if self.sd.is_flow_matching:
# u = compute_density_for_timestep_sampling(
# weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
# batch_size=batch_size,
# logit_mean=0.0,
# logit_std=1.0,
# mode_scale=1.29,
# )
# timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
if content_or_style == 'style':
u = compute_density_for_timestep_sampling(
weighting_scheme="logit_normal", # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
batch_size=batch_size,
logit_mean=-6,
logit_std=2.0,
mode_scale=1.29,
)
timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
# convert the timestep_indices to a timestep
timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
timesteps = torch.stack(timesteps, dim=0)
@@ -1701,10 +1702,10 @@ def run(self):
self.before_dataset_load()
# load datasets if passed in the root process
if self.datasets is not None:
self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd)
self.data_loader, self.data_loader_val = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd, self.validation_every)
if self.datasets_reg is not None:
self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
self.sd)
self.data_loader_reg, self.data_loader_val_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
self.sd, self.validation_every)

flush()
### HOOK ###
@@ -1744,6 +1745,13 @@ def run(self):
dataloader_reg = None
dataloader_iterator_reg = None

if self.data_loader_val is not None:
dataloader_val = self.data_loader_val
dataloader_iterator_val = iter(dataloader_val)
else:
dataloader_val = None
dataloader_iterator_val = None

# zero any gradients
optimizer.zero_grad()

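dataloader_iterator_val is created here but not consumed anywhere in the visible part of the diff (hook_validation_loop receives the loader itself). One way it could be used is to pull a single validation batch per check instead of sweeping the whole holdout; a sketch under that assumption, with a hypothetical helper name:

# Hypothetical helper: fetch one validation batch, restarting the iterator when the holdout is exhausted.
def next_validation_batch(dataloader_val, dataloader_iterator_val):
    try:
        batch = next(dataloader_iterator_val)
    except StopIteration:
        dataloader_iterator_val = iter(dataloader_val)
        batch = next(dataloader_iterator_val)
    return batch, dataloader_iterator_val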
5 changes: 5 additions & 0 deletions toolkit/config_modules.py
@@ -408,6 +408,9 @@ def __init__(self, **kwargs):
# optimal noise pairing
self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)

# validation
self.validation_every = kwargs.get('validation_every', 1000) # Set to same as save_every by default


class ModelConfig:
def __init__(self, **kwargs):
@@ -666,6 +669,8 @@ def __init__(self, **kwargs):
self.square_crop: bool = kwargs.get('square_crop', False)
# apply same augmentations to control images. Usually want this true unless special case
self.replay_transforms: bool = kwargs.get('replay_transforms', True)
# validation
self.validation_percent = kwargs.get('validation_percent', 0.1) # 10% holdout


def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
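Taken together, the defaults mean validation runs every 1000 steps on a 10% holdout. A standalone toy sketch of the step gate that hook_train_loop applies with the new validation_every option (values illustrative):

# Toy illustration of the validation_every gate; mirrors the check added in hook_train_loop.
validation_every = 1000  # TrainConfig default added in this PR

def should_validate(step_num: int) -> bool:
    # skip entirely when disabled, otherwise fire every validation_every steps
    return validation_every is not None and step_num % validation_every == 0

assert should_validate(1000) and should_validate(2000)
assert not should_validate(999)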
37 changes: 31 additions & 6 deletions toolkit/data_loader.py
@@ -12,7 +12,7 @@
from PIL import Image
from PIL.ImageOps import exif_transpose
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split
from tqdm import tqdm
import albumentations as A

@@ -560,7 +560,8 @@ def get_dataloader_from_datasets(
dataset_options,
batch_size=1,
sd: 'StableDiffusion' = None,
) -> DataLoader:
validation_every=None,
) -> List[DataLoader]:
if dataset_options is None or len(dataset_options) == 0:
return None

@@ -593,6 +594,8 @@

concatenated_dataset = ConcatDataset(datasets)

if validation_every is not None:
val_size = int(len(concatenated_dataset) * dataset_config_list[0].validation_percent)
train_dataset, validation_dataset = random_split(
concatenated_dataset,
[len(concatenated_dataset) - val_size, val_size],
generator=torch.Generator().manual_seed(42)
)
else:
# no validation requested, keep every sample for training
train_dataset, validation_dataset = concatenated_dataset, None

# todo build scheduler that can get buckets from all datasets that match
# todo and evenly distribute reg images

@@ -613,28 +616,50 @@ def dto_collation(batch: List['FileItemDTO']):
dataloader_kwargs['num_workers'] = dataset_config_list[0].num_workers
dataloader_kwargs['prefetch_factor'] = dataset_config_list[0].prefetch_factor

result_loaders = []
if has_buckets:
# make sure they all have buckets
for dataset in datasets:
assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none"

data_loader = DataLoader(
concatenated_dataset,
train_loader = DataLoader(
train_dataset,
batch_size=None, # we batch in the datasets for now
drop_last=False,
shuffle=True,
collate_fn=dto_collation, # Use the custom collate function
**dataloader_kwargs
)
result_loaders.append(train_loader)
if validation_every is not None:
validation_loader = DataLoader(
validation_dataset,
batch_size=1, # we batch in the datasets for now
drop_last=False,
shuffle=False,
collate_fn=dto_collation, # Use the custom collate function
**dataloader_kwargs
)
result_loaders.append(validation_loader)
else:
result_loaders.append(None) # keep the return a pair so callers can always unpack two loaders
else:
data_loader = DataLoader(
concatenated_dataset,
train_dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=dto_collation,
**dataloader_kwargs
)
return data_loader
result_loaders.append(data_loader)
if validation_every is not None:
validation_loader = DataLoader(
validation_dataset,
batch_size=1,
shuffle=False,
collate_fn=dto_collation,
**dataloader_kwargs
)
result_loaders.append(validation_loader)
else:
result_loaders.append(None) # keep the return a pair so callers can always unpack two loaders
return result_loaders


def trigger_dataloader_setup_epoch(dataloader: DataLoader):
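For reference, a self-contained sketch of the seeded split used above: with the same validation_percent and generator seed, the same holdout is carved out on every run (stand-in dataset, purely illustrative):

import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset; the real code splits the ConcatDataset built from all configured datasets.
dataset = TensorDataset(torch.arange(100))
validation_percent = 0.1  # DatasetConfig default

val_size = int(len(dataset) * validation_percent)
train_ds, val_ds = random_split(
    dataset,
    [len(dataset) - val_size, val_size],
    generator=torch.Generator().manual_seed(42),  # fixed seed keeps the holdout stable across runs
)
print(len(train_ds), len(val_ds))  # 90 10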