From 3c2eaf48db6902ae2135853f828359bb823f6e44 Mon Sep 17 00:00:00 2001 From: FlyingFathead Date: Sat, 7 Sep 2024 23:11:12 +0300 Subject: [PATCH] Fixed low VRAM mode and CPU option, added image check --- scripts/demo/streamlit_helpers.py | 52 +++++++++++++++++-------------- scripts/demo/video_sampling.py | 6 ++++ 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index e79fc193..759cc94e 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -40,6 +40,12 @@ from torchvision import transforms from torchvision.utils import make_grid, save_image +# Additional options for lower end setups +USE_CUDA = True # Set this to `False`` if you want to force CPU-only mode +lowvram_mode = False # Set to `True` to enable low VRAM mode +# (low VRAM mode = float32 => float16, tested to work great on RTX 3060 w/ 12GB VRAM) + +device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu") @st.cache_resource() def init_st(version_dict, load_ckpt=True, load_filter=True): @@ -59,35 +65,29 @@ def init_st(version_dict, load_ckpt=True, load_filter=True): state["filter"] = DeepFloydDataFiltering(verbose=False) return state - def load_model(model): - model.cuda() - - -lowvram_mode = False - + device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu") + model.to(device) def set_lowvram_mode(mode): global lowvram_mode lowvram_mode = mode - def initial_model_load(model): + device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu") global lowvram_mode if lowvram_mode: - model.model.half() + model.model.half().to(device) else: - model.cuda() + model.to(device) return model - def unload_model(model): global lowvram_mode - if lowvram_mode: - model.cpu() + if lowvram_mode or not USE_CUDA: + model.cpu() # Move model to CPU to free GPU memory torch.cuda.empty_cache() - def load_model_from_config(config, ckpt=None, verbose=True): model = instantiate_from_config(config.model) @@ -497,13 +497,14 @@ def load_img( st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}") return img - def get_init_img(batch_size=1, key=None): - init_image = load_img(key=key).cuda() + device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu") + + init_image = load_img(key=key).to(device) # Use `to(device)` to move to the correct device init_image = repeat(init_image, "1 ... -> b ...", b=batch_size) + return init_image - def do_sample( model, sampler, @@ -529,9 +530,9 @@ def do_sample( st.text("Sampling") outputs = st.empty() - precision_scope = autocast + precision_scope = autocast if USE_CUDA else lambda device: device with torch.no_grad(): - with precision_scope("cuda"): + with precision_scope("cuda" if USE_CUDA else "cpu"): with model.ema_scope(): if T is not None: num_samples = [num_samples, T] @@ -754,7 +755,7 @@ def do_img2img( outputs = st.empty() precision_scope = autocast with torch.no_grad(): - with precision_scope("cuda"): + with precision_scope("cuda" if USE_CUDA else "cpu"): with model.ema_scope(): load_model(model.conditioner) batch, batch_uc = get_batch( @@ -783,20 +784,25 @@ def do_img2img( noise = torch.randn_like(z) - sigmas = sampler.discretization(sampler.num_steps).cuda() + # Move sigmas to the correct device (CUDA or CPU) + sigmas = sampler.discretization(sampler.num_steps).to(device) sigma = sigmas[0] st.info(f"all sigmas: {sigmas}") st.info(f"noising sigma: {sigma}") + + # Offset noise level handling if offset_noise_level > 0.0: noise = noise + offset_noise_level * append_dims( - torch.randn(z.shape[0], device=z.device), z.ndim + torch.randn(z.shape[0], device=device), z.ndim ) + + # Add noise handling if add_noise: - noised_z = z + noise * append_dims(sigma, z.ndim).cuda() + noised_z = z + noise * append_dims(sigma, z.ndim).to(device) noised_z = noised_z / torch.sqrt( 1.0 + sigmas[0] ** 2.0 - ) # Note: hardcoded to DDPM-like scaling. need to generalize later. + ) # Hardcoded to DDPM-like scaling; generalize if needed else: noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0) diff --git a/scripts/demo/video_sampling.py b/scripts/demo/video_sampling.py index 1f4fcfc4..786baa11 100644 --- a/scripts/demo/video_sampling.py +++ b/scripts/demo/video_sampling.py @@ -180,6 +180,12 @@ if mode == "img2vid": img = load_img_for_prediction(W, H) + + # Check if the image is None and use a dummy image if necessary + if img is None: + st.warning("No image provided. Using a dummy tensor for initialization.") + img = torch.zeros([1, 3, H, W]).to(device) # Dummy tensor + if "sv3d" in version: cond_aug = 1e-5 else: