Update cross_entropy_loss.py
danielhanchen committed Jul 2, 2024
1 parent 95c934d commit 66f35fd
Showing 1 changed file with 70 additions and 26 deletions.
96 changes: 70 additions & 26 deletions unsloth/kernels/cross_entropy_loss.py
@@ -19,14 +19,17 @@
from transformers.models.llama.modeling_llama import logger


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _cross_entropy_forward(
    logits_ptr, logits_row_stride,
    loss_ptr,
    logsumexp_ptr,
    labels_ptr,
    VOCAB_SIZE     : tl.constexpr,
    BLOCK_SIZE     : tl.constexpr,
    DO_SOFTCAPPING : tl.constexpr,
    SOFTCAP        : tl.constexpr,
):
    """
    Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
@@ -58,29 +61,38 @@ def _cross_entropy_forward(
    mask = col_offsets < VOCAB_SIZE

    label_idx = tl.load(labels_ptr).to(tl.int32)
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
    if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)

    logits = logits.to(tl.float32)
    c = tl.max(logits, 0)
    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

    if label_idx != -100:
        x = tl.load(logits_ptr + label_idx)
        # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
        if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
        loss = logsumexp - x.to(tl.float32)
    else:
        loss = 0.0
    tl.store(logsumexp_ptr, logsumexp)
    tl.store(loss_ptr, loss)
pass
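
# Not part of the commit: a rough PyTorch reference (illustrative names, assuming a
# 1-D `logits` row and a scalar integer `label`) for what the kernel above computes
# per row, i.e. softcap first, then a stable logsumexp, then loss = logsumexp - x:
import torch

def reference_ce_forward(logits, label, softcap = 0.0):
    if softcap != 0.0:
        logits = softcap * torch.tanh(logits / softcap)  # t * tanh(x / t)
    logits = logits.float()
    logsumexp = torch.logsumexp(logits, -1)  # same max-shifted trick internally
    if label == -100:
        return logits.new_zeros(()), logsumexp          # masked-out position
    return logsumexp - logits[label], logsumexp         # -log softmax at the label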


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _chunked_cross_entropy_forward(
    logits_ptr, logits_row_stride,
    loss_ptr,
    logsumexp_ptr,
    labels_ptr,
    VOCAB_SIZE     : tl.constexpr,
    N_CHUNKS       : tl.constexpr,
    BLOCK_SIZE     : tl.constexpr,
    DO_SOFTCAPPING : tl.constexpr,
    SOFTCAP        : tl.constexpr,
):
    """
    256K vocab divided in 4 chunks
@@ -117,7 +129,11 @@ def _chunked_cross_entropy_forward(
    mask = col_offsets < VOCAB_SIZE

    label_idx = tl.load(labels_ptr).to(tl.int32)
    logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
    if DO_SOFTCAPPING: logits = SOFTCAP * tl.math.tanh(logits / SOFTCAP)

    logits = logits.to(tl.float32)
    c = tl.max(logits, 0)
    logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))

@@ -126,7 +142,9 @@
    # Do the -x separately
    if label_idx != -100:
        x = tl.load(logits_ptr + label_idx).to(tl.float32)
        # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
        if DO_SOFTCAPPING: x = SOFTCAP * tl.math.tanh(x / SOFTCAP)
        loss = -1.0 * x.to(tl.float32)
    else:
        loss = 0.0
    tl.store(loss_ptr, loss)
@@ -135,14 +153,17 @@ def _chunked_cross_entropy_forward(
pass
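
# Not part of the commit: a quick check of why the chunked kernel can write one
# logsumexp per chunk and defer the final reduction to the host. logsumexp over the
# full vocab equals logsumexp of the per-chunk logsumexps (toy sizes assumed):
import torch

_logits = torch.randn(32)                             # toy "vocab" of 4 chunks x 8
_chunk_lse = torch.logsumexp(_logits.view(4, 8), 1)   # one logsumexp per chunk
assert torch.allclose(torch.logsumexp(_chunk_lse, 0), # logsumexp(chunked_logsumexp)
                      torch.logsumexp(_logits, 0))    # == logsumexp over full vocab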


@triton.heuristics({"DO_SOFTCAPPING": lambda args: args["DO_SOFTCAPPING"],})
@triton.jit
def _cross_entropy_backward(
    logits_ptr, logits_row_stride,
    dloss_ptr, dloss_row_stride,
    logsumexp_ptr,
    labels_ptr,
    VOCAB_SIZE     : tl.constexpr,
    BLOCK_SIZE     : tl.constexpr,
    DO_SOFTCAPPING : tl.constexpr,
    SOFTCAP        : tl.constexpr,
):
    """
    CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
@@ -173,15 +194,27 @@ def _cross_entropy_backward(
    else:
        dloss = 0.0

    x = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf"))
    # Do logit softcapping for Gemma 2: t * tanh(1/t * x)
    if DO_SOFTCAPPING:
        # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
        partial = tl.math.tanh(x / SOFTCAP)
        x = SOFTCAP * partial
    pass

    logsumexp = tl.load(logsumexp_ptr + row_idx)
    y = tl.exp(x.to(tl.float32) - logsumexp)
    y = tl.where(
        col_offsets == label_idx,
        y - 1.0, # exp(x - logsumexp) - 1
        y,       # exp(x - logsumexp)
    )

    if DO_SOFTCAPPING:
        # d/dx [t * tanh(1/t * x)] = 1 - tanh^2(1/t * x)
        y = y * (1.0 - partial*partial)
    pass

    # If y == 0: dC/dx = 0 ==> we already masked it to be = 0, so dloss = 0.
    tl.store(logits_ptr + col_offsets, dloss * y, mask = mask)
pass
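
# Not part of the commit: the backward above is the chain rule through the softcap,
# dL/dx = (softmax(capped) - onehot(label)) * (1 - tanh^2(x / t)). A quick autograd
# check of that formula on toy float64 data (names here are illustrative):
import torch

def _capped_ce(x, label, t = 30.0):
    capped = t * torch.tanh(x / t)                    # t * tanh(1/t * x)
    return torch.logsumexp(capped, -1) - capped[label]

_x = torch.randn(16, dtype = torch.float64, requires_grad = True)
_capped_ce(_x, 3).backward()
with torch.no_grad():
    _capped = 30.0 * torch.tanh(_x / 30.0)
    _manual = torch.softmax(_capped, -1)
    _manual[3] -= 1.0                                 # exp(x - logsumexp) - 1 at label
    _manual *= 1.0 - torch.tanh(_x / 30.0)**2         # d/dx [t * tanh(1/t * x)]
assert torch.allclose(_x.grad, _manual)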
@@ -191,13 +224,15 @@ def _cross_entropy_backward(

class Fast_CrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels, logit_softcapping = 0):
        n_rows, vocab_size = logits.shape

        div, mod = divmod(vocab_size, MAX_FUSED_SIZE)
        n_chunks = div + (mod != 0)
        losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")

        DO_SOFTCAPPING = (logit_softcapping != 0)

        if n_chunks == 1:
            # For small vocabs <= 65536 like Llama, Mistral
            BLOCK_SIZE, num_warps = calculate_settings(vocab_size)
@@ -208,9 +243,11 @@ def forward(ctx, logits, labels):
                losses,
                logsumexp,
                labels,
                VOCAB_SIZE     = vocab_size,
                BLOCK_SIZE     = BLOCK_SIZE,
                DO_SOFTCAPPING = DO_SOFTCAPPING,
                SOFTCAP        = logit_softcapping,
                num_warps      = num_warps,
            )
        else:
            # For large vocabs > 65536 like Gemma 256K
@@ -221,10 +258,12 @@ def forward(ctx, logits, labels):
                losses,
                logsumexp,
                labels,
                VOCAB_SIZE     = vocab_size,
                N_CHUNKS       = n_chunks,
                BLOCK_SIZE     = MAX_FUSED_SIZE,
                DO_SOFTCAPPING = DO_SOFTCAPPING,
                SOFTCAP        = logit_softcapping,
                num_warps      = 32,
            )
            # logsumexp(chunked_logsumexp) - x
            # Do the -x separately
@@ -234,6 +273,8 @@ def forward(ctx, logits, labels):
        pass

        ctx.save_for_backward(logits, logsumexp, labels)
        ctx.DO_SOFTCAPPING    = DO_SOFTCAPPING
        ctx.logit_softcapping = logit_softcapping
        return losses
    pass

@@ -251,16 +292,18 @@ def backward(ctx, dlosses):
            dlosses, dlosses.stride(0),
            logsumexp,
            labels,
            VOCAB_SIZE     = vocab_size,
            BLOCK_SIZE     = BLOCK_SIZE,
            DO_SOFTCAPPING = ctx.DO_SOFTCAPPING,
            SOFTCAP        = ctx.logit_softcapping,
            num_warps      = 8,
        )
        return logits, None, None,
    pass
pass


def fast_cross_entropy_loss(logits, labels, logit_softcapping = 0):
    """
    Arguments:
        logits: (batch, seq_len, vocab_size)
@@ -274,6 +317,7 @@ def fast_cross_entropy_loss(logits, labels):
    loss = Fast_CrossEntropyLoss.apply(
        logits.view(batch*seq_len, d),
        labels.view(-1),
        logit_softcapping,
    )
    n_items = torch.count_nonzero(labels != -100)
    return loss.sum() / n_items
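
# Illustrative usage (not part of the commit): shift logits/labels for causal LM
# loss and pass a Gemma-2-style softcap of 30.0; logit_softcapping = 0 (the default)
# disables capping. Requires CUDA since the kernels and loss buffer are CUDA-only,
# and the shapes below are toy values:
import torch

if torch.cuda.is_available():
    B, T, V = 2, 5, 128
    logits = torch.randn(B, T, V, device = "cuda", requires_grad = True)
    labels = torch.randint(0, V, (B, T), device = "cuda")
    loss = fast_cross_entropy_loss(
        logits[:, :-1, :].contiguous(),   # predict token t+1 from token t
        labels[:, 1:].contiguous(),       # shifted labels; -100 positions are ignored
        logit_softcapping = 30.0,
    )
    loss.backward()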
