diff --git a/classy_train.py b/classy_train.py
index e693220057..7b942496d0 100755
--- a/classy_train.py
+++ b/classy_train.py
@@ -117,7 +117,7 @@ def main(args, config):
 
 
 def configure_hooks(args, config):
-    hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()]
+    hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook(verbose=True)]
 
     # Make a folder to store checkpoints and tensorboard logging outputs
     suffix = datetime.now().isoformat()
diff --git a/classy_vision/dataset/dataloader_limit_wrapper.py b/classy_vision/dataset/dataloader_limit_wrapper.py
index 6e6e347636..787ebe5339 100644
--- a/classy_vision/dataset/dataloader_limit_wrapper.py
+++ b/classy_vision/dataset/dataloader_limit_wrapper.py
@@ -58,7 +58,7 @@ def __next__(self) -> Any:
             if self.wrap_around:
                 # create a new iterator to load data from the beginning
                 logging.info(
-                    f"Wrapping around after {self._count} calls. Limit: {self.limit}"
+                    f"Wrapping around after {self._count - 1} calls. Limit: {self.limit}"
                 )
                 try:
                     self._iter = iter(self.dataloader)
diff --git a/classy_vision/generic/profiler.py b/classy_vision/generic/profiler.py
index 020f67bcd5..712978f3bb 100644
--- a/classy_vision/generic/profiler.py
+++ b/classy_vision/generic/profiler.py
@@ -86,7 +86,24 @@ def get_shape(x: Union[Tuple, List, Dict]) -> Union[Tuple, List, Dict]:
         return x.size()
 
 
