From b5fc6cc47d427db994bf38657b641012ab1fd22d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Rozet?= Date: Sun, 24 Nov 2024 15:01:44 -0500 Subject: [PATCH] =?UTF-8?q?=F0=9F=A9=B9=20Fix=20MaskedMLP=20initialization?= =?UTF-8?q?=20on=20CUDA=20device?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- zuko/nn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/zuko/nn.py b/zuko/nn.py index daf57ba..dd56a3a 100644 --- a/zuko/nn.py +++ b/zuko/nn.py @@ -270,7 +270,7 @@ def __init__( adjacency, inverse = torch.unique(adjacency, dim=0, return_inverse=True) # P_ij = 1 if A_ik = 1 for all k such that A_jk = 1 - precedence = adjacency.int() @ adjacency.int().t() == adjacency.sum(dim=-1) + precedence = adjacency.double() @ adjacency.double().t() == adjacency.sum(dim=-1) # Layers layers = []