
Commit a9beab0

rebased and added other profile
1 parent: 72e9814

6 files changed: +12 −33 lines

Cargo.toml

Lines changed: 6 additions & 1 deletion
@@ -21,7 +21,12 @@ hf-hub = { version = "0.3.1", features = ["tokio"] }
 [profile.release]
 debug = 1
 incremental = true
+panic = "abort"
+
+[profile.release-opt]
+inherits = "release"
+debug = 0
+incremental = false
 lto = "fat"
 opt-level = 3
 codegen-units = 1
-panic = "abort"
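For context, `inherits = "release"` means `release-opt` starts from the `[profile.release]` values and overrides only the keys it lists, so the plain release profile becomes a quick, debug-friendly build while the heavy optimizations move to the new profile. The two profiles after this change resolve roughly to the following (a sketch assembled from the hunk above, not a separate file in the repo):

# Effective profiles implied by the hunk above (sketch, not a file in the repo)

[profile.release]          # default `cargo build --release`: fast to compile
debug = 1
incremental = true
panic = "abort"

[profile.release-opt]      # `cargo build --profile release-opt`: fully optimized
inherits = "release"       # picks up panic = "abort" from release
debug = 0
incremental = false
lto = "fat"
opt-level = 3
codegen-units = 1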

Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
 
 # Python builder
 # Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
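One knock-on effect worth noting: Cargo writes a named profile's artifacts to target/<profile>/, so any later stage that copies binaries out of this builder has to point at target/release-opt/ instead of target/release/. A sketch of what that looks like (the binary name and destination are illustrative, not taken from this diff):

RUN cargo build --profile release-opt

# In a later stage, copy from the profile-named output directory.
# Binary name and destination path below are illustrative only.
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router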

Dockerfile_amd

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
 
 # Text Generation Inference base image for RoCm
 FROM rocm/dev-ubuntu-22.04:6.1.1_hip_update as base

Dockerfile_intel

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ COPY proto proto
 COPY benchmark benchmark
 COPY router router
 COPY launcher launcher
-RUN cargo build --release
+RUN cargo build --profile release-opt
 
 
 # Text Generation Inference base image for Intel

server/text_generation_server/models/flash_causal_lm.py

Lines changed: 3 additions & 4 deletions
@@ -766,8 +766,7 @@ def init_kv_cache(
         device: torch.device,
     ):
         self.kv_cache = []
-        if IS_CUDA_SYSTEM:
-            torch.cuda.empty_cache()
+        empty_cache()
 
         element_size = torch.tensor([], dtype=dtype).element_size()
         if SYSTEM == "xpu":
@@ -960,7 +959,6 @@ def tunableop_warmup(self, seqlen: int):
         input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
         position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
-        kv_cache = get_cache_manager().kv_cache
 
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
@@ -972,12 +970,13 @@ def tunableop_warmup(self, seqlen: int):
             cu_seqlen_prefill=torch.tensor(
                 [0, seqlen], device=self.device, dtype=torch.int32
             ),
-            kv_cache=get_cache_manager().kv_cache,
+            kv_cache=self.kv_cache,
             block_tables=None,
             input_lengths=input_lengths,
            slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
+            prefill_cache_indices=None,
         )
 
     def forward(
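The hunk only shows the call site swapping the CUDA-only branch for a generic empty_cache(); where that helper lives and how the SYSTEM string is detected are not part of this diff. A minimal sketch of such a dispatcher, assuming it keys off the same SYSTEM value used in the context lines above:

# Hypothetical sketch of a hardware-agnostic empty_cache() helper; names and
# module location are assumptions, only the call site appears in this commit.
import torch

SYSTEM = "cuda"  # assumed to be detected elsewhere: "cuda", "rocm", "xpu", or "cpu"


def empty_cache():
    # Release cached allocator blocks on whichever accelerator is active,
    # instead of hard-coding torch.cuda.empty_cache() at every call site.
    if SYSTEM in ("cuda", "rocm"):
        # ROCm builds of PyTorch reuse the torch.cuda namespace.
        torch.cuda.empty_cache()
    elif SYSTEM == "xpu":
        # Assumes an XPU-enabled PyTorch build (e.g. intel_extension_for_pytorch).
        torch.xpu.empty_cache()
    # Plain CPU has no allocator cache to drop.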

server/text_generation_server/models/flash_mistral.py

Lines changed: 0 additions & 25 deletions
@@ -98,31 +98,6 @@ def get_layer_config(self, model) -> Tuple[int, int, int]:
             model.model.head_size,
         )
 
-    def tunableop_warmup(self, seqlen: int):
-        input_ids = torch.zeros(seqlen, dtype=torch.int64, device=self.device)
-        position_ids = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
-        slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
-        kv_cache = get_cache_manager().kv_cache
-
-        # Dummy value, some models (starcoder2) don't accept `None`.
-        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-
-        # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
-        self.model.forward(
-            input_ids=input_ids,
-            position_ids=position_ids,
-            cu_seqlen_prefill=torch.tensor(
-                [0, seqlen], device=self.device, dtype=torch.int32
-            ),
-            kv_cache=self.kv_cache,
-            block_tables=None,
-            input_lengths=input_lengths,
-            slots=slots,
-            max_s=seqlen,
-            lm_head_indices=None,
-            prefill_cache_indices=None,
-        )
-
 
 class FlashMistral(BaseFlashMistral):
     def __init__(
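With this override deleted, Mistral-style models fall back to the tunableop_warmup updated in flash_causal_lm.py above. A sketch of the resulting relationship (class bodies elided; that BaseFlashMistral subclasses FlashCausalLM is implied by the deletion rather than shown in the hunk):

# Sketch only: bodies elided, inheritance assumed from the context of this diff.
class FlashCausalLM:
    def tunableop_warmup(self, seqlen: int):
        # the version updated above: uses self.kv_cache and
        # passes prefill_cache_indices=None
        ...


class BaseFlashMistral(FlashCausalLM):
    # no tunableop_warmup override any more; the inherited one is used
    ...


class FlashMistral(BaseFlashMistral):
    ...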
