@@ -239,7 +239,7 @@ def __init__(
239
239
f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
240
240
)
241
241
# Disable caching if gradient checkpointing is enabled (not supported)
242
- config = AutoConfig.from_pretrained(model_id)
242
+ config = AutoConfig.from_pretrained(model_id, trust_remote_code=self.args.trust_remote_code)
243
243
architecture = getattr(transformers, config.architectures[0])
244
244
model = architecture.from_pretrained(model_id, **model_init_kwargs)
245
245
else:
@@ -263,7 +263,9 @@ def __init__(
263
263
264
264
# Processing class
265
265
if processing_class is None:
266
- processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
266
+ processing_class = AutoProcessor.from_pretrained(
267
+ model.config._name_or_path, trust_remote_code=self.args.trust_remote_code
268
+ )
267
269
268
270
# Handle pad token for processors or tokenizers
269
271
if isinstance(processing_class, ProcessorMixin):
@@ -427,7 +429,7 @@ def __init__(
427
429
self.ref_model = None
428
430
else:
429
431
# For deepspeed, fsdp or non-distributed models, create a reference model from scratch
430
- config = AutoConfig.from_pretrained(model_id)
432
+ config = AutoConfig.from_pretrained(model_id, trust_remote_code=self.args.trust_remote_code)
431
433
architecture = getattr(transformers, config.architectures[0])
432
434
self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
433
435
@@ -537,6 +539,7 @@ def __init__(
537
539
max_num_batched_tokens=4096,
538
540
model_impl=self.args.vllm_model_impl,
539
541
enable_sleep_mode=self.args.vllm_enable_sleep_mode,
542
+ trust_remote_code=self.args.trust_remote_code,
540
543
)
541
544
if self.args.vllm_enable_sleep_mode:
542
545
self.llm.sleep(level=1)
0 commit comments