34
34
from torch .utils .data import DataLoader , Sampler
35
35
from transformers import (
36
36
AutoConfig ,
37
+ AutoModelForCausalLM ,
37
38
AutoModelForSequenceClassification ,
38
39
AutoProcessor ,
39
40
AutoTokenizer ,
@@ -239,9 +240,13 @@ def __init__(
239
240
f"a `torch.dtype` (e.g., 'float32'), but got { dtype } ."
240
241
)
241
242
# Disable caching if gradient checkpointing is enabled (not supported)
242
- config = AutoConfig .from_pretrained (model_id )
243
- architecture = getattr (transformers , config .architectures [0 ])
244
- model = architecture .from_pretrained (model_id , ** model_init_kwargs )
243
+ config = AutoConfig .from_pretrained (model_id , trust_remote_code = args .trust_remote_code )
244
+ if architecture := getattr (transformers , config .architectures [0 ], None ):
245
+ model = architecture .from_pretrained (model_id , ** model_init_kwargs )
246
+ else :
247
+ model = AutoModelForCausalLM .from_pretrained (
248
+ model_id , trust_remote_code = args .trust_remote_code , ** model_init_kwargs
249
+ )
245
250
else :
246
251
model_id = model .config ._name_or_path
247
252
if args .model_init_kwargs is not None :
@@ -263,7 +268,9 @@ def __init__(
263
268
264
269
# Processing class
265
270
if processing_class is None :
266
- processing_class = AutoProcessor .from_pretrained (model .config ._name_or_path )
271
+ processing_class = AutoProcessor .from_pretrained (
272
+ model .config ._name_or_path , trust_remote_code = args .trust_remote_code
273
+ )
267
274
268
275
# Handle pad token for processors or tokenizers
269
276
if isinstance (processing_class , ProcessorMixin ):
@@ -427,9 +434,13 @@ def __init__(
427
434
self .ref_model = None
428
435
else :
429
436
# For deepspeed, fsdp or non-distributed models, create a reference model from scratch
430
- config = AutoConfig .from_pretrained (model_id )
431
- architecture = getattr (transformers , config .architectures [0 ])
432
- self .ref_model = architecture .from_pretrained (model_id , ** model_init_kwargs )
437
+ config = AutoConfig .from_pretrained (model_id , trust_remote_code = args .trust_remote_code )
438
+ if architecture := getattr (transformers , config .architectures [0 ], None ):
439
+ self .ref_model = architecture .from_pretrained (model_id , ** model_init_kwargs )
440
+ else :
441
+ self .ref_model = AutoModelForCausalLM .from_pretrained (
442
+ model_id , trust_remote_code = args .trust_remote_code , ** model_init_kwargs
443
+ )
433
444
434
445
# Disable dropout in the models
435
446
if args .disable_dropout :
@@ -537,6 +548,7 @@ def __init__(
537
548
max_num_batched_tokens = 4096 ,
538
549
model_impl = self .args .vllm_model_impl ,
539
550
enable_sleep_mode = self .args .vllm_enable_sleep_mode ,
551
+ trust_remote_code = self .args .trust_remote_code ,
540
552
)
541
553
if self .args .vllm_enable_sleep_mode :
542
554
self .llm .sleep (level = 1 )
0 commit comments