Commit
Merge remote-tracking branch 'origin/main' into mamba_merge_main
rrutmann committed May 27, 2024
2 parents d094fa6 + 2b562b1 commit 8c7f9bd
Showing 12 changed files with 4 additions and 314 deletions.
25 changes: 3 additions & 22 deletions src/modalities/__main__.py
@@ -13,7 +13,7 @@
from modalities.activation_checkpointing import apply_activation_checkpointing_inplace
from modalities.batch import EvaluationResultBatch
from modalities.config.component_factory import ComponentFactory
from modalities.config.config import ProcessGroupBackendType, TokenizerTypes, load_app_config_dict
from modalities.config.config import ProcessGroupBackendType, load_app_config_dict
from modalities.config.instantiation_models import (
    PackedDatasetComponentsInstantiationModel,
    TrainingComponentsInstantiationModel,
@@ -56,33 +56,14 @@ def entry_point_run_modalities(config_file_path: Path):


@main.command(name="generate_text")
@click.argument("model_path", type=Path)
@click.option(
"--config_file_path",
type=click_pathlib.Path(exists=False),
required=True,
help="Path to a file with the YAML config file.",
)
@click.option(
"--tokenizer_type",
type=TokenizerTypes,
show_default=True,
default=TokenizerTypes.GPT2TokenizerFast,
help="Specify which Tokenizer (inheriting from transformers.PretrainedTokenizers) should get used.",
)
@click.option(
"--tokenizer_file",
type=Path,
show_default=True,
default=Path(__file__).parents[2] / Path("data/tokenizer/tokenizer.json"),
help="path to tokenizer json",
)
@click.option("--max_new_tokens", type=int, show_default=True, default=200, help="maximum amount of tokens to generate")
@click.option("--chat", is_flag=True, show_default=True, default=False, help="activate 'chat' mode")
def entry_point_generate_text(model_path, config_file_path, tokenizer_type, tokenizer_file, max_new_tokens, chat):
    tokenizer = tokenizer_type.value(tokenizer_file=str(tokenizer_file))
    with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl):
        generate_text_main(model_path, config_file_path, tokenizer, max_new_tokens, chat)
def entry_point_generate_text(config_file_path: FilePath):
    generate_text(config_file_path)


@main.group(name="data")
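Note: the tokenizer-specific CLI flags are dropped entirely; generate_text now takes only a config path and delegates to the config-driven generate_text utility. A plausible shape of the slimmed-down command is sketched below (the --config_file_path option and its help text are assumptions, since the new decorators sit outside this hunk):

    @main.command(name="generate_text")
    @click.option(
        "--config_file_path",
        type=click_pathlib.Path(exists=True),
        required=True,
        help="Path to the YAML config, which now also describes tokenizer and generation settings.",
    )
    def entry_point_generate_text(config_file_path: FilePath):
        generate_text(config_file_path)

Invocation would then look something like: python -m modalities generate_text --config_file_path <your_config.yaml> (the config file name is illustrative).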
2 changes: 0 additions & 2 deletions src/modalities/checkpointing/fsdp/fsdp_checkpoint_saving.py
@@ -14,8 +14,6 @@
from modalities.exceptions import CheckpointingError




class CheckpointingEntityType(Enum):
    MODEL = "model"
    OPTIMIZER = "optimizer"
13 changes: 0 additions & 13 deletions src/modalities/models/coca/coca_model.py
@@ -177,16 +177,3 @@ def _forward_decode(self, text_embd: torch.Tensor, vision_embd: torch.Tensor) ->
        decoder_outputs = self.multimodal_decoder(decoder_inputs)
        logits = decoder_outputs[self.multimodal_decoder.prediction_key]
        return logits

    def generate_text(
        self,
        tokenizer: PreTrainedTokenizer,
        context: str,
        max_new_tokens: int,
        temperature: float = 1.0,
    ):
        raise NotImplementedError

    def generate(self, stop_token_ids: List[int], input_ids: torch.Tensor, max_new_tokens: int,
                 temperature: float = 1.0) -> torch.Tensor:
        raise NotImplementedError
13 changes: 0 additions & 13 deletions src/modalities/models/coca/multi_modal_decoder.py
@@ -123,16 +123,3 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)
        return {self.prediction_key: logits}

    def generate_text(
        self,
        tokenizer: PreTrainedTokenizer,
        context: str,
        max_new_tokens: int,
        temperature: float = 1.0,
    ):
        raise NotImplementedError

    def generate(self, stop_token_ids: List[int], input_ids: torch.Tensor, max_new_tokens: int,
                 temperature: float = 1.0) -> torch.Tensor:
        raise NotImplementedError
13 changes: 0 additions & 13 deletions src/modalities/models/coca/text_decoder.py
@@ -72,16 +72,3 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        for block in self.transformer.h:
            x = block(x)
        return {self.prediction_key: x}

    def generate_text(
        self,
        tokenizer: PreTrainedTokenizer,
        context: str,
        max_new_tokens: int,
        temperature: float = 1.0,
    ):
        raise NotImplementedError

    def generate(self, stop_token_ids: List[int], input_ids: torch.Tensor, max_new_tokens: int,
                 temperature: float = 1.0) -> torch.Tensor:
        raise NotImplementedError
35 changes: 0 additions & 35 deletions src/modalities/models/gpt2/gpt2_model.py
@@ -465,38 +465,3 @@ def forward_impl(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tenso

    def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return self.forward_impl(inputs)

    def generate_text(
        self,
        tokenizer: PreTrainedTokenizer,
        context: str,
        max_new_tokens: int,
        temperature: float = 1.0,
    ):
        in_batch = tokenizer([context])
        in_batch[self.sample_key] = torch.Tensor(in_batch[self.sample_key]).to(torch.int64)

        for _ in range(max_new_tokens):
            in_batch[self.sample_key] = (
                in_batch[self.sample_key] if in_batch[self.sample_key].size(1) <= self.block_size else in_batch[
                    self.sample_key][
                    :,
                    -self.block_size:]
            )
            logits = self.forward(in_batch)[self.prediction_key]
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx_next_str = tokenizer.decode(idx_next[0])
            if idx_next_str == tokenizer.eos_token:
                print("\n<reached eos token>", end="")
                break
            else:
                print(idx_next_str, end="")
                sys.stdout.flush()
            in_batch[self.sample_key] = torch.cat((in_batch[self.sample_key], idx_next), dim=1)
        print("")

    def generate(self, stop_token_ids: List[int], input_ids: torch.Tensor, max_new_tokens: int,
                 temperature: float = 1.0) -> torch.Tensor:
        raise NotImplementedError
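The generate_text implementation removed here is plain temperature sampling: crop the context to block_size, take the logits of the last position, divide by the temperature, softmax, and draw the next token with torch.multinomial until the EOS token appears. A minimal self-contained sketch of that sampling step, independent of the modalities classes (function name is illustrative):

    import torch
    import torch.nn.functional as F

    def sample_next_token(last_logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        # last_logits: (batch, vocab_size), logits of the final position only
        probs = F.softmax(last_logits / temperature, dim=-1)  # temperature < 1 sharpens, > 1 flattens
        return torch.multinomial(probs, num_samples=1)        # (batch, 1) sampled token ids

Greedy decoding would replace the multinomial draw with last_logits.argmax(dim=-1, keepdim=True).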
13 changes: 0 additions & 13 deletions src/modalities/models/huggingface/huggingface_models.py
@@ -73,19 +73,6 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    def fsdp_block_names(self) -> List[str]:
        return self.huggingface_model._no_split_modules

    def generate_text(
        self,
        tokenizer: PreTrainedTokenizer,
        context: str,
        max_new_tokens: int,
        temperature: float = 1.0,
    ):
        raise NotImplementedError

    def generate(self, stop_token_ids: List[int], input_ids: torch.Tensor, max_new_tokens: int,
                 temperature: float = 1.0) -> torch.Tensor:
        raise NotImplementedError


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-7b")
43 changes: 0 additions & 43 deletions src/modalities/models/mamba/mamba_model.py
@@ -237,46 +237,3 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        hidden_states = hidden_states[:, -self.num_last_tokens:]
        lm_logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype))
        return {self.prediction_key: lm_logits}

    def generate_text(
        self,
        tokenizer: PreTrainedTokenizer,
        context: str,
        max_new_tokens: int,
        temperature: float = 1.0,
    ) -> str:
        assert temperature > 0
        if not context:
            raise ValueError("Context must be not empty")

        in_batch = tokenizer([context])
        in_batch[self.sample_key] = torch.Tensor(in_batch[self.sample_key]).to(torch.int32).to(
            next(self.parameters()).device)

        generated_ids = self.generate(stop_token_ids=[tokenizer.eos_token_id],
                                      input_ids=in_batch[self.sample_key],
                                      max_new_tokens=max_new_tokens, temperature=temperature)

        generated_string = tokenizer.decode(generated_ids.tolist()[0])
        return generated_string

    def generate(
        self,
        stop_token_ids: List[int],
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
    ) -> torch.Tensor:

        for _ in range(max_new_tokens):
            logits = self.forward({self.sample_key: input_ids})[
                self.prediction_key]
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            if idx_next.item() in stop_token_ids:
                break
            else:
                input_ids = torch.cat([input_ids, idx_next], dim=1)
        return input_ids
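Unlike the GPT-2 variant, the removed MambaLLM.generate_text returned the decoded string instead of printing tokens, and it moved the tokenized context onto the model's device via next(self.parameters()).device. A standalone illustration of that device idiom (toy module, not modalities code):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8).to("cuda" if torch.cuda.is_available() else "cpu")
    device = next(model.parameters()).device              # device of the first parameter
    input_ids = torch.randint(0, 100, (1, 4)).to(device)  # inputs end up wherever the model lives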
19 changes: 0 additions & 19 deletions src/modalities/models/model.py
@@ -21,25 +21,6 @@ def forward(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    def get_parameters(self) -> Dict[str, torch.Tensor]:
        return {name: param for name, param in self.named_parameters()}

    @abstractmethod
    def generate_text(
        self,
        tokenizer: PreTrainedTokenizer,
        context: str,
        max_new_tokens: int,
        temperature: float = 1.0,
    ) -> str:
        ...

    @abstractmethod
    def generate(
        self,
        stop_token_ids: List[int],
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        ...


def model_predict_batch(model: nn.Module, batch: DatasetBatch) -> InferenceResultBatch:
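Dropping these @abstractmethod declarations from the model base class is what allows the NotImplementedError stubs to be deleted from the subclasses above: once a method is no longer abstract, concrete models instantiate without providing it. A minimal illustration of that Python behavior (hypothetical class names, not the modalities hierarchy):

    from abc import ABC, abstractmethod

    class Base(ABC):
        @abstractmethod
        def forward(self): ...
        # generate_text is no longer declared abstract here

    class Concrete(Base):
        def forward(self):
            return "ok"

    Concrete()  # instantiates fine without a generate_text implementation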
2 changes: 1 addition & 1 deletion src/modalities/running_env/fsdp/fsdp_auto_wrapper.py
@@ -4,7 +4,7 @@
from typing import Callable, List

import torch.nn as nn
from accelerate.utils.dataclasses import FullyShardedDataParallelPlugin
from accelerate.utils.dataclasses import get_module_class_from_name
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from modalities.config.lookup_enum import LookupEnum
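The swapped import suggests the auto wrapper now resolves transformer block classes by name via accelerate's get_module_class_from_name instead of reaching through FullyShardedDataParallelPlugin. A hedged sketch of that pattern (the helper function and block names are assumptions, not the actual wrapper code):

    import functools
    from typing import List

    from accelerate.utils.dataclasses import get_module_class_from_name
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    def build_auto_wrap_policy(model, block_names: List[str]):
        # map class names such as "GPT2Block" to the module classes actually present in the model
        block_classes = {get_module_class_from_name(model, name) for name in block_names}
        return functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=block_classes)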
114 changes: 0 additions & 114 deletions src/modalities/utils/generate_text.py

This file was deleted.

26 changes: 0 additions & 26 deletions tests/models/mamba/test_mamba_model.py
@@ -54,29 +54,3 @@ def test_tie_weights(mamba_llm):
    mamba_llm.tie_embeddings = True
    mamba_llm.tie_weights()
    assert (mamba_llm.lm_head.weight == mamba_llm.backbone.embedding.weight).all()


def test_generate_text(d_model, n_layer, rms_norm, residual_in_fp32, fused_add_norm, prediction_key, sample_key, seed, dtype, initializer_cfg, mixer_model_config):
    mamba_llm = MambaLLM(d_model=d_model, n_layer=n_layer, vocab_size=50257, rms_norm=rms_norm,
                         residual_in_fp32=residual_in_fp32, fused_add_norm=fused_add_norm, pad_vocab_size_multiple=1,
                         tie_embeddings=False, prediction_key=prediction_key, sample_key=sample_key, seed=seed, dtype=dtype,
                         initializer_cfg=initializer_cfg, num_last_tokens=0, inference_params={},
                         mixer_model_config=mixer_model_config)
    tokenizer = AutoTokenizer.from_pretrained(_ROOT_DIR / "data/tokenizer/hf_gpt2")
    context = "My name is"
    output = mamba_llm.to("cuda").generate_text(tokenizer=tokenizer, context=context, max_new_tokens=5,
                                                temperature=1)
    assert type(output) == str
    assert context in output
    assert len(output) > len(context)


def test_generate(mamba_llm, vocab_size):
    num_input_tokens = 3
    max_new_tokens = 5
    input_ids = torch.randint(0, vocab_size, (1, num_input_tokens)).to("cuda")
    output = mamba_llm.to("cuda").generate(stop_token_ids=[], input_ids=input_ids, max_new_tokens=max_new_tokens, temperature=1)

    assert type(output) == torch.Tensor
    assert output.shape[1] == num_input_tokens + max_new_tokens

