diff --git a/galai/__init__.py b/galai/__init__.py index ba535a6..863362d 100644 --- a/galai/__init__.py +++ b/galai/__init__.py @@ -17,6 +17,7 @@ def load_model( name: str, + cache_dir: str = None, dtype: Union[str, torch.dtype] = None, num_gpus: int = None, parallelize: bool = False @@ -128,6 +129,6 @@ def load_model( tensor_parallel=parallelize, ) model._set_tokenizer(hf_model) - model._load_checkpoint(checkpoint_path=hf_model) + model._load_checkpoint(checkpoint_path=hf_model, cache_dir=cache_dir) return model diff --git a/galai/model.py b/galai/model.py index 866ffbb..fbf1610 100644 --- a/galai/model.py +++ b/galai/model.py @@ -82,7 +82,7 @@ def __init__( self.max_input_length = 2020 self._master_port = None - def _load_checkpoint(self, checkpoint_path: str): + def _load_checkpoint(self, checkpoint_path: str, cache_dir: str = None): """ Loads the checkpoint for the model @@ -108,6 +108,7 @@ def _load_checkpoint(self, checkpoint_path: str): self.model = OPTForCausalLM.from_pretrained( checkpoint_path, + cache_dir=cache_dir, torch_dtype=self.dtype, low_cpu_mem_usage=True, device_map=device_map,