3 changes: 3 additions & 0 deletions auto_round/__init__.py
@@ -18,6 +18,9 @@
from auto_round.schemes import QuantizationScheme
from auto_round.auto_scheme import AutoScheme
from auto_round.utils import LazyImport
from auto_round.utils import monkey_patch

monkey_patch()


def __getattr__(name):
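In effect, merely importing the package now applies the patch. A minimal usage sketch, assuming both packages are installed; the model id is illustrative, not part of this PR:

import auto_round  # side effect: monkey_patch() patches transformers

from transformers import AutoModelForCausalLM

# On transformers >= 4.56.0 the legacy `torch_dtype` kwarg is transparently
# renamed to `dtype`; on older versions the reverse mapping applies, so
# either spelling works regardless of the installed version.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # illustrative model id
    torch_dtype="auto",
)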
36 changes: 36 additions & 0 deletions auto_round/utils/common.py
@@ -73,6 +73,42 @@ def __call__(self, *args, **kwargs):
        return function(*args, **kwargs)




def rename_kwargs(**name_map):
    """Return a decorator that renames keyword arguments (old_name -> new_name) before calling the wrapped function."""
    from functools import wraps

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for old_name, new_name in name_map.items():
                if old_name in kwargs:
                    if new_name in kwargs:
                        raise TypeError(f"Cannot specify both {old_name} and {new_name}")
                    kwargs[new_name] = kwargs.pop(old_name)
            return func(*args, **kwargs)

        return wrapper

    return decorator


# TODO: this is not very robust, as only AutoModelForCausalLM is patched
def monkey_patch_transformers():
    # transformers 4.56.0 renamed the `torch_dtype` argument of `from_pretrained`
    # to `dtype`; map whichever name the installed version does not understand
    # onto the one it does, so callers can pass either spelling.
    # (`transformers` and `packaging.version` are assumed to be imported
    # earlier in this module.)
    if version.parse(transformers.__version__) >= version.parse("4.56.0"):
        transformers.AutoModelForCausalLM.from_pretrained = rename_kwargs(torch_dtype="dtype")(
            transformers.AutoModelForCausalLM.from_pretrained
        )
    else:
        transformers.AutoModelForCausalLM.from_pretrained = rename_kwargs(dtype="torch_dtype")(
            transformers.AutoModelForCausalLM.from_pretrained
        )


def monkey_patch():
monkey_patch_transformers()


auto_gptq = LazyImport("auto_gptq")
htcore = LazyImport("habana_frameworks.torch.core")

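For clarity, a minimal sketch of the rename_kwargs decorator in isolation, assuming it is importable from its defining module; greet is a hypothetical toy function, not part of this PR:

from auto_round.utils.common import rename_kwargs

@rename_kwargs(torch_dtype="dtype")
def greet(name, dtype=None):
    return f"{name}: {dtype}"

print(greet("opt", torch_dtype="float16"))  # old name is remapped -> "opt: float16"
print(greet("opt", dtype="float16"))        # new name passes through unchanged
# greet("opt", torch_dtype="a", dtype="b")  # would raise TypeError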