compute_metric(eval_pred) in trainer is not mini-batch #31667

Closed
SamYuen101234 opened this issue Jun 27, 2024 · 2 comments

Comments

@SamYuen101234
I am trying to implement a custom `compute_metrics` function for `Trainer`. The logits and labels it receives are NumPy arrays covering the full evaluation set, and my evaluation logits have shape (1000, 43, 50257). The computation can't be done on a 24GB L4 GPU on Colab. Is there a way to receive the data in mini-batches, as with a DataLoader, instead of being given one full NumPy array?

```python
import numpy as np
from datasets import load_metric

# eval_pred holds the entire validation set, not a single mini-batch
def compute_metrics(eval_pred):
    accuracy_metric = load_metric("accuracy")
    logits, labels = eval_pred

    # Get predictions (next word prediction)
    predictions = np.argmax(logits, axis=-1)

    # Align targets with predictions: position i predicts token i + 1
    labels_shifted = labels[:, 1:].flatten()
    predictions_shifted = predictions[:, :-1].flatten()

    # Create a mask based on labels (assuming -100 marks padding)
    attention_mask_shifted = (labels[:, 1:] != -100).flatten()

    # Remove padding tokens using the mask
    predictions_shifted = predictions_shifted[attention_mask_shifted]
    labels_shifted = labels_shifted[attention_mask_shifted]

    # Compute accuracy
    if len(labels_shifted) == 0:
        return {"accuracy": 0.0}
    accuracy = accuracy_metric.compute(predictions=predictions_shifted, references=labels_shifted)

    return {"accuracy": accuracy["accuracy"]}
```
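For scale: a (1000, 43, 50257) array has about 2.16 billion elements, which is roughly 8.6 GB if stored as float32 (the dtype is an assumption). The accuracy above only needs the argmax token ids, so one possible approach is to reduce each batch of logits to ids as soon as it is produced and accumulate running counts. `Trainer` accepts a `preprocess_logits_for_metrics` callable that could hold the per-batch argmax, though there it receives torch tensors rather than NumPy arrays. The sketch below is NumPy-only and illustrative; the helper names `reduce_logits`, `batch_counts`, and `streaming_accuracy` are hypothetical, not part of any library.

```python
import numpy as np

IGNORE_INDEX = -100  # label value treated as padding, as in the snippet above


def reduce_logits(logits):
    """Collapse (batch, seq_len, vocab) logits to (batch, seq_len) token ids."""
    return np.argmax(logits, axis=-1)


def batch_counts(pred_ids, labels):
    """Return (correct, total) next-token hits for one batch, ignoring padding."""
    shifted_labels = labels[:, 1:]    # target at position i + 1
    shifted_preds = pred_ids[:, :-1]  # prediction made at position i
    mask = shifted_labels != IGNORE_INDEX
    correct = int(((shifted_preds == shifted_labels) & mask).sum())
    return correct, int(mask.sum())


def streaming_accuracy(batches):
    """Accumulate accuracy over an iterable of (logits, labels) batches."""
    correct = total = 0
    for logits, labels in batches:
        c, t = batch_counts(reduce_logits(logits), labels)
        correct += c
        total += t
    return correct / total if total else 0.0


# Tiny synthetic check with made-up data: 6 sequences of length 5, vocab of 11.
rng = np.random.default_rng(0)
vocab = 11
labels = rng.integers(0, vocab, size=(6, 5))
labels[:, -1] = IGNORE_INDEX  # pretend the last position is padding
logits = rng.normal(size=(6, 5, vocab))

# Feed the data two sequences at a time instead of all at once.
batches = [(logits[i:i + 2], labels[i:i + 2]) for i in range(0, 6, 2)]
acc = streaming_accuracy(batches)
print(f"streaming accuracy: {acc:.3f}")
```

Because only integer ids and two counters are kept, peak memory scales with one batch of logits rather than with the whole evaluation set.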
@amyeroberts
Collaborator

Hi @SamYuen101234, thanks for raising an issue!

This question is best placed in our forums. We try to reserve GitHub issues for feature requests and bug reports.

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed, please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@github-actions github-actions bot closed this as completed Aug 5, 2024