diff --git a/kandinsky2/kandinsky2_1_model.py b/kandinsky2/kandinsky2_1_model.py
index 2ae2ddd..ab18dfa 100644
--- a/kandinsky2/kandinsky2_1_model.py
+++ b/kandinsky2/kandinsky2_1_model.py
@@ -30,6 +30,8 @@ def __init__(
     ):
         self.config = config
         self.device = device
+        if device != "cuda":
+            self.config["model_config"]["use_fp16"] = False
         self.use_fp16 = self.config["model_config"]["use_fp16"]
         self.task_type = task_type
         self.clip_image_size = config["clip_image_size"]
@@ -54,7 +56,7 @@ def __init__(
             clip_mean,
             clip_std,
         )
-        self.prior.load_state_dict(torch.load(prior_path), strict=False)
+        self.prior.load_state_dict(torch.load(prior_path, map_location='cpu'), strict=False)
         if self.use_fp16:
             self.prior = self.prior.half()
         self.text_encoder = TextEncoder(**self.config["text_enc_params"])
@@ -88,7 +90,7 @@ def __init__(
 
         self.config["model_config"]["cache_text_emb"] = True
         self.model = create_model(**self.config["model_config"])
-        self.model.load_state_dict(torch.load(model_path))
+        self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
         if self.use_fp16:
             self.model.convert_to_fp16()
             self.image_encoder = self.image_encoder.half()
@@ -261,12 +263,14 @@ def denoised_fun(x):
                 model=model_fn,
                 old_diffusion=diffusion,
                 schedule="linear",
+                device=self.device,
             )
         elif sampler == "plms_sampler":
             sampler = PLMSSampler(
                 model=model_fn,
                 old_diffusion=diffusion,
                 schedule="linear",
+                device=self.device,
             )
         else:
             raise ValueError("Only ddim_sampler and plms_sampler is available")
diff --git a/kandinsky2/model/gaussian_diffusion.py b/kandinsky2/model/gaussian_diffusion.py
index b5449e1..1a8d2b0 100644
--- a/kandinsky2/model/gaussian_diffusion.py
+++ b/kandinsky2/model/gaussian_diffusion.py
@@ -822,7 +822,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
                             dimension equal to the length of timesteps.
     :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
     """
-    res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+    res = th.from_numpy(arr).to(dtype=th.float32).to(device=timesteps.device)[timesteps]
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res.expand(broadcast_shape)
diff --git a/kandinsky2/model/samplers.py b/kandinsky2/model/samplers.py
index 0b4db1d..16f8ff2 100644
--- a/kandinsky2/model/samplers.py
+++ b/kandinsky2/model/samplers.py
@@ -66,17 +66,18 @@ def extract_into_tensor(a, t, x_shape):
 
 
 class DDIMSampler(object):
-    def __init__(self, model, old_diffusion, schedule="linear", **kwargs):
+    def __init__(self, model, old_diffusion, schedule="linear", device="cuda", **kwargs):
         super().__init__()
         self.model = model
         self.old_diffusion = old_diffusion
         self.ddpm_num_timesteps = 1000
         self.schedule = schedule
+        self.device = device
 
     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != torch.device(self.device):
+                attr = attr.to(dtype=torch.float32).to(torch.device(self.device))
         setattr(self, name, attr)
 
     def make_schedule(
@@ -98,7 +99,7 @@ def make_schedule(
         assert (
             alphas_cumprod.shape[0] == self.ddpm_num_timesteps
         ), "alphas have to be defined for each timestep"
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to("cuda")
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
 
         self.register_buffer(
             "betas", to_torch(torch.from_numpy(self.old_diffusion.betas))
@@ -223,10 +224,9 @@ def ddim_sampling(
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
     ):
-        device = "cuda"
         b = shape[0]
         if x_T is None:
-            img = torch.randn(shape, device=device)
+            img = torch.randn(shape, device=self.device)
         else:
             img = x_T
 
@@ -258,7 +258,7 @@ def ddim_sampling(
 
         for i, step in enumerate(iterator):
             index = total_steps - i - 1
-            ts = torch.full((b,), step, device=device, dtype=torch.long)
+            ts = torch.full((b,), step, device=self.device, dtype=torch.long)
 
             outs = self.p_sample_ddim(
                 img,
@@ -332,17 +332,18 @@ def p_sample_ddim(
 
 
 class PLMSSampler(object):
-    def __init__(self, model, old_diffusion, schedule="linear", **kwargs):
+    def __init__(self, model, old_diffusion, schedule="linear", device="cuda", **kwargs):
         super().__init__()
         self.model = model
         self.old_diffusion = old_diffusion
         self.ddpm_num_timesteps = 1000
         self.schedule = schedule
+        self.device = device
 
     def register_buffer(self, name, attr):
         if type(attr) == torch.Tensor:
-            if attr.device != torch.device("cuda"):
-                attr = attr.to(torch.device("cuda"))
+            if attr.device != torch.device(self.device):
+                attr = attr.to(dtype=torch.float32).to(torch.device(self.device))
         setattr(self, name, attr)
 
     def make_schedule(
@@ -366,7 +367,7 @@ def make_schedule(
         assert (
             alphas_cumprod.shape[0] == self.ddpm_num_timesteps
         ), "alphas have to be defined for each timestep"
-        to_torch = lambda x: x.clone().detach().to(torch.float32).to("cuda")
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.device)
 
         self.register_buffer(
             "betas", to_torch(torch.from_numpy(self.old_diffusion.betas))
@@ -492,10 +493,9 @@ def plms_sampling(
         unconditional_guidance_scale=1.0,
         unconditional_conditioning=None,
     ):
-        device = "cuda"
         b = shape[0]
         if x_T is None:
-            img = torch.randn(shape, device=device)
+            img = torch.randn(shape, device=self.device)
         else:
             img = x_T
 
@@ -529,11 +529,11 @@ def plms_sampling(
 
         for i, step in enumerate(iterator):
            index = total_steps - i - 1
-            ts = torch.full((b,), step, device=device, dtype=torch.long)
+            ts = torch.full((b,), step, device=self.device, dtype=torch.long)
             ts_next = torch.full(
                 (b,),
                 time_range[min(i + 1, len(time_range) - 1)],
-                device=device,
+                device=self.device,
                 dtype=torch.long,
             )
 
diff --git a/kandinsky2/model/utils.py b/kandinsky2/model/utils.py
index c79aad9..0b4009f 100644
--- a/kandinsky2/model/utils.py
+++ b/kandinsky2/model/utils.py
@@ -15,7 +15,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
                             dimension equal to the length of timesteps.
     :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
     """
-    res = torch.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+    res = torch.from_numpy(arr).to(dtype=torch.float32).to(device=timesteps.device)[timesteps]
     while len(res.shape) < len(broadcast_shape):
         res = res[..., None]
     return res.expand(broadcast_shape)
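
Notes for review.

Both _extract_into_tensor edits replace the trailing .float() with an up-front float32 cast, so the float64 array that numpy produces never reaches the target device; this matters for backends like MPS, which reject float64 tensors outright. A runnable sketch of the patched helper; the schedule and shapes below are illustrative, not values from the Kandinsky configs:

    import numpy as np
    import torch

    def _extract_into_tensor(arr, timesteps, broadcast_shape):
        # Cast on the host first: torch.from_numpy() preserves numpy's float64,
        # and .to(dtype=torch.float32) before .to(device=...) keeps float64 off
        # devices that cannot represent it.
        res = torch.from_numpy(arr).to(dtype=torch.float32).to(device=timesteps.device)[timesteps]
        while len(res.shape) < len(broadcast_shape):
            res = res[..., None]
        return res.expand(broadcast_shape)

    betas = np.linspace(1e-4, 0.02, 1000)  # float64, as np.linspace returns
    t = torch.randint(0, 1000, (4,))       # a batch of four timesteps on CPU
    out = _extract_into_tensor(betas, t, (4, 3, 64, 64))
    assert out.dtype == torch.float32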
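
The sampler edits thread a single device argument through every place "cuda" was hard-coded: register_buffer, the to_torch lambda in make_schedule, and the noise and timestep tensors in ddim_sampling/plms_sampling. A construction sketch; model_fn and diffusion stand in for the wrapped model and diffusion objects that kandinsky2_1_model.py builds (__init__ only stores them, so the construction itself runs as written):

    from kandinsky2.model.samplers import DDIMSampler

    model_fn, diffusion = ..., ...  # stand-ins for objects built in kandinsky2_1_model.py
    sampler = DDIMSampler(
        model=model_fn,
        old_diffusion=diffusion,
        schedule="linear",
        device="cpu",  # every buffer registered later now follows this device
    )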
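
Taken together, the patch makes the pipeline usable without a GPU: fp16 is forced off away from CUDA, checkpoints deserialize via map_location='cpu', and the samplers allocate on self.device. A minimal usage sketch; get_kandinsky2 and generate_text2img are the repo's entry points, but the exact keyword arguments below are assumptions and may differ from the README:

    from kandinsky2 import get_kandinsky2

    # "cpu" now works end to end instead of crashing on the hard-coded "cuda".
    model = get_kandinsky2("cpu", task_type="text2img", model_version="2.1")
    images = model.generate_text2img(
        "red cat, 4k photo",
        num_steps=50,
        batch_size=1,
        guidance_scale=4,
        sampler="ddim_sampler",  # one of the two samplers patched above
    )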