
Novel embedding conditioning for a free improvement - free to use #2045

@kooshi

Description

I don't want to fight for a winning run, but I want to contribute a helpful technique.
I've been independently testing a lot of ideas, even before this competition, and one of the simplest things I found was also one of my first findings.
Embeddings drift a little in ways that make them less than ideal, and we can directly condition them with a loss penalty.

Here's a graph of the penalty: https://www.desmos.com/calculator/padllb0rwj
Each activation is individually penalized by $f(r)=\max(0,\sqrt{2d/(r^{2}+0.1)}-1)$, which prevents the second penalty from collapsing activations to the origin.
The second penalty, $m^{2}/\sqrt{d}$, acts on the norm $m$ of the mean of all activations,
and the whole thing is scaled by lambda. I found 0.008 to be reasonable.

This effectively conditions the embeddings to stay outside a shell of radius ~ $\sqrt{2d-0.1}$ (where the per-activation penalty reaches zero), while conditioning the mass of that shell to be nearly perfectly centered. Embeddings can cluster, but cannot form a cone.
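As a quick sanity check, here's a small standalone sketch (the width `d = 64` is arbitrary) confirming that the per-activation penalty is positive inside the shell, vanishes at radius $\sqrt{2d-0.1}$, and stays zero outside it:

```python
import math

# Sanity check of the per-activation penalty f(r) = max(0, sqrt(2d/(r^2+0.1)) - 1);
# d = 64 is an arbitrary illustrative model width.
d = 64
shell_radius = math.sqrt(2 * d - 0.1)  # radius where f(r) reaches zero

def f(r: float) -> float:
    return max(0.0, math.sqrt(2 * d / (r**2 + 0.1)) - 1.0)

print(f(1.0) > 0)                    # True: activations inside the shell are penalized
print(abs(f(shell_radius)) < 1e-9)   # True: the penalty vanishes at the shell boundary
print(f(2 * shell_radius))           # 0.0: no penalty outside the shell
```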

import torch
import torch.nn as nn

class ShellCenteringPenalty(nn.Module):
    """Prevents activation collapse and centroid drift from the origin."""
    def __init__(self, d_model: int, lam: float = 0.008):
        super().__init__()
        self.lam = lam
        self.d_model = d_model
        self.penalty = None

    def forward(self, x: torch.Tensor, step: int | None = None) -> torch.Tensor:
        # x: [B, T, d_model]
        # Term 1: keep individual activations from collapsing toward the origin
        norms = x.norm(dim=-1, p=2)  # [B, T]
        self.penalty = self.lam * (((2 * self.d_model / (norms**2 + 0.1)) ** 0.5) - 1).clamp(min=0).mean()

        # Term 2: force the centroid of the activations to stay near the origin
        norm_of_mean = x.flatten(0, 1).mean(dim=0).norm(p=2)
        self.penalty = self.penalty + self.lam * (norm_of_mean**2 / self.d_model**0.5)

        return x  # pass-through, graph intact

usage:

  x = self.tok_emb(input_ids)
+ x = self.shell_centering_penalty(x)
  logits = self.logit_softcap * torch.tanh(logits_proj / self.logit_softcap)
- return F.cross_entropy(
+ loss = F.cross_entropy(
      logits.reshape(-1, logits.size(-1)).float(),
      flat_targets,
      reduction="mean",
  )
+ return loss + self.shell_centering_penalty.penalty.to(loss.dtype)
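To make the wiring concrete, here's a self-contained smoke test of the pattern above. The tiny vocab/width and the `head` linear layer are made-up stand-ins, and the `ShellCenteringPenalty` class is copied from this post so the snippet runs on its own:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Copy of ShellCenteringPenalty from this post, included so the sketch is standalone.
class ShellCenteringPenalty(nn.Module):
    def __init__(self, d_model: int, lam: float = 0.008):
        super().__init__()
        self.lam = lam
        self.d_model = d_model
        self.penalty = None

    def forward(self, x):
        norms = x.norm(dim=-1, p=2)
        self.penalty = self.lam * (((2 * self.d_model / (norms**2 + 0.1)) ** 0.5) - 1).clamp(min=0).mean()
        norm_of_mean = x.flatten(0, 1).mean(dim=0).norm(p=2)
        self.penalty = self.penalty + self.lam * (norm_of_mean**2 / self.d_model**0.5)
        return x

# Hypothetical tiny model just to exercise the wiring; sizes are arbitrary.
torch.manual_seed(0)
vocab, d_model = 32, 16
tok_emb = nn.Embedding(vocab, d_model)
head = nn.Linear(d_model, vocab)
scp = ShellCenteringPenalty(d_model)

input_ids = torch.randint(0, vocab, (2, 8))
targets = torch.randint(0, vocab, (2, 8))

x = scp(tok_emb(input_ids))    # pass-through; penalty cached on the module
logits = head(x)
loss = F.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1))
total = loss + scp.penalty.to(loss.dtype)
total.backward()               # penalty gradients flow back into the embedding
print(tok_emb.weight.grad is not None)  # True
```

Note the penalty is stashed on the module during the forward pass and picked up when the loss is assembled, so the embedding layer receives gradients from both the task loss and the conditioning terms.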

The conditioned embeddings should result in an immediately improved final loss, basically for free.

If this ends up being helpful and anyone wants to help optimize the penalty shape or otherwise collaborate, let me know. This is just a hobby for me, but I have other findings worth formalizing too.
