Skip to content

Commit

Permalink
Clean up some controlnet code.
Browse files Browse the repository at this point in the history
Remove self.device which was useless.
  • Loading branch information
comfyanonymous committed Oct 23, 2024
1 parent 915fdb5 commit 754597c
Showing 1 changed file with 10 additions and 11 deletions.
21 changes: 10 additions & 11 deletions comfy/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class StrengthType(Enum):
LINEAR_UP = 2

class ControlBase:
def __init__(self, device=None):
def __init__(self):
self.cond_hint_original = None
self.cond_hint = None
self.strength = 1.0
Expand All @@ -72,10 +72,6 @@ def __init__(self, device=None):
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}

if device is None:
device = comfy.model_management.get_torch_device()
self.device = device
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
Expand Down Expand Up @@ -185,8 +181,8 @@ def set_extra_arg(self, argument, value=None):


class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
super().__init__(device)
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
super().__init__()
self.control_model = control_model
self.load_device = load_device
if control_model is not None:
Expand Down Expand Up @@ -242,7 +238,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)

self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
self.cond_hint = self.cond_hint.to(device=self.load_device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)

Expand Down Expand Up @@ -341,8 +337,8 @@ def forward(self, input):


class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): #TODO? model_options
ControlBase.__init__(self, device)
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
ControlBase.__init__(self)
self.control_weights = control_weights
self.global_average_pooling = global_average_pooling
self.extra_conds += ["y"]
Expand Down Expand Up @@ -662,12 +658,15 @@ def load_controlnet(ckpt_path, model=None, model_options={}):

class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__(device)
super().__init__()
self.t2i_model = t2i_model
self.channels_in = channels_in
self.control_input = None
self.compression_ratio = compression_ratio
self.upscale_algorithm = upscale_algorithm
if device is None:
device = comfy.model_management.get_torch_device()
self.device = device

def scale_image_to(self, width, height):
unshuffle_amount = self.t2i_model.unshuffle_amount
Expand Down

0 comments on commit 754597c

Please sign in to comment.