I don't want to fight for a winning run, but I want to contribute a helpful technique.
I've been independently testing a lot of ideas, even before this competition, and one of the simplest things I found was also one of my first findings.
Embeddings tend to drift in ways that make them less than ideal, and we can condition them directly with a loss penalty.
Here's a graph of the penalty: https://www.desmos.com/calculator/padllb0rwj
Each activation is individually penalized by $f(r)=\max(0,\sqrt{2d/(r^{2}+0.1)}-1)$, where $r$ is the activation's norm; this term prevents collapse under the second penalty,
which acts on the mean $m$ of all activations: $\lVert m\rVert^{2}/\sqrt{d}$,
and the whole thing is scaled by $\lambda$. I found $\lambda=0.008$ to be reasonable.
This effectively conditions the embeddings to stay outside a shell of radius $\approx\sqrt{2d-0.1}$ (where $f(r)$ reaches zero), while conditioning the mass on that shell to stay nearly perfectly centered. Embeddings can cluster, but cannot form a cone.
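As a quick standalone sanity check of that shell radius (hypothetical, with an assumed width of d = 768), you can evaluate $f(r)$ directly: it is positive inside the shell, hits zero right at $\sqrt{2d-0.1}$, and is exactly zero outside.

```python
# Hypothetical check of the per-activation penalty
# f(r) = max(0, sqrt(2d / (r^2 + 0.1)) - 1); d = 768 is an assumed width.
import math

d = 768
shell_radius = math.sqrt(2 * d - 0.1)  # radius where f(r) reaches zero

def f(r: float) -> float:
    return max(0.0, math.sqrt(2 * d / (r ** 2 + 0.1)) - 1.0)

print(f(0.5 * shell_radius))  # inside the shell: positive penalty
print(f(shell_radius))        # on the shell boundary: ~0
print(f(2.0 * shell_radius))  # outside the shell: exactly 0
```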
import torch
import torch.nn as nn

class ShellCenteringPenalty(nn.Module):
    """Prevents activation collapse and centroid drift from the origin."""
    def __init__(self, d_model: int, lam: float = 0.008):
        super().__init__()
        self.lam = lam
        self.d_model = d_model
        self.penalty = None

    def forward(self, x: torch.Tensor, step: int | None = None):
        # x: [B, T, d_model]
        # Keep individual activations from collapsing to the origin
        norms = x.norm(dim=-1, p=2)  # [B, T]
        self.penalty = self.lam * (((2 * self.d_model / (norms ** 2 + 0.1)) ** 0.5) - 1).clamp(min=0).mean()
        # Force the centroid of the activations to stay near the origin
        norm_of_mean = x.flatten(0, 1).mean(dim=0).norm(p=2)
        self.penalty += self.lam * ((norm_of_mean ** 2) / (self.d_model ** 0.5))
        return x  # pass-through, graph intact
usage:

  x = self.tok_emb(input_ids)
+ x = self.shell_centering_penalty(x)

  logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
- return F.cross_entropy(
+ loss = F.cross_entropy(
      logits.reshape(-1, logits.size(-1)).float(),
      flat_targets,
      reduction="mean",
  )
+ return loss + self.shell_centering_penalty.penalty.to(loss.dtype)
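To see both terms in action in isolation, here is a hypothetical toy run on a random table (no model, no data loss): the shell term alone pushes every row's norm out toward $\sqrt{2d-0.1}$ while the centering term pulls the centroid toward the origin. The table size, width, learning rate, and step count are all made up for illustration, and $\lambda$ is dropped so the effect is visible in a few hundred steps.

```python
import torch

torch.manual_seed(0)
d = 64
emb = torch.randn(100, d, requires_grad=True)  # toy "embedding table"
# Large lr since there is no data loss here, only the penalty itself
opt = torch.optim.SGD([emb], lr=50.0)

for _ in range(300):
    norms = emb.norm(dim=-1)
    # Per-row shell term: positive only while a row is inside the shell
    shell = ((2 * d / (norms ** 2 + 0.1)).sqrt() - 1).clamp(min=0).mean()
    # Centering term: squared norm of the centroid, scaled by sqrt(d)
    center = emb.mean(dim=0).norm() ** 2 / d ** 0.5
    loss = shell + center  # lambda omitted for a fast, visible effect
    opt.zero_grad()
    loss.backward()
    opt.step()

print(emb.norm(dim=-1).min().item())  # pushed out toward sqrt(2*64 - 0.1) ~ 11.3, up from ~8
print(emb.mean(dim=0).norm().item())  # centroid driven toward 0
```

This is only a sketch of the optimization pressure; in real training the $\lambda=0.008$ scaling means the same drift happens gradually alongside the data loss.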
The conditioned embeddings should result in an immediately improved final loss, basically for free.
If this ends up being helpful and anyone wants to help optimize the shape or otherwise collaborate, let me know. This is just a hobby for me, but I have other things worthy of formalizing too.