Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eval modalities #1

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions lm_eval/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from . import megatronlm
from . import textsynth
from . import dummy
from . import modalities

MODEL_REGISTRY = {
"hf": gpt2.HFLM,
Expand All @@ -17,6 +18,7 @@
"megatronlm": megatronlm.MegatronLMClient,
"textsynth": textsynth.TextSynthLM,
"dummy": dummy.DummyLM,
"modalities": modalities.Modalities
}


Expand Down
22 changes: 22 additions & 0 deletions lm_eval/models/modalities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

from typing import Union, List, Optional
import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, BatchEncoding
from modalities.config.config import HuggingFaceModelConfig
from modalities.models.gpt2.huggingface_model import HuggingFaceModel
from .huggingface import AutoCausalLM

TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]


class Modalities(AutoCausalLM):
    """Adapter exposing a modalities GPT-2 checkpoint through the eval harness.

    Registers the custom ``modalities_gpt2`` architecture with Hugging Face's
    ``Auto*`` factories so that ``from_pretrained`` on the parent class can
    resolve checkpoints of that model type.
    """

    # Class-level guard: transformers raises ValueError if the same
    # model_type is registered twice, so registration must happen at most once
    # even when several Modalities instances are created (e.g. across tests).
    _hf_registered = False

    def __init__(self, *args, **kwargs):
        if not Modalities._hf_registered:
            AutoConfig.register("modalities_gpt2", HuggingFaceModelConfig)
            AutoModelForCausalLM.register(HuggingFaceModelConfig, HuggingFaceModel)
            Modalities._hf_registered = True
        # TODO load our own tokenizer (currently falls back to the stock GPT-2 one)
        super().__init__(tokenizer="gpt2", *args, **kwargs)

    def _model_call(
        self, inputs: TokenSequence, labels: Optional[TokenSequence] = None
    ) -> TokenSequence:
        # NOTE(review): `labels` is accepted for interface compatibility with
        # the harness but ignored — confirm the caller never needs the
        # model-side loss computed here.
        return self.model(inputs)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is labels never needed?

24 changes: 22 additions & 2 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,26 @@
]


# Test the Modalities model wrapper
def test_modalities():
    """Smoke-test the modalities model wrapper against the tiny test checkpoint."""
    # Skip sequences that are too long for our small test checkpoint
    # (block_size 128 in the checkpoint config).
    test_cases = LOGLIKELIHOOD_TEST_CASES[:5]
    modalities = models.get_model("modalities").create_from_arg_string(
        "pretrained='testdata/models/modalities/checkpoint'"
    )
    results = modalities.loglikelihood(test_cases)
    for loglikelihood, is_max_loglikelihood in results:
        # isinstance is the idiomatic type check (type(x) == T rejects subclasses).
        assert isinstance(loglikelihood, float)
        assert isinstance(is_max_loglikelihood, bool)

    # An empty context must not crash loglikelihood scoring.
    modalities.loglikelihood([("", "test")])

    greedy_len = 20
    (gen,) = modalities.greedy_until(
        [("The quick brown fox jumps over the lazy", {"until": [".", "\n"], "max_length": greedy_len})]
    )
    assert isinstance(gen, str)
    # NOTE(review): max_length bounds *tokens*, but split() counts whitespace
    # words — this equality only holds when each generated token is one word;
    # confirm this is the intended assertion.
    assert len(gen.split()) == greedy_len


# Test HuggingFace Models (GPT-2)
def test_gpt2():
gpt2 = models.get_model("gpt2").create_from_arg_string("device=cpu")
(
Expand All @@ -78,8 +95,11 @@ def test_gpt2():
# test empty context
gpt2.loglikelihood([("", "test")])

# (gen,) = gpt2.greedy_until(
# [("The quick brown fox jumps over the lazy", [".", "\n"])]
# )
ajude2s marked this conversation as resolved.
Show resolved Hide resolved
(gen,) = gpt2.greedy_until(
ajude2s marked this conversation as resolved.
Show resolved Hide resolved
[("The quick brown fox jumps over the lazy", [".", "\n"])]
[("The quick brown fox jumps over the lazy", {"until": [".", "\n"]})]
)

assert gen == ", lazy fox and they both fall to the ground"
Expand Down
1 change: 1 addition & 0 deletions tests/testdata/models/modalities/checkpoint/config.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"config": {"sample_key": "input_ids", "prediction_key": "logits", "block_size": 128, "vocab_size": 50304, "n_layer": 1, "n_head": 1, "n_embd": 128, "ffn_hidden": 128, "dropout": 0.0, "bias": true, "attention": {"attention_type": "pytorch_flash_attention", "scaling_factor": 3}, "activation": "gelu", "epsilon": 1e-05, "weight_init": {"mean": 0.0, "std": 0.02}}, "model_type": "modalities_gpt2"}
Binary file not shown.