96 changes: 95 additions & 1 deletion optimum/habana/transformers/trainer.py
@@ -27,6 +27,7 @@
 import shutil
 import time
 import warnings
+from collections import OrderedDict
 from collections.abc import Mapping
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
@@ -341,7 +342,100 @@ def _move_model_to_device(self, model, device):
         model = model.to(device)
         # Moving a model to HPU disconnects the tied weights, so we have to retie them.
         if self.args.use_habana and hasattr(model, "tie_weights"):
-            model.tie_weights()
+            try:
+                model.tie_weights()
+            except KeyError as error:
+                logger.warning(
+                    "[GaudiTrainer] model.tie_weights() failed with KeyError (%s). "
+                    "Falling back to safe manual retie for PEFT/LoRA modules.",
+                    error,
+                )
+                self._safe_tie_weights(model)
+
+    def _safe_tie_weights(self, model):
+        """Fallback that re-ties embeddings without re-registering existing parameters."""
+        get_input = getattr(model, "get_input_embeddings", None)
+        get_output = getattr(model, "get_output_embeddings", None)
+
+        if not callable(get_input) or not callable(get_output):
+            return
+
+        input_embeddings = get_input()
+        output_embeddings = get_output()
+
+        if input_embeddings is None or output_embeddings is None:
+            return
+
+        if not hasattr(input_embeddings, "weight") or not hasattr(output_embeddings, "weight"):
+            return
+
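+        # This path only runs after tie_weights() has already failed, so log the modules involved for diagnosis.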
+        logger.warning(
+            "[GaudiTrainer][tie_weights fallback] input=%s output=%s | input_param_id=%s output_param_id=%s | "
+            "output_has_param=%s",
+            type(input_embeddings).__name__,
+            type(output_embeddings).__name__,
+            id(input_embeddings.weight),
+            id(output_embeddings.weight),
+            hasattr(output_embeddings, "_parameters") and "weight" in output_embeddings._parameters,
+        )
+
+        self._replace_module_parameter(output_embeddings, "weight", input_embeddings.weight)
+
+    def _replace_module_parameter(self, module, name, new_parameter):
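+        """Attach ``new_parameter`` to ``module`` under ``name``, tolerating stale or shadowing attributes."""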
+        params = getattr(module, "_parameters", None)
+        attr_value = getattr(module, name, None)
+        attr_type = type(attr_value).__name__ if attr_value is not None else "None"
+
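+        # nn.Module normally creates _parameters in __init__; rebuild it defensively if it is somehow missing.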
+        if params is None:
+            params = OrderedDict()
+            module._parameters = params
+
+        if name in params:
+            logger.warning(
+                "[GaudiTrainer][tie_weights fallback] Overwriting existing parameter '%s' on %s",
+                name,
+                type(module).__name__,
+            )
+            params[name] = new_parameter
+            return
+
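+        # The attribute is still a Parameter but its _parameters entry went missing; restore both views.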
+        if isinstance(attr_value, torch.nn.Parameter):
+            logger.warning(
+                "[GaudiTrainer][tie_weights fallback] Restoring Parameter '%s' on %s whose _parameters entry went missing",
+                name,
+                type(module).__name__,
+            )
+            params[name] = new_parameter
+            module.__dict__[name] = new_parameter
+            return
+
+        if hasattr(module, name):
+            logger.warning(
+                "[GaudiTrainer][tie_weights fallback] Removing non-Parameter attribute '%s' (type=%s) on %s before registering tied weight",
+                name,
+                attr_type,
+                type(module).__name__,
+            )
+            with contextlib.suppress(AttributeError):
+                delattr(module, name)
+
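+        # Prefer the regular registration path; fall back to direct injection only if it still raises.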
+        try:
+            module.register_parameter(name, new_parameter)
+        except KeyError as exc:
+            logger.error(
+                "[GaudiTrainer][tie_weights fallback] register_parameter still failed for '%s' on %s (existing attr type=%s): %s",
+                name,
+                type(module).__name__,
+                attr_type,
+                exc,
+            )
+            params[name] = new_parameter
+            module.__dict__[name] = new_parameter
+            logger.warning(
+                "[GaudiTrainer][tie_weights fallback] Directly injected Parameter '%s' into %s._parameters after register failure",
+                name,
+                type(module).__name__,
+            )
+
     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
         if self.train_dataset is None or not has_length(self.train_dataset):
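
For context, here is a minimal, self-contained sketch of the failure mode and the manual retie the new fallback performs. The toy modules below are illustrative only and not taken from this PR: register_parameter raises KeyError when an attribute shadows a name that is no longer in _parameters, so the fallback removes the shadowing attribute before re-registering the shared weight.

import torch

embedding = torch.nn.Embedding(10, 4)          # stand-in for model.get_input_embeddings()
lm_head = torch.nn.Linear(4, 10, bias=False)   # stand-in for model.get_output_embeddings()

# Simulate the broken state: 'weight' has dropped out of _parameters while a plain
# tensor still shadows the attribute on the instance.
plain_tensor = lm_head.weight.data
del lm_head._parameters["weight"]
lm_head.__dict__["weight"] = plain_tensor

try:
    lm_head.register_parameter("weight", embedding.weight)
except KeyError as exc:
    print(f"register_parameter fails, as seen in the trainer: {exc}")

# The fallback: drop the shadowing attribute first, then retie the shared Parameter.
delattr(lm_head, "weight")
lm_head.register_parameter("weight", embedding.weight)
assert lm_head.weight is embedding.weight  # embeddings are tied again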