diff --git a/subgraph_matching/alignment.py b/subgraph_matching/alignment.py index cbebdf2..8499173 100644 --- a/subgraph_matching/alignment.py +++ b/subgraph_matching/alignment.py @@ -48,8 +48,8 @@ def gen_alignment_matrix(model, query, target, method_type="order"): """ mat = np.zeros((len(query), len(target))) - for u in query.nodes: - for v in target.nodes: + for i, u in enumerate(query.nodes): + for j, v in enumerate(target.nodes): batch = utils.batch_nx_graphs([query, target], anchors=[u, v]) embs = model.emb_model(batch) pred = model(embs[1].unsqueeze(0), embs[0].unsqueeze(0)) @@ -58,7 +58,7 @@ def gen_alignment_matrix(model, query, target, method_type="order"): raw_pred = torch.log(raw_pred) elif method_type == "mlp": raw_pred = raw_pred[0][1] - mat[u][v] = raw_pred.item() + mat[i][j] = raw_pred.item() return mat def main():