diff --git a/kgpip.py b/kgpip.py index 127a433..fb56ff0 100644 --- a/kgpip.py +++ b/kgpip.py @@ -271,7 +271,7 @@ def fit(self, X, y, task, verbose=True): # read the graph g = pd.read_pickle(graph_file) # relabel nodes and edges (use label instead of IDs) - node_labels = {n: g.node[n]['label'] for n in g.nodes} + node_labels = {n: g.nodes[n]['label'] for n in g.nodes} # TODO: is this needed? for k, v in node_labels.items(): node_labels[k] = v.replace('http://purl.org/twc/', '') @@ -384,4 +384,4 @@ def predict(self, X): kgpip.fit(X_train, y_train, 'regression' if is_regression else 'classification') score = r2_score(y_test, kgpip.predict(X_test)) if is_regression else f1_score(y_test, kgpip.predict(X_test), average='macro') - print('Score:', score) \ No newline at end of file + print('Score:', score)