Report NER annotation coordinates relative to the full input text.

annotate() tokenizes words one sentence at a time, so the word spans it gets
back are relative to the start of each sentence. This patch obtains sentence
spans from PunktSentenceTokenizer.span_tokenize() and adds each sentence's
start offset to the word spans before building Annotation objects.

Known limitation (see the FIXME in the patch): sentence offsets always come
from NLTK, so a custom sent_split_fun that splits differently can yield
sentences that do not line up with those offsets.

The assertions in test_context_test keep indexing results[0]: the test
asserts len(results) == 1, so results[1] would raise IndexError.

diff --git a/gilda/ner.py b/gilda/ner.py
index bc0a25b..d904734 100644
--- a/gilda/ner.py
+++ b/gilda/ner.py
@@ -47,10 +47,8 @@
 
 from typing import List
 
-import nltk
 from nltk.corpus import stopwords
-from nltk.tokenize import sent_tokenize
-from nltk.tokenize import TreebankWordTokenizer
+from nltk.tokenize import PunktSentenceTokenizer, TreebankWordTokenizer
 
 from gilda import get_grounder
 from gilda.grounder import Annotation
@@ -105,14 +103,20 @@ def annotate(
     """
     if grounder is None:
         grounder = get_grounder()
+    sent_tokenizer = PunktSentenceTokenizer()
     if sent_split_fun is None:
-        sent_split_fun = sent_tokenize
+        sent_split_fun = sent_tokenizer.tokenize
     # Get sentences
     sentences = sent_split_fun(text)
+    sentence_coords = list(sent_tokenizer.span_tokenize(text))
     text_coord = 0
     annotations = []
     word_tokenizer = TreebankWordTokenizer()
-    for sentence in sentences:
+    # FIXME: a custom sentence split function can be inconsistent
+    # with the coordinates being used here which come from NLTK
+    for sentence, sentence_coord in zip(sentences, sentence_coords):
+        # FIXME: one rare corner case is named entities with single quotes
+        # in them which get tokenized in a weird way
         raw_word_coords = \
             list(word_tokenizer.span_tokenize(sentence.rstrip('.')))
         raw_words = [sentence[start:end] for start, end in raw_word_coords]
@@ -143,7 +147,8 @@ def annotate(
                                     raw_word_coords[idx:idx+span]):
                     # Figure out if we need a space before this word, then
                     # append the word.
-                    spaces = ' ' * (c[0] - len(raw_span) - raw_word_coords[idx][0])
+                    spaces = ' ' * (c[0] - len(raw_span) -
+                                    raw_word_coords[idx][0])
                     txt_span += spaces + w
                     raw_span += spaces + rw
                 context = text if context_text is None else context_text
@@ -152,8 +157,9 @@ def annotate(
                                           organisms=organisms,
                                           namespaces=namespaces)
                 if matches:
-                    start_coord = raw_word_coords[idx][0]
-                    end_coord = raw_word_coords[idx+span-1][1]
+                    start_coord = sentence_coord[0] + raw_word_coords[idx][0]
+                    end_coord = sentence_coord[0] + \
+                        raw_word_coords[idx+span-1][1]
                     annotations.append(Annotation(
                         raw_span, matches, start_coord, end_coord
                     ))
diff --git a/gilda/tests/test_ner.py b/gilda/tests/test_ner.py
index 6c54758..c8b24c1 100644
--- a/gilda/tests/test_ner.py
+++ b/gilda/tests/test_ner.py
@@ -83,9 +83,9 @@ def test_context_test():
     context_text = "Calcium is released from the ER."
     results = gilda.annotate(text, context_text=context_text)
     assert len(results) == 1
-    assert results[0].matches[0].term.get_curie() == "GO:0005783"
-    assert results[0].text == "ER"
-    assert (results[0].start, results[0].end) == (14, 16)
+    assert results[0].matches[0].term.get_curie() == "GO:0005783"
+    assert results[0].text == "ER"
+    assert (results[0].start, results[0].end) == (14, 16)
 
 
 def test_punctuation_comma_in_entity():