Draft: HACK to preprocess data and pass to SFT Trainer #38
base: main
Conversation
Signed-off-by: Sukriti-Sharma4 <[email protected]>
def infer_max_steps(
    num_epochs: int,
    batch_size: int,
Gradient accumulation needs to be taken into consideration here.
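For illustration only, a minimal sketch of how gradient accumulation could be folded into the step count; `dataset_length` and `gradient_accumulation_steps` are assumed parameters here, not part of the PR:

```python
import math


def infer_max_steps(
    num_epochs: int,
    batch_size: int,
    dataset_length: int,
    gradient_accumulation_steps: int = 1,
) -> int:
    # One optimizer step consumes batch_size * gradient_accumulation_steps examples,
    # so the effective batch size has to include the accumulation factor; otherwise
    # max_steps overestimates the number of optimizer updates by that factor.
    effective_batch_size = batch_size * gradient_accumulation_steps
    steps_per_epoch = math.ceil(dataset_length / effective_batch_size)
    return num_epochs * steps_per_epoch
```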
if dataset_type == IterableDataset:
    return mapped_dataset
else:
    return HFDataset(mapped_dataset)
Since we are not using an iterable dataset, this can potentially blow up memory for larger datasets. I think the processing would also happen upfront.
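To illustrate the concern, a hedged sketch using the datasets library; the file name, column names, and preprocess function are hypothetical stand-ins for what this PR does:

```python
from datasets import load_dataset


def preprocess(example):
    # hypothetical stand-in for the tokenize/format step in this PR
    example["text"] = example["input"] + example["output"]
    return example


# Eager path: the mapped dataset is materialized (and cached) up front,
# which is where memory and preprocessing time grow with dataset size.
eager_ds = load_dataset("json", data_files="train.json", split="train")
eager_ds = eager_ds.map(preprocess)

# Lazy path: streaming=True returns an IterableDataset, and .map is applied
# on the fly as the trainer consumes examples, avoiding the upfront cost.
lazy_ds = load_dataset("json", data_files="train.json", split="train", streaming=True)
lazy_ds = lazy_ds.map(preprocess)
```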
max_concat_length = max_seq_length

# Truncate based on max source or max target length before considering as a joined sequence
model_inputs = tokenizer(source)
Hmm, truncation is off by default in the tokenizer, so no truncation would actually happen here? (As opposed to what the comment above says.)
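To make the point concrete, a small sketch with a standard Hugging Face tokenizer; the model name and `max_source_length` are placeholders, not values from this PR:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
source = "some long source text ..."
max_source_length = 128  # placeholder limit

# As written in the diff, truncation defaults to False, so nothing is cut:
untruncated = tokenizer(source)

# Enforcing the max source length mentioned in the comment requires passing
# truncation and max_length explicitly:
truncated = tokenizer(source, truncation=True, max_length=max_source_length)
```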
QQ for the testing parameters:
To test this:
Use twitter_complaints_formatted.json as the dataset.
On a local installation of the SFT Trainer (https://github.com/huggingface/trl/blob/main/trl/trainer/sft_trainer.py#L248), comment out lines 248 and 249, i.e. the check
if dataset_text_field is None and formatting_func is None:
so that an input/output dataset is allowed.
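For reference, a sketch of the guard those two lines belong to; the exact error text may differ between TRL versions, so treat this as an approximation rather than verbatim TRL source:

```python
# trl/trainer/sft_trainer.py, around lines 248-249 (approximate):
# if dataset_text_field is None and formatting_func is None:
#     raise ValueError(...)  # rejects datasets without a text field or formatting function
#
# Commenting out this guard lets the preprocessed input/output dataset
# from this PR reach the SFTTrainer without a dataset_text_field.
```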
Command to run