-def _layer_flops(layer: nn.Module, x: Any, y: Any) -> int:
+def _get_batchsize_per_replica(x: Union[Tuple, List, Dict]) -> int:
+    """
+    Some layers take a tuple/list/dict/list[dict] as input to their forward function. We
+    recursively descend into the tuple/list until we reach a tensor and infer the batch size.
+    """
+    while isinstance(x, (list, tuple)):
+        assert len(x) > 0, "input x of tuple/list type must have at least one element"
+        x = x[0]
+
+    if isinstance(x, (dict,)):
+        # dim 0 of every value is the batch size, so select an arbitrary key
+        key_list = list(x.keys())
+        x = x[key_list[0]]
+
+    return x.size()[0]
+
+
+def _layer_flops(layer: nn.Module, x: Any, y: Any, verbose: bool = False) -> int:
     """
     Computes the number of FLOPs required for a single layer.
 
@@ -146,6 +163,36 @@ def flops(self, x): / layer.groups ) + # 3D convolution + elif layer_type in ["Conv3d"]: + out_t = int( + (x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) + // layer.stride[0] + + 1 + ) + out_h = int( + (x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) + // layer.stride[1] + + 1 + ) + out_w = int( + (x.size()[4] + 2 * layer.padding[2] - layer.kernel_size[2]) + // layer.stride[2] + + 1 + ) + flops = ( + batchsize_per_replica + * layer.in_channels + * layer.out_channels + * layer.kernel_size[0] + * layer.kernel_size[1] + * layer.kernel_size[2] + * out_t + * out_h + * out_w + / layer.groups + ) + # learned group convolution: elif layer_type in ["LearnedGroupConv"]: conv = layer.conv @@ -170,45 +217,31 @@ def flops(self, x): ) flops = count1 + count2 - # non-linearities: + # non-linearities are not considered in MAC counting elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax"]: - flops = x.numel() - - # 2D pooling layers: - elif layer_type in ["AvgPool2d", "MaxPool2d"]: - in_h = x.size()[2] - in_w = x.size()[3] - if isinstance(layer.kernel_size, int): - layer.kernel_size = (layer.kernel_size, layer.kernel_size) - kernel_ops = layer.kernel_size[0] * layer.kernel_size[1] - out_h = 1 + int( - (in_h + 2 * layer.padding - layer.kernel_size[0]) / layer.stride - ) - out_w = 1 + int( - (in_w + 2 * layer.padding - layer.kernel_size[1]) / layer.stride + flops = 0 + + elif layer_type in [ + "MaxPool1d", + "MaxPool2d", + "MaxPool3d", + "AdaptiveMaxPool1d", + "AdaptiveMaxPool2d", + "AdaptiveMaxPool3d", + ]: + flops = 0 + + elif layer_type in ["AvgPool1d", "AvgPool2d", "AvgPool3d"]: + kernel_ops = 1 + flops = kernel_ops * y.numel() + + elif layer_type in ["AdaptiveAvgPool1d", "AdaptiveAvgPool2d", "AdaptiveAvgPool3d"]: + assert isinstance(layer.output_size, (list, tuple)) + kernel = torch.Tensor(list(x.shape[2:])) // torch.Tensor( + [list(layer.output_size)] ) - flops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops - - # adaptive avg pool2d - # This is approximate and works only for downsampling without padding - # based on aten/src/ATen/native/AdaptiveAveragePooling.cpp - elif layer_type in ["AdaptiveAvgPool2d"]: - in_h = x.size()[2] - in_w = x.size()[3] - if isinstance(layer.output_size, int): - out_h, out_w = layer.output_size, layer.output_size - elif len(layer.output_size) == 1: - out_h, out_w = layer.output_size[0], layer.output_size[0] - else: - out_h, out_w = layer.output_size - if out_h > in_h or out_w > in_w: - raise ClassyProfilerNotImplementedError(layer) - batchsize_per_replica = x.size()[0] - num_channels = x.size()[1] - kh = in_h - out_h + 1 - kw = in_w - out_w + 1 - kernel_ops = kh * kw - flops = batchsize_per_replica * num_channels * out_h * out_w * kernel_ops + kernel_ops = torch.prod(kernel) + flops = kernel_ops * y.numel() # linear layer: elif layer_type in ["Linear"]: @@ -224,94 +257,12 @@ def flops(self, x): "SyncBatchNorm", "LayerNorm", ]: - flops = 2 * x.numel() - - # 3D convolution - elif layer_type in ["Conv3d"]: - out_t = int( - (x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0]) - // layer.stride[0] - + 1 - ) - out_h = int( - (x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1]) - // layer.stride[1] - + 1 - ) - out_w = int( - (x.size()[4] + 2 * layer.padding[2] - layer.kernel_size[2]) - // layer.stride[2] - + 1 - ) - flops = ( - batchsize_per_replica - * layer.in_channels - * layer.out_channels - * layer.kernel_size[0] - * layer.kernel_size[1] - * layer.kernel_size[2] - * out_t - * out_h - * out_w - / 
layer.groups - ) - - # 3D pooling layers - elif layer_type in ["AvgPool3d", "MaxPool3d"]: - in_t = x.size()[2] - in_h = x.size()[3] - in_w = x.size()[4] - if isinstance(layer.kernel_size, int): - layer.kernel_size = ( - layer.kernel_size, - layer.kernel_size, - layer.kernel_size, - ) - if isinstance(layer.padding, int): - layer.padding = (layer.padding, layer.padding, layer.padding) - if isinstance(layer.stride, int): - layer.stride = (layer.stride, layer.stride, layer.stride) - kernel_ops = layer.kernel_size[0] * layer.kernel_size[1] * layer.kernel_size[2] - out_t = 1 + int( - (in_t + 2 * layer.padding[0] - layer.kernel_size[0]) / layer.stride[0] - ) - out_h = 1 + int( - (in_h + 2 * layer.padding[1] - layer.kernel_size[1]) / layer.stride[1] - ) - out_w = 1 + int( - (in_w + 2 * layer.padding[2] - layer.kernel_size[2]) / layer.stride[2] - ) - flops = batchsize_per_replica * x.size()[1] * out_t * out_h * out_w * kernel_ops - - # adaptive avg pool3d - # This is approximate and works only for downsampling without padding - # based on aten/src/ATen/native/AdaptiveAveragePooling3d.cpp - elif layer_type in ["AdaptiveAvgPool3d"]: - in_t = x.size()[2] - in_h = x.size()[3] - in_w = x.size()[4] - out_t = layer.output_size[0] - out_h = layer.output_size[1] - out_w = layer.output_size[2] - if out_t > in_t or out_h > in_h or out_w > in_w: - raise ClassyProfilerNotImplementedError(layer) - batchsize_per_replica = x.size()[0] - num_channels = x.size()[1] - kt = in_t - out_t + 1 - kh = in_h - out_h + 1 - kw = in_w - out_w + 1 - kernel_ops = kt * kh * kw - flops = ( - batchsize_per_replica * num_channels * out_t * out_w * out_h * kernel_ops - ) + # batchnorm can be merged into conv op. Thus, count 0 FLOPS + flops = 0 # dropout layer elif layer_type in ["Dropout"]: - # At test time, we do not drop values but scale the feature map by the - # dropout ratio - flops = 1 - for dim_size in x.size(): - flops *= dim_size + flops = 0 elif layer_type == "Identity": flops = 0 @@ -335,11 +286,14 @@ def flops(self, x): f"params(M): {count_params(layer) / 1e6}", f"flops(M): {int(flops) / 1e6}", ] - logging.debug("\t".join(message)) - return flops + if verbose: + logging.info("\t".join(message)) + return int(flops) -def _layer_activations(layer: nn.Module, x: Any, out: Any) -> int: +def _layer_activations( + layer: nn.Module, x: Any, out: Any, verbose: bool = False +) -> int: """ Computes the number of activations produced by a single layer. 
@@ -360,8 +314,9 @@ def activations(self, x, out): return 0 message = [f"module: {typestr}", f"activations: {activations}"] - logging.debug("\t".join(message)) - return activations + if verbose: + logging.info("\t".join(message)) + return int(activations) def summarize_profiler_info(prof: torch.autograd.profiler.profile) -> str: @@ -386,17 +341,19 @@ def summarize_profiler_info(prof: torch.autograd.profiler.profile) -> str: class ComplexityComputer: - def __init__(self, compute_fn: Callable, count_unique: bool): + def __init__(self, compute_fn: Callable, count_unique: bool, verbose: bool = False): self.compute_fn = compute_fn self.count_unique = count_unique self.count = 0 + self.verbose = verbose self.seen_modules = set() def compute(self, layer: nn.Module, x: Any, out: Any, module_name: str): if self.count_unique and module_name in self.seen_modules: return - logging.debug(f"module name: {module_name}") - self.count += self.compute_fn(layer, x, out) + self.count += self.compute_fn(layer, x, out, self.verbose) + if self.verbose: + logging.info(f"module name: {module_name}, count {self.count}") self.seen_modules.add(module_name) def reset(self): @@ -482,6 +439,7 @@ def compute_complexity( input_key: Optional[Union[str, List[str]]] = None, patch_attr: str = None, compute_unique: bool = False, + verbose: bool = False, ) -> int: """ Compute the complexity of a forward pass. @@ -501,7 +459,7 @@ def compute_complexity( else: input = get_model_dummy_input(model, input_shape, input_key) - complexity_computer = ComplexityComputer(compute_fn, compute_unique) + complexity_computer = ComplexityComputer(compute_fn, compute_unique, verbose) # measure FLOPs: modify_forward(model, complexity_computer, patch_attr=patch_attr) @@ -519,12 +477,13 @@ def compute_flops( model: nn.Module, input_shape: Tuple[int] = (3, 224, 224), input_key: Optional[Union[str, List[str]]] = None, + verbose: bool = False, ) -> int: """ Compute the number of FLOPs needed for a forward pass. """ return compute_complexity( - model, _layer_flops, input_shape, input_key, patch_attr="flops" + model, _layer_flops, input_shape, input_key, patch_attr="flops", verbose=verbose ) @@ -532,12 +491,18 @@ def compute_activations( model: nn.Module, input_shape: Tuple[int] = (3, 224, 224), input_key: Optional[Union[str, List[str]]] = None, + verbose: bool = False, ) -> int: """ Compute the number of activations created in a forward pass. """ return compute_complexity( - model, _layer_activations, input_shape, input_key, patch_attr="activations" + model, + _layer_activations, + input_shape, + input_key, + patch_attr="activations", + verbose=verbose, ) diff --git a/classy_vision/hooks/exponential_moving_average_model_hook.py b/classy_vision/hooks/exponential_moving_average_model_hook.py index 5aa2144b41..324a173f99 100644 --- a/classy_vision/hooks/exponential_moving_average_model_hook.py +++ b/classy_vision/hooks/exponential_moving_average_model_hook.py @@ -32,7 +32,11 @@ class ExponentialMovingAverageModelHook(ClassyHook): on_end = ClassyHook._noop def __init__( - self, decay: float, consider_bn_buffers: bool = True, device: str = "gpu" + self, + decay: float, + consider_bn_buffers: bool = True, + device: str = "gpu", + ema_model_state_init: bool = False, ) -> None: """The constructor method of ExponentialMovingAverageModelHook. 
@@ -49,6 +53,7 @@ def __init__( self.decay: int = decay self.consider_bn_buffers = consider_bn_buffers self.device = "cuda" if device == "gpu" else "cpu" + self.ema_model_state_init = ema_model_state_init self.state.model_state = {} self.state.ema_model_state = {} self.ema_model_state_list = [] @@ -72,25 +77,37 @@ def get_model_state_iterator(self, model: nn.Module) -> Iterable[Tuple[str, Any] iterable = itertools.chain(iterable, buffers_iterable) return iterable - def _save_current_model_state(self, model: nn.Module, model_state: Dict[str, Any]): + def _save_current_model_state( + self, model: nn.Module, model_state: Dict[str, Any], overwrite: bool = False + ): """Copy the model's state to the provided dict.""" for name, param in self.get_model_state_iterator(model): - model_state[name] = param.detach().clone().to(device=self.device) + if overwrite or (name not in model_state): + model_state[name] = param.detach().clone().to(device=self.device) def on_start(self, task) -> None: if self.state.model_state: # loaded state from checkpoint, do not re-initialize, only move the state # to the right device for name in self.state.model_state: - self.state.model_state[name] = self.state.model_state[name].to( - device=self.device - ) + if self.ema_model_state_init: + self.state.model_state[name] = ( + self.state.ema_model_state[name] + .detach() + .clone() + .to(device=self.device) + ) + else: + self.state.model_state[name] = self.state.model_state[name].to( + device=self.device + ) + self.state.ema_model_state[name] = self.state.ema_model_state[name].to( device=self.device ) - else: - self._save_current_model_state(task.base_model, self.state.model_state) - self._save_current_model_state(task.base_model, self.state.ema_model_state) + + self._save_current_model_state(task.base_model, self.state.model_state) + self._save_current_model_state(task.base_model, self.state.ema_model_state) if self.use_optimization(task): non_fp_states = [] @@ -129,7 +146,9 @@ def on_phase_end(self, task) -> None: if task.train: # save the current model state since this will be overwritten by the ema # state in the test phase - self._save_current_model_state(task.base_model, self.state.model_state) + self._save_current_model_state( + task.base_model, self.state.model_state, overwrite=True + ) def on_step(self, task) -> None: if not task.train: diff --git a/classy_vision/hooks/loss_lr_meter_logging_hook.py b/classy_vision/hooks/loss_lr_meter_logging_hook.py index 99078efb27..34b5d80c3c 100644 --- a/classy_vision/hooks/loss_lr_meter_logging_hook.py +++ b/classy_vision/hooks/loss_lr_meter_logging_hook.py @@ -7,6 +7,7 @@ import logging from typing import Optional +import torch from classy_vision.generic.distributed_util import get_rank from classy_vision.hooks import register_hook from classy_vision.hooks.classy_hook import ClassyHook @@ -33,6 +34,10 @@ def __init__(self, log_freq: Optional[int] = None) -> None: log_freq, int ), "log_freq must be an int or None" self.log_freq: Optional[int] = log_freq + self.state.meter_max = { + "train": {}, + "test": {}, + } def on_start(self, task) -> None: logging.info(f"Starting training. Task: {task}") @@ -47,7 +52,16 @@ def on_phase_end(self, task) -> None: # do not explicitly state this since it is possible for a # trainer to implement an unsynced end of phase meter or # for meters to not provide a sync function. 
- self._log_loss_lr_meters(task, prefix="Synced meters: ", log_batches=True) + self._log_loss_lr_meters( + task, prefix="Synced meters: ", log_batches=True, log_meter_max=True + ) + + logging.info( + f"max memory allocated(MB) {torch.cuda.max_memory_allocated() // 1e6}" + ) + logging.info( + f"max memory reserved(MB) {torch.cuda.max_memory_reserved() // 1e6}" + ) def on_step(self, task) -> None: """ @@ -59,7 +73,25 @@ def on_step(self, task) -> None: if batches and batches % self.log_freq == 0: self._log_loss_lr_meters(task, prefix="Approximate meters: ") - def _log_loss_lr_meters(self, task, prefix="", log_batches=False) -> None: + def _log_meter_max(self, task): + for meter in task.meters: + if meter.name not in self.state.meter_max[task.phase_type]: + self.state.meter_max[task.phase_type][meter.name] = { + k: v for k, v in meter.value.items() + } + else: + for k, v in meter.value.items(): + self.state.meter_max[task.phase_type][meter.name][k] = max( + v, self.state.meter_max[task.phase_type][meter.name][k] + ) + for k, v in self.state.meter_max[task.phase_type][meter.name].items(): + logging.info( + f"phase {task.phase_type}, meter {meter.name} {k}, current best: {v}" + ) + + def _log_loss_lr_meters( + self, task, prefix="", log_batches=False, log_meter_max=False + ) -> None: """ Compute and log the loss, lr, and meters. """ @@ -82,4 +114,7 @@ def _log_loss_lr_meters(self, task, prefix="", log_batches=False) -> None: if log_batches: msg += f", processed batches: {batches}" + if log_meter_max: + self._log_meter_max(task) + logging.info(msg) diff --git a/classy_vision/hooks/model_complexity_hook.py b/classy_vision/hooks/model_complexity_hook.py index 2d950e229a..0fc4b234c3 100644 --- a/classy_vision/hooks/model_complexity_hook.py +++ b/classy_vision/hooks/model_complexity_hook.py @@ -27,11 +27,12 @@ class ModelComplexityHook(ClassyHook): on_phase_end = ClassyHook._noop on_end = ClassyHook._noop - def __init__(self) -> None: + def __init__(self, verbose=False) -> None: super().__init__() self.num_flops = None self.num_activations = None self.num_parameters = None + self.verbose = verbose def on_start(self, task) -> None: """Measure number of parameters, FLOPs and activations.""" @@ -48,15 +49,13 @@ def on_start(self, task) -> None: input_key=task.base_model.input_key if hasattr(task.base_model, "input_key") else None, + verbose=self.verbose, ) if self.num_flops is None: logging.info("FLOPs for forward pass: skipped.") self.num_flops = 0 else: - logging.info( - "FLOPs for forward pass: %d MFLOPs" - % (float(self.num_flops) / 1e6) - ) + logging.info(f"FLOPs for forward pass: {self.num_flops} FLOPs") except ClassyProfilerNotImplementedError as e: logging.warning(f"Could not compute FLOPs for model forward pass: {e}") try: @@ -66,6 +65,7 @@ def on_start(self, task) -> None: input_key=task.base_model.input_key if hasattr(task.base_model, "input_key") else None, + verbose=self.verbose, ) logging.info(f"Number of activations in model: {self.num_activations}") except ClassyProfilerNotImplementedError as e: diff --git a/classy_vision/hooks/precise_batch_norm_hook.py b/classy_vision/hooks/precise_batch_norm_hook.py index de08fde558..1304d88e42 100644 --- a/classy_vision/hooks/precise_batch_norm_hook.py +++ b/classy_vision/hooks/precise_batch_norm_hook.py @@ -4,21 +4,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import torch -from classy_vision.generic.util import ( - get_batchsize_per_replica, - recursive_copy_to_device, - recursive_copy_to_gpu, -) +import logging + +from classy_vision.generic.util import get_batchsize_per_replica, recursive_copy_to_gpu from classy_vision.hooks import ClassyHook, register_hook from fvcore.nn.precise_bn import update_bn_stats -def _get_iterator(cache, use_gpu): - for elem in cache: +def _get_iterator(data_iter, use_gpu): + for elem in data_iter: if use_gpu: elem = recursive_copy_to_gpu(elem, non_blocking=True) - yield elem + yield elem["input"] @register_hook("precise_bn") @@ -32,9 +29,7 @@ class PreciseBatchNormHook(ClassyHook): fvcore/nn/precise_bn.py>`_ for more information. """ - on_start = ClassyHook._noop on_phase_start = ClassyHook._noop - on_step = ClassyHook._noop on_end = ClassyHook._noop def __init__(self, num_samples: int) -> None: @@ -49,31 +44,32 @@ def __init__(self, num_samples: int) -> None: if num_samples <= 0: raise ValueError("num_samples has to be a positive integer") self.num_samples = num_samples - self.cache = [] - self.current_samples = 0 + self.batch_size = None @classmethod def from_config(cls, config): return cls(config["num_samples"]) - def on_phase_start(self, task) -> None: - self.cache = [] - self.current_samples = 0 + def on_start(self, task) -> None: + logging.info("Use precise BatchNorm hook") def on_step(self, task) -> None: - if not task.train or self.current_samples >= self.num_samples: + if not task.train or self.batch_size is not None: return - input = recursive_copy_to_device( - task.last_batch.sample["input"], - non_blocking=True, - device=torch.device("cpu"), - ) - self.cache.append(input) - self.current_samples += get_batchsize_per_replica(input) + + self.batch_size = get_batchsize_per_replica(task.last_batch.sample["input"]) def on_phase_end(self, task) -> None: if not task.train: return - iterator = _get_iterator(self.cache, task.use_gpu) - num_batches = len(self.cache) + + num_batches = (self.num_samples + self.batch_size - 1) // self.batch_size + + task.build_dataloaders_for_current_phase() + task.create_data_iterators() + if num_batches > len(task.data_iterator): + num_batches = len(task.data_iterator) + logging.info(f"Reduce no. of samples to {num_batches * self.batch_size}") + + iterator = _get_iterator(task.data_iterator, task.use_gpu) update_bn_stats(task.base_model, iterator, num_batches) diff --git a/classy_vision/tasks/fine_tuning_task.py b/classy_vision/tasks/fine_tuning_task.py index bd30e29c7d..53f910a89d 100644 --- a/classy_vision/tasks/fine_tuning_task.py +++ b/classy_vision/tasks/fine_tuning_task.py @@ -4,7 +4,8 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-from typing import Any, Dict
+import logging
+from typing import Any, Dict, List
 
 from classy_vision.generic.util import (
     load_and_broadcast_checkpoint,
@@ -20,6 +21,7 @@ def __init__(self, *args, **kwargs):
         self.pretrained_checkpoint_dict = None
         self.pretrained_checkpoint_path = None
         self.pretrained_checkpoint_load_strict = True
+        self.hooks_load_from_pretrained_checkpoint = []
         self.reset_heads = False
         self.freeze_trunk = False
 
@@ -38,9 +40,12 @@ def from_config(cls, config: Dict[str, Any]) -> "FineTuningTask":
         pretrained_checkpoint_path = config.get("pretrained_checkpoint")
 
         if pretrained_checkpoint_path:
-            task.set_pretrained_checkpoint(pretrained_checkpoint_path)
-            task.set_pretrained_checkpoint_load_strict(
+            task.set_pretrained_checkpoint(
+                pretrained_checkpoint_path
+            ).set_pretrained_checkpoint_load_strict(
                 config.get("pretrained_checkpoint_load_strict", True)
+            ).set_hooks_load_from_pretrained_checkpoint(
+                config.get("hooks_load_from_pretrained_checkpoint", [])
             )
 
         task.set_reset_heads(config.get("reset_heads", False))
@@ -57,6 +62,19 @@ def set_pretrained_checkpoint_load_strict(
         self.pretrained_checkpoint_load_strict = pretrained_checkpoint_load_strict
         return self
 
+    def set_hooks_load_from_pretrained_checkpoint(
+        self, hooks_load_from_pretrained_checkpoint: List[str]
+    ):
+        """
+        Args:
+            hooks_load_from_pretrained_checkpoint: names of the hooks whose state
+                should be loaded from the pretrained checkpoint
+        """
+        self.hooks_load_from_pretrained_checkpoint = (
+            hooks_load_from_pretrained_checkpoint
+        )
+        return self
+
     def _set_pretrained_checkpoint_dict(
         self, checkpoint_dict: Dict[str, Any]
     ) -> "FineTuningTask":
@@ -84,6 +102,14 @@ def _set_model_train_mode(self):
             else:
                 self.base_model.train(phase["train"])
 
+    def _load_hooks_from_pretrained_checkpoint(self, state: Dict[str, Any]):
+        for hook in self.hooks:
+            if (
+                hook.name() in state["hooks"]
+                and hook.name() in self.hooks_load_from_pretrained_checkpoint
+            ):
+                hook.set_classy_state(state["hooks"][hook.name()])
+
     def prepare(self) -> None:
         super().prepare()
         if self.checkpoint_dict is None:
@@ -95,19 +121,26 @@ def prepare(self) -> None:
                     self.pretrained_checkpoint_path
                 )
 
-            assert (
-                self.pretrained_checkpoint_dict is not None
-            ), "Need a pretrained checkpoint for fine tuning"
+            if self.pretrained_checkpoint_dict is None:
+                logging.warning("A pretrained checkpoint was not provided")
+            else:
+                assert (
+                    self.pretrained_checkpoint_dict is not None
+                ), "Need a pretrained checkpoint for fine tuning"
 
-            state_load_success = update_classy_model(
-                self.base_model,
-                self.pretrained_checkpoint_dict["classy_state_dict"]["base_model"],
-                self.reset_heads,
-                self.pretrained_checkpoint_load_strict,
-            )
-            assert (
-                state_load_success
-            ), "Update classy state from pretrained checkpoint was unsuccessful."
+                state = self.pretrained_checkpoint_dict["classy_state_dict"]
+
+                state_load_success = update_classy_model(
+                    self.base_model,
+                    state["base_model"],
+                    self.reset_heads,
+                    self.pretrained_checkpoint_load_strict,
+                )
+                assert (
+                    state_load_success
+                ), "Update classy state from pretrained checkpoint was unsuccessful."
+
+                self._load_hooks_from_pretrained_checkpoint(state)
 
         if self.freeze_trunk:
             # do not track gradients for all the parameters in the model except