Skip to content

Commit

Permalink
fix llama2
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Dec 10, 2024
1 parent 8d7dafe commit d2ffd59
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions scripts/export_dcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def _get_ffn_dim(hidden_dim: int, ffn_dim_multiplier: float, multiple_of: int) -
return hidden_dim


def convert_config_zb_to_hf(zb_config: ModelArgs, with_debug_automap: bool = False) -> LlamaConfig:
def convert_config_zb_to_hf(
zb_config: ModelArgs, with_debug_automap: bool = False, type_model: str = "llama3"
) -> LlamaConfig:
"""Convert ZeroBand config to HuggingFace config"""
config = LlamaConfig()
config.hidden_size = zb_config.dim
Expand All @@ -64,7 +66,7 @@ def convert_config_zb_to_hf(zb_config: ModelArgs, with_debug_automap: bool = Fal
config.rope_theta = float(zb_config.rope_theta)
config.max_position_embeddings = zb_config.max_seq_len

if zb_config.type_model == "llama3":
if type_model == "llama2":
config.bos_token_id = [1]
config.eos_token_id = [2]
else:
Expand Down Expand Up @@ -140,7 +142,9 @@ def main(config: ExportConfig):
)

# Convert ZeroBand config to HuggingFace config
hf_config = convert_config_zb_to_hf(model_config, with_debug_automap=config.with_debug_automap)
hf_config = convert_config_zb_to_hf(
model_config, with_debug_automap=config.with_debug_automap, type_model=config.type_model
)
hf_config.to_json_file(save_path / "config.json")

# Load checkpoint
Expand Down

0 comments on commit d2ffd59

Please sign in to comment.