add eval (#175)
* add eval

* fix llama2 compatibility

* update readme

* fix llama2

* Update README.md

* update instruction

---------

Co-authored-by: Johannes Hagemann <[email protected]>
samsja and JohannesHa authored Dec 10, 2024
1 parent a116ef1 commit a974cf5
Showing 4 changed files with 598 additions and 9 deletions.
16 changes: 16 additions & 0 deletions README.md
@@ -113,6 +113,22 @@ Ensure you have at least two GPUs to run the full test suite:
```bash
uv run pytest
```


### Eval

To run an eval, you first need to convert the checkpoint into a HuggingFace-compatible model:

```bash
uv run python scripts/export_dcp.py @configs/10B/H100.toml --ckpt.path CONVERTED_MODEL_PATH --ckpt.resume CHECKPOINT_PATH --torch_dtype bfloat16 --ckpt.interval 1
```

Then run the evaluation with `lm_eval` through `accelerate`:

```bash
uv run accelerate launch -m lm_eval --model hf --model_args pretrained=CONVERTED_MODEL_PATH,add_bos_token=True --tasks hellaswag --num_fewshot 10
```
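
Before handing the export to the harness, a quick sanity check can confirm that it loads as a plain `transformers` Llama model. A minimal sketch (not part of this commit), assuming `CONVERTED_MODEL_PATH` is the directory written by `export_dcp.py` above:

```python
# Editorial sketch, not repo code: verify the exported checkpoint loads as a
# standard HuggingFace Llama model before running lm-eval on it.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "CONVERTED_MODEL_PATH",  # directory produced by scripts/export_dcp.py
    torch_dtype=torch.bfloat16,
)
print(model.config.architectures)  # expected: ['LlamaForCausalLM']
```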


## Environment variables
### Global Store Initialization
| Environment Variable | Description | Default Value |
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -22,13 +22,17 @@ dependencies = [
]

[project.optional-dependencies]


all = [
"wandb",
"asyncio>=3.4.3",
"aiohttp>=3.10.5",
"requests>=2.32.3",
"lm-eval"
]


[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
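Since `lm-eval` sits under the optional `all` extra rather than the base dependencies, the eval commands in the README presumably require an environment where that extra (or `lm-eval` itself) has been installed.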
33 changes: 24 additions & 9 deletions scripts/export_dcp.py
@@ -51,7 +51,9 @@ def _get_ffn_dim(hidden_dim: int, ffn_dim_multiplier: float, multiple_of: int) -
return hidden_dim


def convert_config_zb_to_hf(zb_config: ModelArgs, with_debug_automap: bool = False) -> LlamaConfig:
def convert_config_zb_to_hf(
zb_config: ModelArgs, with_debug_automap: bool = False, type_model: str = "llama3"
) -> LlamaConfig:
"""Convert ZeroBand config to HuggingFace config"""
config = LlamaConfig()
config.hidden_size = zb_config.dim
@@ -63,8 +65,14 @@ config.rms_norm_eps = zb_config.norm_eps
config.rms_norm_eps = zb_config.norm_eps
config.rope_theta = float(zb_config.rope_theta)
config.max_position_embeddings = zb_config.max_seq_len
config.bos_token_id = 128000
config.eos_token_id = [128001, 128008, 128009]

if type_model == "llama2":
config.bos_token_id = [1]
config.eos_token_id = [2]
else:
config.bos_token_id = [128000]
config.eos_token_id = [128001, 128008, 128009]

config.architectures = ["LlamaForCausalLM"]

# Rope scaling
@@ -76,11 +84,12 @@ def convert_config_zb_to_hf(zb_config: ModelArgs, with_debug_automap: bool = Fal
if with_debug_automap:
config.auto_map = {
"AutoConfig": "PrimeIntellect/prime-llama-debug--configuration_llama.LlamaConfig",
"AutoModelForCausalLM": "PrimeIntellect/prime-llama-debug--modeling_llama.LlamaForCausalLM"
"AutoModelForCausalLM": "PrimeIntellect/prime-llama-debug--modeling_llama.LlamaForCausalLM",
}

return config


@torch.no_grad
def convert_qk_from_complex_to_rotate_half(linear_weight: torch.FloatTensor, head_dim: int) -> torch.FloatTensor:
"""Converts the Q/K weight from complex to rotate half form.
@@ -99,8 +108,12 @@ def convert_qk_from_complex_to_rotate_half(linear_weight: torch.FloatTensor, hea
# This applies the riffle shuffle permutation to the outputs of the linear for each attn head
# Even numbers go to the top half, odd numbers go to the bottom half
for i in range(num_heads):
new_weight[i * head_dim:(i * head_dim + hhd), :].copy_(linear_weight[i * head_dim + 0:(i + 1) * head_dim:2, :])
new_weight[i * head_dim + hhd:(i + 1) * head_dim, :].copy_(linear_weight[i * head_dim + 1:(i + 1) * head_dim:2, :])
new_weight[i * head_dim : (i * head_dim + hhd), :].copy_(
linear_weight[i * head_dim + 0 : (i + 1) * head_dim : 2, :]
)
new_weight[i * head_dim + hhd : (i + 1) * head_dim, :].copy_(
linear_weight[i * head_dim + 1 : (i + 1) * head_dim : 2, :]
)

return new_weight
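
For readers less familiar with the two rotary layouts, here is a small standalone sketch (not part of the commit) of the same per-head permutation: within each head, even-indexed output rows move to the top half and odd-indexed rows to the bottom half, which is the layout the rotate-half rotary implementation expects.

```python
# Standalone illustration (not from the repo) of the per-head riffle-shuffle
# permutation performed by convert_qk_from_complex_to_rotate_half.
import torch

num_heads, head_dim, in_features = 2, 4, 3
hhd = head_dim // 2
w = torch.arange(num_heads * head_dim * in_features, dtype=torch.float32)
w = w.reshape(num_heads * head_dim, in_features)

# Loop form, mirroring the function above: even rows of each head go to the
# top half, odd rows to the bottom half.
looped = torch.empty_like(w)
for i in range(num_heads):
    looped[i * head_dim : i * head_dim + hhd] = w[i * head_dim : (i + 1) * head_dim : 2]
    looped[i * head_dim + hhd : (i + 1) * head_dim] = w[i * head_dim + 1 : (i + 1) * head_dim : 2]

# Equivalent vectorised form: view each head's rows as (pair, parity), then
# swap the parity axis to the front of the head.
vectorised = (
    w.reshape(num_heads, hhd, 2, in_features)
    .permute(0, 2, 1, 3)
    .reshape(num_heads * head_dim, in_features)
)
assert torch.equal(looped, vectorised)
```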

@@ -127,9 +140,11 @@ def main(config: ExportConfig):
seq_length=config.data.seq_length,
attn_fn=config.train.attn_fn,
)

# Convert ZeroBand config to HuggingFace config
hf_config = convert_config_zb_to_hf(model_config, with_debug_automap=config.with_debug_automap)
hf_config = convert_config_zb_to_hf(
model_config, with_debug_automap=config.with_debug_automap, type_model=config.type_model
)
hf_config.to_json_file(save_path / "config.json")

# Load checkpoint
@@ -152,7 +167,7 @@ def main(config: ExportConfig):
index_json = {}
total_size = 0
state_dict = {remap_keys_llama(k): v for k, v in state_dict.items()}
if not config.with_debug_automap: # The debug uses complex rotary impl
if not config.with_debug_automap: # The debug uses complex rotary impl
with torch.no_grad():
for i in range(hf_config.num_hidden_layers):
old_q = state_dict[f"model.layers.{i}.self_attn.q_proj.weight"]
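In this diff, the new `type_model` switch only changes which special-token IDs `convert_config_zb_to_hf` writes into the exported `config.json`. A standalone sketch (illustrative helper, not repo code) of that selection:

```python
# Illustrative sketch (not part of the repo) of the token-id selection added
# to convert_config_zb_to_hf: Llama 2 uses <s>=1 and </s>=2, while Llama 3
# uses 128000 for BOS and several EOS/EOT ids.
from transformers import LlamaConfig


def special_token_ids(type_model: str) -> tuple[list[int], list[int]]:
    if type_model == "llama2":
        return [1], [2]
    return [128000], [128001, 128008, 128009]


config = LlamaConfig()
config.bos_token_id, config.eos_token_id = special_token_ids("llama2")
```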
