At the moment, the LLMs and their associated inference/embedding code live in class-specific implementations. I'm not sure this is necessary, or DRY, once you start catering for different architectures (e.g. Dolly v2).
Consider refactoring with interfaces driven by the model config, i.e. a `config = AutoConfig.from_pretrained(path_or_repo)` type approach. This might allow scaling to different model types without heavy configuration on the user's side or large amounts of boilerplate rewritten for each specific implementation.
E.g., a stub:
```python
import logging

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)


def load_model_and_tokenizer(path_or_repo: str, **model_kwargs):
    # Read only the config first to discover the architecture before loading weights
    config = AutoConfig.from_pretrained(path_or_repo)

    if config.model_type == "llama":
        from transformers import LlamaForCausalLM, LlamaTokenizer

        tokenizer = LlamaTokenizer.from_pretrained(path_or_repo)
        model = LlamaForCausalLM.from_pretrained(
            path_or_repo, **model_kwargs
        )  # , load_in_8bit=True, device_map="auto")
    elif config.model_type == "gpt_neox":
        from transformers import GPTNeoXForCausalLM, GPTNeoXTokenizerFast

        # Tokenizer and model classes were swapped in the original stub
        tokenizer = GPTNeoXTokenizerFast.from_pretrained(path_or_repo)
        model = GPTNeoXForCausalLM.from_pretrained(
            path_or_repo, **model_kwargs
        )  # , load_in_8bit=True, device_map="auto")
    else:
        logger.warning(
            f"No specific loader for model type {config.model_type!r}. Attempting AutoModelForCausalLM"
        )
        try:
            tokenizer = AutoTokenizer.from_pretrained(path_or_repo)
            model = AutoModelForCausalLM.from_pretrained(
                path_or_repo, **model_kwargs
            )  # , load_in_8bit=True, device_map="auto")
        except Exception:
            logger.exception("Failed to load %s", path_or_repo)
            raise

    return model, tokenizer
```