-
-
Notifications
You must be signed in to change notification settings - Fork 851
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add kto support * test cleanup * fix outdated comment * fix llama3 ultra * chore: lint * update to use rl_beta instead of dpo_beta --------- Co-authored-by: Wing Lian <[email protected]>
- Loading branch information
1 parent
ba45531
commit 22ae21a
Showing
11 changed files
with
435 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
""" | ||
module for KTO style dataset transform strategies | ||
""" | ||
|
||
from functools import partial | ||
|
||
from ..base import load as load_base | ||
|
||
# Package-level loader: `load_base` with `module_base` pre-bound to this
# package's dotted path. Presumably `load_base` uses it to resolve a strategy
# name to a module under axolotl.prompt_strategies.kto -- confirm in ..base.
load = partial(load_base, module_base="axolotl.prompt_strategies.kto")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
""" | ||
KTO strategies for chatml | ||
""" | ||
# pylint: disable=duplicate-code | ||
|
||
|
||
def argilla(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Build a ChatML transform for samples with `instruction`/`completion` fields.

    `cfg` and `kwargs` are accepted but not used.

    Returns a function that rewrites a sample in place and returns it:
    `prompt` is set to the ChatML user turn built from `instruction` (preceded
    by a system turn when the sample has a non-empty `system` value) followed
    by the opening of the assistant turn, and `completion` gets `<|im_end|>`
    appended.
    """

    def transform_fn(sample):
        # The system turn is optional: emit it only for a present, truthy value.
        system_turn = ""
        if "system" in sample and sample["system"]:
            system_turn = f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
        sample["prompt"] = (
            f"{system_turn}"
            f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
        )
        sample["completion"] = f"{sample['completion']}<|im_end|>"
        return sample

    return transform_fn
|
||
|
||
def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/kto-mix-15k conversations

    `cfg` and `kwargs` are accepted but not used.

    Returns a function that rewrites a sample in place and returns it. The
    sample's `completion` is expected to be a two-message conversation
    (`[user_message, assistant_message]`, each with a `content` key):
    `prompt` is set to the ChatML user turn built from the first message, and
    `completion` is replaced by the second message's content plus `<|im_end|>`.
    """

    def transform_fn(sample):
        # The conversation lives in `completion` (KTO rows have no `chosen`
        # column); read both messages before `completion` is overwritten.
        conversation = sample["completion"]
        sample[
            "prompt"
        ] = f"<|im_start|>user\n{conversation[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
        sample["completion"] = f"{conversation[1]['content']}<|im_end|>"
        return sample

    return transform_fn
|
||
|
||
def intel(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    For Intel Orca KTO
    ex: argilla/distilabel-intel-orca-kto

    `cfg` and `kwargs` are accepted but not used.

    Returns a function that rewrites a sample in place and returns it:
    `prompt` is set to the ChatML user turn built from `question` (preceded by
    a system turn when the sample has a non-empty `system` value) followed by
    the opening of the assistant turn, and `completion` gets `<|im_end|>`
    appended.
    """

    def transform_fn(sample):
        # The system turn is optional: emit it only for a present, truthy value.
        system_turn = ""
        if "system" in sample and sample["system"]:
            system_turn = f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
        sample["prompt"] = (
            f"{system_turn}"
            f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
        )
        sample["completion"] = f"{sample['completion']}<|im_end|>"
        return sample

    return transform_fn
|
||
|
||
def prompt_pairs(
    cfg, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Build a ChatML transform for samples with `prompt`/`completion` fields.

    `cfg` and `kwargs` are accepted but not used.

    Returns a function that rewrites a sample in place and returns it:
    the raw `prompt` text is wrapped as a ChatML user turn (preceded by a
    system turn when the sample has a non-empty `system` value) followed by
    the opening of the assistant turn, and `completion` gets `<|im_end|>`
    appended.
    """

    def transform_fn(sample):
        # The system turn is optional: emit it only for a present, truthy value.
        system_turn = ""
        if "system" in sample and sample["system"]:
            system_turn = f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
        sample["prompt"] = (
            f"{system_turn}"
            f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
        )
        sample["completion"] = f"{sample['completion']}<|im_end|>"
        return sample

    return transform_fn
|
||
|
||
def ultra(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for ultrafeedback binarized conversations
    ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto

    `cfg` and `kwargs` are accepted but not used.

    Returns a function that rewrites a sample in place and returns it:
    the raw `prompt` text is wrapped as a ChatML user turn (preceded by a
    system turn when the sample has a non-empty `system` value) followed by
    the opening of the assistant turn, and `completion` gets `<|im_end|>`
    appended.
    """

    def transform_fn(sample):
        # The system turn is optional: emit it only for a present, truthy value.
        system_turn = ""
        if "system" in sample and sample["system"]:
            system_turn = f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
        sample["prompt"] = (
            f"{system_turn}"
            f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
        )
        sample["completion"] = f"{sample['completion']}<|im_end|>"
        return sample

    return transform_fn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
""" | ||
KTO strategies for llama-3 chat template | ||
""" | ||
# pylint: disable=duplicate-code | ||
|
||
|
||
def argilla(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Build a llama-3 chat transform for samples with `instruction`/`completion`.

    `cfg` and `kwargs` are accepted but not used.

    Returns a function that rewrites a sample in place and returns it:
    `prompt` is set to the llama-3 user turn built from `instruction`
    (preceded by a system turn when the sample has a non-empty `system` value)
    followed by the assistant header, and `completion` gets `<|eot_id|>`
    appended.
    """

    def transform_fn(sample):
        # The system turn is optional: emit it only for a present, truthy value.
        system_turn = ""
        if "system" in sample and sample["system"]:
            system_turn = f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
        sample["prompt"] = (
            f"{system_turn}"
            f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        sample["completion"] = f"{sample['completion']}<|eot_id|>"
        return sample

    return transform_fn
|
||
|
||
def argilla_chat(
    cfg,
    **kwargs,
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for argilla/kto-mix-15k conversations

    `cfg` and `kwargs` are accepted but not used.

    Returns a function that rewrites a sample in place and returns it. The
    sample's `completion` is expected to be a two-message conversation
    (`[user_message, assistant_message]`, each with a `content` key):
    `prompt` is set to the llama-3 user turn built from the first message
    followed by the assistant header, and `completion` is replaced by the
    second message's content plus `<|eot_id|>`.
    """

    def transform_fn(sample):
        # Read both messages before `completion` is overwritten below.
        conversation = sample["completion"]
        sample[
            "prompt"
        ] = f"<|start_header_id|>user<|end_header_id|>\n\n{conversation[0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        sample["completion"] = f"{conversation[1]['content']}<|eot_id|>"
        return sample

    return transform_fn
|
||
|
||
def intel(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    For Intel Orca KTO
    ex: argilla/distilabel-intel-orca-kto

    `cfg` and `kwargs` are accepted but not used.

    Returns a function that rewrites a sample in place and returns it:
    `prompt` is set to the llama-3 user turn built from `question` (preceded
    by a system turn when the sample has a non-empty `system` value) followed
    by the assistant header, and `completion` gets `<|eot_id|>` appended.
    """

    def transform_fn(sample):
        # The system turn is optional: emit it only for a present, truthy value.
        system_turn = ""
        if "system" in sample and sample["system"]:
            system_turn = f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
        sample["prompt"] = (
            f"{system_turn}"
            f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        sample["completion"] = f"{sample['completion']}<|eot_id|>"
        return sample

    return transform_fn
|
||
|
||
def prompt_pairs(
    cfg, **kwargs
):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    Build a llama-3 chat transform for samples with `prompt`/`completion`.

    `cfg` and `kwargs` are accepted but not used.

    Returns a function that rewrites a sample in place and returns it:
    the raw `prompt` text is wrapped as a llama-3 user turn (preceded by a
    system turn when the sample has a non-empty `system` value) followed by
    the assistant header, and `completion` gets `<|eot_id|>` appended.
    """

    def transform_fn(sample):
        # The system turn is optional: emit it only for a present, truthy value.
        system_turn = ""
        if "system" in sample and sample["system"]:
            system_turn = f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
        sample["prompt"] = (
            f"{system_turn}"
            f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        sample["completion"] = f"{sample['completion']}<|eot_id|>"
        return sample

    return transform_fn
|
||
|
||
def ultra(cfg, **kwargs):  # pylint: disable=possibly-unused-variable,unused-argument
    """
    for ultrafeedback binarized conversations
    ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto

    `cfg` and `kwargs` are accepted but not used.

    Returns a function that rewrites a sample in place and returns it:
    the raw `prompt` text is wrapped as a llama-3 user turn (preceded by a
    system turn when the sample has a non-empty `system` value) followed by
    the assistant header, and `completion` gets `<|eot_id|>` appended.
    """

    def transform_fn(sample):
        # The system turn is optional: emit it only for a present, truthy value.
        system_turn = ""
        if "system" in sample and sample["system"]:
            system_turn = f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
        sample["prompt"] = (
            f"{system_turn}"
            f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
        )
        sample["completion"] = f"{sample['completion']}<|eot_id|>"
        return sample

    return transform_fn
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
""" | ||
User-defined KTO strategies | ||
""" | ||
# pylint: disable=duplicate-code | ||
|
||
|
||
def default(cfg, dataset_idx=0, **kwargs):  # pylint: disable=unused-argument
    """
    Build a KTO transform from a user-defined dataset `type` mapping.

    Args:
        cfg: config object; `cfg["datasets"][dataset_idx]["type"]` must be a
            dict. Recognised keys, all optional:
            `field_prompt`, `field_system`, `field_completion`, `field_label`
            name the source columns (defaults: `prompt`, `system`,
            `completion`, `label`); `prompt_format` and `completion_format`
            are `str.format` templates. `prompt_format` may reference
            `{prompt}` / `{system}` or the configured field names;
            `completion_format` may reference `{completion}` or the
            configured completion field name. A missing or empty template
            defaults to echoing the source field unchanged.
        dataset_idx: index of the entry to read from `cfg["datasets"]`.
        **kwargs: ignored.

    Returns:
        A function that rewrites a sample in place and returns it, setting
        `prompt`, `completion` and `label`. The system value is made
        available to `prompt_format` only when the sample has a non-empty
        value for it; a template that references the system placeholder
        raises `KeyError` for samples without one.

    Raises:
        ValueError: if the dataset `type` is not a dict.
    """
    ds_cfg = cfg["datasets"][dataset_idx]["type"]
    if not isinstance(ds_cfg, dict):
        raise ValueError(
            f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
        )
    field_prompt = ds_cfg.get("field_prompt", "prompt")
    field_system = ds_cfg.get("field_system", "system")
    field_completion = ds_cfg.get("field_completion", "completion")
    field_label = ds_cfg.get("field_label", "label")
    # Fall back to templates that simply echo the source field.
    prompt_format = ds_cfg.get("prompt_format") or "{" + field_prompt + "}"
    completion_format = (
        ds_cfg.get("completion_format") or "{" + field_completion + "}"
    )

    def transform_fn(sample):
        # Expose each value under both its configured field name and its
        # canonical name so either placeholder spelling works in a template.
        # str.format ignores keyword arguments the template does not use.
        prompt_text = sample[field_prompt]
        prompt_values = {field_prompt: prompt_text, "prompt": prompt_text}
        if field_system in sample and sample[field_system]:
            system_text = sample[field_system]
            prompt_values[field_system] = system_text
            prompt_values["system"] = system_text

        completion_text = sample[field_completion]
        completion_values = {
            field_completion: completion_text,
            "completion": completion_text,
        }

        sample["prompt"] = prompt_format.format(**prompt_values)
        sample["completion"] = completion_format.format(**completion_values)
        sample["label"] = sample[field_label]
        return sample

    return transform_fn
Oops, something went wrong.