Skip to content

Commit 57898ac

Browse files
committed
small fix
1 parent a2ae798 commit 57898ac

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

crystalformer/src/transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def network(G, XYZ, A, W, M, is_train):
178178
[ jnp.where(W==0, jnp.ones((n)), jnp.zeros((n))).reshape(n, 1),
179179
jnp.zeros((n, wyck_types-1))
180180
], axis = 1 ) # (n, wyck_types) mask = 1 for those locations to place pad atoms of type 0
181-
w_logit = w_logit + jnp.where(w_mask, 0.0, -1e10)
181+
w_logit = jnp.where(w_mask, 1e10, w_logit)
182182
w_logit -= jax.scipy.special.logsumexp(w_logit, axis=1)[:, None] # normalization
183183

184184
# (3) mask out unavaiable position after w_max for the given spacegroup

0 commit comments

Comments
 (0)