-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
22 lines (20 loc) · 776 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
from torch.nn import functional as F
def gather_log_probabilities(
    logits: torch.Tensor,  # size = (B, L, V)
    labels: torch.LongTensor,  # size = (B, L)
) -> torch.Tensor:  # size = (B, L)
    """Return the log-probability that *logits* assign to each entry of *labels*.

    The logits are normalized with a log-softmax over their last dimension, and
    for every position the entry whose last-dimension index equals the
    corresponding label is selected.

    Args:
        logits: Unnormalized scores, size (B, L, V).
        labels: Integer indices into the last dimension of *logits*, size (B, L).

    Returns:
        The selected log-probabilities, size (B, L).
    """
    # A trailing singleton axis lets the labels index the last dimension of the
    # normalized scores; it is removed again once the values are picked out.
    index = labels[..., None]  # size = (B, L, 1)
    picked = logits.log_softmax(dim=-1).gather(-1, index)  # size = (B, L, 1)
    return picked.squeeze(-1)
def remove_eligal_characters(text: str) -> str:
    """Strip single quotes, backticks and the substring "python" from *text*.

    Every single quote (') and backtick (`) is deleted first; afterwards every
    occurrence of the lowercase substring "python" is deleted from the result.
    The substring match is case-sensitive, so e.g. "Python" is left intact.
    (The function name's spelling is kept as-is for existing callers.)
    """
    # One translate pass deletes both characters. It must run before the
    # substring removal: a deleted character inside the word (e.g. "py`thon")
    # would otherwise keep "python" from matching.
    drop_quote_and_backtick = str.maketrans({"'": None, "`": None})
    without_chars = text.translate(drop_quote_and_backtick)
    return without_chars.replace("python", "")