feat: Allow loading and serializing with tensorizer #2

Draft · wants to merge 7 commits into master
63 changes: 50 additions & 13 deletions exllamav2/config.py
@@ -84,6 +84,11 @@ class ExLlamaV2Config:
    tensor_file_map: dict
    tensor_files: list

    # Tensorizer args
    write_state_dict: bool        # Whether to construct a state_dict, necessary for serializing with `tensorizer`
    load_with_tensorizer: bool    # Deserialize tensors with `tensorizer`; model tensors are expected at model_dir/model.tensors or 'TENSORIZER_LOC'
    # TODO: May want to raise a NameError/ValueError/warning if no model.tensors file can be found in model_dir/*

    tokenizer_path: str

    bos_token_id: int
@@ -156,6 +161,11 @@ class ExLlamaV2Config:
    fasttensors: bool             # Fasttensors loader removed in v0.2.3


    def __new__(cls, *args, **kwargs):
        if kwargs.get("load_with_tensorizer"):
            from util.tensorizer_utils import TensorizerConfigExtension
            return TensorizerConfigExtension(*args, **kwargs)
        # Fall through to a plain instance; without this return,
        # ExLlamaV2Config() would evaluate to None.
        return super().__new__(cls)

    def __init__(self,
                 model_dir: str | None = None):
        """
@@ -176,6 +186,20 @@ def __init__(self,
        self.no_flash_attn = 'EXLLAMA_NO_FLASH_ATTN' in os.environ
        self.no_xformers = 'EXLLAMA_NO_XFORMERS' in os.environ
        self.no_sdpa = 'EXLLAMA_NO_SDPA' in os.environ
        self.fasttensors = 'EXLLAMA_FASTTENSORS' in os.environ

        # TODO: Expose this in a better way; serializing requires config.write_state_dict = True
        self.write_state_dict = False

        # TODO: Think of a nicer way to enable this than an environment variable
        self.load_with_tensorizer = 'TENSORIZER' in os.environ
        self.tensorizer_args = {
            "s3_access_key_id": os.environ.get("S3_ACCESS_KEY_ID"),
            "s3_secret_access_key": os.environ.get("S3_SECRET_ACCESS_KEY"),
            "s3_endpoint": os.environ.get("S3_ENDPOINT_URL"),
        }

        self.load_in_q4 = False
        self.no_graphs = 'EXLLAMA_NO_GRAPHS' in os.environ

@@ -202,15 +226,16 @@ def set_low_mem(self):
    def prepare(self, no_tensors: bool = False):

        assert self.model_dir is not None, "No model_dir specified in ExLlamaV2Config"
        assert os.path.exists(self.model_dir), "Can't find " + self.model_dir

        # Load config.json
        # TODO: Add this in the __init__ of the tensorizer subclass
        if self.load_with_tensorizer or self.write_state_dict:
            from util.tensorizer_utils import validate_tensorizer_args
            validate_tensorizer_args(self)

        self.model_config = os.path.join(self.model_dir, "config.json")
        assert os.path.exists(self.model_config), "Can't find " + self.model_config
        assert os.path.exists(self.model_dir), "Can't find " + self.model_dir

        with open(self.model_config, encoding = "utf8") as f:
            read_config = json.load(f)
        # Load config.json
        read_config = self._load_config()

        # Load generation_config.json

@@ -393,13 +418,7 @@ def prepare(self, no_tensors: bool = False):
        st_pattern = os.path.join(self.model_dir, "*.safetensors")
        self.tensor_files = glob.glob(st_pattern)

        if len(self.tensor_files) == 0:
            raise ValueError(f" ## No .safetensors files found in {self.model_dir}")

        for st_file in self.tensor_files:
            f = STFile.open(st_file, keymap = self.arch.keymap)
            for key in f.get_dict():
                self.tensor_file_map[key] = st_file
        self._load_tensor_file_map()

        # For loading checkpoints with fused MLP layers

@@ -590,3 +609,21 @@ def arch_compat_overrides(self, quiet: bool = False, warn_only = False):
        if not quiet:
            for w in warnings:
                print(w)


    def _load_config(self):
        self.model_config = os.path.join(self.model_dir, "config.json")
        assert os.path.exists(self.model_config), "Can't find " + self.model_config

        with open(self.model_config, encoding = "utf8") as f:
            read_config = json.load(f)
        return read_config

    def _load_tensor_file_map(self):
        if len(self.tensor_files) == 0:
            raise ValueError(f" ## No .safetensors files found in {self.model_dir}")

        for st_file in self.tensor_files:
            f = STFile.open(st_file, keymap = self.arch.keymap)
            for key in f.get_dict():
                self.tensor_file_map[key] = st_file
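
As the diff above shows, the tensorizer path is toggled entirely through environment variables rather than constructor arguments. Below is a minimal sketch of opting in from calling code, assuming only the variable names visible in `__init__` (`TENSORIZER`, `S3_ACCESS_KEY_ID`, `S3_SECRET_ACCESS_KEY`, `S3_ENDPOINT_URL`); the model path and credential values are placeholders:

```python
import os

# Presence of TENSORIZER (any value) flips config.load_with_tensorizer to True.
os.environ["TENSORIZER"] = "1"

# Optional S3 credentials, collected into config.tensorizer_args.
os.environ["S3_ACCESS_KEY_ID"] = "AKIA..."                    # placeholder
os.environ["S3_SECRET_ACCESS_KEY"] = "secret"                 # placeholder
os.environ["S3_ENDPOINT_URL"] = "https://object.example.com"  # placeholder

from exllamav2 import ExLlamaV2Config

config = ExLlamaV2Config("/models/my-model")  # hypothetical local path
config.prepare()  # expects model.tensors per the comment in the class body
```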
11 changes: 10 additions & 1 deletion exllamav2/model.py
@@ -75,6 +75,12 @@ class ExLlamaV2:

    tp_context: TPContext | None

    def __new__(cls, *args, **kwargs):
        # args[0] is the config: __new__ receives the constructor's positional
        # args without self, so the config is the first element.
        assert isinstance(args[0], ExLlamaV2Config)
        if args[0].load_with_tensorizer:
            from util.tensorizer_utils import TensorizerModelExtension
            return TensorizerModelExtension(*args, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        config: ExLlamaV2Config,
@@ -319,7 +325,7 @@ def load_gen(
        with torch.inference_mode():
            set_device_streams()

            stats_ = self.set_device_map(gpu_split or [99999])
            stats_ = self._get_stats(gpu_split)

            # Load module weights

@@ -1044,3 +1050,6 @@ def forward_chunk(
r["logits"] = x

return r

    def _get_stats(self, gpu_split):
        return self.set_device_map(gpu_split or [99999])
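
All four classes touched by this PR use the same `__new__`-based dispatch: when the config opts in, construction is redirected to a tensorizer-aware subclass. Here is a self-contained sketch of that pattern with illustrative names (`Base`, `Extension`, `Cfg` are not the PR's classes):

```python
class Base:
    def __new__(cls, *args, **kwargs):
        # Redirect construction when the config-like first argument opts in.
        if cls is Base and args and getattr(args[0], "use_extension", False):
            return super().__new__(Extension)
        # Without this fallthrough, Base(cfg) would evaluate to None.
        return super().__new__(cls)

    def __init__(self, cfg):
        self.cfg = cfg


class Extension(Base):
    pass


class Cfg:
    use_extension = True


assert isinstance(Base(Cfg()), Extension)       # dispatched to the subclass
assert isinstance(Extension(Cfg()), Extension)  # direct construction still works
```

Returning an uninitialized instance via `super().__new__(Extension)` sidesteps a subtlety in the diff's approach: when `__new__` returns an instance of the requested class, Python calls `__init__` on it again, so returning a fully constructed subclass instance runs its `__init__` twice.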
8 changes: 8 additions & 0 deletions exllamav2/module.py
@@ -25,6 +25,12 @@ class ExLlamaV2Module:
    submodules: list[ExLlamaV2Module]
    assumed_footprint: int

    def __new__(cls, *args, **kwargs):
        # args[0] is the parent ExLlamaV2 model (see __init__ below), not the
        # config, so the tensorizer flag is read from the attached config.
        if args[0].config.load_with_tensorizer:
            from util.tensorizer_utils import TensorizerModuleExtension
            return TensorizerModuleExtension(*args, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        model: ExLlamaV2,
@@ -94,6 +100,8 @@ def load_multi(self,
                size += stfile.measure(key + "." + k)
            else:
                tensors[k] = stfile.get_tensor(key + "." + k, device = self.device() if not cpu else "cpu")
                if self.model.config.write_state_dict:
                    self.model.state_dict[key + "." + k] = tensors[k].to(device="cpu", copy=True)

        return size if measure else tensors

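
The `write_state_dict` branch above copies each loaded tensor back to CPU so that `model.state_dict` ends up as an ordinary name-to-tensor mapping, which is the shape of input `tensorizer` serializes. A sketch of writing it out with tensorizer's `TensorSerializer`; the `model.state_dict` attribute and the output path are assumptions based on this diff:

```python
from tensorizer import TensorSerializer

# model was loaded with config.write_state_dict = True, so every tensor
# has been mirrored to CPU under its original key.
serializer = TensorSerializer("/models/my-model/model.tensors")  # hypothetical path
serializer.write_state_dict(model.state_dict)
serializer.close()
```

The reverse direction is handled by `tensorizer.TensorDeserializer`, which presumably backs the `TensorizerModuleExtension` used when `load_with_tensorizer` is set.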
14 changes: 12 additions & 2 deletions exllamav2/tokenizer/tokenizer.py
@@ -76,6 +76,12 @@ def tokenizer(self):

    tokenizer_config_dict: dict | None

    def __new__(cls, *args, **kwargs):
        # args[0] is the config, since __new__ does not receive self.
        assert isinstance(args[0], ExLlamaV2Config)
        if args[0].load_with_tensorizer:
            from util.tensorizer_utils import TensorizerTokenizerExtension
            return TensorizerTokenizerExtension(*args, **kwargs)
        return super().__new__(cls)

    def __init__(
        self,
        config,
@@ -123,8 +129,7 @@ def __init__(

        # Detect tokenizer model type and initialize

        path_spm = os.path.join(self.config.model_dir, "tokenizer.model")
        path_hf = os.path.join(self.config.model_dir, "tokenizer.json")
        path_spm, path_hf = self._load_tokenizer_artifacts()

        if os.path.exists(path_hf) and not force_spm:
            self.tokenizer_model = ExLlamaV2TokenizerHF(path_hf)
@@ -815,3 +820,8 @@ def cached_encode_str(self, text: str):
        new_enc = self.encode(text)
        self.tokenized_str_cache[text] = new_enc
        return new_enc

    def _load_tokenizer_artifacts(self):
        path_spm = os.path.join(self.config.model_dir, "tokenizer.model")
        path_hf = os.path.join(self.config.model_dir, "tokenizer.json")
        return path_spm, path_hf
3 changes: 2 additions & 1 deletion requirements.txt
@@ -12,4 +12,5 @@ regex
numpy~=1.26.4
tokenizers
rich
pillow>=9.1.0
tensorizer~=2.9.0
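
With `tensorizer` pinned as a dependency, the intended round trip appears to be: one conventional safetensors load that captures a state dict, a one-time serialization, and tensorizer-backed loads thereafter. A sketch under the same assumptions as above (paths hypothetical, `model.state_dict` inferred from the module.py change):

```python
import os
from exllamav2 import ExLlamaV2, ExLlamaV2Config
from tensorizer import TensorSerializer

# One-time: load from .safetensors while capturing a state_dict, then serialize.
config = ExLlamaV2Config("/models/my-model")  # hypothetical path
config.write_state_dict = True
config.prepare()
model = ExLlamaV2(config)
model.load()

serializer = TensorSerializer(os.path.join(config.model_dir, "model.tensors"))
serializer.write_state_dict(model.state_dict)
serializer.close()

# Afterwards: deserialize with tensorizer instead of safetensors.
os.environ["TENSORIZER"] = "1"
config = ExLlamaV2Config("/models/my-model")
config.prepare()
model = ExLlamaV2(config)
model.load()
```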