diff --git a/environment.yaml b/environment.yaml
index f41c3cada..c369cefae 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -10,6 +10,7 @@ dependencies:
   - torchvision=0.12.0
   - numpy=1.20.3
   - pip:
+    - safetensors==0.3.1
     - albumentations==0.4.3
     - opencv-python==4.1.2.30
     - pudb==2019.2
diff --git a/optimizedSD/optimized_txt2img.py b/optimizedSD/optimized_txt2img.py
index c82918240..3fef33370 100644
--- a/optimizedSD/optimized_txt2img.py
+++ b/optimizedSD/optimized_txt2img.py
@@ -15,21 +15,58 @@ from ldm.util import instantiate_from_config
 from optimUtils import split_weighted_subprompts, logger
 from transformers import logging
+import os
+import safetensors.torch
 # from samplers import CompVisDenoiser
 
 logging.set_verbosity_error()
 
+checkpoint_dict_replacements = {
+    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
+    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
+    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
+}
+
 
 def chunk(it, size):
     it = iter(it)
     return iter(lambda: tuple(islice(it, size)), ())
 
 
+def transform_checkpoint_dict_key(k):
+    for text, replacement in checkpoint_dict_replacements.items():
+        if k.startswith(text):
+            k = replacement + k[len(text):]
+
+    return k
+
+
+def get_state_dict_from_checkpoint(pl_sd):
+    pl_sd = pl_sd.pop("state_dict", pl_sd)
+    pl_sd.pop("state_dict", None)
+
+    sd = {}
+    for k, v in pl_sd.items():
+        new_key = transform_checkpoint_dict_key(k)
+        if new_key is not None:
+            sd[new_key] = v
+
+    pl_sd.clear()
+    pl_sd.update(sd)
+
+    return pl_sd
+
+
+# The code for loading the model with safetensors was taken from https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/modules/sd_models.py
 def load_model_from_config(ckpt, verbose=False):
+    print(f"Loading model from {ckpt}")
-    pl_sd = torch.load(ckpt, map_location="cpu")
-    if "global_step" in pl_sd:
-        print(f"Global Step: {pl_sd['global_step']}")
-    sd = pl_sd["state_dict"]
+    _, extension = os.path.splitext(ckpt)
+    if extension.lower() == ".safetensors":
+        pl_sd = safetensors.torch.load_file(ckpt)
+    else:
+        pl_sd = torch.load(ckpt, map_location="cpu")
+    if "global_step" in pl_sd:
+        print(f"Global Step: {pl_sd['global_step']}")
+    #sd = pl_sd["state_dict"]
+    sd = get_state_dict_from_checkpoint(pl_sd)
     return sd
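
For reference, below is a minimal stand-alone sketch of the loading behaviour this patch introduces, mirroring the file-extension check in the patched load_model_from_config. It is not part of the diff: the checkpoint path "model.safetensors" is hypothetical, and it assumes safetensors and torch are installed as pinned in environment.yaml.

# Minimal sketch of the checkpoint-loading path added by this patch; illustration only.
# "model.safetensors" is a hypothetical path.
import os

import safetensors.torch
import torch


def load_raw_checkpoint(ckpt):
    # .safetensors files are read with safetensors.torch.load_file;
    # anything else falls back to the original torch.load behaviour.
    _, extension = os.path.splitext(ckpt)
    if extension.lower() == ".safetensors":
        return safetensors.torch.load_file(ckpt)
    return torch.load(ckpt, map_location="cpu")


if __name__ == "__main__":
    state = load_raw_checkpoint("model.safetensors")
    print(f"Loaded {len(state)} entries from the checkpoint")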