diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 3747972d..603a98f3 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -1711,6 +1711,79 @@ def hook_train_loop(self, batch: Union[DataLoaderBatchDTO, List[DataLoaderBatchD
                 {'loss': loss.item()}
             )
 
+        # Periodically run validation. Guard against a missing validation
+        # loader (datasets may not have been passed to this process).
+        if (
+            self.train_config.validation_every is not None
+            and self.data_loader_val is not None
+            and self.step_num % self.train_config.validation_every == 0
+        ):
+            validation_loss = self.hook_validation_loop(self.data_loader_val)
+            if validation_loss is not None:
+                loss_dict['validation_loss'] = validation_loss.item()
+
         self.end_of_training_loop()
 
         return loss_dict
+
+    def hook_validation_loop(self, dataloader) -> Union[torch.Tensor, None]:
+        """
+        Evaluate the model on the validation dataloader without updating weights.
+
+        Mirrors the training forward pass (noise prediction + loss) under
+        torch.no_grad(), summing the loss over every validation batch.
+
+        :param dataloader: the validation DataLoader built by
+            get_dataloader_from_datasets; iterated in full each call.
+        :return: summed, detached validation loss, or None if the loader
+            yielded no batches (avoids .item() on None at the call site).
+        """
+        total_loss = None
+
+        # Put trainable components in eval mode for the duration of validation.
+        if self.network is not None:
+            self.network.eval()
+        if self.adapter is not None:
+            self.adapter.eval()
+        if self.sd.unet is not None:
+            self.sd.unet.eval()
+
+        with torch.no_grad():
+            # Iterate the DataLoader itself; each item is a DataLoaderBatchDTO.
+            for batch in dataloader:
+                # Preprocess batch and get required tensors
+                batch = self.preprocess_batch(batch)
+                dtype = get_torch_dtype(self.train_config.dtype)
+
+                # Process batch the same way training does
+                noisy_latents, noise, timesteps, conditioned_prompts, imgs = \
+                    self.process_general_training_batch(batch)
+
+                # Encode prompts
+                conditional_embeds = self.sd.encode_prompt(
+                    conditioned_prompts,
+                    long_prompts=self.do_long_prompts
+                ).to(self.device_torch, dtype=dtype)
+
+                # Predict noise
+                noise_pred = self.predict_noise(
+                    noisy_latents=noisy_latents,
+                    timesteps=timesteps,
+                    conditional_embeds=conditional_embeds
+                )
+
+                # Calculate loss
+                loss = self.calculate_loss(
+                    noise_pred=noise_pred,
+                    noise=noise,
+                    noisy_latents=noisy_latents,
+                    timesteps=timesteps,
+                    batch=batch,
+                )
+
+                total_loss = loss if total_loss is None else total_loss + loss
+
+        # Restore training mode
+        if self.network is not None:
+            self.network.train()
+        if self.adapter is not None:
+            self.adapter.train()
+        if self.sd.unet is not None:
+            self.sd.unet.train()
+
+        return None if total_loss is None else total_loss.detach()
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 3b210154..a0e8e504 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -214,7 +214,13 @@ def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=No
         self.named_lora = True
         self.snr_gos: Union[LearnableSNRGamma, None] = None
         self.ema: ExponentialMovingAverage = None
-
+
+        # Validation holdout loaders. Initialized to None so run() and the
+        # training hook can read them safely even when datasets are not
+        # passed to this process or validation is disabled.
+        self.data_loader_val = None
+        self.data_loader_val_reg = None
+        self.validation_every = self.train_config.validation_every
         validate_configs(self.train_config, self.model_config, self.save_config)
 
     def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
@@ -1056,15 +1062,17 @@ def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
             raise ValueError(f"Unknown content_or_style {content_or_style}")
 
         # do flow matching
-        # if self.sd.is_flow_matching:
-        #     u = compute_density_for_timestep_sampling(
-        #         weighting_scheme="logit_normal",  # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
-        #         batch_size=batch_size,
-        #         logit_mean=0.0,
-        #         logit_std=1.0,
-        #         mode_scale=1.29,
-        #     )
-        #     timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
+        # NOTE(review): this overrides timestep sampling for 'style' regardless
+        # of self.sd.is_flow_matching, and changes logit_mean/std from the old
+        # commented values — confirm this unrelated change is intentional.
+        if content_or_style == 'style':
+            u = compute_density_for_timestep_sampling(
+                weighting_scheme="logit_normal",  # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
+                batch_size=batch_size,
+                logit_mean=-6,
+                logit_std=2.0,
+                mode_scale=1.29,
+            )
+            timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()
         # convert the timestep_indices to a timestep
         timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
         timesteps = torch.stack(timesteps, dim=0)
@@ -1701,10 +1709,12 @@ def run(self):
         self.before_dataset_load()
         # load datasets if passed in the root process
         if self.datasets is not None:
-            self.data_loader = get_dataloader_from_datasets(self.datasets, self.train_config.batch_size, self.sd)
+            self.data_loader, self.data_loader_val = get_dataloader_from_datasets(
+                self.datasets, self.train_config.batch_size, self.sd, self.validation_every)
         if self.datasets_reg is not None:
-            self.data_loader_reg = get_dataloader_from_datasets(self.datasets_reg, self.train_config.batch_size,
-                                                                self.sd)
+            self.data_loader_reg, self.data_loader_val_reg = get_dataloader_from_datasets(
+                self.datasets_reg, self.train_config.batch_size, self.sd, self.validation_every)
         flush()
 
         ### HOOK ###
@@ -1744,6 +1754,10 @@ def run(self):
             dataloader_reg = None
             dataloader_iterator_reg = None
 
+        # The validation loader is iterated in full on each validation pass
+        # (see hook_validation_loop), so no persistent iterator is kept here.
+        dataloader_val = self.data_loader_val
+
         # zero any gradients
         optimizer.zero_grad()
 
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 15dca441..ebca5968 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -408,6 +408,10 @@ def __init__(self, **kwargs):
         # optimal noise pairing
         self.optimal_noise_pairing_samples = kwargs.get('optimal_noise_pairing_samples', 1)
 
+        # validation: number of steps between validation passes. Defaults to
+        # None (disabled) so existing configs keep their behavior and no
+        # holdout split is taken unless explicitly requested.
+        self.validation_every = kwargs.get('validation_every', None)
+
 
 class ModelConfig:
     def __init__(self, **kwargs):
@@ -666,6 +670,8 @@ def __init__(self, **kwargs):
         self.square_crop: bool = kwargs.get('square_crop', False)
         # apply same augmentations to control images. Usually want this true unless special case
         self.replay_transforms: bool = kwargs.get('replay_transforms', True)
+        # fraction of the dataset held out for validation (used only when enabled)
+        self.validation_percent = kwargs.get('validation_percent', 0.1)  # 10% holdout
 
 
 def preprocess_dataset_raw_config(raw_config: List[dict]) -> List[dict]:
diff --git a/toolkit/data_loader.py b/toolkit/data_loader.py
index 5285b371..87f52b45 100644
--- a/toolkit/data_loader.py
+++ b/toolkit/data_loader.py
@@ -12,7 +12,7 @@
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torchvision import transforms
-from torch.utils.data import Dataset, DataLoader, ConcatDataset
+from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split
 from tqdm import tqdm
 
 import albumentations as A
@@ -560,9 +560,13 @@ def get_dataloader_from_datasets(
     dataset_options,
     batch_size=1,
     sd: 'StableDiffusion' = None,
-) -> DataLoader:
+    validation_every=None,
+):
+    # Returns a (train_loader, validation_loader) pair. validation_loader is
+    # None when validation is disabled (validation_every is None); callers
+    # always unpack two values.
     if dataset_options is None or len(dataset_options) == 0:
-        return None
+        return None, None
 
@@ -593,6 +597,19 @@ def get_dataloader_from_datasets(
 
     concatenated_dataset = ConcatDataset(datasets)
 
+    if validation_every is not None:
+        # Hold out a fixed, seeded fraction for validation so the split is
+        # reproducible across restarts; training keeps the remainder.
+        val_size = int(len(concatenated_dataset) * dataset_config_list[0].validation_percent)
+        train_dataset, validation_dataset = random_split(
+            concatenated_dataset,
+            [len(concatenated_dataset) - val_size, val_size],
+            generator=torch.Generator().manual_seed(42),
+        )
+    else:
+        # No validation: train on the full dataset, nothing is held out.
+        train_dataset, validation_dataset = concatenated_dataset, None
+
     # todo build scheduler that can get buckets from all datasets that match
     # todo and evenly distribute reg images
 
@@ -613,28 +630,49 @@ def dto_collation(batch: List['FileItemDTO']):
         dataloader_kwargs['num_workers'] = dataset_config_list[0].num_workers
         dataloader_kwargs['prefetch_factor'] = dataset_config_list[0].prefetch_factor
 
+    validation_loader = None
     if has_buckets:
         # make sure they all have buckets
         for dataset in datasets:
             assert dataset.dataset_config.buckets, f"buckets not found on dataset {dataset.dataset_config.folder_path}, you either need all buckets or none"
 
-        data_loader = DataLoader(
-            concatenated_dataset,
+        train_loader = DataLoader(
+            train_dataset,
             batch_size=None,  # we batch in the datasets for now
             drop_last=False,
             shuffle=True,
            collate_fn=dto_collation,  # Use the custom collate function
             **dataloader_kwargs
         )
+        if validation_dataset is not None:
+            validation_loader = DataLoader(
+                validation_dataset,
+                batch_size=None,  # bucketed datasets yield pre-built batches
+                drop_last=False,
+                shuffle=False,  # deterministic order for comparable val losses
+                collate_fn=dto_collation,  # Use the custom collate function
+                **dataloader_kwargs
+            )
     else:
-        data_loader = DataLoader(
-            concatenated_dataset,
+        train_loader = DataLoader(
+            train_dataset,
             batch_size=batch_size,
             shuffle=True,
             collate_fn=dto_collation,
             **dataloader_kwargs
         )
+        if validation_dataset is not None:
+            validation_loader = DataLoader(
+                validation_dataset,
+                batch_size=batch_size,
+                shuffle=False,  # deterministic order for comparable val losses
+                collate_fn=dto_collation,
+                **dataloader_kwargs
+            )
-    return data_loader
+    return train_loader, validation_loader
 
 
 def trigger_dataloader_setup_epoch(dataloader: DataLoader):