diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index e30ddae0..3bf517b4 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1,72 +1,107 @@
 import copy
+import gc
 import glob
+import hashlib
 import inspect
 import json
+import os
 import random
+import re
 import shutil
 from collections import OrderedDict
-import os
-import re
-from typing import Union, List, Optional
+from hashlib import md5
+from typing import List, Optional, Union
 
+import diffusers
 import numpy as np
-import yaml
-from diffusers import T2IAdapter, ControlNetModel
-from diffusers.training_utils import compute_density_for_timestep_sampling
-from safetensors.torch import save_file, load_file
-# from lycoris.config import PRESET
-from torch.utils.data import DataLoader
 import torch
 import torch.backends.cuda
+import torchvision.transforms as transforms
+import transformers
+import yaml
+from accelerate import Accelerator
+from diffusers import ControlNetModel, FluxTransformer2DModel, T2IAdapter
+from diffusers.training_utils import compute_density_for_timestep_sampling
 from huggingface_hub import HfApi, Repository, interpreter_login
 from huggingface_hub.utils import HfFolder
-
+from jobs.process import BaseTrainProcess
+from PIL import Image
+from safetensors.torch import load_file, save_file
+from toolkit.accelerator import get_accelerator
 from toolkit.basic import value_map
 from toolkit.clip_vision_adapter import ClipVisionAdapter
+from toolkit.config_modules import (
+    AdapterConfig,
+    DatasetConfig,
+    DecoratorConfig,
+    EmbeddingConfig,
+    GenerateImageConfig,
+    GuidanceConfig,
+    LoggingConfig,
+    ModelConfig,
+    NetworkConfig,
+    SampleConfig,
+    SaveConfig,
+    TrainConfig,
+    preprocess_dataset_raw_config,
+    validate_configs,
+)
 from toolkit.custom_adapter import CustomAdapter
-from toolkit.data_loader import get_dataloader_from_datasets, trigger_dataloader_setup_epoch
-from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
+from toolkit.data_loader import (
+    get_dataloader_from_datasets,
+    trigger_dataloader_setup_epoch,
+)
+from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
 from toolkit.ema import ExponentialMovingAverage
 from toolkit.embedding import Embedding
-from toolkit.image_utils import show_tensors, show_latents, reduce_contrast
+from toolkit.image_utils import reduce_contrast, show_latents, show_tensors
 from toolkit.ip_adapter import IPAdapter
+from toolkit.logging import create_logger
 from toolkit.lora_special import LoRASpecialNetwork
-from toolkit.lorm import convert_diffusers_unet_to_lorm, count_parameters, print_lorm_extract_details, \
-    lorm_ignore_if_contains, lorm_parameter_threshold, LORM_TARGET_REPLACE_MODULE
+from toolkit.lorm import (
+    LORM_TARGET_REPLACE_MODULE,
+    convert_diffusers_unet_to_lorm,
+    count_parameters,
+    lorm_ignore_if_contains,
+    lorm_parameter_threshold,
+    print_lorm_extract_details,
+)
 from toolkit.lycoris_special import LycorisSpecialNetwork
+from toolkit.metadata import (
+    add_base_model_info_to_meta,
+    get_meta_for_safetensors,
+    load_metadata_from_safetensors,
+    parse_metadata_from_safetensors,
+)
 from toolkit.models.decorator import Decorator
 from toolkit.network_mixins import Network
 from toolkit.optimizer import get_optimizer
 from toolkit.paths import CONFIG_ROOT
+from toolkit.print import print_acc
 from toolkit.progress_bar import ToolkitProgressBar
 from toolkit.reference_adapter import ReferenceAdapter
 from toolkit.sampler import get_sampler
-from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \
-    load_ip_adapter_model, load_custom_adapter_model
-
+from toolkit.saving import (
+    load_custom_adapter_model,
+    load_ip_adapter_model,
+    load_t2i_model,
+    save_ip_adapter_from_diffusers,
+    save_t2i_from_diffusers,
+)
 from toolkit.scheduler import get_lr_scheduler
 from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
 from toolkit.stable_diffusion_model import StableDiffusion
+from toolkit.train_tools import (
+    LearnableSNRGamma,
+    apply_learnable_snr_gos,
+    apply_snr_weight,
+    get_torch_dtype,
+)
 
-from jobs.process import BaseTrainProcess
-from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors, add_base_model_info_to_meta, \
-    parse_metadata_from_safetensors
-from toolkit.train_tools import get_torch_dtype, LearnableSNRGamma, apply_learnable_snr_gos, apply_snr_weight
-import gc
-
+# from lycoris.config import PRESET
+from torch.utils.data import DataLoader
 from tqdm import tqdm
-from toolkit.config_modules import SaveConfig, LoggingConfig, SampleConfig, NetworkConfig, TrainConfig, ModelConfig, \
-    GenerateImageConfig, EmbeddingConfig, DatasetConfig, preprocess_dataset_raw_config, AdapterConfig, GuidanceConfig, validate_configs, \
-    DecoratorConfig
-from toolkit.logging import create_logger
-from diffusers import FluxTransformer2DModel
-from toolkit.accelerator import get_accelerator
-from toolkit.print import print_acc
-from accelerate import Accelerator
-import transformers
-import diffusers
-import hashlib
 
 
 def flush():
     torch.cuda.empty_cache()
@@ -260,13 +295,8 @@ def sample(self, step=None, is_first=False):
             if sample_config.walk_seed:
                 current_seed = start_seed + i
 
-            step_num = ''
-            if step is not None:
-                # zero-pad 9 digits
-                step_num = f"_{str(step).zfill(9)}"
-
+            step_num = f"_{str(step).zfill(9)}" if step is not None else ''
             filename = f"[time]_{step_num}_[count].{self.sample_config.ext}"
-
            output_path = os.path.join(sample_folder, filename)

             prompt = sample_config.prompts[i]
@@ -322,6 +352,30 @@ def sample(self, step=None, is_first=False):
         if self.ema is not None:
             self.ema.train()

+        # Log generated images to TensorBoard
+
+        if step is not None and hasattr(self, 'writer'):
+            step_num = f"_{step:09d}"
+            pattern = os.path.join(sample_folder, f"*{step_num}_*.{self.sample_config.ext}")
+            image_paths = glob.glob(pattern)
+            image_paths.sort()  # Maintain order using filename sorting
+
+            for idx, (img_path, config) in enumerate(zip(image_paths, gen_img_config_list)):
+                img = Image.open(img_path)
+                img_tensor = transforms.ToTensor()(img)
+
+                # Generate a truncated, safe tag for TensorBoard
+                prompt_text = config.prompt
+                sanitized_prompt = re.sub(r'[^a-zA-Z0-9]+', '_', prompt_text)[:50]  # Truncate and sanitize
+                hash_prompt = md5(prompt_text.encode()).hexdigest()[:8]  # Short hash for uniqueness
+
+                tag = f"samples/{idx}_{sanitized_prompt}_{hash_prompt}"
+
+                self.writer.add_image(tag, img_tensor, global_step=step)
+
+            # Optionally, log the full prompt text separately
+            for idx, config in enumerate(gen_img_config_list):
+                self.writer.add_text(f"prompt/{idx}", config.prompt, global_step=step)

     def update_training_metadata(self):
         o_dict = OrderedDict({
diff --git a/jobs/process/TrainESRGANProcess.py b/jobs/process/TrainESRGANProcess.py
index 4ff3a69d..cadc52a5 100644
--- a/jobs/process/TrainESRGANProcess.py
+++ b/jobs/process/TrainESRGANProcess.py
@@ -5,29 +5,31 @@
 from collections import OrderedDict
 from typing import List, Optional
 
+import numpy as np
+import torch
+from diffusers import AutoencoderKL
+from jobs.process import BaseTrainProcess
 from PIL import Image
 from PIL.ImageOps import exif_transpose
-
+from safetensors.torch import load_file, save_file
 from toolkit.basic import flush
-from toolkit.models.RRDB import RRDBNet as ESRGAN, esrgan_safetensors_keys
-from safetensors.torch import save_file, load_file
-from torch.utils.data import DataLoader, ConcatDataset
-import torch
-from torch import nn
-from torchvision.transforms import transforms
-
-from jobs.process import BaseTrainProcess
 from toolkit.data_loader import AugmentedImageDataset
-from toolkit.esrgan_utils import convert_state_dict_to_basicsr, convert_basicsr_state_dict_to_save_format
-from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
+from toolkit.esrgan_utils import (
+    convert_basicsr_state_dict_to_save_format,
+    convert_state_dict_to_basicsr,
+)
+from toolkit.losses import ComparativeTotalVariation, PatternLoss, get_gradient_penalty
 from toolkit.metadata import get_meta_for_safetensors
+from toolkit.models.RRDB import RRDBNet as ESRGAN
+from toolkit.models.RRDB import esrgan_safetensors_keys
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
-from diffusers import AutoencoderKL
+from torch import nn
+from torch.utils.data import ConcatDataset, DataLoader
+from torchvision.transforms import transforms
 from tqdm import tqdm
-import time
-import numpy as np
+
 from .models.vgg19_critic import Critic
 
 IMAGE_TRANSFORMS = transforms.Compose(
     [
@@ -289,7 +291,6 @@ def sample(self, step=None, batch: Optional[List[torch.Tensor]] = None):
         def process_and_save(img, target_img, save_path):
             img = img.to(self.device, dtype=self.esrgan_dtype)
             output = self.model(img)
-            # output = (output / 2 + 0.5).clamp(0, 1)
             output = output.clamp(0, 1)
             img = img.clamp(0, 1)
             # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -330,6 +331,9 @@ def process_and_save(img, target_img, save_path):
 
             output_img.save(save_path)
 
+            # Return the stacked image for TensorBoard logging
+            return output_img
+
         with torch.no_grad():
             for i, img_url in enumerate(self.sample_sources):
                 img = exif_transpose(Image.open(img_url))
@@ -345,8 +349,6 @@ def process_and_save(img, target_img, save_path):
 
                 # downscale the image input
                 img = img.resize((self.resolution, self.resolution), resample=Image.BICUBIC)
-                # downscale the image input
-
                 img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.esrgan_dtype)
 
                 img = img
@@ -358,7 +360,15 @@ def process_and_save(img, target_img, save_path):
                 # zero-pad 2 digits
                 i_str = str(i).zfill(2)
                 filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
-                process_and_save(img, target_image, os.path.join(sample_folder, filename))
+                output_img = process_and_save(img, target_image, os.path.join(sample_folder, filename))
+
+                # Log to TensorBoard
+                if step is not None and hasattr(self, 'writer'):
+                    import torchvision.transforms as transforms
+                    # Convert PIL image to tensor
+                    img_tensor = transforms.ToTensor()(output_img)
+                    # Log the stacked image (input + output + target) to TensorBoard
+                    self.writer.add_image(f'sample_{i}', img_tensor, global_step=step)
 
         if batch is not None:
             batch_targets = batch[0].detach()
@@ -374,7 +384,14 @@ def process_and_save(img, target_img, save_path):
                 # zero-pad 2 digits
                 i_str = str(i).zfill(2)
                 filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
-                process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))
+                output_img = process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))
+
+                # Log to TensorBoard
+                if step is not None and hasattr(self, 'writer'):
+                    # Convert PIL image to tensor
+                    img_tensor = transforms.ToTensor()(output_img)
+                    # Log the stacked image (input + output + target) to TensorBoard
+                    self.writer.add_image(f'batch_sample_{i}', img_tensor, global_step=step)
 
         self.model.train()
diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index fb6536cd..e7b1a543 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -5,30 +5,28 @@
 import time
 from collections import OrderedDict
 
-from PIL import Image
-from PIL.ImageOps import exif_transpose
-from safetensors.torch import save_file, load_file
-from torch.utils.data import DataLoader, ConcatDataset
+import lpips
+import numpy as np
 import torch
-from torch import nn
-from torchvision.transforms import transforms
-
+from diffusers import AutoencoderKL
 from jobs.process import BaseTrainProcess
-from toolkit.image_utils import show_tensors
-from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from safetensors.torch import load_file, save_file
 from toolkit.data_loader import ImageDataset
-from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
+from toolkit.image_utils import show_tensors
+from toolkit.kohya_model_util import convert_diffusers_back_to_ldm, load_vae
+from toolkit.losses import ComparativeTotalVariation, PatternLoss, get_gradient_penalty
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
-from diffusers import AutoencoderKL
+from torch import nn
+from torch.utils.data import ConcatDataset, DataLoader
+from torchvision.transforms import Resize, transforms
 from tqdm import tqdm
-import time
-import numpy as np
+
 from .models.vgg19_critic import Critic
-from torchvision.transforms import Resize
-import lpips
 
 IMAGE_TRANSFORMS = transforms.Compose(
     [
@@ -310,6 +308,13 @@ def sample(self, step=None):
                 filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
                 output_img.save(os.path.join(sample_folder, filename))
 
+                # Log to TensorBoard
+                if step is not None and hasattr(self, 'writer'):
+                    # Convert PIL image to tensor
+                    img_tensor = transforms.ToTensor()(output_img)
+                    # Log the combined image (input + decoded) to TensorBoard
+                    self.writer.add_image(f'sample_{i}', img_tensor, global_step=step)
+
     def load_vae(self):
         path_to_load = self.vae_path
         # see if we have a checkpoint in out output to resume from
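
Note on the writer attribute: every logging hunk above guards on hasattr(self, 'writer'), but the diff never shows where that attribute comes from. The calls used (add_image, add_text) match torch.utils.tensorboard.SummaryWriter. A minimal sketch of the setup these hunks appear to assume; setup_writer and log_dir are hypothetical names, not part of this change:

import os

from torch.utils.tensorboard import SummaryWriter


def setup_writer(process, log_dir: str) -> None:
    # Once this runs, the hasattr(self, 'writer') guards in the patched
    # sample() methods pass and images/text are written to TensorBoard.
    os.makedirs(log_dir, exist_ok=True)
    process.writer = SummaryWriter(log_dir=log_dir)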
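
The BaseSDTrainProcess hunk builds TensorBoard tags from prompts by collapsing non-alphanumeric runs to underscores, truncating to 50 characters, and appending a short md5 digest. The digest matters: two long prompts can share the same 50-character sanitized prefix, and without it they would collapse into a single tag. A standalone restatement of that logic (prompt_to_tag is a hypothetical helper name):

import re
from hashlib import md5


def prompt_to_tag(idx: int, prompt: str) -> str:
    # Same steps as the diff: sanitize, truncate, then hash for uniqueness.
    sanitized = re.sub(r'[^a-zA-Z0-9]+', '_', prompt)[:50]
    digest = md5(prompt.encode()).hexdigest()[:8]
    return f"samples/{idx}_{sanitized}_{digest}"


# e.g. prompt_to_tag(0, "a photo of a cat, 4k")
# -> "samples/0_a_photo_of_a_cat_4k_" followed by eight hex characters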
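
All three processes convert the saved PIL image with transforms.ToTensor() before calling add_image. That single call suffices because ToTensor maps an (H, W, C) uint8 PIL image to a (C, H, W) float32 tensor in [0.0, 1.0], which matches SummaryWriter.add_image's default dataformats='CHW' and expected value range. A quick self-contained check of that assumption:

import torch
import torchvision.transforms as transforms
from PIL import Image

# Stand-in for a saved sample image; PIL sizes are (width, height).
img = Image.new("RGB", (64, 32))
t = transforms.ToTensor()(img)
assert t.shape == torch.Size([3, 32, 64])  # channels-first, as add_image expects
assert t.dtype == torch.float32
assert 0.0 <= float(t.min()) and float(t.max()) <= 1.0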