
Commit f208ba6

Fix HF_HUB_OFFLINE=1 for Gaudi backend (#3193)

* Fix `HF_HUB_OFFLINE=1` for Gaudi backend
* Fix HF cache default value in server.rs
* Format

1 parent 7253be3 · commit f208ba6
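Note (not part of the commit): `HF_HUB_OFFLINE=1` only helps if the model snapshot is already present in the local Hugging Face cache. A minimal pre-flight sketch, assuming `huggingface_hub` is installed and using an example repo id:

from huggingface_hub import snapshot_download

model_id = "meta-llama/Llama-3.1-8B-Instruct"  # example repo id, not from this commit

# local_files_only=True resolves the snapshot from the local cache and fails fast
# instead of reaching out to the Hub, which is what offline serving needs.
local_path = snapshot_download(
    model_id,
    allow_patterns=["*.safetensors", "*.json", "*.model"],
    local_files_only=True,
)
print(f"Weights cached at: {local_path}")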

File tree

4 files changed: +33 -20 lines changed


backends/gaudi/Makefile

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ PYTORCH_VERSION := 2.6.0
 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install

 image:
-	docker build -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)
+	docker build --ulimit nofile=4096 -t tgi-gaudi -f ${root_dir}/Dockerfile_gaudi ${root_dir} --build-arg HABANA_VERSION=$(HABANA_VERSION) --build-arg PYTORCH_VERSION=$(PYTORCH_VERSION)

 run-local-dev-container:
 	docker run -it \

backends/gaudi/server/text_generation_server/models/causal_lm.py

Lines changed: 16 additions & 9 deletions

@@ -4,6 +4,7 @@
 from dataclasses import dataclass
 from functools import wraps
 import itertools
+import json
 import math
 import os
 import tempfile
@@ -17,15 +18,12 @@
 from opentelemetry import trace

 import text_generation_server.habana_quantization_env as hq_env
+from text_generation_server.utils import weight_files
 import habana_frameworks.torch as htorch
 from optimum.habana.utils import HabanaProfile
 from optimum.habana.transformers.generation import MODELS_OPTIMIZED_WITH_STATIC_SHAPES
 from text_generation_server.utils.chunks import concat_text_chunks
-from optimum.habana.checkpoint_utils import (
-    get_repo_root,
-    model_on_meta,
-    write_checkpoints_json,
-)
+from optimum.habana.checkpoint_utils import model_on_meta
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
@@ -708,15 +706,16 @@ def __init__(
         if hq_env.is_quantization_enabled:
             htorch.core.hpu_set_env()

+        # Get weight files
+        weight_files(model_id, revision=revision, extension=".safetensors")
+
         if world_size > 1:
             os.environ.setdefault(
                 "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1"
             )
             model = self.get_deepspeed_model(model_id, dtype, revision)
             model = hq_env.prepare_model_for_quantization(model)
         else:
-            get_repo_root(model_id)
-
             # Check support for rope scaling
             model_kwargs = {}
             config = AutoConfig.from_pretrained(model_id)
@@ -868,7 +867,6 @@ def get_deepspeed_model(
             with deepspeed.OnDevice(dtype=dtype, device="meta"):
                 model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype)
         else:
-            get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
             # TODO: revisit placement on CPU when auto-injection is possible
             with deepspeed.OnDevice(dtype=dtype, device="cpu"):
                 model = AutoModelForCausalLM.from_pretrained(
@@ -884,7 +882,16 @@ def get_deepspeed_model(
         if load_to_meta:
             # model loaded to meta is managed differently
             checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
-            write_checkpoints_json(model_id, local_rank, checkpoints_json)
+            checkpoint_files = [
+                str(f)
+                for f in weight_files(
+                    model_id, revision=revision, extension=".safetensors"
+                )
+            ]
+            data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
+            json.dump(data, checkpoints_json)
+            checkpoints_json.flush()
+
             ds_inference_kwargs["checkpoint"] = checkpoints_json.name
         model = deepspeed.init_inference(model, **ds_inference_kwargs)
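The hunk above drops the `get_repo_root` / `write_checkpoints_json` helpers, which can reach out to the Hub, and instead resolves already-downloaded shards with `weight_files` and writes the DeepSpeed checkpoint manifest directly. A standalone sketch of that manifest step, assuming the safetensors shards are already cached locally and using an example repo id:

import json
import tempfile

from text_generation_server.utils import weight_files

model_id = "bigscience/bloom-560m"  # example repo id, not from this commit

# weight_files resolves the cached *.safetensors shards without network access.
checkpoint_files = [
    str(f) for f in weight_files(model_id, revision=None, extension=".safetensors")
]

# DeepSpeed receives the shard list through a small JSON manifest on disk;
# its path is later passed as ds_inference_kwargs["checkpoint"].
checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
json.dump(
    {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0},
    checkpoints_json,
)
checkpoints_json.flush()
print(checkpoints_json.name, len(checkpoint_files))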

backends/gaudi/server/text_generation_server/models/vlm_causal_lm.py

Lines changed: 15 additions & 9 deletions

@@ -1,3 +1,4 @@
+import json
 import re
 import torch
 import os
@@ -12,6 +13,7 @@
 import copy
 from text_generation_server.models import Model
 from transformers import PreTrainedTokenizerBase
+from text_generation_server.utils import weight_files
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.pb import generate_pb2
 from text_generation_server.models.causal_lm import (
@@ -43,11 +45,7 @@
     AutoTokenizer,
     AutoConfig,
 )
-from optimum.habana.checkpoint_utils import (
-    get_repo_root,
-    model_on_meta,
-    write_checkpoints_json,
-)
+from optimum.habana.checkpoint_utils import model_on_meta

 from text_generation_server.utils.speculate import get_speculate
 from text_generation_server.models.types import (
@@ -840,15 +838,16 @@ def __init__(
         if hq_env.is_quantization_enabled:
             htorch.core.hpu_set_env()

+        # Get weight files
+        weight_files(model_id, revision=revision, extension=".safetensors")
+
         if world_size > 1:
             os.environ.setdefault(
                 "DEEPSPEED_USE_HABANA_FRAMEWORKS_DETERMINISTIC_API", "1"
             )
             model = self.get_deepspeed_model(model_class, model_id, dtype, revision)
             model = hq_env.prepare_model_for_quantization(model)
         else:
-            get_repo_root(model_id)
-
             # Check support for rope scaling
             model_kwargs = {}
             config = AutoConfig.from_pretrained(model_id)
@@ -1000,7 +999,6 @@ def get_deepspeed_model(
             with deepspeed.OnDevice(dtype=dtype, device="meta"):
                 model = model_class.from_config(config, torch_dtype=dtype)
         else:
-            get_repo_root(model_id, local_rank=os.getenv("LOCAL_RANK"))
             # TODO: revisit placement on CPU when auto-injection is possible
             with deepspeed.OnDevice(dtype=dtype, device="cpu"):
                 model = model_class.from_pretrained(
@@ -1019,7 +1017,15 @@ def get_deepspeed_model(
         if load_to_meta:
             # model loaded to meta is managed differently
             checkpoints_json = tempfile.NamedTemporaryFile(suffix=".json", mode="+w")
-            write_checkpoints_json(model_id, local_rank, checkpoints_json)
+            checkpoint_files = [
+                str(f)
+                for f in weight_files(
+                    model_id, revision=revision, extension=".safetensors"
+                )
+            ]
+            data = {"type": "ds_model", "checkpoints": checkpoint_files, "version": 1.0}
+            json.dump(data, checkpoints_json)
+            checkpoints_json.flush()
             ds_inference_kwargs["checkpoint"] = checkpoints_json.name
         model = deepspeed.init_inference(model, **ds_inference_kwargs)

router/src/server.rs

Lines changed: 1 addition & 1 deletion

@@ -1578,7 +1578,7 @@ pub async fn run(
         let cache = std::env::var("HUGGINGFACE_HUB_CACHE")
             .map_err(|_| ())
             .map(|cache_dir| Cache::new(cache_dir.into()))
-            .unwrap_or_else(|_| Cache::default());
+            .unwrap_or_else(|_| Cache::from_env());
         tracing::warn!("Offline mode active using cache defaults");
         Type::Cache(cache)
     } else {
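For context (an interpretation, not stated in the diff): `Cache::default()` uses a fixed default location, while `Cache::from_env()` from the hf-hub crate picks the cache directory up from the environment, so the router's offline path should land on the same cache the Python side populated via `HF_HOME` / `HUGGINGFACE_HUB_CACHE`. A small Python-side check of how that location is resolved in recent `huggingface_hub` versions, with an example path:

import os

os.environ["HF_HOME"] = "/data/hf-home"  # example path, not from this commit

# huggingface_hub resolves its cache constants at import time, so set the
# variable first; the hub cache then defaults to $HF_HOME/hub.
from huggingface_hub.constants import HF_HUB_CACHE

print(HF_HUB_CACHE)  # e.g. /data/hf-home/hub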
