Skip to content

Commit

Permalink
Fixes to run on CPU and MPS
Browse files Browse the repository at this point in the history
  • Loading branch information
Wojtek Kowaluk authored and Wojtek Kowaluk committed Apr 6, 2023
1 parent a4354c0 commit d5162d4
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion kandinsky2/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
"model_dim": 768,
"use_scale_shift_norm": True,
"resblock_updown": True,
"use_fp16": True,
"use_fp16": False,
"cache_text_emb": True,
"text_encoder_in_dim1": 1024,
"text_encoder_in_dim2": 768,
Expand Down
4 changes: 2 additions & 2 deletions kandinsky2/kandinsky2_1_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __init__(
clip_mean,
clip_std,
)
self.prior.load_state_dict(torch.load(prior_path), strict=False)
self.prior.load_state_dict(torch.load(prior_path, map_location='cpu'), strict=False)
if self.use_fp16:
self.prior = self.prior.half()
self.text_encoder = TextEncoder(**self.config["text_enc_params"])
Expand Down Expand Up @@ -88,7 +88,7 @@ def __init__(

self.config["model_config"]["cache_text_emb"] = True
self.model = create_model(**self.config["model_config"])
self.model.load_state_dict(torch.load(model_path))
self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
if self.use_fp16:
self.model.convert_to_fp16()
self.image_encoder = self.image_encoder.half()
Expand Down
2 changes: 1 addition & 1 deletion kandinsky2/model/gaussian_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,7 +822,7 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
dimension equal to the length of timesteps.
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
"""
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
res = th.from_numpy(arr).to(dtype=th.float32).to(device=timesteps.device)[timesteps]
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)

0 comments on commit d5162d4

Please sign in to comment.