
Draft: HACK to preprocess data and pass to SFT Trainer #38

Draft · wants to merge 6 commits into main
Conversation

@Ssukriti Ssukriti commented Feb 9, 2024

To test this, use twitter_complaints_formatted.json.

In a local installation of the SFT Trainer (https://github.com/huggingface/trl/blob/main/trl/trainer/sft_trainer.py#L248), comment out lines 248 and 249 (the check "if dataset_text_field is None and formatting_func is None:") to allow an input/output dataset.
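
For reference, this is roughly what that guard looks like once commented out locally. Only the if condition is quoted in this PR; the body of the guard is elided here rather than guessed:

```python
# trl/trainer/sft_trainer.py, around L248 of the version linked above.
# Disabling this guard lets the trainer accept a pre-tokenized
# input/output dataset without a dataset_text_field or formatting_func.
#
# if dataset_text_field is None and formatting_func is None:
#     ...  # the error this guard raises (line 249) is commented out as well
```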

Command to run

python tuning/sft_trainer.py  \
--model_name_or_path $MODEL_PATH  \
--data_path $DATA_PATH  \
--output_dir $OUTPUT_PATH  \
--num_train_epochs 20  \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4  \
--gradient_accumulation_steps 1  \
--evaluation_strategy "no"  \
--save_strategy "epoch"  \
--learning_rate 0.03  \
--weight_decay 0.  \
--warmup_ratio 0.03  \
--lr_scheduler_type "cosine"  \
--logging_steps 1  \
--include_tokens_per_second  \
--packing False  \
--use_flash_attn False  \
--tokenizer_name_or_path $MODEL_PATH \
--torch_dtype "float32" \
--peft_method "pt" \
--num_virtual_tokens 1500 \
--prompt_tuning_init_text "Classify if the tweet is a complaint or not:"

Signed-off-by: Sukriti-Sharma4 <[email protected]>

def infer_max_steps(
    num_epochs: int,
    batch_size: int,

Gradient accumulation needs to be taken into consideration for this.
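
For illustration, a minimal sketch of how gradient accumulation could be folded into the calculation; parameter names beyond num_epochs/batch_size are assumptions, not the PR's actual code:

```python
import math

def infer_max_steps(
    num_epochs: int,
    batch_size: int,
    num_examples: int,
    gradient_accumulation_steps: int = 1,
) -> int:
    # One optimizer step consumes batch_size * gradient_accumulation_steps
    # examples, so the number of steps per epoch shrinks accordingly.
    effective_batch_size = batch_size * gradient_accumulation_steps
    steps_per_epoch = math.ceil(num_examples / effective_batch_size)
    return num_epochs * steps_per_epoch
```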

if dataset_type == IterableDataset:
    return mapped_dataset
else:
    return HFDataset(mapped_dataset)

Since we are not using an iterable dataset, this can potentially blow up memory for larger datasets. I think the processing would also happen upfront.
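
A minimal sketch of one way to keep the mapping lazy with Hugging Face datasets, so large datasets are not tokenized and held in memory up front; the model path placeholder, field names, and tokenize_fn are illustrative assumptions:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("<MODEL_PATH>")  # placeholder path

raw_dataset = load_dataset(
    "json", data_files="twitter_complaints_formatted.json", split="train"
)
# Convert to an IterableDataset so map() is applied lazily during iteration
# instead of materializing the full tokenized dataset in memory.
iterable_dataset = raw_dataset.to_iterable_dataset()

def tokenize_fn(example):
    # "input"/"output" are assumed field names for the formatted file.
    return tokenizer(example["input"] + example["output"])

mapped_dataset = iterable_dataset.map(tokenize_fn)
```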

max_concat_length = max_seq_length

# Truncate based on max source or max target length before considering as a joined sequence
model_inputs = tokenizer(source)

Hmm, truncation is off by default in the tokenizer, so no truncation would actually happen here? (as opposed to what the comment above says)
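
If truncation is really intended here, it likely has to be requested explicitly. A small sketch reusing tokenizer/source from the snippet above, with max_source_length as an assumed name for the per-source cap:

```python
# Truncation is opt-in for Hugging Face tokenizers: without truncation=True
# and a max_length, the full source is tokenized regardless of its length.
model_inputs = tokenizer(
    source,
    truncation=True,
    max_length=max_source_length,  # assumed name for the per-source cap
)
```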

@gkumbhat

QQ for the testing parameters:

  1. Not sure which model you tried with, but larger models with float32 will probably get close to OOM, especially at train time.
  2. Do we really need 1500 virtual tokens for the twitter dataset? That seems like a lot and would have implications on quality (see the sketch after this list).
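
For point 2, a minimal sketch of what a leaner prompt-tuning setup could look like with PEFT's PromptTuningConfig; the virtual-token count and model_path placeholder are illustrative assumptions, not tuned recommendations:

```python
from peft import PromptTuningConfig, PromptTuningInit

model_path = "<MODEL_PATH>"  # same base model path as in the command above

# Every virtual token lengthens each input sequence by one position, so 1500
# of them adds memory cost and can affect quality; a much smaller soft prompt
# is typical for a simple classification task like twitter complaints.
peft_config = PromptTuningConfig(
    task_type="CAUSAL_LM",
    prompt_tuning_init=PromptTuningInit.TEXT,
    prompt_tuning_init_text="Classify if the tweet is a complaint or not:",
    num_virtual_tokens=8,  # illustrative value, not a tuned recommendation
    tokenizer_name_or_path=model_path,
)
```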
