Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add low VRAM mode, CPU-only mode + image pre-loading fix #407

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 29 additions & 23 deletions scripts/demo/streamlit_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@
from torchvision import transforms
from torchvision.utils import make_grid, save_image

# Additional options for lower-end setups.
USE_CUDA = True  # Set to `False` to force CPU-only mode
lowvram_mode = False  # Set to `True` to enable low VRAM mode
# (low VRAM mode casts the diffusion model float32 -> float16;
#  reported to work well on an RTX 3060 with 12 GB VRAM)

device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")  # global compute device used by the helpers below

@st.cache_resource()
def init_st(version_dict, load_ckpt=True, load_filter=True):
Expand All @@ -59,35 +65,29 @@ def init_st(version_dict, load_ckpt=True, load_filter=True):
state["filter"] = DeepFloydDataFiltering(verbose=False)
return state


def load_model(model):
    """Move *model* onto the globally selected compute device.

    Uses the module-level ``device`` (CUDA when available and enabled via
    ``USE_CUDA``, otherwise CPU) rather than a hard-coded ``model.cuda()``,
    so CPU-only setups work.
    """
    # NOTE(review): the pasted diff contained both the old `model.cuda()`
    # and stray deleted module-level lines here; this is the intended
    # post-change body.
    model.to(device)

def set_lowvram_mode(mode):
    """Record *mode* as the module-wide low-VRAM flag.

    ``initial_model_load`` and ``unload_model`` consult this flag to decide
    whether to use half precision and offload models between uses.
    """
    # Assign through globals() — equivalent to `global lowvram_mode; lowvram_mode = mode`.
    globals()["lowvram_mode"] = mode


def initial_model_load(model):
    """Place *model* on the compute device, optionally in half precision.

    When the module-level ``lowvram_mode`` flag is set, the inner diffusion
    network (``model.model``) is cast to fp16 before the transfer to roughly
    halve VRAM usage; otherwise the whole model is moved as float32.

    Returns the same *model* object for chaining.
    """
    device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")
    if lowvram_mode:
        # NOTE(review): only `model.model` is halved/moved here (matching the
        # pasted diff); conditioner/first-stage submodules are loaded on demand
        # via load_model/unload_model — confirm against the caller.
        model.model.half().to(device)
    else:
        model.to(device)
    return model


def unload_model(model):
    """Offload *model* to CPU and release cached CUDA memory.

    Acts when low-VRAM mode is enabled or CUDA is disabled; otherwise the
    model is left in place. ``torch.cuda.empty_cache()`` returns cached GPU
    blocks to the driver (a no-op on CPU-only builds).
    """
    # NOTE(review): the pasted diff contained both the old `if lowvram_mode:`
    # branch and the new condition; this is the intended post-change body.
    if lowvram_mode or not USE_CUDA:
        model.cpu()  # free GPU memory held by the module's parameters
        torch.cuda.empty_cache()


def load_model_from_config(config, ckpt=None, verbose=True):
model = instantiate_from_config(config.model)

Expand Down Expand Up @@ -497,13 +497,14 @@ def load_img(
st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
return img


def get_init_img(batch_size=1, key=None):
    """Load the user-provided init image and tile it to *batch_size*.

    Parameters
    ----------
    batch_size : int
        Number of copies of the image to stack along the batch dimension.
    key : optional
        Widget key forwarded to ``load_img``.

    Returns the image tensor repeated to ``(batch_size, ...)`` on the active
    device (CUDA when available and enabled, else CPU).
    """
    device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")
    # `.to(device)` instead of `.cuda()` so CPU-only mode works.
    init_image = load_img(key=key).to(device)
    return repeat(init_image, "1 ... -> b ...", b=batch_size)


def do_sample(
model,
sampler,
Expand All @@ -529,9 +530,9 @@ def do_sample(
st.text("Sampling")

outputs = st.empty()
precision_scope = autocast
precision_scope = autocast if USE_CUDA else lambda device: device
with torch.no_grad():
with precision_scope("cuda"):
with precision_scope("cuda" if USE_CUDA else "cpu"):
with model.ema_scope():
if T is not None:
num_samples = [num_samples, T]
Expand Down Expand Up @@ -754,7 +755,7 @@ def do_img2img(
outputs = st.empty()
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with precision_scope("cuda" if USE_CUDA else "cpu"):
with model.ema_scope():
load_model(model.conditioner)
batch, batch_uc = get_batch(
Expand Down Expand Up @@ -783,20 +784,25 @@ def do_img2img(

noise = torch.randn_like(z)

sigmas = sampler.discretization(sampler.num_steps).cuda()
# Move sigmas to the correct device (CUDA or CPU)
sigmas = sampler.discretization(sampler.num_steps).to(device)
sigma = sigmas[0]

st.info(f"all sigmas: {sigmas}")
st.info(f"noising sigma: {sigma}")

# Offset noise level handling
if offset_noise_level > 0.0:
noise = noise + offset_noise_level * append_dims(
torch.randn(z.shape[0], device=z.device), z.ndim
torch.randn(z.shape[0], device=device), z.ndim
)

# Add noise handling
if add_noise:
noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
noised_z = z + noise * append_dims(sigma, z.ndim).to(device)
noised_z = noised_z / torch.sqrt(
1.0 + sigmas[0] ** 2.0
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
) # Hardcoded to DDPM-like scaling; generalize if needed
else:
noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)

Expand Down
6 changes: 6 additions & 0 deletions scripts/demo/video_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@

if mode == "img2vid":
img = load_img_for_prediction(W, H)

# Check if the image is None and use a dummy image if necessary
if img is None:
st.warning("No image provided. Using a dummy tensor for initialization.")
img = torch.zeros([1, 3, H, W]).to(device) # Dummy tensor

if "sv3d" in version:
cond_aug = 1e-5
else:
Expand Down