Skip to content

Commit

Permalink
🩹 Fix MaskedMLP initialization on CUDA device
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Nov 24, 2024
1 parent 2f325e4 commit b5fc6cc
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion zuko/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def __init__(
adjacency, inverse = torch.unique(adjacency, dim=0, return_inverse=True)

# P_ij = 1 if A_ik = 1 for all k such that A_jk = 1
precedence = adjacency.int() @ adjacency.int().t() == adjacency.sum(dim=-1)
precedence = adjacency.double() @ adjacency.double().t() == adjacency.sum(dim=-1)

# Layers
layers = []
Expand Down

0 comments on commit b5fc6cc

Please sign in to comment.