
Commit 678b2f3

feat(server): cleanup flash neox loading (#139)
1 parent d6a93fe commit 678b2f3

File tree

2 files changed: +53 -22 lines changed

server/text_generation_server/models/flash_neox.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -450,8 +450,6 @@ def generate_token(
             next_batch_input_ids = next_batch_input_ids[0].view(1)
             next_batch_past_key_values = next_batch_past_key_values[0]
 
-        print(next_batch_input_ids.shape)
-
         next_batch = FlashNeoXBatch(
             batch_id=batch.batch_id,
             requests=next_batch_requests,
@@ -507,6 +505,7 @@ def __init__(
             rank=self.rank,
             world_size=self.world_size,
         )
+        model.post_load_weights()
        self.model = model.eval().to(dtype)
         torch.distributed.barrier(group=self.process_group)
         super(FlashNeoX, self).__init__(
```
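Together, these two hunks drop a leftover debug `print` and replace the lazy "fix up the weights on the first forward pass" behaviour with an explicit `post_load_weights()` call made right after the weights are loaded. A minimal sketch of the before/after pattern, using toy class names that are not from the repository:

```python
import torch
from torch import nn


class LazyLinear(nn.Linear):
    """Old style: mutate the weight inside the first forward call."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.swap_dims = True

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if self.swap_dims:  # branch evaluated on every call
            self.weight = nn.Parameter(self.weight.T)
            self.swap_dims = False
        return torch.addmm(self.bias, input, self.weight)


class EagerLinear(nn.Linear):
    """New style: prepare the weight once, explicitly, after loading."""

    def transpose_weight(self):
        self.weight = nn.Parameter(self.weight.T)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.addmm(self.bias, input, self.weight)


eager = EagerLinear(16, 32)
eager.transpose_weight()               # analogous to model.post_load_weights()
print(eager(torch.randn(4, 16)).shape)  # torch.Size([4, 32]), forward stays branch-free
```

Doing the preparation once at load time keeps `forward()` free of "is this the first call?" checks on a path that runs at every decoding step.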

server/text_generation_server/models/flash_neox_modeling.py

Lines changed: 52 additions & 20 deletions
```diff
@@ -1,6 +1,8 @@
 import torch
 import torch.distributed
 
+from torch.nn import functional as F
+
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
@@ -24,13 +26,11 @@ def __init__(
         dtype=None,
     ) -> None:
         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
-        self.swap_dims = True
 
-    def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.swap_dims:
-            self.weight = nn.Parameter(self.weight.T)
-            self.swap_dims = False
+    def transpose_weight(self):
+        self.weight = nn.Parameter(self.weight.T)
 
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.bias is not None:
             return torch.addmm(self.bias, input, self.weight)
         return torch.matmul(input, self.weight)
```
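For context on why `FastLinear` wants its weight stored transposed at all: `torch.addmm(bias, input, weight)` and `torch.matmul(input, weight)` expect the weight with shape `(in_features, out_features)`, whereas `nn.Linear` stores it as `(out_features, in_features)`. A small self-contained equivalence check (illustrative, not code from the repository):

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
in_features, out_features = 16, 32

x = torch.randn(4, in_features)
weight = torch.randn(out_features, in_features)   # nn.Linear's native layout
bias = torch.randn(out_features)

# F.linear multiplies by the transpose internally: x @ weight.T + bias.
expected = F.linear(x, weight, bias)

# Storing the transpose once lets forward() call addmm/matmul directly,
# without re-transposing on every call.
weight_t = weight.T.contiguous()
actual = torch.addmm(bias, x, weight_t)

torch.testing.assert_close(actual, expected)
```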
```diff
@@ -120,6 +120,10 @@ def __init__(
         self.min_id = self.tp_rank * block_size
         self.max_id = (self.tp_rank + 1) * block_size
 
+        # Additional entry that will map to zero
+        # Used for masking
+        self.null_idx = block_size
+
         super().__init__(
             block_size,
             embedding_dim,
@@ -133,15 +137,19 @@ def __init__(
             dtype=dtype,
         )
 
+    def add_null_idx(self):
+        """Additional 0 entry used for masking"""
+        self.weight = nn.Parameter(F.pad(self.weight, (0, 0, 0, 1)))
+
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        # `0` if input is in the correct interval, else `1`
-        input_mask = torch.logical_or(self.min_id > input, input >= self.max_id)
+        # default all out of bounds values to `self.null_idx` that will then be mapped to 0
         # translate for [0, self.max_id - self.min_id[
-        input = input - self.min_id
-        # default all out of bounds values to `0`
-        input[input_mask] = 0
+        input = torch.where(
+            (self.min_id > input) | (input >= self.max_id),
+            self.null_idx,
+            input - self.min_id,
+        )
         out = super().forward(input)
-        out[input_mask] = 0.0
         torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
```
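The two `TensorParallelEmbedding` hunks replace boolean-mask indexing with a sentinel row: the weight matrix gets one extra all-zero row at index `null_idx`, and out-of-range ids are routed to it via `torch.where`, so tokens owned by another shard contribute zeros to the subsequent all-reduce without a second masking pass on the output. A single-process sketch of the idea (no distributed all-reduce, names simplified):

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
block_size, embedding_dim = 4, 3
min_id, max_id = 4, 8          # this shard owns ids in [4, 8)
null_idx = block_size          # extra row that maps to zeros

# This shard's slice of the embedding table, padded with one zero row.
weight = torch.randn(block_size, embedding_dim)
weight = F.pad(weight, (0, 0, 0, 1))   # shape: (block_size + 1, embedding_dim)

input = torch.tensor([1, 4, 7, 9])     # mix of in-shard and out-of-shard ids

# Out-of-range ids go to the zero row; in-range ids shift into [0, block_size).
local = torch.where((min_id > input) | (input >= max_id), null_idx, input - min_id)
out = F.embedding(local, weight)

assert torch.all(out[0] == 0) and torch.all(out[3] == 0)   # ids 1 and 9 are masked
# In the real module, torch.distributed.all_reduce then sums the shards' outputs.
```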

```diff
@@ -214,11 +222,9 @@ def __init__(
             hidden_size,
             process_group=process_group,
         )
-        self.swap_dims = True
 
-    # TODO: remove and swap dims when loading weights
-    def _swap_dims(self):
-        """Swap dims for the first inference to avoid an additional permute"""
+    def shuffle_qkv_dims(self):
+        """Swap dims to avoid an additional permute"""
         self.query_key_value.weight = torch.nn.Parameter(
             self.query_key_value.weight.view(
                 self.num_heads, 3, self.head_size, self.hidden_size
@@ -231,7 +237,6 @@ def _swap_dims(self):
             .permute(1, 0, 2)
             .reshape(-1)
         )
-        self.swap_dims = False
 
     def forward(
         self,
@@ -244,9 +249,6 @@ def forward(
         layer_past_present_indices,
         cu_seqlens_q,
     ):
-        if self.swap_dims:
-            self._swap_dims()
-
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
         qkv_rot = self.rotary_emb(qkv, cos, sin)
```
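`shuffle_qkv_dims` reorders the fused query/key/value projection's output dimension once at load time, so the attention forward can split Q/K/V with a plain `view(-1, 3, num_heads, head_size)` instead of a per-step permute. Assuming the head-major checkpoint layout implied by the `view(num_heads, 3, head_size, hidden_size)` call in the diff, a small standalone check of the equivalence (toy sizes, not code from the repository):

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
num_heads, head_size, hidden_size = 4, 8, 32

# Fused QKV projection with head-major output layout: (num_heads, 3, head_size).
weight = torch.randn(num_heads * 3 * head_size, hidden_size)
bias = torch.randn(num_heads * 3 * head_size)
x = torch.randn(5, hidden_size)

# Old style: permute the activations on every forward pass.
out = F.linear(x, weight, bias)
reference = out.view(-1, num_heads, 3, head_size).permute(0, 2, 1, 3)

# New style: shuffle the parameters once so a plain view() is enough.
weight_s = (
    weight.view(num_heads, 3, head_size, hidden_size)
    .permute(1, 0, 2, 3)
    .reshape(-1, hidden_size)
)
bias_s = bias.view(num_heads, 3, head_size).permute(1, 0, 2).reshape(-1)

qkv = F.linear(x, weight_s, bias_s).view(-1, 3, num_heads, head_size)
torch.testing.assert_close(qkv, reference)
```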
```diff
@@ -329,7 +331,6 @@ def __init__(self, act, hidden_size, intermediate_size, process_group=None):
             hidden_size,
             process_group=process_group,
         )
-        self.heuristic = "auto"
         self.process_group = process_group
 
     def forward(self, hidden_states):
@@ -531,6 +532,25 @@ def __init__(self, config, process_group=None):
         self.head_size = self.layers[0].attention.head_size
         self.num_heads = self.layers[0].attention.num_heads
 
+    def post_load_weights(self):
+        if isinstance(self.embed_in, TensorParallelEmbedding):
+            self.embed_in.add_null_idx()
+        for layer in self.layers:
+            layer: FlashNeoXLayer
+            layer.attention.shuffle_qkv_dims()
+            layer.attention.query_key_value.transpose_weight()
+            layer.attention.dense.transpose_weight()
+            layer.mlp.dense_h_to_4h.transpose_weight()
+            layer.mlp.dense_4h_to_h.transpose_weight()
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        model = super(FlashGPTNeoXModel, cls).from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
+        model.post_load_weights()
+        return model
+
     def forward(
         self,
         input_ids,
@@ -627,6 +647,18 @@ def __init__(self, config):
             config.hidden_size, config.vocab_size, bias=False
         )
 
+    def post_load_weights(self):
+        self.gpt_neox.post_load_weights()
+        self.embed_out.transpose_weight()
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
+            pretrained_model_name_or_path, *model_args, **kwargs
+        )
+        model.post_load_weights()
+        return model
+
     def forward(
         self,
         input_ids,
```
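Both `FlashGPTNeoXModel` and `FlashGPTNeoXForCausalLM` override `from_pretrained` the same way: defer to the parent implementation, then run `post_load_weights()` exactly once on the loaded instance, while the sharded path in flash_neox.py calls the hook explicitly after loading weights by hand. A minimal, generic sketch of that post-load hook pattern (toy class, not the repository's code):

```python
import torch
from torch import nn


class ToyModel(nn.Module):
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def post_load_weights(self):
        # One-time, in-place preparation of loaded parameters.
        self.proj.weight = nn.Parameter(self.proj.weight.T)

    @classmethod
    def from_state_dict(cls, state_dict: dict) -> "ToyModel":
        model = cls()
        model.load_state_dict(state_dict)
        model.post_load_weights()   # hook runs once, after loading
        return model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.addmm(self.proj.bias, x, self.proj.weight)


source = ToyModel()
loaded = ToyModel.from_state_dict(source.state_dict())
print(loaded(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```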
