Merge pull request #146 from gyorilab/ner_improve
Improve NER with word tokenization
bgyori authored Jul 24, 2024
2 parents 7677132 + cfc0dcd commit 1ee3e85
Showing 3 changed files with 57 additions and 18 deletions.
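
The core change replaces whitespace-based word splitting in gilda/ner.py with NLTK's span-aware tokenizers, so every sentence and every word carries its exact character offsets in the input text. A minimal standalone sketch of that idea (not part of the diff; the sample text and variable names are illustrative):

    from nltk.tokenize import PunktSentenceTokenizer, TreebankWordTokenizer

    text = "EGF binds EGFR, which is a receptor."
    sent_tokenizer = PunktSentenceTokenizer()
    word_tokenizer = TreebankWordTokenizer()

    # span_tokenize yields (start, end) character offsets rather than just tokens
    for sent_start, sent_end in sent_tokenizer.span_tokenize(text):
        sentence = text[sent_start:sent_end]
        for word_start, word_end in word_tokenizer.span_tokenize(sentence):
            # Word offsets are sentence-relative; adding the sentence's own
            # start offset maps them back into the full text
            print(sentence[word_start:word_end],
                  sent_start + word_start, sent_start + word_end)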
10 changes: 5 additions & 5 deletions gilda/generate_terms.py
@@ -12,9 +12,9 @@
from indra.databases import hgnc_client, uniprot_client, chebi_client, \
    go_client, mesh_client, doid_client
from indra.statements.resources import amino_acids
-from .term import Term, dump_terms, filter_out_duplicates
-from .process import normalize
-from .resources import resource_dir, popular_organisms
+from gilda.term import Term, dump_terms, filter_out_duplicates
+from gilda.process import normalize
+from gilda.resources import resource_dir, popular_organisms


indra_module_path = indra.__path__[0]
@@ -666,7 +666,7 @@ def _generate_obo_terms(prefix, ignore_mappings=False, map_to_ns=None):

def _make_mesh_mappings():
    # Load MeSH ID/label mappings
-    from .resources import MESH_MAPPINGS_PATH
+    from gilda.resources import MESH_MAPPINGS_PATH
    mesh_mappings = {}
    mesh_mappings_reverse = {}
    for row in read_csv(MESH_MAPPINGS_PATH, delimiter='\t'):
@@ -715,7 +715,7 @@ def get_all_terms():


def main():
-    from .resources import GROUNDING_TERMS_PATH as fname
+    from gilda.resources import GROUNDING_TERMS_PATH as fname
    terms = get_all_terms()
    dump_terms(terms, fname)

40 changes: 27 additions & 13 deletions gilda/ner.py
@@ -48,7 +48,7 @@
from typing import List

from nltk.corpus import stopwords
-from nltk.tokenize import sent_tokenize
+from nltk.tokenize import PunktSentenceTokenizer, TreebankWordTokenizer

from gilda import get_grounder
from gilda.grounder import Annotation
@@ -103,17 +103,23 @@ def annotate(
    """
    if grounder is None:
        grounder = get_grounder()
+    sent_tokenizer = PunktSentenceTokenizer()
    if sent_split_fun is None:
-        sent_split_fun = sent_tokenize
+        sent_split_fun = sent_tokenizer.tokenize
    # Get sentences
    sentences = sent_split_fun(text)
+    sentence_coords = list(sent_tokenizer.span_tokenize(text))
    text_coord = 0
    annotations = []
-    for sentence in sentences:
-        raw_words = [w for w in sentence.rstrip('.').split()]
-        word_coords = [text_coord]
-        for word in raw_words:
-            word_coords.append(word_coords[-1] + len(word) + 1)
+    word_tokenizer = TreebankWordTokenizer()
+    # FIXME: a custom sentence split function can be inconsistent
+    # with the coordinates being used here which come from NLTK
+    for sentence, sentence_coord in zip(sentences, sentence_coords):
+        # FIXME: one rare corner case is named entities with single quotes
+        # in them which get tokenized in a weird way
+        raw_word_coords = \
+            list(word_tokenizer.span_tokenize(sentence.rstrip('.')))
+        raw_words = [sentence[start:end] for start, end in raw_word_coords]
        text_coord += len(sentence) + 1
        words = [normalize(w) for w in raw_words]
        skip_until = 0
@@ -132,17 +138,25 @@

            # Find the largest matching span
            for span in sorted(applicable_spans, reverse=True):
-                txt_span = ' '.join(raw_words[idx:idx+span])
+                # We have to reconstruct a text span while adding spaces
+                # where needed
+                raw_span = ''
+                for rw, c in zip(raw_words[idx:idx+span],
+                                 raw_word_coords[idx:idx+span]):
+                    # Figure out if we need a space before this word, then
+                    # append the word.
+                    spaces = ' ' * (c[0] - len(raw_span) -
+                                    raw_word_coords[idx][0])
+                    raw_span += spaces + rw
                context = text if context_text is None else context_text
-                matches = grounder.ground(txt_span,
+                matches = grounder.ground(raw_span,
                                          context=context,
                                          organisms=organisms,
                                          namespaces=namespaces)
                if matches:
-                    start_coord = word_coords[idx]
-                    end_coord = word_coords[idx+span-1] + \
-                        len(raw_words[idx+span-1])
-                    raw_span = ' '.join(raw_words[idx:idx+span])
+                    start_coord = sentence_coord[0] + raw_word_coords[idx][0]
+                    end_coord = sentence_coord[0] + \
+                        raw_word_coords[idx+span-1][1]
                    annotations.append(Annotation(
                        raw_span, matches, start_coord, end_coord
                    ))
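
The subtlest part of the new annotate loop is rebuilding the exact surface text of a multi-word candidate span from the word coordinates, padding with the right number of spaces instead of naively joining tokens with ' '. A standalone sketch of that reconstruction, mirroring the logic in the diff above with illustrative names:

    from nltk.tokenize import TreebankWordTokenizer

    sentence = "access, internet"
    word_tokenizer = TreebankWordTokenizer()
    raw_word_coords = list(word_tokenizer.span_tokenize(sentence))
    raw_words = [sentence[start:end] for start, end in raw_word_coords]

    # Rebuild the original text covering the first three tokens
    # ('access', ',', 'internet'): the comma gets no space before it,
    # while 'internet' gets one, exactly as in the source string.
    idx, span = 0, 3
    raw_span = ''
    for rw, (start, _) in zip(raw_words[idx:idx + span],
                              raw_word_coords[idx:idx + span]):
        raw_span += ' ' * (start - len(raw_span) - raw_word_coords[idx][0]) + rw
    assert raw_span == 'access, internet'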
25 changes: 25 additions & 0 deletions gilda/tests/test_ner.py
@@ -86,3 +86,28 @@ def test_context_test():
    assert results[0].matches[0].term.get_curie() == "GO:0005783"
    assert results[0].text == "ER"
    assert (results[0].start, results[0].end) == (14, 16)
+
+
+def test_punctuation_comma_in_entity():
+    # A named entity with an actual comma in its name
+    res = gilda.annotate('access, internet')
+    assert len(res) == 1
+    # Make sure we capture the text span exactly despite
+    # tokenization
+    assert res[0].text == 'access, internet'
+    assert res[0].start == 0
+    assert res[0].end == 16
+    assert res[0].matches[0].term.db == 'MESH'
+    assert res[0].matches[0].term.id == 'D000077230'
+
+
+def test_punctuation_outside_entities():
+    res = gilda.annotate('EGF binds EGFR, which is a receptor.')
+    assert len(res) == 3
+
+    assert [ann.text for ann in res] == ['EGF', 'EGFR', 'receptor']
+
+    res = gilda.annotate('EGF binds EGFR: a receptor.')
+    assert len(res) == 3
+
+    assert [ann.text for ann in res] == ['EGF', 'EGFR', 'receptor']

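The two new tests above exercise exactly these cases: punctuation inside an entity name ('access, internet') must be preserved in the reported span, while punctuation adjacent to an entity ('EGFR,' / 'EGFR:') must not leak into it. Assuming a standard pytest setup run from the repository root (an assumption, not stated in the diff), they can be selected by name:

    import pytest

    # Run only the two new punctuation-related NER tests
    pytest.main(["gilda/tests/test_ner.py", "-k", "punctuation"])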