feat: upload images during training to TensorBoard #251

Open
wants to merge 1 commit into base: main

Changes from all commits
136 changes: 95 additions & 41 deletions jobs/process/BaseSDTrainProcess.py
@@ -1,72 +1,107 @@
import copy
import gc
import glob
import hashlib
import inspect
import json
import os
import random
import re
import shutil
from collections import OrderedDict
from hashlib import md5
from typing import List, Optional, Union

import diffusers
import numpy as np
import torch
import torch.backends.cuda
import torchvision.transforms as transforms
import transformers
import yaml
from accelerate import Accelerator
from diffusers import ControlNetModel, FluxTransformer2DModel, T2IAdapter
from diffusers.training_utils import compute_density_for_timestep_sampling
from huggingface_hub import HfApi, Repository, interpreter_login
from huggingface_hub.utils import HfFolder

from jobs.process import BaseTrainProcess
from PIL import Image
from safetensors.torch import load_file, save_file
from toolkit.accelerator import get_accelerator
from toolkit.basic import value_map
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.config_modules import (
    AdapterConfig,
    DatasetConfig,
    DecoratorConfig,
    EmbeddingConfig,
    GenerateImageConfig,
    GuidanceConfig,
    LoggingConfig,
    ModelConfig,
    NetworkConfig,
    SampleConfig,
    SaveConfig,
    TrainConfig,
    preprocess_dataset_raw_config,
    validate_configs,
)
from toolkit.custom_adapter import CustomAdapter
from toolkit.data_loader import (
    get_dataloader_from_datasets,
    trigger_dataloader_setup_epoch,
)
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
from toolkit.ema import ExponentialMovingAverage
from toolkit.embedding import Embedding
from toolkit.image_utils import reduce_contrast, show_latents, show_tensors
from toolkit.ip_adapter import IPAdapter
from toolkit.logging import create_logger
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.lorm import (
    LORM_TARGET_REPLACE_MODULE,
    convert_diffusers_unet_to_lorm,
    count_parameters,
    lorm_ignore_if_contains,
    lorm_parameter_threshold,
    print_lorm_extract_details,
)
from toolkit.lycoris_special import LycorisSpecialNetwork
from toolkit.metadata import (
    add_base_model_info_to_meta,
    get_meta_for_safetensors,
    load_metadata_from_safetensors,
    parse_metadata_from_safetensors,
)
from toolkit.models.decorator import Decorator
from toolkit.network_mixins import Network
from toolkit.optimizer import get_optimizer
from toolkit.paths import CONFIG_ROOT
from toolkit.print import print_acc
from toolkit.progress_bar import ToolkitProgressBar
from toolkit.reference_adapter import ReferenceAdapter
from toolkit.sampler import get_sampler
from toolkit.saving import (
    load_custom_adapter_model,
    load_ip_adapter_model,
    load_t2i_model,
    save_ip_adapter_from_diffusers,
    save_t2i_from_diffusers,
)
from toolkit.scheduler import get_lr_scheduler
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.train_tools import (
    LearnableSNRGamma,
    apply_learnable_snr_gos,
    apply_snr_weight,
    get_torch_dtype,
)

# from lycoris.config import PRESET
from torch.utils.data import DataLoader
from tqdm import tqdm

def flush():
    torch.cuda.empty_cache()
@@ -260,13 +295,8 @@ def sample(self, step=None, is_first=False):
if sample_config.walk_seed:
    current_seed = start_seed + i

step_num = f"_{str(step).zfill(9)}" if step is not None else ''
filename = f"[time]_{step_num}_[count].{self.sample_config.ext}"
output_path = os.path.join(sample_folder, filename)

prompt = sample_config.prompts[i]
@@ -322,6 +352,30 @@ def sample(self, step=None, is_first=False):

if self.ema is not None:
    self.ema.train()

# Log generated images to TensorBoard
if step is not None and hasattr(self, 'writer'):
    step_num = f"_{step:09d}"
    pattern = os.path.join(sample_folder, f"*{step_num}_*.{self.sample_config.ext}")
    image_paths = glob.glob(pattern)
    image_paths.sort()  # maintain generation order via filename sorting

    for idx, (img_path, config) in enumerate(zip(image_paths, gen_img_config_list)):
        img = Image.open(img_path)
        img_tensor = transforms.ToTensor()(img)

        # Generate a truncated, safe tag for TensorBoard
        prompt_text = config.prompt
        sanitized_prompt = re.sub(r'[^a-zA-Z0-9]+', '_', prompt_text)[:50]  # truncate and sanitize
        hash_prompt = md5(prompt_text.encode()).hexdigest()[:8]  # short hash for uniqueness

        tag = f"samples/{idx}_{sanitized_prompt}_{hash_prompt}"
        self.writer.add_image(tag, img_tensor, global_step=step)

    # Optionally, log the full prompt text separately
    for idx, config in enumerate(gen_img_config_list):
        self.writer.add_text(f"prompt/{idx}", config.prompt, global_step=step)

def update_training_metadata(self):
    o_dict = OrderedDict({
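A note on the logging added to BaseSDTrainProcess.sample above: the glob pattern only finds the right files because it reuses the same 9-digit zero padding the sampler embeds in filenames, and the tag combines a sanitized prompt prefix with a short hash so distinct prompts never collide under one TensorBoard tag. A minimal standalone sketch of that logic, using only the standard library; `build_sample_tag` is a hypothetical helper name, not part of this PR:

```python
import re
from hashlib import md5

# f-string formatting and zfill produce the same 9-digit padding, so the
# glob pattern matches the filenames that sample() wrote earlier.
assert f"_{1500:09d}" == f"_{str(1500).zfill(9)}" == "_000001500"

def build_sample_tag(idx: int, prompt: str) -> str:
    # Collapse runs of non-alphanumeric characters, truncate to 50 chars,
    # and append a short md5 digest so similar prompts still get unique tags.
    sanitized = re.sub(r'[^a-zA-Z0-9]+', '_', prompt)[:50]
    return f"samples/{idx}_{sanitized}_{md5(prompt.encode()).hexdigest()[:8]}"

# Two prompts that sanitize to the same string still produce distinct tags:
print(build_sample_tag(0, "a photo of a cat!"))
print(build_sample_tag(0, "a photo of a cat?"))
```

The zero padding also means a plain lexicographic sort() of the matched paths equals numeric order, which is presumably what keeps image_paths aligned with gen_img_config_list.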
55 changes: 36 additions & 19 deletions jobs/process/TrainESRGANProcess.py
@@ -5,29 +5,31 @@
from collections import OrderedDict
from typing import List, Optional

import numpy as np
import torch
from diffusers import AutoencoderKL
from jobs.process import BaseTrainProcess
from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import load_file, save_file
from toolkit.basic import flush
from toolkit.data_loader import AugmentedImageDataset
from toolkit.esrgan_utils import (
    convert_basicsr_state_dict_to_save_format,
    convert_state_dict_to_basicsr,
)
from toolkit.losses import ComparativeTotalVariation, PatternLoss, get_gradient_penalty
from toolkit.metadata import get_meta_for_safetensors
from toolkit.models.RRDB import RRDBNet as ESRGAN
from toolkit.models.RRDB import esrgan_safetensors_keys
from toolkit.optimizer import get_optimizer
from toolkit.style import get_style_model_and_losses
from toolkit.train_tools import get_torch_dtype
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader
from torchvision.transforms import transforms
from tqdm import tqdm

from .models.vgg19_critic import Critic

IMAGE_TRANSFORMS = transforms.Compose(
@@ -289,7 +291,6 @@ def sample(self, step=None, batch: Optional[List[torch.Tensor]] = None):
def process_and_save(img, target_img, save_path):
    img = img.to(self.device, dtype=self.esrgan_dtype)
    output = self.model(img)
    output = output.clamp(0, 1)
    img = img.clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
@@ -330,6 +331,9 @@ def process_and_save(img, target_img, save_path):

output_img.save(save_path)

# Return the stacked image for TensorBoard logging
return output_img

with torch.no_grad():
    for i, img_url in enumerate(self.sample_sources):
        img = exif_transpose(Image.open(img_url))
@@ -345,8 +349,6 @@ def process_and_save(img, target_img, save_path):
# downscale the image input
img = img.resize((self.resolution, self.resolution), resample=Image.BICUBIC)

img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.esrgan_dtype)

@@ -358,7 +360,15 @@ def process_and_save(img, target_img, save_path):
# zero-pad 2 digits
i_str = str(i).zfill(2)
filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
output_img = process_and_save(img, target_image, os.path.join(sample_folder, filename))

# Log the stacked image (input + output + target) to TensorBoard
if step is not None and hasattr(self, 'writer'):
    img_tensor = transforms.ToTensor()(output_img)  # convert PIL image to tensor
    self.writer.add_image(f'sample_{i}', img_tensor, global_step=step)

if batch is not None:
    batch_targets = batch[0].detach()
@@ -374,7 +384,14 @@ def process_and_save(img, target_img, save_path):
# zero-pad 2 digits
i_str = str(i).zfill(2)
filename = f"{seconds_since_epoch}{step_num}_{i_str}.jpg"
output_img = process_and_save(batch_inputs[i], batch_targets[i], os.path.join(batch_sample_folder, filename))

# Log the stacked image (input + output + target) to TensorBoard
if step is not None and hasattr(self, 'writer'):
    img_tensor = transforms.ToTensor()(output_img)  # convert PIL image to tensor
    self.writer.add_image(f'batch_sample_{i}', img_tensor, global_step=step)

self.model.train()

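All three files guard their logging with hasattr(self, 'writer'), so runs without a TensorBoard writer are unaffected. The diff never shows how writer is constructed, but the add_image/add_text calls match torch.utils.tensorboard.SummaryWriter; here is a minimal sketch of that assumed contract (the log directory is hypothetical):

```python
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

writer = SummaryWriter(log_dir="output/.tensorboard")  # hypothetical location

# add_image expects a CHW float tensor in [0, 1] by default, which is
# exactly what transforms.ToTensor() produces from a PIL image.
img = Image.new("RGB", (64, 64), color=(200, 100, 50))
writer.add_image("sample_0", transforms.ToTensor()(img), global_step=100)
writer.add_text("prompt/0", "a photo of a cat", global_step=100)
writer.close()
```

If the writer is set up this way, the logged samples show up under the Images tab of `tensorboard --logdir output/.tensorboard`.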
35 changes: 20 additions & 15 deletions jobs/process/TrainVAEProcess.py
@@ -5,30 +5,28 @@
import time
from collections import OrderedDict

import lpips
import numpy as np
import torch
from diffusers import AutoencoderKL
from jobs.process import BaseTrainProcess
from PIL import Image
from PIL.ImageOps import exif_transpose
from safetensors.torch import load_file, save_file
from toolkit.data_loader import ImageDataset
from toolkit.image_utils import show_tensors
from toolkit.kohya_model_util import convert_diffusers_back_to_ldm, load_vae
from toolkit.losses import ComparativeTotalVariation, PatternLoss, get_gradient_penalty
from toolkit.metadata import get_meta_for_safetensors
from toolkit.optimizer import get_optimizer
from toolkit.style import get_style_model_and_losses
from toolkit.train_tools import get_torch_dtype
from torch import nn
from torch.utils.data import ConcatDataset, DataLoader
from torchvision.transforms import Resize, transforms
from tqdm import tqdm

from .models.vgg19_critic import Critic

IMAGE_TRANSFORMS = transforms.Compose(
    [
@@ -310,6 +308,13 @@ def sample(self, step=None):
filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
output_img.save(os.path.join(sample_folder, filename))

# Log the combined image (input + decoded) to TensorBoard
if step is not None and hasattr(self, 'writer'):
    img_tensor = transforms.ToTensor()(output_img)  # convert PIL image to tensor
    self.writer.add_image(f'sample_{i}', img_tensor, global_step=step)

def load_vae(self):
    path_to_load = self.vae_path
    # see if we have a checkpoint in our output to resume from
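As in the other two files, the VAE sampler converts the combined PIL image with ToTensor before logging, and no extra rescaling is needed. A small sketch of why, assuming only numpy and torchvision (mirroring the file's own import style):

```python
import numpy as np
from PIL import Image
from torchvision.transforms import transforms

# ToTensor maps an HWC uint8 PIL image in [0, 255] to a CHW float32 tensor
# in [0.0, 1.0], the default layout and range add_image expects.
img = Image.fromarray(np.full((8, 16, 3), 255, dtype=np.uint8))
t = transforms.ToTensor()(img)
assert tuple(t.shape) == (3, 8, 16)
assert float(t.max()) == 1.0
```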