diff --git a/comfy/model_management.py b/comfy/model_management.py index 2346d4ac..0d5b0730 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -840,27 +840,21 @@ def force_channels_last(): #TODO return False -def cast_to_device(tensor, device, dtype, copy=False): - device_supports_cast = False - if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: - device_supports_cast = True - elif tensor.dtype == torch.bfloat16: - if hasattr(device, 'type') and device.type.startswith("cuda"): - device_supports_cast = True - elif is_intel_xpu(): - device_supports_cast = True +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False): + if device is None or weight.device == device: + if not copy: + if dtype is None or weight.dtype == dtype: + return weight + return weight.to(dtype=dtype, copy=copy) - non_blocking = device_should_use_non_blocking(device) + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight, non_blocking=non_blocking) + return r + +def cast_to_device(tensor, device, dtype, copy=False): + non_blocking = device_supports_non_blocking(device) + return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) - if device_supports_cast: - if copy: - if tensor.device == device: - return tensor.to(dtype, copy=copy, non_blocking=non_blocking) - return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking) - else: - return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking) - else: - return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking) def xformers_enabled(): global directml_enabled diff --git a/comfy/ops.py b/comfy/ops.py index c90e25ea..a8bfe1ea 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -20,19 +20,10 @@ import comfy.model_management from comfy.cli_args import args -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False): - if device is None or weight.device == device: - if not copy: - if dtype is None or weight.dtype == dtype: - return weight - return weight.to(dtype=dtype, copy=copy) - - r = torch.empty_like(weight, dtype=dtype, device=device) - r.copy_(weight, non_blocking=non_blocking) - return r +cast_to = comfy.model_management.cast_to #TODO: remove once no more references def cast_to_input(weight, input, non_blocking=False, copy=True): - return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) + return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if input is not None: @@ -47,12 +38,12 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): non_blocking = comfy.model_management.device_supports_non_blocking(device) if s.bias is not None: has_function = s.bias_function is not None - bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function) + bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function) if has_function: bias = s.bias_function(bias) has_function = s.weight_function is not None - weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function) + weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function) if has_function: weight = s.weight_function(weight) return weight, bias