
Commit 47c1ab5

tirthasheshpatel authored and mattdangerw committed
Fix Mistral memory consumption with JAX and default dtype bug (#1460)
1 parent 712f172 commit 47c1ab5

3 files changed (+72 additions, -123 deletions)

keras_nlp/models/mistral/mistral_causal_lm.py

Lines changed: 1 addition & 0 deletions
@@ -190,6 +190,7 @@ def next(prompt, cache, index):
             mask=padding_mask,
             end_token_id=end_token_id,
             hidden_states=hidden_states,
+            model=self,
         )
 
         # Compute an output padding mask with the token ids we updated.
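The added model=self argument is the JAX memory fix: the sampler now receives
an explicit handle to the model instead of closing over its weights inside the
compiled generation loop. For context, a hypothetical reconstruction of the
surrounding call site in generate_step (only the arguments visible in the hunk
are taken from the source; the call target and remaining names are assumptions):

    token_ids = self.sampler(
        next=next,  # the per-step decode function named in the hunk header
        prompt=token_ids,
        cache=cache,
        index=index,
        mask=padding_mask,
        end_token_id=end_token_id,
        hidden_states=hidden_states,
        # New argument: with a reference to the model, the sampler can thread
        # the weights through the sampling loop as explicit state on the JAX
        # backend rather than baking a second copy into the traced function.
        model=self,
    )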

keras_nlp/models/mistral/mistral_presets.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
             "path": "mistral",
             "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/3",
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_7b_en/6",
     },
     "mistral_instruct_7b_en": {
         "metadata": {
@@ -33,6 +33,6 @@
             "path": "mistral",
             "model_card": "https://github.com/mistralai/mistral-src/blob/main/README.md",
         },
-        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/3",
+        "kaggle_handle": "kaggle://keras/mistral/keras/mistral_instruct_7b_en/6",
     },
 }
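The only change here is bumping both Kaggle handles from version 3 to version
6, presumably pointing at the checkpoints re-uploaded after the
conversion-script fix below. Users pick the new weights up transparently
through the preset API, for example (ordinary KerasNLP usage, not part of this
diff):

    import keras_nlp

    # Resolves the kaggle_handle above and downloads the version-6 weights.
    mistral_lm = keras_nlp.models.MistralCausalLM.from_preset("mistral_7b_en")
    print(mistral_lm.generate("What is Keras?", max_length=64))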

tools/checkpoint_conversion/convert_mistral_checkpoints.py

Lines changed: 69 additions & 121 deletions
@@ -11,14 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import datetime
 import gc
-import json
 import os
-import pathlib
+import shutil
+import tempfile
 import traceback
 
-import keras
 import numpy as np
 import requests
 from absl import app
@@ -27,10 +25,10 @@
 from transformers import AutoTokenizer
 from transformers import MistralForCausalLM
 
-import keras_nlp
 from keras_nlp.models import MistralBackbone
 from keras_nlp.models import MistralCausalLMPreprocessor
 from keras_nlp.models import MistralTokenizer
+from keras_nlp.utils.preset_utils import save_to_preset
 
 PRESET_MAP = {
     "mistral_7b_en": "mistralai/Mistral-7B-v0.1",
@@ -227,124 +225,74 @@ def main(_):
     preset = FLAGS.preset
     hf_preset = PRESET_MAP[preset]
 
-    # === Create the save directories ===
-    model_dir = pathlib.Path(__file__).parent / f"{preset}"
-    tokenizer_dir = model_dir / "assets" / "tokenizer"
-    if not model_dir.exists():
-        os.makedirs(model_dir)
-    if not tokenizer_dir.exists():
-        os.makedirs(tokenizer_dir)
+    # === Create the temporary save directories ===
+    temp_dir = tempfile.mkdtemp()
 
-    # === Load the Huggingface model ===
-    hf_model = MistralForCausalLM.from_pretrained(hf_preset)
-    hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
-    hf_model.eval()
-    print("\n-> Huggingface model and tokenizer loaded")
-
-    # === Load the KerasNLP model ===
-    keras_nlp_config = dict(
-        vocabulary_size=hf_model.config.vocab_size,
-        hidden_dim=hf_model.config.hidden_size,
-        num_layers=hf_model.config.num_hidden_layers,
-        num_query_heads=hf_model.config.num_attention_heads,
-        num_key_value_heads=hf_model.config.num_key_value_heads,
-        intermediate_dim=hf_model.config.intermediate_size,
-        sliding_window=hf_model.config.sliding_window,
-        layer_norm_epsilon=hf_model.config.rms_norm_eps,
-        rope_max_wavelength=hf_model.config.rope_theta,
-        dtype="float32",
-    )
-    keras_nlp_model = MistralBackbone(**keras_nlp_config)
-
-    # === Download the tokenizer from Huggingface model card ===
-    spm_path = (
-        f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model"
-    )
-    response = requests.get(spm_path)
-    if not response.ok:
-        raise ValueError(f"Couldn't fetch {preset}'s tokenizer.")
-    tokenizer_path = tokenizer_dir / "vocabulary.spm"
-    with open(tokenizer_path, "wb") as tokenizer_file:
-        tokenizer_file.write(response.content)
-    keras_nlp_tokenizer = MistralTokenizer(str(tokenizer_path.absolute()))
-    print("\n-> Keras 3 model and tokenizer loaded.")
-
-    # === Port the weights ===
-    convert_checkpoints(keras_nlp_model, hf_model)
-    print("\n-> Weight transfer done.")
-
-    # === Check that the models and tokenizers outputs match ===
-    test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
-    test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
-    print("\n-> Tests passed!")
-
-    # === Save the model weights in float32 format ===
-    keras_nlp_model.save_weights(
-        str((model_dir / "model.weights.h5").absolute())
-    )
-    print("\n-> Saved the model weights in float16")
-
-    del keras_nlp_model, hf_model
-    gc.collect()
-
-    keras_nlp_config["dtype"] = "float16"
-
-    # === Save the weights again in float16 ===
-    keras_nlp_model = MistralBackbone(**keras_nlp_config)
-    keras_nlp_model.load_weights(
-        str((model_dir / "model.weights.h5").absolute())
-    )
-    keras_nlp_model.save_weights(
-        str((model_dir / "model.weights.h5").absolute())
-    )
-    print("-> Saved the model weights in float16")
-
-    # === Save the model config ===
-    keras_nlp_config["dtype"] = "bfloat16"
-    model_config = {
-        "module": "keras_nlp.src.models.mistral.mistral_backbone",
-        "class_name": "MistralBackbone",
-        "config": {**keras_nlp_config},
-        "registered_name": "keras_nlp>MistralBackbone",
-        "assets": [],
-        "weights": "model.weights.h5",
-    }
-    model_config_json = json.dumps(model_config)
-    with open(model_dir / "config.json", "w") as model_config_file:
-        model_config_file.write(model_config_json)
-    print("\n-> Saved model config")
-
-    # === Save the tokenizer config ===
-    tokenizer_config = {
-        "module": "keras_nlp.src.models.mistral.Mistral_tokenizer",
-        "class_name": "MistralTokenizer",
-        "config": {
-            "name": "mistral_tokenizer",
-            "trainable": True,
-            "dtype": "int32",
-            "proto": None,
-            "sequence_length": None,
-        },
-        "registered_name": "keras_nlp>MistralTokenizer",
-        "assets": ["assets/tokenizer/vocabulary.spm"],
-        "weights": None,
-    }
-    tokenizer_config_json = json.dumps(tokenizer_config)
-    with open(model_dir / "tokenizer.json", "w") as tokenizer_config_file:
-        tokenizer_config_file.write(tokenizer_config_json)
-    print("\n-> Saved tokenizer config")
+    try:
+        # === Load the Huggingface model ===
+        hf_model = MistralForCausalLM.from_pretrained(hf_preset)
+        hf_tokenizer = AutoTokenizer.from_pretrained(hf_preset)
+        hf_model.eval()
+        print("\n-> Huggingface model and tokenizer loaded")
+
+        # === Load the KerasNLP model ===
+        backbone_kwargs = dict(
+            vocabulary_size=hf_model.config.vocab_size,
+            hidden_dim=hf_model.config.hidden_size,
+            num_layers=hf_model.config.num_hidden_layers,
+            num_query_heads=hf_model.config.num_attention_heads,
+            num_key_value_heads=hf_model.config.num_key_value_heads,
+            intermediate_dim=hf_model.config.intermediate_size,
+            sliding_window=hf_model.config.sliding_window,
+            layer_norm_epsilon=hf_model.config.rms_norm_eps,
+            rope_max_wavelength=hf_model.config.rope_theta,
+            dtype="float32",
+        )
+        keras_nlp_model = MistralBackbone(**backbone_kwargs)
 
-    # === Save metadata ===
-    metadata_config = {
-        "keras_version": keras.__version__,
-        "keras_nlp_version": keras_nlp.__version__,
-        "parameter_count": keras_nlp_model.count_params(),
-        "date_saved": datetime.datetime.utcnow().strftime("%Y-%m-%d@%H:%M:%S"),
-    }
-    metadata_config_json = json.dumps(metadata_config)
-    with open(model_dir / "metadata.json", "w") as metadata_config_file:
-        metadata_config_file.write(metadata_config_json)
-    print("\n-> Saved metadata")
+        # === Download the tokenizer from Huggingface model card ===
+        spm_path = (
+            f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model"
+        )
+        response = requests.get(spm_path)
+        if not response.ok:
+            raise ValueError(f"Couldn't fetch {preset}'s tokenizer.")
+        tokenizer_path = os.path.join(temp_dir, "vocabulary.spm")
+        with open(tokenizer_path, "wb") as tokenizer_file:
+            tokenizer_file.write(response.content)
+        keras_nlp_tokenizer = MistralTokenizer(tokenizer_path)
+        print("\n-> Keras 3 model and tokenizer loaded.")
+
+        # === Port the weights ===
+        convert_checkpoints(keras_nlp_model, hf_model)
+        print("\n-> Weight transfer done.")
+
+        # === Check that the models and tokenizers outputs match ===
+        test_tokenizer(keras_nlp_tokenizer, hf_tokenizer)
+        test_model(keras_nlp_model, keras_nlp_tokenizer, hf_model, hf_tokenizer)
+        print("\n-> Tests passed!")
+
+        # === Save the model weights in float32 format ===
+        keras_nlp_model.save_weights(os.path.join(temp_dir, "model.weights.h5"))
+        print("\n-> Saved the model weights in float32")
+
+        del keras_nlp_model, hf_model
+        gc.collect()
+
+        # === Save the weights again in float16 ===
+        backbone_kwargs["dtype"] = "float16"
+        keras_nlp_model = MistralBackbone(**backbone_kwargs)
+        keras_nlp_model.load_weights(os.path.join(temp_dir, "model.weights.h5"))
+        save_to_preset(keras_nlp_model, preset)
+        print("\n-> Saved the model preset in float16")
+
+        # === Save the tokenizer ===
+        save_to_preset(
+            keras_nlp_tokenizer, preset, config_filename="tokenizer.json"
+        )
+        print("\n-> Saved the tokenizer")
+    finally:
+        shutil.rmtree(temp_dir)
 
 
 if __name__ == "__main__":
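The restructured main() fixes the default-dtype bug named in the commit title:
the old script saved float16 weights but hand-wrote a config.json with dtype
"bfloat16", so a loaded preset would build the model in a dtype that did not
match the saved weights. The new flow builds the backbone in float32, ports
and verifies the weights, then rebuilds in float16 and lets save_to_preset
record a consistent config; the temporary directory is removed in a finally
block so a failed run does not leave multi-gigabyte files behind. A minimal
sketch of that two-pass pattern, assuming a backbone class whose load_weights
casts saved float32 weights into float16 variables (build_kwargs must not
already contain a dtype entry; save_half_precision_preset is a hypothetical
helper, not part of the commit):

    import gc
    import os
    import shutil
    import tempfile

    from keras_nlp.utils.preset_utils import save_to_preset

    def save_half_precision_preset(backbone_cls, build_kwargs, preset):
        """Port weights at full precision, then re-save the preset as float16."""
        temp_dir = tempfile.mkdtemp()
        try:
            weights_path = os.path.join(temp_dir, "model.weights.h5")

            # Pass 1: build in float32 so ported weights keep full precision.
            model = backbone_cls(**build_kwargs, dtype="float32")
            # ... port and verify the converted weights here ...
            model.save_weights(weights_path)
            del model  # free the float32 copy before building the float16 one
            gc.collect()

            # Pass 2: rebuild with float16 variables; load_weights casts the
            # saved float32 values down, and the preset config now records the
            # same dtype the weights are stored in.
            model = backbone_cls(**build_kwargs, dtype="float16")
            model.load_weights(weights_path)
            save_to_preset(model, preset)
        finally:
            shutil.rmtree(temp_dir)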
