
Commit afa9e26

refine eval (#282)
1 parent 00122bc commit afa9e26

3 files changed: +49 -40 lines changed

auto_round/__main__.py

Lines changed: 15 additions & 9 deletions
```diff
@@ -328,16 +328,22 @@ def tune(args):
     logger.info(f"Using lm-eval version {lm_eval_version}")
     model_args = f"pretrained={eval_folder}"
     model_args = model_args + f",trust_remote_code={not args.disable_trust_remote_code}"
-    user_model = None
     if args.act_bits <= 8:
-        user_model = model.to(device_str)
-
-    res = simple_evaluate(
-        model="hf",
-        model_args=model_args,
-        tasks=tasks,
-        batch_size=args.eval_bs,
-        user_model=user_model)
+        if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
+            from accelerate.big_modeling import dispatch_model
+
+            dispatch_model(model, model.hf_device_map)
+            user_model = model
+        else:
+            user_model = model.to(device_str)
+        if args.eval_bs == "auto":
+            args.eval_bs = 16
+        from auto_round.eval.evaluation import simple_evaluate_user_model
+        res = simple_evaluate_user_model(user_model, tokenizer, tasks=tasks, batch_size=args.eval_bs)
+    else:
+        res = simple_evaluate(model="hf", model_args=model_args,
+                              tasks=tasks,
+                              batch_size=args.eval_bs)
     print(make_table(res))
 
 
```
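The new branch keeps multi-GPU models usable during evaluation: a model loaded with `device_map="auto"` carries an `hf_device_map`, and when that map spans more than one device, collapsing the model with `.to(device_str)` would undo the sharding. Below is a minimal standalone sketch of that placement decision; the helper name and the default `device_str` are illustrative, not part of the commit:

```python
# Hypothetical helper restating the placement logic from the diff above.
from accelerate.big_modeling import dispatch_model


def place_for_eval(model, device_str="cuda:0"):  # device_str is illustrative
    if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
        # Sharded across several devices: re-attach accelerate's dispatch
        # hooks instead of collapsing the model onto a single device.
        dispatch_model(model, model.hf_device_map)
        return model
    # Single-device case: a plain .to() is enough.
    return model.to(device_str)
```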
auto_round/eval/evaluation.py

Lines changed: 19 additions & 26 deletions
```diff
@@ -23,11 +23,28 @@
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+from lm_eval.models.huggingface import HFLM
+
+
+def simple_evaluate_user_model(
+    user_model,
+    tokenizer,
+    batch_size: Optional[int] = None,
+    max_batch_size: Optional[int] = None,
+    **kwargs
+):
+    hflm = HFLM(pretrained=user_model, tokenizer=tokenizer, batch_size=batch_size, max_batch_size=max_batch_size)
+    return lm_simple_evaluate(
+        model=hflm,
+        model_args=None,
+        batch_size=batch_size,
+        max_batch_size=max_batch_size,
+        **kwargs)
+
 
 def simple_evaluate(
     model,
     model_args: Optional[Union[str, dict]] = None,
-    user_model=None,
     batch_size: Optional[int] = None,
     max_batch_size: Optional[int] = None,
     device: Optional[str] = None,
@@ -37,32 +54,8 @@ def simple_evaluate(
     except:
         from auto_round.auto_quantizer import AutoHfQuantizer
 
-    if model_args is None:
-        model_args = ""
-
-    if isinstance(model_args, dict):
-        lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
-            model_args,
-            {
-                "batch_size": batch_size,
-                "max_batch_size": max_batch_size,
-                "device": device,
-            },
-        )
-
-    else:
-        lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
-            model_args,
-            {
-                "batch_size": batch_size,
-                "max_batch_size": max_batch_size,
-                "device": device,
-            },
-        )
-    if user_model is not None:
-        lm._model = user_model
     return lm_simple_evaluate(
-        model=lm,
+        model=model,
         model_args=model_args,
         batch_size=batch_size,
         max_batch_size=max_batch_size,
```

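The new `simple_evaluate_user_model` helper wraps an already-instantiated model and tokenizer in lm-eval's `HFLM` adapter, so a fake-quantized model can be scored in memory instead of being reloaded from disk by name. A usage sketch follows; the checkpoint and task are placeholders chosen purely for illustration:

```python
# Illustrative only: the checkpoint and task below are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer
from lm_eval.utils import make_table

from auto_round.eval.evaluation import simple_evaluate_user_model

name = "facebook/opt-125m"  # any causal LM checkpoint works here
model = AutoModelForCausalLM.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)

# tasks, limit, etc. flow through **kwargs to lm_eval's simple_evaluate.
res = simple_evaluate_user_model(model, tokenizer,
                                 tasks=["lambada_openai"], batch_size=16)
print(make_table(res))
```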
examples/language-modeling/main.py

Lines changed: 15 additions & 5 deletions
```diff
@@ -446,11 +446,21 @@
     model_args = model_args + f",trust_remote_code={not args.disable_trust_remote_code}"
     user_model = None
     if args.act_bits <= 8:
-        user_model = model.to(device_str)
-
-    res = simple_evaluate(model="hf", model_args=model_args,
-                          tasks=tasks,
-                          batch_size=args.eval_bs, user_model=user_model)
+        if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
+            from accelerate.big_modeling import dispatch_model
+
+            dispatch_model(model, model.hf_device_map)
+            user_model = model
+        else:
+            user_model = model.to(device_str)
+        if args.eval_bs == "auto":
+            args.eval_bs = 16
+        from auto_round.eval.evaluation import simple_evaluate_user_model
+        res = simple_evaluate_user_model(user_model, tokenizer, tasks=tasks, batch_size=args.eval_bs)
+    else:
+        res = simple_evaluate(model="hf", model_args=model_args,
+                              tasks=tasks,
+                              batch_size=args.eval_bs)
     from lm_eval.utils import make_table
 
     print(make_table(res))
```

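When activations are not fake-quantized (`args.act_bits > 8`), both scripts now take the plain lm-eval route: `simple_evaluate` no longer constructs the model itself but forwards the `"hf"` backend name and the `model_args` string to `lm_eval.simple_evaluate`, which loads the exported checkpoint on its own. A sketch of that fallback, with an illustrative output directory:

```python
# Sketch of the fallback path; eval_folder is a placeholder output directory.
from auto_round.eval.evaluation import simple_evaluate

eval_folder = "./tmp_autoround"
model_args = f"pretrained={eval_folder},trust_remote_code=True"

# Model instantiation is delegated entirely to lm-eval.
res = simple_evaluate(model="hf", model_args=model_args,
                      tasks=["lambada_openai"], batch_size=16)
```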