Unstable adaptive temperature? #13

@wangh09

Hi, thanks for the innovative work! I'm trying to reproduce Transolver++ on top of Transolver, and I found that the adaptive temperature seems unstable. Is there any trick, beyond the linear layer described in the paper, that needs to be applied?

In particular, I believe Eq. (3) amounts to adding the following code to PhysicsAttention.
In __init__():

self.temperature0 = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5)
self.to_delta_temp = nn.Linear(head_dim, 1)
nn.init.zeros_(self.to_delta_temp.weight)
nn.init.zeros_(self.to_delta_temp.bias)

In forward():

temperature = self.temperature0 + self.to_delta_temp(x_mid)
slice_weights = self.softmax(self.in_project_slice(x_mid) / temperature)

With the above code, gradients may explode, and it seems hard to keep the temperature from reaching zero. Is there any trick for this? Thanks!
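
For reference, the snippet above can be assembled into a self-contained module so the behaviour is easy to reproduce. The sketch below is not taken from the paper or the repository. The class name `AdaptiveTemperatureSlice`, the `slice_num` argument, the `min_temperature` floor, and the tensor layout `(batch, heads, points, head_dim)` are all assumptions made for illustration. The only change relative to the snippet is a `torch.clamp` on the temperature, which is one generic way to keep it away from zero and negative values.

```python
import torch
import torch.nn as nn


class AdaptiveTemperatureSlice(nn.Module):
    """Hypothetical stand-in for the slice-weight step of PhysicsAttention."""

    def __init__(self, heads=8, head_dim=64, slice_num=32, min_temperature=0.1):
        super().__init__()
        # Assumed lower bound on the temperature; not a value from the paper.
        self.min_temperature = min_temperature
        # Per-head base temperature, as in the snippet above.
        self.temperature0 = nn.Parameter(torch.ones(1, heads, 1, 1) * 0.5)
        # Per-point offset, zero-initialised so training starts at temperature0.
        self.to_delta_temp = nn.Linear(head_dim, 1)
        nn.init.zeros_(self.to_delta_temp.weight)
        nn.init.zeros_(self.to_delta_temp.bias)
        self.in_project_slice = nn.Linear(head_dim, slice_num)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x_mid):
        # x_mid: (batch, heads, points, head_dim), an assumed layout.
        temperature = self.temperature0 + self.to_delta_temp(x_mid)
        # Guard: without this, the sum can reach zero or turn negative.
        temperature = torch.clamp(temperature, min=self.min_temperature)
        logits = self.in_project_slice(x_mid)
        return self.softmax(logits / temperature)


x_mid = torch.randn(2, 8, 100, 64)
slice_weights = AdaptiveTemperatureSlice()(x_mid)
print(slice_weights.shape)  # torch.Size([2, 8, 100, 32])
```

Note that `torch.clamp` passes zero gradient to `to_delta_temp` at points where the floor is active, so a smooth positive mapping such as `torch.nn.functional.softplus` may be worth trying instead. Neither option is confirmed as what the paper uses.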
