From 37f7c2c90b482f51f789f7dd08236405386ae9e5 Mon Sep 17 00:00:00 2001 From: galileosteinberg Date: Tue, 9 Jul 2024 16:19:38 -0400 Subject: [PATCH 01/19] Initial implementation of NER benchmark. --- benchmarks/bioid_ner_benchmark.py | 718 ++++++++++++++++++++++++++++++ 1 file changed, 718 insertions(+) create mode 100644 benchmarks/bioid_ner_benchmark.py diff --git a/benchmarks/bioid_ner_benchmark.py b/benchmarks/bioid_ner_benchmark.py new file mode 100644 index 0000000..714c4b6 --- /dev/null +++ b/benchmarks/bioid_ner_benchmark.py @@ -0,0 +1,718 @@ +import os +import json +import pathlib +from collections import defaultdict + +from functools import lru_cache +import pandas as pd +import xml.etree.ElementTree as ET + +from lxml import etree +from tqdm import tqdm +from datetime import datetime +from typing import List, Tuple, Set, Dict, Optional, Iterable, Collection +import click +import pystow +import gilda +from gilda import ground +# from benchmarks.bioid_evaluation import fplx_members +from gilda.ner import annotate +from gilda.grounder import logger + +import famplex +from indra.databases.chebi_client import get_chebi_id_from_pubchem +from indra.databases.hgnc_client import get_hgnc_from_entrez +from indra.databases.uniprot_client import get_hgnc_id +from indra.ontology.bio import bio_ontology + + + +logger.setLevel('WARNING') + +# Constants +HERE = os.path.dirname(os.path.abspath(__file__)) +# MODULE = pystow.module('gilda', 'biocreative') +# URL = ('https://biocreative.bioinformatics.udel.edu/media/store/files/2017' +# '/BioIDtraining_2.tar.gz') +DATA_DIR = os.path.join(HERE, 'data', 'BioIDtraining_2', 'caption_bioc') +ANNOTATIONS_PATH = os.path.join(HERE, 'data', 'BioIDtraining_2', + 'annotations.csv') +RESULTS_DIR = os.path.join(HERE, 'results', "bioid_ner_performance", + gilda.__version__) +MODULE = pystow.module('gilda', 'biocreative') +URL = 'https://biocreative.bioinformatics.udel.edu/media/store/files/2017/BioIDtraining_2.tar.gz' + +tqdm.pandas() + +BO_MISSING_XREFS = set() + + +class BioIDNERBenchmarker: + def __init__(self): + print("Instantiating benchmarker...") + self.equivalences = self._load_equivalences() + self.paper_level_grounding = defaultdict(set) + self.processed_data = self.process_xml_files() #xml files processesed + self.annotations_df = self._process_annotations_table() #csv annotations + # self.reference_map = self.create_reference_map() # Create reference map for efficient lookup + self.gilda_annotations_map = defaultdict(list) # New field to store Gilda annotations + self.counts_table = None + self.precision_recall = None + + # Print a small sample of annotations_df for debugging + print("Sample of annotations_df:") + print(self.annotations_df.head(10)) # Display first 10 rows + print(self.annotations_df.columns) # Display column names + + # Print unique values of doc_id and don_article for debugging + # print("Unique doc_id values in processed_data:") + # print(self.processed_data['id'].unique()[:10]) # Display first 10 unique IDs + # print("Unique don_article values in annotations_df:") + # print(self.annotations_df['don_article'].unique()[:10]) # Display first 10 unique IDs + # Print unique values of doc_id and don_article for debugging + print("First 10 unique doc_id values in processed_data:") + print(self.processed_data['doc_id'].unique()[ + :10]) # Display first 10 unique IDs + print("First 10 unique figure values in processed_data:") + print(self.processed_data['figure'].unique()[ + :10]) # Display first 10 unique IDs + print("First 10 unique 
don_article values in annotations_df:") + print(self.annotations_df['don_article'].unique()[ + :10]) # Display first 10 unique IDs + print("First 10 unique figure values in annotations_df:") + print(self.annotations_df['figure'].unique()[ + :10]) # Display first 10 unique IDs + + def process_xml_files(self): + """Extract relevant information from XML files.""" + print("Extracting information from XML files...") + data = [] + total_annotations = 0 + for filename in os.listdir(DATA_DIR): + if filename.endswith('.xml'): + filepath = os.path.join(DATA_DIR, filename) + tree = ET.parse(filepath) + root = tree.getroot() + for document in root.findall('.//document'): + doc_id_full = document.find('.//id').text.strip() + don_article, figure = doc_id_full.split(' ', 1) # Split the full ID to get don_article and figure + don_article = don_article + for passage in document.findall('.//passage'): + offset = int(passage.find('.//offset').text) + text = passage.find('.//text').text + annotations = [] + for annotation in passage.findall('.//annotation'): + annot_id = annotation.get('id') + annot_text = annotation.find('.//text').text + annot_type = annotation.find('.//infon[@key="type"]').text + annot_offset = int(annotation.find('.//location').attrib['offset']) + annot_length = int(annotation.find('.//location').attrib['length']) + annotations.append({ + 'annot_id': annot_id, + 'annot_text': annot_text, + 'annot_type': annot_type, + 'annot_offset': annot_offset, + 'annot_length': annot_length, + }) + total_annotations += 1 + data.append({ + 'doc_id': don_article, + 'figure': figure, + 'offset': offset, + 'text': text, + 'annotations': annotations, + }) + # df = pd.DataFrame(data) + # print(f"{len(df)} rows in processed XML data.") + print(f"Total annotations in XML files: {total_annotations}") + print("Finished extracting information from XML files.") + return pd.DataFrame(data) + + + + + + + # document = root.find('.//document') + # doc_id = document.find('.//id').text.strip() + # try: + # doc_id = int(doc_id) + # except ValueError: + # print(f"Skipping file with non-integer doc_id: {filename}") + # continue + # + # text_elements = document.findall('.//text') + # texts = [elem.text for elem in text_elements if elem.text] + # full_text = ' '.join(texts) + # + # if doc_id == 3868508: + # # Print the text being used for annotation for document ID 3868508 + # print(f"Document ID: {doc_id}") + # print( + # f"Full Text: {full_text[:500]}...") # Print first 500 characters for brevity + # + # data.append({'id': doc_id, 'text': full_text}) + # df = pd.DataFrame(data) + # print(f"{len(df)} rows in processed XML data.") + # return df + + def _load_equivalences(self) -> Dict[str, List[str]]: + try: + with open(os.path.join(DATA_DIR, 'equivalences.json')) as f: + equivalences = json.load(f) + except FileNotFoundError: + equivalences = {} + return equivalences + + @classmethod + def _normalize_ids(cls, curies: str) -> List[Tuple[str, str]]: + return [cls._normalize_id(y) for y in curies.split('|')] + + @staticmethod + def _normalize_id(curie): + """Convert ID into standardized format, f'{namespace}:{id}'.""" + if curie.startswith('CVCL'): + return curie.replace('_', ':') + split_id = curie.split(':', maxsplit=1) + if split_id[0] == 'Uberon': + return split_id[1] + if split_id[0] == 'Uniprot': + return f'UP:{split_id[1]}' + if split_id[0] in ['GO', 'CHEBI']: + return f'{split_id[0]}:{split_id[0]}:{split_id[1]}' + return curie + + def get_synonym_set(self, curies: Iterable[str]) -> Set[str]: + """Return set containing 
all elements in input list along with synonyms + """ + output = set() + for curie in curies: + output.update(self._get_equivalent_entities(curie)) + # We accept all FamPlex terms that cover some or all of the specific + # entries in the annotations + # covered_fplx = {fplx_entry for fplx_entry, members + # in fplx_members.items() if (members <= output)} + # output |= {'FPLX:%s' % fplx_entry for fplx_entry in covered_fplx} + return output + + def _get_equivalent_entities(self, curie: str) -> Set[str]: + """Return set of equivalent entity groundings + + Uses set of equivalences in self.equiv_map as well as those + available in indra's hgnc, uniprot, and chebi clients. + """ + output = {curie} + prefix, identifier = curie.split(':', maxsplit=1) + for xref_prefix, xref_id in bio_ontology.get_mappings(prefix, + identifier): + output.add(f'{xref_prefix}:{xref_id}') + + # TODO these should all be in bioontology, eventually + for xref_curie in self.equivalences.get(curie, []): + if xref_curie in output: + continue + xref_prefix, xref_id = xref_curie.split(':', maxsplit=1) + if (prefix, xref_prefix) not in BO_MISSING_XREFS: + BO_MISSING_XREFS.add((prefix, xref_prefix)) + tqdm.write( + f'Bioontology v{bio_ontology.version} is missing mappings from {prefix} to {xref_prefix}') + output.add(xref_curie) + + if prefix == 'NCBI gene': + hgnc_id = get_hgnc_from_entrez(identifier) + if hgnc_id is not None: + output.add(f'HGNC:{hgnc_id}') + if prefix == 'UP': + hgnc_id = get_hgnc_id(identifier) + if hgnc_id is not None: + output.add(f'HGNC:{hgnc_id}') + if prefix == 'PubChem': + chebi_id = get_chebi_id_from_pubchem(identifier) + if chebi_id is not None: + output.add(f'CHEBI:{chebi_id}') + return output + + def _get_entity_type_helper(self, row) -> str: + if self._get_entity_type(row.obj) != 'Gene': + return self._get_entity_type(row.obj) + elif any(y.startswith('HGNC') for y in row.obj_synonyms): + return 'Human Gene' + else: + return 'Nonhuman Gene' + + @staticmethod + def _get_entity_type(groundings: Collection[str]) -> str: + """Get entity type based on entity groundings of text in corpus.""" + if any( + grounding.startswith('NCBI gene') or grounding.startswith('UP') + for grounding in groundings + ): + return 'Gene' + elif any(grounding.startswith('Rfam') for grounding in groundings): + return 'miRNA' + elif any( + grounding.startswith('CHEBI') or grounding.startswith('PubChem') + for grounding in groundings): + return 'Small Molecule' + elif any(grounding.startswith('GO') for grounding in groundings): + return 'Cellular Component' + elif any( + grounding.startswith('CVCL') or grounding.startswith('CL') + for grounding in groundings + ): + return 'Cell types/Cell lines' + elif any(grounding.startswith('UBERON') for grounding in groundings): + return 'Tissue/Organ' + elif any( + grounding.startswith('NCBI taxon') for grounding in groundings): + return 'Taxon' + else: + return 'unknown' + + def _process_annotations_table(self): + """Extract relevant information from annotations table.""" + print("Extracting information from annotations table...") + df = MODULE.ensure_tar_df( + url=URL, + inner_path='BioIDtraining_2/annotations.csv', + read_csv_kwargs=dict(sep=',', low_memory=False), + ) + # Split entries with multiple groundings then normalize ids + df.loc[:, 'obj'] = df['obj'].apply(self._normalize_ids) + # Add synonyms of gold standard groundings to help match more things + df.loc[:, 'obj_synonyms'] = df['obj'].apply(self.get_synonym_set) + # Create column for entity type + df.loc[:, 'entity_type'] = 
df.apply(self._get_entity_type_helper, axis=1) + processed_data = df[['text', 'obj', 'obj_synonyms', 'entity_type', + 'don_article', 'figure', 'annot id', 'first left', 'last right']] + print("%d rows in processed annotations table." % len(processed_data)) + processed_data = processed_data[processed_data.entity_type + != 'unknown'] + print("%d rows in annotations table without unknowns." % + len(processed_data)) + for don_article, text, synonyms in df[['don_article', 'text', + 'obj_synonyms']].values: + self.paper_level_grounding[don_article, text].update(synonyms) + return processed_data + + @lru_cache(maxsize=None) + def _get_plaintext(self, don_article: str) -> str: + """Get plaintext content from XML file in BioID corpus + + Parameters + ---------- + don_article : + Identifier for paper used within corpus. + + Returns + ------- + : + Plaintext of specified article + """ + directory = MODULE.ensure_untar(url=URL, directory='BioIDtraining_2') + path = directory.joinpath('BioIDtraining_2', 'fulltext_bioc', + f'{don_article}.xml') + tree = etree.parse(path.as_posix()) + paragraphs = tree.xpath('//text') + paragraphs = [' '.join(text.itertext()) for text in paragraphs] + return '/n'.join(paragraphs) + '/n' + + def print_annotations_for_doc_id(self, doc_id): + filtered_df = self.annotations_df[ + self.annotations_df['don_article'] == doc_id] + print(f"Annotations for Document ID {doc_id}:") + print(filtered_df) + + def annotate_entities_with_gilda(self): + """Performs NER on the XML files using gilda.annotate()""" + # df = self.processed_data + tqdm.write("Annotating corpus with Gilda...") + + results = [] + # for _, row in df.iterrows(): + for _, item in self.processed_data.iterrows(): + doc_id = item['doc_id'] + figure = item['figure'] + text = item['text'] + # annotations = item['annotations'] + + # Get the full text for the paper-level disambiguation + full_text = self._get_plaintext(doc_id) + + gilda_annotations = annotate(text, context_text=full_text) + + for matched_text, scored_ner_match, start, end in gilda_annotations: + grounding_results = ground(matched_text, context=full_text) + + for scored_grounding_match in grounding_results: + + db, entity_id = grounding_results[0].term.db, grounding_results[0].term.id + curie = f"{db}:{entity_id}" + # normalized_id = self._normalize_id(curie) + synonyms = self.get_synonym_set([curie]) + entity_type = self._get_entity_type([curie]) + + self.gilda_annotations_map[(doc_id, figure)].append({ + 'matched_text': matched_text, + 'db': db, + 'id': entity_id, + 'start': start, + 'end': end, + # 'normalized_id': normalized_id, + 'synonyms': synonyms, + 'entity_type': entity_type + }) + # results.append( + # (scored_match.term.db, scored_match.term.id, start, end)) + if doc_id == '3868508' and figure == 'Figure_1-A': + print(f"Document ID: {doc_id}, Figure: {figure}") + print(f"Scored NER Match: {scored_ner_match}") + print( + f"Annotated Text Segment: {text[max(0, start - 50):min(len(text), end + 50)]}") + print( + f"Matched Text: {matched_text}, DB: {db}, ID: {entity_id}, Start: {start}, End: {end}") + print(f"Grounding Results: {curie}") + print(f"synonyms: {synonyms}") + print(f"entity type: {entity_type}") + # annotations.append({'doc_id': doc_id, 'text': text, 'gilda_annotations': gilda_annotations}) + + + + # df = pd.DataFrame(annotations) + # self.processed_data = df + # self.processed_data = pd.DataFrame(results) + tqdm.write("Finished annotating corpus with Gilda...") + + # Print a small sample of processed_data for debugging + # 
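
# For orientation, gilda.ground() returns a ranked list of ScoredMatch
# objects, and the annotation loop above only needs each match's namespace
# and identifier. A minimal, self-contained sketch of that access pattern
# (the query string "ERK" is illustrative, not tied to the corpus):
from gilda import ground

matches = ground("ERK")  # ranked best-first
if matches:
    top_match = matches[0]
    print(f"{top_match.term.db}:{top_match.term.id}", top_match.score)
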
print("Sample of processed_data:") + # print(self.processed_data.head(10)) # Display first 10 rows + # print(self.processed_data.columns) # Display column names + + # Update paper-level grounding with Gilda annotations + # for _, row in df.iterrows(): + # doc_id = row['doc_id'] + # text = row['text'] + # gilda_ann = row['gilda_annotations'] + # for ann in gilda_ann: + # self.paper_level_grounding[(doc_id, text)].add(ann) + + # for doc_id, text, gilda_ann in zip(df['id'], df['text'], + # df['gilda_annotations']): + # for ann in gilda_ann: + # self.paper_level_grounding[(doc_id, text)].add( + # (ann[1].term.db, ann[1].term.id, ann[2], ann[3])) + + # def is_correct_annotation(self, row): + # """Cross-references gilda annotations with annotations + # provided by the dataset""" + # doc_id = int(row['doc_id']) + # figure = row['figure'] + # db = row['db'] + # annot_id = row['id'] + # start = row['start'] + # end = row['end'] + # + # # Ensure there are no leading/trailing spaces in the DataFrame + # self.annotations_df['figure'] = self.annotations_df['figure'].str.strip() + # + # # ref_annotations = set() + # + # ref_data = self.annotations_df[ + # (self.annotations_df['don_article'] == doc_id) & + # (self.annotations_df['figure'] == figure) + # ] + # + # # specific_doc_id = 3868508 + # # specific_figure = 'Figure_1-A' + # # Debugging output to check why ref_data might be empty + # # if doc_id == specific_doc_id and figure == specific_figure: + # # print( + # # f"Checking annotations for Document ID: {doc_id}, Figure: {figure}") + # # print(f"Annotations in DataFrame:\n{self.annotations_df.head()}") + # # print(f"Filtered Reference Data:\n{ref_data}") + # + # ref_annotations = set( + # (r['annot id'], r['first left'], r['last right']) + # for _, r in ref_data.iterrows() + # ) + # + # gilda_annotation = (annot_id, f'{db}:{row["id"]}') + # + # + # true_positive = gilda_annotation in ref_annotations + # false_positive = gilda_annotation not in ref_annotations + # false_negative = gilda_annotation not in ref_annotations + # + # # expanded_gilda_annotations = {annot_id, db, row['start'], row['end']} + # + # # for _, ref_row in ref_data.iterrows(): + # # ref_text = ref_row['obj'] + # # annot_id = ref_row['annot id'] + # # start_pos = int(ref_row['first left']) + # # end_pos = int(ref_row['last right']) + # # for normalized_id in ref_text: + # # parts = normalized_id.split(':') + # # db, identifier = parts[0], ':'.join(parts[1:]) + # # ref_annotations.add((annot_id, db, identifier, start_pos, end_pos)) + # # + # # Use synonym sets to match equivalent terms + # # expanded_gilda_annotations = set() + # # for ann in gilda_annotations: + # # expanded_gilda_annotations.update( + # # self.get_synonym_set([f"{ann[0]}:{ann[1]}"])) + # + # # Conditional print for debugging: Limit to a small set of data + # # if doc_id == specific_doc_id and figure == specific_figure: + # # print("Gilda Annotations:") + # # print(list(gilda_annotations)[:5]) # Print first few annotations for brevity + # # print("Reference Annotations:") + # # print(list(ref_annotations)[:5]) # Print first few annotations for brevity + # + # # return expanded_gilda_annotations, ref_annotations + # + # # true_positives = 0 + # # false_positives = 0 + # # false_negatives = 0 + # + # # Calculate true positives and false positives + # # for gilda_ann in expanded_gilda_annotations: + # # if gilda_ann in ref_annotations: + # # true_positives += 1 + # # else: + # # false_positives += 1 + # # + # # # Calculate false negatives + # # for ref_ann in 
ref_annotations: + # # if ref_ann not in expanded_gilda_annotations: + # # false_negatives += 1 + # + # return true_positive, false_positive, false_negative + + # def check_annotation(self, row, reference_map): + # """Checks if the Gilda annotation exists in the reference annotations""" + # doc_id = int(row['doc_id']) + # figure = row['figure'] + # gilda_annotation = f'{row["db"]}:{row["id"]}' + # gilda_synonyms = self.get_synonym_set([gilda_annotation]) + # + # if (doc_id, figure) in reference_map: + # ref_annotations = reference_map[(doc_id, figure)] + # + # true_positive = any( + # syn in ref_annotations for syn in gilda_synonyms) + # false_positive = not true_positive + # false_negative = not true_positive + # + # return int(true_positive), int(false_positive), int(false_negative) + # + # return 0, 1, 0 # If there are no reference annotations, it's a false positive + + def evaluate_gilda_performance(self): + """Calculates precision, recall, and F1""" + print("Evaluating performance...") + # df = self.processed_data + + total_true_positives = 0 + total_false_positives = 0 + total_false_negatives = 0 + + # reference_map = self.create_reference_map() + + for (doc_id, figure), annotations in self.gilda_annotations_map.items(): + # print(f"Processing Document ID: {doc_id}, Figure: {figure}") + for annotation in annotations: + start = annotation['start'] + end = annotation['end'] + # gilda_annotation = annotation['id'] + gilda_synonyms = annotation['synonyms'] + text = annotation['matched_text'] + + # gilda_annotation = f'{annotation["db"]}:{annotation["id"]}' + # gilda_synonyms = self.get_synonym_set([gilda_annotation]) + + # ref_annotations = self.reference_map.get((doc_id, figure), set()) + + ref_annotations = self.annotations_df[ + (self.annotations_df['don_article'] == int(doc_id)) & + (self.annotations_df['figure'] == figure) + ] + + + + # Check if any synonym of the Gilda annotation matches reference annotations + # match_found = any((text, syn, start, end) in ref_annotations for syn in gilda_synonyms) + + match_found = any( + (text, syn, start, end) in ref_annotations[ + ['text', 'obj_synonyms', 'first left', + 'last right']].values + for syn in gilda_synonyms + ) + + + if doc_id == '3868508' and figure == "Figure_1-A": + print(f"Document ID: {doc_id}, Figure: {figure}") + print(f"Gilda Annotation: {annotation}") + print(f"Reference Annotations: {ref_annotations}") + print(f"Match Found: {match_found}") + + if match_found: + total_true_positives += 1 + else: + total_false_positives += 1 + total_false_negatives += 1 + + # total_true_positives += int(true_positive) + # total_false_positives += int(false_positive) + # total_false_negatives += int(false_negative) + + # tp, fp, fn = self.check_annotation(row, reference_map) + + # = self.is_correct_annotation(row) + # total_true_positives += tp + # total_false_positives += fp + # total_false_negatives += fn + + + + # for data in self.process_xml_files(): + # doc_id = data['doc_id'] + # text = data['text'] + # gilda_ann = data['gilda_annotations'] + # figure = data.get('figure','') # Ensure figure is retrieved if it exists + + precision = total_true_positives / ( + total_true_positives + total_false_positives) if (total_true_positives + total_false_positives) > 0 else 0 + recall = total_true_positives / ( + total_true_positives + total_false_negatives) if (total_true_positives + total_false_negatives) > 0 else 0 + f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + + # df = self.processed_data + # 
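
# As a sanity check on the aggregate scoring above, the precision/recall/F1
# arithmetic can be factored into a zero-division-safe helper and verified
# on hand-counted values (an illustrative sketch, not part of the benchmark):
def _prf1(tp, fp, fn):
    """Return (precision, recall, F1) with zero-division guards."""
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) \
        if (precision + recall) else 0.0
    return precision, recall, f1

# 8 TP, 2 FP, 4 FN -> precision 0.8, recall ~0.667, F1 ~0.727
assert _prf1(8, 2, 4)[0] == 0.8
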
results = df.apply(self.is_correct_annotation, axis=1) + + # df['true_positives'] = results.apply(lambda x: x[0]) + # df['false_positives'] = results.apply(lambda x: x[1]) + # df['false_negatives'] = results.apply(lambda x: x[2]) + + # total_true_positives = df['true_positives'].sum() + # total_false_positives = df['false_positives'].sum() + # total_false_negatives = df['false_negatives'].sum() + + # precision = total_true_positives / (total_true_positives + # + total_false_positives) \ + # if (total_true_positives + total_false_positives) > 0 else 0 + # + # recall = total_true_positives / (total_true_positives + + # total_false_negatives) \ + # if (total_true_positives + total_false_negatives) > 0 else 0 + # + # f1 = 2 * (precision * recall) / (precision + recall) \ + # if (precision + recall) > 0 else 0 + + counts_table = pd.DataFrame([{ + 'True Positives': total_true_positives, + 'False Positives': total_false_positives, + 'False Negatives': total_false_negatives + }]) + + precision_recall = pd.DataFrame([{ + 'Precision': precision, + 'Recall': recall, + 'F1 Score': f1 + }]) + + print("Aggregated Results:") + print(counts_table) + print(precision_recall) + + self.counts_table = counts_table + self.precision_recall = precision_recall + + print("Finished evaluating performance...") + + # def create_reference_map(self): + # """Creates a hashmap (dictionary) of reference annotations for efficient lookup""" + # reference_map = defaultdict(set) + # print("Creating reference map...") + # + # for _, row in self.annotations_df.iterrows(): + # doc_id = str(row['don_article']) + # figure = row['figure'] + # start = row['first left'] + # end = row['last right'] + # obj = row['obj'] # Use the object directly + # text = row['text'] + # + # for original_id in obj: + # reference_map[(doc_id, figure)].add( + # (text, original_id, start, end)) + # + # for synonym in row['obj_synonyms']: + # reference_map[(doc_id, figure)].add((text, synonym, start, end,)) + # + # # ref_obj_synonyms = self.get_synonym_set(row['obj']) + # # for syn in ref_obj_synonyms: + # # reference_map[(doc_id, figure)].add((syn, start, end)) + # + # if doc_id == '3868508' and figure == "Figure_1-A": + # print( + # f"Adding Reference Annotation: {(text, obj, start, end)} for Document ID: {doc_id}, Figure: {figure}") + # + # # reference_map[(doc_id, figure)].append({ + # # 'text': row['text'], + # # 'obj': row['obj'], + # # 'obj_synonyms': row['obj_synonyms'], + # # 'entity_type': row['entity_type'], + # # 'annot_id': row['annot id'], + # # 'first_left': row['first left'], + # # 'last_right': row['last right'] + # # }) + # + # return reference_map + + def get_results_tables(self): + return self.counts_table, self.precision_recall + + +# def get_famplex_members(): +# from indra.databases import hgnc_client +# entities_path = os.path.join(HERE, 'data', 'entities.csv') +# fplx_entities = famplex.load_entities() +# fplx_children = defaultdict(set) +# for fplx_entity in fplx_entities: +# members = famplex.individual_members('FPLX', fplx_entity) +# for db_ns, db_id in members: +# if db_ns == 'HGNC': +# db_id = hgnc_client.get_current_hgnc_id(db_id) +# if db_id: +# fplx_children[fplx_entity].add('%s:%s' % (db_ns, db_id)) +# return dict(fplx_children) +# +# +# fplx_members = get_famplex_members() + + +# def main(results: str): +def main(): + # results_path = os.path.expandvars(os.path.expanduser(results)) + # os.makedirs(results_path, exist_ok=True) + + benchmarker = BioIDNERBenchmarker() + benchmarker.print_annotations_for_doc_id(3868508) + 
benchmarker.annotate_entities_with_gilda() + benchmarker.evaluate_gilda_performance() + counts, precision_recall = benchmarker.get_results_tables() + print("Counts table:") + print(counts.to_markdown(index=False)) + print("Precision and Recall table:") + print(precision_recall.to_markdown(index=False)) + # time = datetime.now().strftime('%y%m%d-%H%M%S') + # result_stub = pathlib.Path(results_path).joinpath(f'benchmark_{time}') + # counts.to_csv(result_stub.with_suffix(".counts.csv"), index=False) + # precision_recall.to_csv(result_stub.with_suffix(".precision_recall.csv"), + # index=False) + # print(f'Results saved to {results_path}') + + +if __name__ == '__main__': + main() From b1dfbb13c03d040a041b4b840403a5736f258174 Mon Sep 17 00:00:00 2001 From: galileosteinberg Date: Fri, 12 Jul 2024 13:49:06 -0400 Subject: [PATCH 02/19] dded implementation of a stoplist for NER, set membership for for matching reference annotations, counter for most common false positives in bioid_ner_benchmark.py. Implemented a condition in the NER component to match only 2+ letter entities in ner.py. --- benchmarks/bioid_ner_benchmark.py | 572 +++++++++++------------------- gilda/ner.py | 2 + 2 files changed, 202 insertions(+), 372 deletions(-) diff --git a/benchmarks/bioid_ner_benchmark.py b/benchmarks/bioid_ner_benchmark.py index 714c4b6..f8dc7e5 100644 --- a/benchmarks/bioid_ner_benchmark.py +++ b/benchmarks/bioid_ner_benchmark.py @@ -1,7 +1,7 @@ import os import json import pathlib -from collections import defaultdict +from collections import defaultdict, Counter from functools import lru_cache import pandas as pd @@ -25,15 +25,11 @@ from indra.databases.uniprot_client import get_hgnc_id from indra.ontology.bio import bio_ontology - - logger.setLevel('WARNING') # Constants HERE = os.path.dirname(os.path.abspath(__file__)) -# MODULE = pystow.module('gilda', 'biocreative') -# URL = ('https://biocreative.bioinformatics.udel.edu/media/store/files/2017' -# '/BioIDtraining_2.tar.gz') + DATA_DIR = os.path.join(HERE, 'data', 'BioIDtraining_2', 'caption_bioc') ANNOTATIONS_PATH = os.path.join(HERE, 'data', 'BioIDtraining_2', 'annotations.csv') @@ -42,6 +38,8 @@ MODULE = pystow.module('gilda', 'biocreative') URL = 'https://biocreative.bioinformatics.udel.edu/media/store/files/2017/BioIDtraining_2.tar.gz' +STOPLIST_PATH = os.path.join(HERE, 'data', 'ner_stoplist.txt') + tqdm.pandas() BO_MISSING_XREFS = set() @@ -52,37 +50,15 @@ def __init__(self): print("Instantiating benchmarker...") self.equivalences = self._load_equivalences() self.paper_level_grounding = defaultdict(set) - self.processed_data = self.process_xml_files() #xml files processesed - self.annotations_df = self._process_annotations_table() #csv annotations + self.processed_data = self.process_xml_files() # xml files processesed + self.annotations_df = self._process_annotations_table() # csv annotations # self.reference_map = self.create_reference_map() # Create reference map for efficient lookup - self.gilda_annotations_map = defaultdict(list) # New field to store Gilda annotations + self.stoplist = self._load_stoplist() # Load stoplist + self.gilda_annotations_map = defaultdict( + list) # New field to store Gilda annotations self.counts_table = None self.precision_recall = None - # Print a small sample of annotations_df for debugging - print("Sample of annotations_df:") - print(self.annotations_df.head(10)) # Display first 10 rows - print(self.annotations_df.columns) # Display column names - - # Print unique values of doc_id and don_article for debugging - # 
print("Unique doc_id values in processed_data:") - # print(self.processed_data['id'].unique()[:10]) # Display first 10 unique IDs - # print("Unique don_article values in annotations_df:") - # print(self.annotations_df['don_article'].unique()[:10]) # Display first 10 unique IDs - # Print unique values of doc_id and don_article for debugging - print("First 10 unique doc_id values in processed_data:") - print(self.processed_data['doc_id'].unique()[ - :10]) # Display first 10 unique IDs - print("First 10 unique figure values in processed_data:") - print(self.processed_data['figure'].unique()[ - :10]) # Display first 10 unique IDs - print("First 10 unique don_article values in annotations_df:") - print(self.annotations_df['don_article'].unique()[ - :10]) # Display first 10 unique IDs - print("First 10 unique figure values in annotations_df:") - print(self.annotations_df['figure'].unique()[ - :10]) # Display first 10 unique IDs - def process_xml_files(self): """Extract relevant information from XML files.""" print("Extracting information from XML files...") @@ -95,7 +71,8 @@ def process_xml_files(self): root = tree.getroot() for document in root.findall('.//document'): doc_id_full = document.find('.//id').text.strip() - don_article, figure = doc_id_full.split(' ', 1) # Split the full ID to get don_article and figure + don_article, figure = doc_id_full.split(' ', + 1) # Split the full ID to get don_article and figure don_article = don_article for passage in document.findall('.//passage'): offset = int(passage.find('.//offset').text) @@ -104,9 +81,12 @@ def process_xml_files(self): for annotation in passage.findall('.//annotation'): annot_id = annotation.get('id') annot_text = annotation.find('.//text').text - annot_type = annotation.find('.//infon[@key="type"]').text - annot_offset = int(annotation.find('.//location').attrib['offset']) - annot_length = int(annotation.find('.//location').attrib['length']) + annot_type = annotation.find( + './/infon[@key="type"]').text + annot_offset = int( + annotation.find('.//location').attrib['offset']) + annot_length = int( + annotation.find('.//location').attrib['length']) annotations.append({ 'annot_id': annot_id, 'annot_text': annot_text, @@ -128,39 +108,26 @@ def process_xml_files(self): print("Finished extracting information from XML files.") return pd.DataFrame(data) - - - - - - # document = root.find('.//document') - # doc_id = document.find('.//id').text.strip() - # try: - # doc_id = int(doc_id) - # except ValueError: - # print(f"Skipping file with non-integer doc_id: {filename}") - # continue - # - # text_elements = document.findall('.//text') - # texts = [elem.text for elem in text_elements if elem.text] - # full_text = ' '.join(texts) - # - # if doc_id == 3868508: - # # Print the text being used for annotation for document ID 3868508 - # print(f"Document ID: {doc_id}") - # print( - # f"Full Text: {full_text[:500]}...") # Print first 500 characters for brevity - # - # data.append({'id': doc_id, 'text': full_text}) - # df = pd.DataFrame(data) - # print(f"{len(df)} rows in processed XML data.") - # return df + def _load_stoplist(self) -> Set[str]: + """Load NER stoplist from file.""" + stoplist_path = STOPLIST_PATH + try: + with open(stoplist_path, 'r') as file: + stoplist = {line.strip().lower() for line in file} + print(f"Loaded stoplist with {len(stoplist)} entries.") + return stoplist + except FileNotFoundError: + print( + f"No stoplist found at {stoplist_path}. 
Proceeding without it.") + return set() def _load_equivalences(self) -> Dict[str, List[str]]: try: - with open(os.path.join(DATA_DIR, 'equivalences.json')) as f: + with open(os.path.join(HERE, 'data', 'equivalences.json')) as f: equivalences = json.load(f) except FileNotFoundError: + print( + f"No Equivalences found at {os.path.join(HERE, 'data', 'equivalences.json')}. Proceeding without it.") equivalences = {} return equivalences @@ -190,9 +157,9 @@ def get_synonym_set(self, curies: Iterable[str]) -> Set[str]: output.update(self._get_equivalent_entities(curie)) # We accept all FamPlex terms that cover some or all of the specific # entries in the annotations - # covered_fplx = {fplx_entry for fplx_entry, members - # in fplx_members.items() if (members <= output)} - # output |= {'FPLX:%s' % fplx_entry for fplx_entry in covered_fplx} + covered_fplx = {fplx_entry for fplx_entry, members + in fplx_members.items() if (members <= output)} + output |= {'FPLX:%s' % fplx_entry for fplx_entry in covered_fplx} return output def _get_equivalent_entities(self, curie: str) -> Set[str]: @@ -282,9 +249,11 @@ def _process_annotations_table(self): # Add synonyms of gold standard groundings to help match more things df.loc[:, 'obj_synonyms'] = df['obj'].apply(self.get_synonym_set) # Create column for entity type - df.loc[:, 'entity_type'] = df.apply(self._get_entity_type_helper, axis=1) + df.loc[:, 'entity_type'] = df.apply(self._get_entity_type_helper, + axis=1) processed_data = df[['text', 'obj', 'obj_synonyms', 'entity_type', - 'don_article', 'figure', 'annot id', 'first left', 'last right']] + 'don_article', 'figure', 'annot id', 'first left', + 'last right']] print("%d rows in processed annotations table." % len(processed_data)) processed_data = processed_data[processed_data.entity_type != 'unknown'] @@ -317,18 +286,13 @@ def _get_plaintext(self, don_article: str) -> str: paragraphs = [' '.join(text.itertext()) for text in paragraphs] return '/n'.join(paragraphs) + '/n' - def print_annotations_for_doc_id(self, doc_id): - filtered_df = self.annotations_df[ - self.annotations_df['don_article'] == doc_id] - print(f"Annotations for Document ID {doc_id}:") - print(filtered_df) - def annotate_entities_with_gilda(self): """Performs NER on the XML files using gilda.annotate()""" # df = self.processed_data tqdm.write("Annotating corpus with Gilda...") - results = [] + # results = [] + total_gilda_annotations = 0 # for _, row in df.iterrows(): for _, item in self.processed_data.iterrows(): doc_id = item['doc_id'] @@ -339,185 +303,73 @@ def annotate_entities_with_gilda(self): # Get the full text for the paper-level disambiguation full_text = self._get_plaintext(doc_id) - gilda_annotations = annotate(text, context_text=full_text) - - for matched_text, scored_ner_match, start, end in gilda_annotations: - grounding_results = ground(matched_text, context=full_text) - - for scored_grounding_match in grounding_results: - - db, entity_id = grounding_results[0].term.db, grounding_results[0].term.id - curie = f"{db}:{entity_id}" - # normalized_id = self._normalize_id(curie) - synonyms = self.get_synonym_set([curie]) - entity_type = self._get_entity_type([curie]) - - self.gilda_annotations_map[(doc_id, figure)].append({ - 'matched_text': matched_text, - 'db': db, - 'id': entity_id, - 'start': start, - 'end': end, - # 'normalized_id': normalized_id, - 'synonyms': synonyms, - 'entity_type': entity_type - }) + gilda_annotations = annotate(text, context_text=full_text, + return_first=True) + # for testing all matches for each 
entity, return_first = False. + + for matched_text, grounding_result, start, end in gilda_annotations: + + # Checking against stoplist + if matched_text in self.stoplist: + continue + + db, entity_id = grounding_result.term.db, grounding_result.term.id + curie = f"{db}:{entity_id}" + # normalized_id = self._normalize_id(curie) + synonyms = self.get_synonym_set([curie]) + # entity_type = self._get_entity_type([curie]) + + self.gilda_annotations_map[(doc_id, figure)].append({ + 'matched_text': matched_text, + 'db': db, + 'id': entity_id, + 'start': start, + 'end': end, + # 'normalized_id': normalized_id, + 'synonyms': synonyms, + # 'entity_type': entity_type + }) + total_gilda_annotations += 1 # results.append( # (scored_match.term.db, scored_match.term.id, start, end)) - if doc_id == '3868508' and figure == 'Figure_1-A': - print(f"Document ID: {doc_id}, Figure: {figure}") - print(f"Scored NER Match: {scored_ner_match}") - print( - f"Annotated Text Segment: {text[max(0, start - 50):min(len(text), end + 50)]}") - print( - f"Matched Text: {matched_text}, DB: {db}, ID: {entity_id}, Start: {start}, End: {end}") - print(f"Grounding Results: {curie}") - print(f"synonyms: {synonyms}") - print(f"entity type: {entity_type}") - # annotations.append({'doc_id': doc_id, 'text': text, 'gilda_annotations': gilda_annotations}) - - - - # df = pd.DataFrame(annotations) - # self.processed_data = df - # self.processed_data = pd.DataFrame(results) - tqdm.write("Finished annotating corpus with Gilda...") + if doc_id == '3868508' and figure == 'Figure_1-A': + print(f"Scored NER Match: {grounding_result}") + print(f"Annotated Text Segment: {text[start:end]} at " + f"indices {start} to {end}") + print( + f"Gilda Matched Text: {matched_text}, DB: {db}, " + f"ID: {entity_id}, Start: {start}, End: {end}") + print(f"Grounding Results: {curie}") + print(f"synonyms: {synonyms}") + # print(f"entity type: {entity_type}") + print("\n") - # Print a small sample of processed_data for debugging - # print("Sample of processed_data:") - # print(self.processed_data.head(10)) # Display first 10 rows - # print(self.processed_data.columns) # Display column names - - # Update paper-level grounding with Gilda annotations - # for _, row in df.iterrows(): - # doc_id = row['doc_id'] - # text = row['text'] - # gilda_ann = row['gilda_annotations'] - # for ann in gilda_ann: - # self.paper_level_grounding[(doc_id, text)].add(ann) - - # for doc_id, text, gilda_ann in zip(df['id'], df['text'], - # df['gilda_annotations']): - # for ann in gilda_ann: - # self.paper_level_grounding[(doc_id, text)].add( - # (ann[1].term.db, ann[1].term.id, ann[2], ann[3])) - - # def is_correct_annotation(self, row): - # """Cross-references gilda annotations with annotations - # provided by the dataset""" - # doc_id = int(row['doc_id']) - # figure = row['figure'] - # db = row['db'] - # annot_id = row['id'] - # start = row['start'] - # end = row['end'] - # - # # Ensure there are no leading/trailing spaces in the DataFrame - # self.annotations_df['figure'] = self.annotations_df['figure'].str.strip() - # - # # ref_annotations = set() - # - # ref_data = self.annotations_df[ - # (self.annotations_df['don_article'] == doc_id) & - # (self.annotations_df['figure'] == figure) - # ] - # - # # specific_doc_id = 3868508 - # # specific_figure = 'Figure_1-A' - # # Debugging output to check why ref_data might be empty - # # if doc_id == specific_doc_id and figure == specific_figure: - # # print( - # # f"Checking annotations for Document ID: {doc_id}, Figure: {figure}") - # # 
print(f"Annotations in DataFrame:\n{self.annotations_df.head()}") - # # print(f"Filtered Reference Data:\n{ref_data}") - # - # ref_annotations = set( - # (r['annot id'], r['first left'], r['last right']) - # for _, r in ref_data.iterrows() - # ) - # - # gilda_annotation = (annot_id, f'{db}:{row["id"]}') - # - # - # true_positive = gilda_annotation in ref_annotations - # false_positive = gilda_annotation not in ref_annotations - # false_negative = gilda_annotation not in ref_annotations - # - # # expanded_gilda_annotations = {annot_id, db, row['start'], row['end']} - # - # # for _, ref_row in ref_data.iterrows(): - # # ref_text = ref_row['obj'] - # # annot_id = ref_row['annot id'] - # # start_pos = int(ref_row['first left']) - # # end_pos = int(ref_row['last right']) - # # for normalized_id in ref_text: - # # parts = normalized_id.split(':') - # # db, identifier = parts[0], ':'.join(parts[1:]) - # # ref_annotations.add((annot_id, db, identifier, start_pos, end_pos)) - # # - # # Use synonym sets to match equivalent terms - # # expanded_gilda_annotations = set() - # # for ann in gilda_annotations: - # # expanded_gilda_annotations.update( - # # self.get_synonym_set([f"{ann[0]}:{ann[1]}"])) - # - # # Conditional print for debugging: Limit to a small set of data - # # if doc_id == specific_doc_id and figure == specific_figure: - # # print("Gilda Annotations:") - # # print(list(gilda_annotations)[:5]) # Print first few annotations for brevity - # # print("Reference Annotations:") - # # print(list(ref_annotations)[:5]) # Print first few annotations for brevity - # - # # return expanded_gilda_annotations, ref_annotations - # - # # true_positives = 0 - # # false_positives = 0 - # # false_negatives = 0 - # - # # Calculate true positives and false positives - # # for gilda_ann in expanded_gilda_annotations: - # # if gilda_ann in ref_annotations: - # # true_positives += 1 - # # else: - # # false_positives += 1 - # # - # # # Calculate false negatives - # # for ref_ann in ref_annotations: - # # if ref_ann not in expanded_gilda_annotations: - # # false_negatives += 1 - # - # return true_positive, false_positive, false_negative - - # def check_annotation(self, row, reference_map): - # """Checks if the Gilda annotation exists in the reference annotations""" - # doc_id = int(row['doc_id']) - # figure = row['figure'] - # gilda_annotation = f'{row["db"]}:{row["id"]}' - # gilda_synonyms = self.get_synonym_set([gilda_annotation]) - # - # if (doc_id, figure) in reference_map: - # ref_annotations = reference_map[(doc_id, figure)] - # - # true_positive = any( - # syn in ref_annotations for syn in gilda_synonyms) - # false_positive = not true_positive - # false_negative = not true_positive - # - # return int(true_positive), int(false_positive), int(false_negative) - # - # return 0, 1, 0 # If there are no reference annotations, it's a false positive + tqdm.write("Finished annotating corpus with Gilda...") + print(f"Total Gilda annotations: {total_gilda_annotations}") def evaluate_gilda_performance(self): """Calculates precision, recall, and F1""" print("Evaluating performance...") + # df = self.processed_data total_true_positives = 0 total_false_positives = 0 total_false_negatives = 0 - - # reference_map = self.create_reference_map() - + false_positives_counter = Counter() + + # Create a set of reference annotations for quick lookup + ref_annotations = set() + for _, row in self.annotations_df.iterrows(): + doc_id = str(row['don_article']) + figure = row['figure'] + text = row['text'] + for syn in 
row['obj_synonyms']: + ref_annotations.add((doc_id, figure, text, syn, + row['first left'], + row['last right'])) + + print(f"Total reference annotations: {len(ref_annotations)}") for (doc_id, figure), annotations in self.gilda_annotations_map.items(): # print(f"Processing Document ID: {doc_id}, Figure: {figure}") for annotation in annotations: @@ -527,87 +379,101 @@ def evaluate_gilda_performance(self): gilda_synonyms = annotation['synonyms'] text = annotation['matched_text'] - # gilda_annotation = f'{annotation["db"]}:{annotation["id"]}' - # gilda_synonyms = self.get_synonym_set([gilda_annotation]) + # Uncomment this if above doesnt work. + # ref_annotations = self.annotations_df[ + # (self.annotations_df['don_article'] == int(doc_id)) & + # (self.annotations_df['figure'] == figure) + # ] - # ref_annotations = self.reference_map.get((doc_id, figure), set()) - - ref_annotations = self.annotations_df[ - (self.annotations_df['don_article'] == int(doc_id)) & - (self.annotations_df['figure'] == figure) - ] - - - - # Check if any synonym of the Gilda annotation matches reference annotations - # match_found = any((text, syn, start, end) in ref_annotations for syn in gilda_synonyms) + # UNCOMMENT IF BELOW DOESN'T WORK + # match_found = any( + # (text, syn, start, end) in ref_annotations[ + # ['text', 'obj_synonyms', 'first left', + # 'last right']].values + # for syn in gilda_synonyms + # ) match_found = any( - (text, syn, start, end) in ref_annotations[ - ['text', 'obj_synonyms', 'first left', - 'last right']].values - for syn in gilda_synonyms - ) + (doc_id, figure, text, syn, start, end) + in ref_annotations for syn in gilda_synonyms) - - if doc_id == '3868508' and figure == "Figure_1-A": - print(f"Document ID: {doc_id}, Figure: {figure}") + # Debugging: Identify and print the exact match + matching_reference = None + if match_found: + for syn in gilda_synonyms: + if (doc_id, figure, text, syn, start, + end) in ref_annotations: + matching_reference = ( + doc_id, figure, text, syn, start, end) + break + + if (match_found == True and doc_id == '3868508' + and figure == "Figure_1-A"): print(f"Gilda Annotation: {annotation}") - print(f"Reference Annotations: {ref_annotations}") + # print(f"Reference Annotations: {ref_annotations}") + print(f"Match Found: {match_found}") + print(f"Synonyms: {gilda_synonyms}") print(f"Match Found: {match_found}") + if match_found: + print(f"Matching Reference: {matching_reference}") if match_found: total_true_positives += 1 else: total_false_positives += 1 - total_false_negatives += 1 - - # total_true_positives += int(true_positive) - # total_false_positives += int(false_positive) - # total_false_negatives += int(false_negative) - - # tp, fp, fn = self.check_annotation(row, reference_map) - - # = self.is_correct_annotation(row) - # total_true_positives += tp - # total_false_positives += fp - # total_false_negatives += fn - - - - # for data in self.process_xml_files(): - # doc_id = data['doc_id'] - # text = data['text'] - # gilda_ann = data['gilda_annotations'] - # figure = data.get('figure','') # Ensure figure is retrieved if it exists - - precision = total_true_positives / ( - total_true_positives + total_false_positives) if (total_true_positives + total_false_positives) > 0 else 0 - recall = total_true_positives / ( - total_true_positives + total_false_negatives) if (total_true_positives + total_false_negatives) > 0 else 0 - f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 - - # df = self.processed_data - # results = 
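
# The heart of this change: reference annotations are flattened once into
# hashable tuples so that every candidate check is an O(1) set-membership
# test instead of a repeated DataFrame filter. A toy illustration of the
# same pattern (the keys below are fabricated for the example):
ref = {("12345", "Figure_1-A", "p53", "HGNC:11998", 10, 13)}

candidate_synonyms = {"HGNC:11998", "UP:P04637"}
match_found = any(
    ("12345", "Figure_1-A", "p53", syn, 10, 13) in ref
    for syn in candidate_synonyms
)
assert match_found
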
df.apply(self.is_correct_annotation, axis=1) + false_positives_counter[text] += 1 + # total_false_negatives += 1 + print(f"20 Most Common False Positives: " + f"{false_positives_counter.most_common(20)}") + + for doc_id, figure, text, syn, start, end in ref_annotations: + gilda_annotations = self.gilda_annotations_map.get((doc_id, figure), + []) + match_found = any( + ann['matched_text'] == text and + syn in ann['synonyms'] and + ann['start'] == start and + ann['end'] == end + for ann in gilda_annotations + ) + if not match_found: + total_false_negatives += 1 + + + + + # UNCOMMENT IF ABOVE DOESNT WORK + # for _, row in self.annotations_df.iterrows(): + # doc_id, figure = row['don_article'], row['figure'] + # new_ref_annotation = ( + # row['text'], set(row['obj_synonyms']), row['first left'], + # row['last right']) + # + # # Check if this reference annotation is in Gilda's annotations + # new_gilda_annotations = self.gilda_annotations_map.get( + # (str(doc_id), figure), []) + # ref_match_found = False + # for ann in new_gilda_annotations: + # # (new_ref_annotation[1].intersection(ann['synonyms']) and + # if new_ref_annotation[0] == ann['matched_text'] and \ + # new_ref_annotation[2] == ann['start'] and \ + # new_ref_annotation[3] == ann['end']: + # ref_match_found = True + # break + # + # if not ref_match_found: + # total_false_negatives += 1 - # df['true_positives'] = results.apply(lambda x: x[0]) - # df['false_positives'] = results.apply(lambda x: x[1]) - # df['false_negatives'] = results.apply(lambda x: x[2]) + precision = total_true_positives / (total_true_positives + + total_false_positives) \ + if (total_true_positives + total_false_positives) > 0 else 0.0 - # total_true_positives = df['true_positives'].sum() - # total_false_positives = df['false_positives'].sum() - # total_false_negatives = df['false_negatives'].sum() + recall = total_true_positives / (total_true_positives + + total_false_negatives) \ + if (total_true_positives + total_false_negatives) > 0 else 0.0 - # precision = total_true_positives / (total_true_positives - # + total_false_positives) \ - # if (total_true_positives + total_false_positives) > 0 else 0 - # - # recall = total_true_positives / (total_true_positives + - # total_false_negatives) \ - # if (total_true_positives + total_false_negatives) > 0 else 0 - # - # f1 = 2 * (precision * recall) / (precision + recall) \ - # if (precision + recall) > 0 else 0 + f1 = (2 * (precision * recall)) / (precision + recall) \ + if ((precision + recall) > 0) else 0 counts_table = pd.DataFrame([{ 'True Positives': total_true_positives, @@ -621,75 +487,38 @@ def evaluate_gilda_performance(self): 'F1 Score': f1 }]) - print("Aggregated Results:") - print(counts_table) - print(precision_recall) - self.counts_table = counts_table self.precision_recall = precision_recall - print("Finished evaluating performance...") + os.makedirs(RESULTS_DIR, exist_ok=True) + false_positives_df = pd.DataFrame(false_positives_counter.items(), + columns=['False Positive Text', + 'Count']) + false_positives_df = false_positives_df.sort_values(by='Count', ascending=False) + false_positives_df.to_csv( + os.path.join(RESULTS_DIR, 'false_positives.csv'), index=False) - # def create_reference_map(self): - # """Creates a hashmap (dictionary) of reference annotations for efficient lookup""" - # reference_map = defaultdict(set) - # print("Creating reference map...") - # - # for _, row in self.annotations_df.iterrows(): - # doc_id = str(row['don_article']) - # figure = row['figure'] - # start = row['first left'] - 
# end = row['last right'] - # obj = row['obj'] # Use the object directly - # text = row['text'] - # - # for original_id in obj: - # reference_map[(doc_id, figure)].add( - # (text, original_id, start, end)) - # - # for synonym in row['obj_synonyms']: - # reference_map[(doc_id, figure)].add((text, synonym, start, end,)) - # - # # ref_obj_synonyms = self.get_synonym_set(row['obj']) - # # for syn in ref_obj_synonyms: - # # reference_map[(doc_id, figure)].add((syn, start, end)) - # - # if doc_id == '3868508' and figure == "Figure_1-A": - # print( - # f"Adding Reference Annotation: {(text, obj, start, end)} for Document ID: {doc_id}, Figure: {figure}") - # - # # reference_map[(doc_id, figure)].append({ - # # 'text': row['text'], - # # 'obj': row['obj'], - # # 'obj_synonyms': row['obj_synonyms'], - # # 'entity_type': row['entity_type'], - # # 'annot_id': row['annot id'], - # # 'first_left': row['first left'], - # # 'last_right': row['last right'] - # # }) - # - # return reference_map + print("Finished evaluating performance...") def get_results_tables(self): return self.counts_table, self.precision_recall -# def get_famplex_members(): -# from indra.databases import hgnc_client -# entities_path = os.path.join(HERE, 'data', 'entities.csv') -# fplx_entities = famplex.load_entities() -# fplx_children = defaultdict(set) -# for fplx_entity in fplx_entities: -# members = famplex.individual_members('FPLX', fplx_entity) -# for db_ns, db_id in members: -# if db_ns == 'HGNC': -# db_id = hgnc_client.get_current_hgnc_id(db_id) -# if db_id: -# fplx_children[fplx_entity].add('%s:%s' % (db_ns, db_id)) -# return dict(fplx_children) -# -# -# fplx_members = get_famplex_members() +def get_famplex_members(): + from indra.databases import hgnc_client + fplx_entities = famplex.load_entities() + fplx_children = defaultdict(set) + for fplx_entity in fplx_entities: + members = famplex.individual_members('FPLX', fplx_entity) + for db_ns, db_id in members: + if db_ns == 'HGNC': + db_id = hgnc_client.get_current_hgnc_id(db_id) + if db_id: + fplx_children[fplx_entity].add('%s:%s' % (db_ns, db_id)) + return dict(fplx_children) + + +fplx_members = get_famplex_members() # def main(results: str): @@ -698,7 +527,6 @@ def main(): # os.makedirs(results_path, exist_ok=True) benchmarker = BioIDNERBenchmarker() - benchmarker.print_annotations_for_doc_id(3868508) benchmarker.annotate_entities_with_gilda() benchmarker.evaluate_gilda_performance() counts, precision_recall = benchmarker.get_results_tables() diff --git a/gilda/ner.py b/gilda/ner.py index 39a5009..355d361 100644 --- a/gilda/ner.py +++ b/gilda/ner.py @@ -149,6 +149,8 @@ def annotate( spaces = ' ' * (c[0] - len(raw_span) - raw_word_coords[idx][0]) raw_span += spaces + rw + # if len(txt_span) <= 1: + # continue context = text if context_text is None else context_text matches = grounder.ground(raw_span, context=context, From 8366fff7301bdae887b0f3d2b982ebc00a292775 Mon Sep 17 00:00:00 2001 From: Ben Gyori Date: Sun, 14 Jul 2024 11:43:27 -0400 Subject: [PATCH 03/19] Make annotation return a list of scored matches --- gilda/ner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gilda/ner.py b/gilda/ner.py index 355d361..570991e 100644 --- a/gilda/ner.py +++ b/gilda/ner.py @@ -14,6 +14,7 @@ - the `start` position in the text string where the entity starts - the `end` position in the text string where the entity ends + In this example, the two concepts are grounded to FamPlex entries. 
>>> results[0].text, results[0].matches[0].term.get_curie(), results[0].start, results[0].end From 18cca737fb715429c15fa47fe7ec0144830b2261 Mon Sep 17 00:00:00 2001 From: galileosteinberg Date: Thu, 18 Jul 2024 09:31:50 -0400 Subject: [PATCH 04/19] Commit with changes to gilda evaluation logic. Returns metrics for both top matches and all matches. --- benchmarks/bioid_ner_benchmark.py | 257 +++++++++++++++++++----------- 1 file changed, 163 insertions(+), 94 deletions(-) diff --git a/benchmarks/bioid_ner_benchmark.py b/benchmarks/bioid_ner_benchmark.py index f8dc7e5..30f381c 100644 --- a/benchmarks/bioid_ner_benchmark.py +++ b/benchmarks/bioid_ner_benchmark.py @@ -54,10 +54,12 @@ def __init__(self): self.annotations_df = self._process_annotations_table() # csv annotations # self.reference_map = self.create_reference_map() # Create reference map for efficient lookup self.stoplist = self._load_stoplist() # Load stoplist - self.gilda_annotations_map = defaultdict( - list) # New field to store Gilda annotations + self.gilda_annotations_map = defaultdict(list) + self.annotations_count = 0 + # New field to store Gilda annotations self.counts_table = None self.precision_recall = None + self.performance_metrics = None def process_xml_files(self): """Extract relevant information from XML files.""" @@ -105,6 +107,7 @@ def process_xml_files(self): # df = pd.DataFrame(data) # print(f"{len(df)} rows in processed XML data.") print(f"Total annotations in XML files: {total_annotations}") + self.annotations_count = total_annotations print("Finished extracting information from XML files.") return pd.DataFrame(data) @@ -303,46 +306,48 @@ def annotate_entities_with_gilda(self): # Get the full text for the paper-level disambiguation full_text = self._get_plaintext(doc_id) - gilda_annotations = annotate(text, context_text=full_text, - return_first=True) + gilda_annotations = annotate(text, context_text=full_text) # for testing all matches for each entity, return_first = False. - for matched_text, grounding_result, start, end in gilda_annotations: - - # Checking against stoplist - if matched_text in self.stoplist: - continue - - db, entity_id = grounding_result.term.db, grounding_result.term.id - curie = f"{db}:{entity_id}" - # normalized_id = self._normalize_id(curie) - synonyms = self.get_synonym_set([curie]) - # entity_type = self._get_entity_type([curie]) - - self.gilda_annotations_map[(doc_id, figure)].append({ - 'matched_text': matched_text, - 'db': db, - 'id': entity_id, - 'start': start, - 'end': end, - # 'normalized_id': normalized_id, - 'synonyms': synonyms, - # 'entity_type': entity_type - }) + for matched_text, scored_matches, start, end in gilda_annotations: total_gilda_annotations += 1 + # all_synonyms = set() # new addition + for grounding_result in scored_matches: + # Checking against stoplist + if matched_text in self.stoplist: + continue + + db, entity_id = grounding_result.term.db, grounding_result.term.id + curie = f"{db}:{entity_id}" # unnecessary btw + # normalized_id = self._normalize_id(curie) + # synonyms = self.get_synonym_set([curie]) + # all_synonyms.update(synonyms) #new addition. 
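
# With this change annotate() yields every scored match for a mention
# rather than only the first, so "top match" handling is just indexing
# rank 0. A minimal consumption sketch, assuming the (matched_text,
# scored_matches, start, end) tuples unpacked in the loop above:
from gilda.ner import annotate

for matched_text, scored_matches, start, end in annotate("BRAF and MEK"):
    top = scored_matches[0]  # matches are ranked best-first
    print(matched_text, f"{top.term.db}:{top.term.id}", start, end)
    for alt in scored_matches[1:]:  # lower-ranked alternatives
        print("  alt:", f"{alt.term.db}:{alt.term.id}", alt.score)
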
+ # entity_type = self._get_entity_type([curie]) + + self.gilda_annotations_map[(doc_id, figure)].append({ + 'matched_text': matched_text, + 'db': db, + 'id': entity_id, + 'start': start, + 'end': end, + # 'normalized_id': normalized_id, + # 'synonyms': all_synonyms #new addition, otherwise change back to just synonyms + # 'entity_type': entity_type + }) + # results.append( # (scored_match.term.db, scored_match.term.id, start, end)) - if doc_id == '3868508' and figure == 'Figure_1-A': - print(f"Scored NER Match: {grounding_result}") - print(f"Annotated Text Segment: {text[start:end]} at " - f"indices {start} to {end}") - print( - f"Gilda Matched Text: {matched_text}, DB: {db}, " - f"ID: {entity_id}, Start: {start}, End: {end}") - print(f"Grounding Results: {curie}") - print(f"synonyms: {synonyms}") - # print(f"entity type: {entity_type}") - print("\n") + if doc_id == '3868508' and figure == 'Figure_1-A': + print(f"Scored NER Match: {grounding_result}") + print(f"Annotated Text Segment: {text[start:end]} at " + f"indices {start} to {end}") + print( + f"Gilda Matched Text: {matched_text}, DB: {db}, " + f"ID: {entity_id}, Start: {start}, End: {end}") + print(f"Grounding Results: {curie}") + # print(f"synonyms: {synonyms}") + # print(f"entity type: {entity_type}") + print("\n") tqdm.write("Finished annotating corpus with Gilda...") print(f"Total Gilda annotations: {total_gilda_annotations}") @@ -353,13 +358,14 @@ def evaluate_gilda_performance(self): # df = self.processed_data - total_true_positives = 0 - total_false_positives = 0 - total_false_negatives = 0 + # total_true_positives = 0 + # total_false_positives = 0 + # total_false_negatives = 0 false_positives_counter = Counter() # Create a set of reference annotations for quick lookup ref_annotations = set() + for _, row in self.annotations_df.iterrows(): doc_id = str(row['don_article']) figure = row['figure'] @@ -369,42 +375,46 @@ def evaluate_gilda_performance(self): row['first left'], row['last right'])) + print(f"Total reference annotations: {len(ref_annotations)}") + metrics = { + 'all_matches': {'tp': 0, 'fp': 0, 'fn': 0}, + 'top_match': {'tp': 0, 'fp': 0, 'fn': 0} + } + for (doc_id, figure), annotations in self.gilda_annotations_map.items(): # print(f"Processing Document ID: {doc_id}, Figure: {figure}") - for annotation in annotations: + for i, annotation in enumerate(annotations): start = annotation['start'] end = annotation['end'] # gilda_annotation = annotation['id'] - gilda_synonyms = annotation['synonyms'] + # gilda_synonyms = annotation['synonyms'] text = annotation['matched_text'] + curie = f"{annotation['db']}:{annotation['id']}" - # Uncomment this if above doesnt work. 
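
# The two-tier bookkeeping above can be exercised in isolation: every
# scored match updates the 'all_matches' counters, while only the rank-0
# match for a mention updates 'top_match'. A simplified sketch with
# fabricated hits:
metrics = {
    'all_matches': {'tp': 0, 'fp': 0, 'fn': 0},
    'top_match': {'tp': 0, 'fp': 0, 'fn': 0},
}
ranked_hits = [True, False, True]  # ranks 0, 1, 2 for a single mention
for i, hit in enumerate(ranked_hits):
    key = 'tp' if hit else 'fp'
    metrics['all_matches'][key] += 1
    if i == 0:  # top-ranked match only
        metrics['top_match'][key] += 1
assert metrics['all_matches'] == {'tp': 2, 'fp': 1, 'fn': 0}
assert metrics['top_match'] == {'tp': 1, 'fp': 0, 'fn': 0}
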
- # ref_annotations = self.annotations_df[ - # (self.annotations_df['don_article'] == int(doc_id)) & - # (self.annotations_df['figure'] == figure) - # ] + match_found = (doc_id, figure, text, curie, start, end) in ref_annotations + + if match_found: + metrics['all_matches']['tp'] += 1 + if i == 0: # Top match + metrics['top_match']['tp'] += 1 + else: + metrics['all_matches']['fp'] += 1 + false_positives_counter[text] += 1 + if i == 0: # Top match + metrics['top_match']['fp'] += 1 - # UNCOMMENT IF BELOW DOESN'T WORK - # match_found = any( - # (text, syn, start, end) in ref_annotations[ - # ['text', 'obj_synonyms', 'first left', - # 'last right']].values - # for syn in gilda_synonyms - # ) - match_found = any( - (doc_id, figure, text, syn, start, end) - in ref_annotations for syn in gilda_synonyms) + # match_found = any( + # (doc_id, figure, text, syn, start, end) + # in ref_annotations for syn in gilda_synonyms) # Debugging: Identify and print the exact match matching_reference = None if match_found: - for syn in gilda_synonyms: - if (doc_id, figure, text, syn, start, - end) in ref_annotations: + if (doc_id, figure, text, curie, start, end) in ref_annotations: matching_reference = ( - doc_id, figure, text, syn, start, end) + doc_id, figure, text, curie, start, end) break if (match_found == True and doc_id == '3868508' @@ -412,33 +422,74 @@ def evaluate_gilda_performance(self): print(f"Gilda Annotation: {annotation}") # print(f"Reference Annotations: {ref_annotations}") print(f"Match Found: {match_found}") - print(f"Synonyms: {gilda_synonyms}") - print(f"Match Found: {match_found}") if match_found: print(f"Matching Reference: {matching_reference}") - if match_found: - total_true_positives += 1 - else: - total_false_positives += 1 - false_positives_counter[text] += 1 + # if match_found: + # total_true_positives += 1 + # else: + # total_false_positives += 1 + # false_positives_counter[text] += 1 # total_false_negatives += 1 print(f"20 Most Common False Positives: " f"{false_positives_counter.most_common(20)}") - for doc_id, figure, text, syn, start, end in ref_annotations: + # Calculate false negatives + # for doc_id, figure, text, syn, start, end in ref_annotations: + # gilda_annotations = self.gilda_annotations_map.get((doc_id, figure), + # []) + # match_found = any( + # ann['matched_text'] == text and + # f"{ann['db']}:{ann['id']}" == syn and + # ann['start'] == start and + # ann['end'] == end + # for ann in gilda_annotations + # ) + # if not match_found: + # metrics['all_matches']['fn'] += 1 + # metrics['top_match']['fn'] += 1 + + # Separate False Negatives Calculation using the DataFrame + for _, row in self.annotations_df.iterrows(): + doc_id = str(row['don_article']) + figure = row['figure'] + text = row['text'] + curie = row['obj'] + start = row['first left'] + end = row['last right'] + gilda_annotations = self.gilda_annotations_map.get((doc_id, figure), []) match_found = any( ann['matched_text'] == text and - syn in ann['synonyms'] and + f"{ann['db']}:{ann['id']}" == curie and ann['start'] == start and ann['end'] == end for ann in gilda_annotations ) if not match_found: - total_false_negatives += 1 - + metrics['all_matches']['fn'] += 1 + metrics['top_match']['fn'] += 1 + + results = {} + for match_type, counts in metrics.items(): + precision = counts['tp'] / (counts['tp'] + counts['fp']) if (counts[ + 'tp'] + + counts[ + 'fp']) > 0 else 0 + recall = counts['tp'] / (counts['tp'] + counts['fn']) if (counts[ + 'tp'] + + counts[ + 'fn']) > 0 else 0 + f1 = 2 * (precision * recall) / 
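
# A worked numeric check of the precision/recall/F1 computation being
# assembled here (counts invented for illustration, not benchmark output):
def precision_recall_f1(tp, fp, fn):
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) > 0 else 0.0)
    return precision, recall, f1

# precision_recall_f1(80, 20, 40) -> (0.8, 0.666..., 0.727...)
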
(precision + recall) if ( + precision + recall) > 0 else 0 + results[match_type] = { + 'precision': precision, + 'recall': recall, + 'f1': f1 + } + + self.performance_metrics = results @@ -464,28 +515,46 @@ def evaluate_gilda_performance(self): # if not ref_match_found: # total_false_negatives += 1 - precision = total_true_positives / (total_true_positives - + total_false_positives) \ - if (total_true_positives + total_false_positives) > 0 else 0.0 - - recall = total_true_positives / (total_true_positives - + total_false_negatives) \ - if (total_true_positives + total_false_negatives) > 0 else 0.0 - - f1 = (2 * (precision * recall)) / (precision + recall) \ - if ((precision + recall) > 0) else 0 - - counts_table = pd.DataFrame([{ - 'True Positives': total_true_positives, - 'False Positives': total_false_positives, - 'False Negatives': total_false_negatives - }]) - - precision_recall = pd.DataFrame([{ - 'Precision': precision, - 'Recall': recall, - 'F1 Score': f1 - }]) + # precision = total_true_positives / (total_true_positives + # + total_false_positives) \ + # if (total_true_positives + total_false_positives) > 0 else 0.0 + # + # recall = total_true_positives / (total_true_positives + # + total_false_negatives) \ + # if (total_true_positives + total_false_negatives) > 0 else 0.0 + # + # f1 = (2 * (precision * recall)) / (precision + recall) \ + # if ((precision + recall) > 0) else 0 + + counts_table = pd.DataFrame([ + { + 'Match Type': 'All Matches', + 'True Positives': metrics['all_matches']['tp'], + 'False Positives': metrics['all_matches']['fp'], + 'False Negatives': metrics['all_matches']['fn'] + }, + { + 'Match Type': 'Top Match', + 'True Positives': metrics['top_match']['tp'], + 'False Positives': metrics['top_match']['fp'], + 'False Negatives': metrics['top_match']['fn'] + } + ]) + + precision_recall = pd.DataFrame([ + { + 'Match Type': 'All Matches', + 'Precision': results['all_matches']['precision'], + 'Recall': results['all_matches']['recall'], + 'F1 Score': results['all_matches']['f1'] + }, + { + 'Match Type': 'Top Match', + 'Precision': results['top_match']['precision'], + 'Recall': results['top_match']['recall'], + 'F1 Score': results['top_match']['f1'] + } + ]) self.counts_table = counts_table self.precision_recall = precision_recall From db72a7110b47910a1bf2e0c3b116805c3f7d7d61 Mon Sep 17 00:00:00 2001 From: galileosteinberg Date: Thu, 18 Jul 2024 09:33:56 -0400 Subject: [PATCH 05/19] Committing changes to app.py with NER class. --- gilda/app/app.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gilda/app/app.py b/gilda/app/app.py index 801f96a..152c68e 100644 --- a/gilda/app/app.py +++ b/gilda/app/app.py @@ -8,6 +8,7 @@ from gilda import __version__ as version from gilda.grounder import GrounderInput, Grounder from gilda.app.proxies import grounder +from gilda.ner import annotate # NOTE: the Flask REST-X API has to be declared here, below the home endpoint # otherwise it reserves the / base path. 
@@ -311,7 +312,6 @@ def post(self): return jsonify([annotation.to_json() for annotation in results]) - def get_app(terms: Optional[GrounderInput] = None, *, ui: bool = True) -> Flask: app = Flask(__name__) app.config['RESTX_MASK_SWAGGER'] = False From 728ee34fcb833957271f7a75a6edac80fb83f4aa Mon Sep 17 00:00:00 2001 From: galileosteinberg Date: Thu, 18 Jul 2024 15:19:08 -0400 Subject: [PATCH 06/19] Modified to work with annotation objects --- benchmarks/bioid_ner_benchmark.py | 287 ++++++++++-------------------- 1 file changed, 93 insertions(+), 194 deletions(-) diff --git a/benchmarks/bioid_ner_benchmark.py b/benchmarks/bioid_ner_benchmark.py index 30f381c..8bdbf2e 100644 --- a/benchmarks/bioid_ner_benchmark.py +++ b/benchmarks/bioid_ner_benchmark.py @@ -15,7 +15,6 @@ import pystow import gilda from gilda import ground -# from benchmarks.bioid_evaluation import fplx_members from gilda.ner import annotate from gilda.grounder import logger @@ -52,14 +51,12 @@ def __init__(self): self.paper_level_grounding = defaultdict(set) self.processed_data = self.process_xml_files() # xml files processesed self.annotations_df = self._process_annotations_table() # csv annotations - # self.reference_map = self.create_reference_map() # Create reference map for efficient lookup self.stoplist = self._load_stoplist() # Load stoplist self.gilda_annotations_map = defaultdict(list) self.annotations_count = 0 # New field to store Gilda annotations self.counts_table = None self.precision_recall = None - self.performance_metrics = None def process_xml_files(self): """Extract relevant information from XML files.""" @@ -73,8 +70,8 @@ def process_xml_files(self): root = tree.getroot() for document in root.findall('.//document'): doc_id_full = document.find('.//id').text.strip() - don_article, figure = doc_id_full.split(' ', - 1) # Split the full ID to get don_article and figure + # Split the full ID to get don_article and figure + don_article, figure = doc_id_full.split(' ',1) don_article = don_article for passage in document.findall('.//passage'): offset = int(passage.find('.//offset').text) @@ -104,8 +101,6 @@ def process_xml_files(self): 'text': text, 'annotations': annotations, }) - # df = pd.DataFrame(data) - # print(f"{len(df)} rows in processed XML data.") print(f"Total annotations in XML files: {total_annotations}") self.annotations_count = total_annotations print("Finished extracting information from XML files.") @@ -130,7 +125,9 @@ def _load_equivalences(self) -> Dict[str, List[str]]: equivalences = json.load(f) except FileNotFoundError: print( - f"No Equivalences found at {os.path.join(HERE, 'data', 'equivalences.json')}. Proceeding without it.") + f"No Equivalences found at " + f"{os.path.join(HERE, 'data', 'equivalences.json')}. 
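
# A minimal sketch of the optional-resource pattern used in
# _load_equivalences above: load the equivalences JSON if present,
# otherwise warn and fall back to an empty mapping (the path argument is
# illustrative).
import json

def load_equivalences(path):
    try:
        with open(path) as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"No equivalences found at {path}. Proceeding without it.")
        return {}
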
" + f"Proceeding without it.") equivalences = {} return equivalences @@ -185,7 +182,8 @@ def _get_equivalent_entities(self, curie: str) -> Set[str]: if (prefix, xref_prefix) not in BO_MISSING_XREFS: BO_MISSING_XREFS.add((prefix, xref_prefix)) tqdm.write( - f'Bioontology v{bio_ontology.version} is missing mappings from {prefix} to {xref_prefix}') + f'Bioontology v{bio_ontology.version} is missing mappings' + f' from {prefix} to {xref_prefix}') output.add(xref_curie) if prefix == 'NCBI gene': @@ -291,17 +289,13 @@ def _get_plaintext(self, don_article: str) -> str: def annotate_entities_with_gilda(self): """Performs NER on the XML files using gilda.annotate()""" - # df = self.processed_data tqdm.write("Annotating corpus with Gilda...") - # results = [] total_gilda_annotations = 0 - # for _, row in df.iterrows(): for _, item in self.processed_data.iterrows(): doc_id = item['doc_id'] figure = item['figure'] text = item['text'] - # annotations = item['annotations'] # Get the full text for the paper-level disambiguation full_text = self._get_plaintext(doc_id) @@ -309,45 +303,28 @@ def annotate_entities_with_gilda(self): gilda_annotations = annotate(text, context_text=full_text) # for testing all matches for each entity, return_first = False. - for matched_text, scored_matches, start, end in gilda_annotations: + for annotation in gilda_annotations: total_gilda_annotations += 1 - # all_synonyms = set() # new addition - for grounding_result in scored_matches: - # Checking against stoplist - if matched_text in self.stoplist: - continue - - db, entity_id = grounding_result.term.db, grounding_result.term.id - curie = f"{db}:{entity_id}" # unnecessary btw - # normalized_id = self._normalize_id(curie) - # synonyms = self.get_synonym_set([curie]) - # all_synonyms.update(synonyms) #new addition. 
- # entity_type = self._get_entity_type([curie]) - - self.gilda_annotations_map[(doc_id, figure)].append({ - 'matched_text': matched_text, - 'db': db, - 'id': entity_id, - 'start': start, - 'end': end, - # 'normalized_id': normalized_id, - # 'synonyms': all_synonyms #new addition, otherwise change back to just synonyms - # 'entity_type': entity_type - }) - - # results.append( - # (scored_match.term.db, scored_match.term.id, start, end)) - if doc_id == '3868508' and figure == 'Figure_1-A': - print(f"Scored NER Match: {grounding_result}") - print(f"Annotated Text Segment: {text[start:end]} at " - f"indices {start} to {end}") + + if annotation.text in self.stoplist: + continue + + self.gilda_annotations_map[(doc_id, figure)].append(annotation) + + if doc_id == '3868508' and figure == 'Figure_1-A': + print(f"Scored NER Match: {annotation}") + print(f"Annotated Text Segment: " + f"{text[annotation.start:annotation.end]} at " + f"indices {annotation.start} to {annotation.end}") + for i, scored_match in enumerate(annotation.matches): + print(f"Scored Match {i + 1}: {scored_match}") + print( + f"DB: {scored_match.term.db}, " + f"ID: {scored_match.term.id}") print( - f"Gilda Matched Text: {matched_text}, DB: {db}, " - f"ID: {entity_id}, Start: {start}, End: {end}") - print(f"Grounding Results: {curie}") - # print(f"synonyms: {synonyms}") - # print(f"entity type: {entity_type}") - print("\n") + f"Score: {scored_match.score}, " + f"Match: {scored_match.match}") + print("\n") tqdm.write("Finished annotating corpus with Gilda...") print(f"Total Gilda annotations: {total_gilda_annotations}") @@ -356,175 +333,97 @@ def evaluate_gilda_performance(self): """Calculates precision, recall, and F1""" print("Evaluating performance...") - # df = self.processed_data - - # total_true_positives = 0 - # total_false_positives = 0 - # total_false_negatives = 0 - false_positives_counter = Counter() - - # Create a set of reference annotations for quick lookup - ref_annotations = set() - - for _, row in self.annotations_df.iterrows(): - doc_id = str(row['don_article']) - figure = row['figure'] - text = row['text'] - for syn in row['obj_synonyms']: - ref_annotations.add((doc_id, figure, text, syn, - row['first left'], - row['last right'])) - - - print(f"Total reference annotations: {len(ref_annotations)}") metrics = { 'all_matches': {'tp': 0, 'fp': 0, 'fn': 0}, 'top_match': {'tp': 0, 'fp': 0, 'fn': 0} } - for (doc_id, figure), annotations in self.gilda_annotations_map.items(): - # print(f"Processing Document ID: {doc_id}, Figure: {figure}") - for i, annotation in enumerate(annotations): - start = annotation['start'] - end = annotation['end'] - # gilda_annotation = annotation['id'] - # gilda_synonyms = annotation['synonyms'] - text = annotation['matched_text'] - curie = f"{annotation['db']}:{annotation['id']}" - - match_found = (doc_id, figure, text, curie, start, end) in ref_annotations - - if match_found: - metrics['all_matches']['tp'] += 1 - if i == 0: # Top match - metrics['top_match']['tp'] += 1 - else: - metrics['all_matches']['fp'] += 1 - false_positives_counter[text] += 1 - if i == 0: # Top match - metrics['top_match']['fp'] += 1 + false_positives_counter = Counter() + ref_dict = defaultdict(list) + for _, row in self.annotations_df.iterrows(): + key = (str(row['don_article']), row['figure'], row['text'], + row['first left'], row['last right']) + ref_dict[key].append((set(row['obj']), row['obj_synonyms'])) - # match_found = any( - # (doc_id, figure, text, syn, start, end) - # in ref_annotations for syn in 
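
# A short sketch of the Annotation-object interface this rewrite switches
# to: each result of gilda.ner.annotate now carries the matched span plus a
# ranked list of ScoredMatch objects, so no tuple unpacking is needed.
# Attribute names follow the code above; the sentence is invented.
from gilda.ner import annotate

for annotation in annotate("BRAF is mutated in many melanomas"):
    print(annotation.text, annotation.start, annotation.end)
    for scored_match in annotation.matches:
        print('  ', scored_match.term.db, scored_match.term.id,
              scored_match.score)
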
gilda_synonyms) + print(f"Total reference annotations: {len(ref_dict)}") - # Debugging: Identify and print the exact match - matching_reference = None - if match_found: - if (doc_id, figure, text, curie, start, end) in ref_annotations: - matching_reference = ( - doc_id, figure, text, curie, start, end) + for (doc_id, figure), annotations in self.gilda_annotations_map.items(): + for annotation in annotations: + key = (doc_id, figure, annotation.text, annotation.start, + annotation.end) + matching_refs = ref_dict.get(key, []) + + match_found = False + for i, scored_match in enumerate(annotation.matches): + curie = f"{scored_match.term.db}:{scored_match.term.id}" + + for original_curies, synonyms in matching_refs: + if curie in original_curies or curie in synonyms: + metrics['all_matches']['tp'] += 1 + if i == 0: # Top match + metrics['top_match']['tp'] += 1 + match_found = True break - if (match_found == True and doc_id == '3868508' - and figure == "Figure_1-A"): - print(f"Gilda Annotation: {annotation}") - # print(f"Reference Annotations: {ref_annotations}") - print(f"Match Found: {match_found}") if match_found: - print(f"Matching Reference: {matching_reference}") - - # if match_found: - # total_true_positives += 1 - # else: - # total_false_positives += 1 - # false_positives_counter[text] += 1 - # total_false_negatives += 1 + if doc_id == '3868508' and figure == "Figure_1-A": + print(f"Gilda Annotation: {annotation}") + # print(f"Reference Annotations: {ref_annotations}") + print(f"Match Found: {match_found}") + print(f"Matching Reference: {matching_refs}") + + break + + if match_found: + break + + if not match_found: + metrics['all_matches']['fp'] += 1 + false_positives_counter[annotation.text] += 1 + if annotation.matches: # Check if there are any matches + metrics['top_match']['fp'] += 1 + print(f"20 Most Common False Positives: " f"{false_positives_counter.most_common(20)}") - # Calculate false negatives - # for doc_id, figure, text, syn, start, end in ref_annotations: - # gilda_annotations = self.gilda_annotations_map.get((doc_id, figure), - # []) - # match_found = any( - # ann['matched_text'] == text and - # f"{ann['db']}:{ann['id']}" == syn and - # ann['start'] == start and - # ann['end'] == end - # for ann in gilda_annotations - # ) - # if not match_found: - # metrics['all_matches']['fn'] += 1 - # metrics['top_match']['fn'] += 1 - - # Separate False Negatives Calculation using the DataFrame - for _, row in self.annotations_df.iterrows(): - doc_id = str(row['don_article']) - figure = row['figure'] - text = row['text'] - curie = row['obj'] - start = row['first left'] - end = row['last right'] - + # False negative calculation using ref dict + for key, refs in ref_dict.items(): + doc_id, figure = key[0], key[1] gilda_annotations = self.gilda_annotations_map.get((doc_id, figure), []) - match_found = any( - ann['matched_text'] == text and - f"{ann['db']}:{ann['id']}" == curie and - ann['start'] == start and - ann['end'] == end - for ann in gilda_annotations - ) - if not match_found: - metrics['all_matches']['fn'] += 1 - metrics['top_match']['fn'] += 1 + for original_curies, synonyms in refs: + match_found = any( + ann.text == key[2] and + ann.start == key[3] and + ann.end == key[4] and + any(f"{match.term.db}:{match.term.id}" in original_curies or + f"{match.term.db}:{match.term.id}" in synonyms + for match in ann.matches) + for ann in gilda_annotations + ) + + if not match_found: + metrics['all_matches']['fn'] += 1 + metrics['top_match']['fn'] += 1 results = {} for match_type, counts in 
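
# A toy illustration of the defaultdict(list) reference index built above:
# one key per annotated span, mapping to every acceptable grounding
# (original CURIEs plus synonyms) for that span. Values are invented.
from collections import defaultdict

ref_dict = defaultdict(list)
key = ('3868508', 'Figure_1-A', 'BRAF', 12, 16)
ref_dict[key].append(({'HGNC:1097'}, {'UP:P15056'}))

for original_curies, synonyms in ref_dict.get(key, []):
    assert 'HGNC:1097' in original_curies or 'HGNC:1097' in synonyms
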
metrics.items(): - precision = counts['tp'] / (counts['tp'] + counts['fp']) if (counts[ - 'tp'] + - counts[ - 'fp']) > 0 else 0 - recall = counts['tp'] / (counts['tp'] + counts['fn']) if (counts[ - 'tp'] + - counts[ - 'fn']) > 0 else 0 - f1 = 2 * (precision * recall) / (precision + recall) if ( - precision + recall) > 0 else 0 + precision = counts['tp'] / (counts['tp'] + counts['fp']) \ + if ((counts['tp'] + counts['fp']) > 0) else 0 + + recall = counts['tp'] / (counts['tp'] + counts['fn']) \ + if (counts['tp'] + counts['fn']) > 0 else 0 + + f1 = 2 * (precision * recall) / (precision + recall) \ + if (precision + recall) > 0 else 0 + results[match_type] = { 'precision': precision, 'recall': recall, 'f1': f1 } - self.performance_metrics = results - - - - # UNCOMMENT IF ABOVE DOESNT WORK - # for _, row in self.annotations_df.iterrows(): - # doc_id, figure = row['don_article'], row['figure'] - # new_ref_annotation = ( - # row['text'], set(row['obj_synonyms']), row['first left'], - # row['last right']) - # - # # Check if this reference annotation is in Gilda's annotations - # new_gilda_annotations = self.gilda_annotations_map.get( - # (str(doc_id), figure), []) - # ref_match_found = False - # for ann in new_gilda_annotations: - # # (new_ref_annotation[1].intersection(ann['synonyms']) and - # if new_ref_annotation[0] == ann['matched_text'] and \ - # new_ref_annotation[2] == ann['start'] and \ - # new_ref_annotation[3] == ann['end']: - # ref_match_found = True - # break - # - # if not ref_match_found: - # total_false_negatives += 1 - - # precision = total_true_positives / (total_true_positives - # + total_false_positives) \ - # if (total_true_positives + total_false_positives) > 0 else 0.0 - # - # recall = total_true_positives / (total_true_positives - # + total_false_negatives) \ - # if (total_true_positives + total_false_negatives) > 0 else 0.0 - # - # f1 = (2 * (precision * recall)) / (precision + recall) \ - # if ((precision + recall) > 0) else 0 counts_table = pd.DataFrame([ { From 5709a752be660c64a233b4173fde35ae745c7094 Mon Sep 17 00:00:00 2001 From: galileosteinberg Date: Fri, 19 Jul 2024 11:35:45 -0400 Subject: [PATCH 07/19] Changes to app.py --- gilda/app/app.py | 70 +++++++++++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 34 deletions(-) diff --git a/gilda/app/app.py b/gilda/app/app.py index 152c68e..c26f7f3 100644 --- a/gilda/app/app.py +++ b/gilda/app/app.py @@ -47,7 +47,7 @@ term_model = api.model( "Term", - {'norm_text' : fields.String( + {'norm_text': fields.String( description='The normalized text corresponding to the text entry, ' 'used for lookups.', example='egf receptor'), @@ -93,6 +93,7 @@ description='In some cases the term\'s db/id was mapped from another ' 'db/id pair given in the original source. If this is the ' 'case, this field provides the original source ID.') + } ) @@ -109,19 +110,18 @@ example=0.9845 ), 'match': fields.Nested(api.model('Match', {}), - description='Additional metadata about the nature of the match.' - ), + description='Additional metadata about the nature of the match.' + ), 'subsumed_terms': fields.List(fields.Nested(term_model), - description='In some cases multiple terms with the same db/id ' - 'matched the input string, potentially with different ' - 'scores, and only the first one is exposed in the ' - 'scored match\'s term attribute (see above). 
This field ' - 'provides additional terms with the same db/id that ' - 'matched the input for additional traceability.') + description='In some cases multiple terms with the same db/id ' + 'matched the input string, potentially with different ' + 'scores, and only the first one is exposed in the ' + 'scored match\'s term attribute (see above). This field ' + 'provides additional terms with the same db/id that ' + 'matched the input for additional traceability.') } ) - get_names_input_model = api.model( "GetNamesInput", {'db': fields.String( @@ -129,26 +129,26 @@ "e.g. HGNC.", required=True, example='HGNC'), - 'id': fields.String( - description="Identifier within the given database", - required=True, - example='3236' - ), - 'status': fields.String( - description="If provided, only entity texts of the given status are " - "returned (e.g., curated, name, synonym, former_name).", - required=False, - enum=['curated', 'name', 'synonym', 'former_name'], - example='synonym' - ), - 'source': fields.String( - description="If provided, only entity texts collected from the given " - "source are returned.This is useful if terms grounded to " - "IDs in a given database are collected from multiple " - "different sources.", - required=False, - example='uniprot' - ) + 'id': fields.String( + description="Identifier within the given database", + required=True, + example='3236' + ), + 'status': fields.String( + description="If provided, only entity texts of the given status are " + "returned (e.g., curated, name, synonym, former_name).", + required=False, + enum=['curated', 'name', 'synonym', 'former_name'], + example='synonym' + ), + 'source': fields.String( + description="If provided, only entity texts collected from the given " + "source are returned.This is useful if terms grounded to " + "IDs in a given database are collected from multiple " + "different sources.", + required=False, + example='uniprot' + ) } ) @@ -186,8 +186,8 @@ }) names_model = fields.List( - fields.String, - example=['EGF receptor', 'EGFR', 'ERBB1', 'Proto-oncogene c-ErbB-1']) + fields.String, + example=['EGF receptor', 'EGFR', 'ERBB1', 'Proto-oncogene c-ErbB-1']) models_model = fields.List( fields.String, @@ -213,7 +213,8 @@ def post(self): text = request.json.get('text') context = request.json.get('context') organisms = request.json.get('organisms') - scored_matches = grounder.ground(text, context=context, organisms=organisms) + scored_matches = grounder.ground(text, context=context, + organisms=organisms) res = [sm.to_json() for sm in scored_matches] return jsonify(res) @@ -239,7 +240,8 @@ def post(self): text = input.get('text') context = input.get('context') organisms = input.get('organisms') - scored_matches = grounder.ground(text, context=context, organisms=organisms) + scored_matches = grounder.ground(text, context=context, + organisms=organisms) all_matches.append([sm.to_json() for sm in scored_matches]) return jsonify(all_matches) From 2bc9cf7f34729d8ae17d30e66c2966cd3f0930ba Mon Sep 17 00:00:00 2001 From: galileosteinberg Date: Thu, 25 Jul 2024 14:02:15 -0400 Subject: [PATCH 08/19] Annotate not accepting spans <= 1, added progress bars --- benchmarks/bioid_ner_benchmark.py | 98 ++++++++++++++++--------------- gilda/ner.py | 5 +- 2 files changed, 55 insertions(+), 48 deletions(-) diff --git a/benchmarks/bioid_ner_benchmark.py b/benchmarks/bioid_ner_benchmark.py index 8bdbf2e..2af8865 100644 --- a/benchmarks/bioid_ner_benchmark.py +++ b/benchmarks/bioid_ner_benchmark.py @@ -35,7 +35,8 @@ RESULTS_DIR = os.path.join(HERE, 
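
# A hedged client-side sketch of exercising the grounding endpoint whose
# handler appears above; the /ground route and the localhost URL/port are
# assumptions for illustration and are not shown in this diff.
import requests

resp = requests.post('http://localhost:8001/ground',
                     json={'text': 'EGF receptor', 'organisms': ['9606']})
print(resp.json())
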
'results', "bioid_ner_performance", gilda.__version__) MODULE = pystow.module('gilda', 'biocreative') -URL = 'https://biocreative.bioinformatics.udel.edu/media/store/files/2017/BioIDtraining_2.tar.gz' +URL = ('https://biocreative.bioinformatics.udel.edu/media/store/files/2017' + '/BioIDtraining_2.tar.gz') STOPLIST_PATH = os.path.join(HERE, 'data', 'ner_stoplist.txt') @@ -49,8 +50,8 @@ def __init__(self): print("Instantiating benchmarker...") self.equivalences = self._load_equivalences() self.paper_level_grounding = defaultdict(set) - self.processed_data = self.process_xml_files() # xml files processesed - self.annotations_df = self._process_annotations_table() # csv annotations + self.processed_data = self.process_xml_files() # xml files processes + self.annotations_df = self._process_annotations_table() # csvannotations self.stoplist = self._load_stoplist() # Load stoplist self.gilda_annotations_map = defaultdict(list) self.annotations_count = 0 @@ -289,10 +290,12 @@ def _get_plaintext(self, don_article: str) -> str: def annotate_entities_with_gilda(self): """Performs NER on the XML files using gilda.annotate()""" - tqdm.write("Annotating corpus with Gilda...") + print("Annotating corpus with Gilda...") total_gilda_annotations = 0 - for _, item in self.processed_data.iterrows(): + for _, item in tqdm(self.processed_data.iterrows(), + total=self.processed_data.shape[0], + desc="Annotating with Gilda"): doc_id = item['doc_id'] figure = item['figure'] text = item['text'] @@ -311,23 +314,23 @@ def annotate_entities_with_gilda(self): self.gilda_annotations_map[(doc_id, figure)].append(annotation) - if doc_id == '3868508' and figure == 'Figure_1-A': - print(f"Scored NER Match: {annotation}") - print(f"Annotated Text Segment: " - f"{text[annotation.start:annotation.end]} at " - f"indices {annotation.start} to {annotation.end}") - for i, scored_match in enumerate(annotation.matches): - print(f"Scored Match {i + 1}: {scored_match}") - print( - f"DB: {scored_match.term.db}, " - f"ID: {scored_match.term.id}") - print( - f"Score: {scored_match.score}, " - f"Match: {scored_match.match}") - print("\n") + # if doc_id == '3868508' and figure == 'Figure_1-A': + # tqdm.write(f"Scored NER Match: {annotation}") + # tqdm.write(f"Annotated Text Segment: " + # f"{text[annotation.start:annotation.end]} at " + # f"indices {annotation.start} to {annotation.end}") + # for i, scored_match in enumerate(annotation.matches): + # tqdm.write(f"Scored Match {i + 1}: {scored_match}") + # tqdm.write( + # f"DB: {scored_match.term.db}, " + # f"ID: {scored_match.term.id}") + # tqdm.write( + # f"Score: {scored_match.score}, " + # f"Match: {scored_match.match}") + # tqdm.write("\n") tqdm.write("Finished annotating corpus with Gilda...") - print(f"Total Gilda annotations: {total_gilda_annotations}") + # tqdm.write(f"Total Gilda annotations: {total_gilda_annotations}") def evaluate_gilda_performance(self): """Calculates precision, recall, and F1""" @@ -346,9 +349,11 @@ def evaluate_gilda_performance(self): row['first left'], row['last right']) ref_dict[key].append((set(row['obj']), row['obj_synonyms'])) - print(f"Total reference annotations: {len(ref_dict)}") + # print(f"Total reference annotations: {len(ref_dict)}") - for (doc_id, figure), annotations in self.gilda_annotations_map.items(): + for (doc_id, figure), annotations in ( + tqdm(self.gilda_annotations_map.items(), + desc="Evaluating Annotations")): for annotation in annotations: key = (doc_id, figure, annotation.text, annotation.start, annotation.end) @@ -366,14 +371,13 
@@ def evaluate_gilda_performance(self): match_found = True break - if match_found: - if doc_id == '3868508' and figure == "Figure_1-A": - print(f"Gilda Annotation: {annotation}") - # print(f"Reference Annotations: {ref_annotations}") - print(f"Match Found: {match_found}") - print(f"Matching Reference: {matching_refs}") + # if match_found: + # if doc_id == '3868508' and figure == "Figure_1-A": + # print(f"Gilda Annotation: {annotation}") + # print(f"Match Found: {match_found}") + # print(f"Matching Reference: {matching_refs}") - break + # break if match_found: break @@ -384,11 +388,12 @@ def evaluate_gilda_performance(self): if annotation.matches: # Check if there are any matches metrics['top_match']['fp'] += 1 - print(f"20 Most Common False Positives: " - f"{false_positives_counter.most_common(20)}") + # print(f"20 Most Common False Positives: " + # f"{false_positives_counter.most_common(20)}") # False negative calculation using ref dict - for key, refs in ref_dict.items(): + for key, refs in tqdm(ref_dict.items(), + desc="Calculating False Negatives"): doc_id, figure = key[0], key[1] gilda_annotations = self.gilda_annotations_map.get((doc_id, figure), []) @@ -424,7 +429,6 @@ def evaluate_gilda_performance(self): 'f1': f1 } - counts_table = pd.DataFrame([ { 'Match Type': 'All Matches', @@ -462,7 +466,8 @@ def evaluate_gilda_performance(self): false_positives_df = pd.DataFrame(false_positives_counter.items(), columns=['False Positive Text', 'Count']) - false_positives_df = false_positives_df.sort_values(by='Count', ascending=False) + false_positives_df = false_positives_df.sort_values(by='Count', + ascending=False) false_positives_df.to_csv( os.path.join(RESULTS_DIR, 'false_positives.csv'), index=False) @@ -489,25 +494,26 @@ def get_famplex_members(): fplx_members = get_famplex_members() -# def main(results: str): -def main(): - # results_path = os.path.expandvars(os.path.expanduser(results)) - # os.makedirs(results_path, exist_ok=True) +def main(results: str = RESULTS_DIR): + results_path = os.path.expandvars(os.path.expanduser(results)) + os.makedirs(results_path, exist_ok=True) benchmarker = BioIDNERBenchmarker() benchmarker.annotate_entities_with_gilda() benchmarker.evaluate_gilda_performance() counts, precision_recall = benchmarker.get_results_tables() - print("Counts table:") + + print(f"Counts Table:") print(counts.to_markdown(index=False)) - print("Precision and Recall table:") + print(f"Precision and Recall table: ") print(precision_recall.to_markdown(index=False)) - # time = datetime.now().strftime('%y%m%d-%H%M%S') - # result_stub = pathlib.Path(results_path).joinpath(f'benchmark_{time}') - # counts.to_csv(result_stub.with_suffix(".counts.csv"), index=False) - # precision_recall.to_csv(result_stub.with_suffix(".precision_recall.csv"), - # index=False) - # print(f'Results saved to {results_path}') + + time = datetime.now().strftime('%y%m%d-%H%M%S') + result_stub = pathlib.Path(results_path).joinpath(f'benchmark_{time}') + counts.to_csv(result_stub.with_suffix(".counts.csv"), index=False) + precision_recall.to_csv(result_stub.with_suffix(".precision_recall.csv"), + index=False) + print(f'Results saved to {results_path}') if __name__ == '__main__': diff --git a/gilda/ner.py b/gilda/ner.py index 570991e..5c6f58a 100644 --- a/gilda/ner.py +++ b/gilda/ner.py @@ -150,8 +150,9 @@ def annotate( spaces = ' ' * (c[0] - len(raw_span) - raw_word_coords[idx][0]) raw_span += spaces + rw - # if len(txt_span) <= 1: - # continue + + if len(raw_span) <= 1: + continue context = text if context_text 
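
# The tqdm pattern added in this patch, in isolation: pass
# total=df.shape[0] so the progress bar has a known length when iterating
# DataFrame rows (the two-row frame is a toy example).
import pandas as pd
from tqdm import tqdm

df = pd.DataFrame({'doc_id': ['1', '2'], 'text': ['a', 'b']})
for _, row in tqdm(df.iterrows(), total=df.shape[0], desc="Annotating"):
    pass  # per-document annotation work goes here
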
is None else context_text matches = grounder.ground(raw_span, context=context, From 4c297281daa5d9df2821c5c3220039be15a87399 Mon Sep 17 00:00:00 2001 From: galileosteinberg Date: Thu, 25 Jul 2024 16:14:32 -0400 Subject: [PATCH 09/19] Extended stopwords; moved stop words file to resources and logic into ner module --- benchmarks/bioid_ner_benchmark.py | 25 +---- gilda/ner.py | 18 +++- gilda/resources/ner_stoplist.txt | 172 ++++++++++++++++++++++++++++++ 3 files changed, 191 insertions(+), 24 deletions(-) create mode 100644 gilda/resources/ner_stoplist.txt diff --git a/benchmarks/bioid_ner_benchmark.py b/benchmarks/bioid_ner_benchmark.py index 2af8865..a23c7c4 100644 --- a/benchmarks/bioid_ner_benchmark.py +++ b/benchmarks/bioid_ner_benchmark.py @@ -14,9 +14,8 @@ import click import pystow import gilda -from gilda import ground from gilda.ner import annotate -from gilda.grounder import logger +import logging import famplex from indra.databases.chebi_client import get_chebi_id_from_pubchem @@ -24,7 +23,8 @@ from indra.databases.uniprot_client import get_hgnc_id from indra.ontology.bio import bio_ontology -logger.setLevel('WARNING') +logging.getLogger('gilda.grounder').setLevel('WARNING') +logger = logging.getLogger('bioid_ner_benchmark') # Constants HERE = os.path.dirname(os.path.abspath(__file__)) @@ -38,8 +38,6 @@ URL = ('https://biocreative.bioinformatics.udel.edu/media/store/files/2017' '/BioIDtraining_2.tar.gz') -STOPLIST_PATH = os.path.join(HERE, 'data', 'ner_stoplist.txt') - tqdm.pandas() BO_MISSING_XREFS = set() @@ -52,7 +50,6 @@ def __init__(self): self.paper_level_grounding = defaultdict(set) self.processed_data = self.process_xml_files() # xml files processes self.annotations_df = self._process_annotations_table() # csvannotations - self.stoplist = self._load_stoplist() # Load stoplist self.gilda_annotations_map = defaultdict(list) self.annotations_count = 0 # New field to store Gilda annotations @@ -107,19 +104,6 @@ def process_xml_files(self): print("Finished extracting information from XML files.") return pd.DataFrame(data) - def _load_stoplist(self) -> Set[str]: - """Load NER stoplist from file.""" - stoplist_path = STOPLIST_PATH - try: - with open(stoplist_path, 'r') as file: - stoplist = {line.strip().lower() for line in file} - print(f"Loaded stoplist with {len(stoplist)} entries.") - return stoplist - except FileNotFoundError: - print( - f"No stoplist found at {stoplist_path}. Proceeding without it.") - return set() - def _load_equivalences(self) -> Dict[str, List[str]]: try: with open(os.path.join(HERE, 'data', 'equivalences.json')) as f: @@ -309,9 +293,6 @@ def annotate_entities_with_gilda(self): for annotation in gilda_annotations: total_gilda_annotations += 1 - if annotation.text in self.stoplist: - continue - self.gilda_annotations_map[(doc_id, figure)].append(annotation) # if doc_id == '3868508' and figure == 'Figure_1-A': diff --git a/gilda/ner.py b/gilda/ner.py index 5c6f58a..49dae9e 100644 --- a/gilda/ner.py +++ b/gilda/ner.py @@ -46,7 +46,8 @@ same name but extension ``.ann``. 
""" -from typing import List +from typing import List, Set +import os from nltk.corpus import stopwords from nltk.tokenize import PunktSentenceTokenizer, TreebankWordTokenizer @@ -61,7 +62,20 @@ "stop_words" ] +STOPLIST_PATH = os.path.join(os.path.dirname(__file__),'resources', + 'ner_stoplist.txt') + + +def _load_stoplist() -> Set[str]: + """Load NER stoplist from file.""" + stoplist_path = STOPLIST_PATH + with open(stoplist_path, 'r') as file: + stoplist = {line.strip() for line in file} + return stoplist + + stop_words = set(stopwords.words('english')) +stop_words.update(_load_stoplist()) def annotate( @@ -150,7 +164,7 @@ def annotate( spaces = ' ' * (c[0] - len(raw_span) - raw_word_coords[idx][0]) raw_span += spaces + rw - + # If span is a single character, we don't want to consider it if len(raw_span) <= 1: continue context = text if context_text is None else context_text diff --git a/gilda/resources/ner_stoplist.txt b/gilda/resources/ner_stoplist.txt new file mode 100644 index 0000000..55e141c --- /dev/null +++ b/gilda/resources/ner_stoplist.txt @@ -0,0 +1,172 @@ +-I +-II +-III +Bark +Rod +Scott +Task +XREF_BIBR +XREF_FIG +[ +] +acid +alpha +andD +ankle +ankles +antigen +bark +bean +beta +bi +bite +blot +cell +cells +crash +cryptic +damage +danger +docking +duet +duration +face +fact +fast +fate +feet +finger +fingers +fist +foot +gain +goat +hand +hands +head +hip +hips +impact +injury +ir +knee +knees +lead +leg +legs +light +link +links +mark +matrix +neck +net +partial +post +prey +probe +processes +result +rod +role +sensor +shoulder +shoulders +spatial +task +time +toe +toes +top +tube +water +wt +figure +fig +control +bars +bar +red +per +antibody +antibodies +right +left +SEM +treatment +Cells +proteins +protein +SD +Student +group +µm +ANOVA +vs +nM +immunoblotting +animals +KO +Fig +experiment +fluorescence +starvation +intensity +white +genes +mM +condition +Bars +transfection +area +type +image +one +plasmid +µM +neurons +microscopy +Right +binding +hr +SDS-PAGE +arrowheads +individual +Bar +phosphorylation +nm +genotype +Left +mitochondrial +Ctrl +14 +DNA +tissue +RNA +clones +Control +plasmids +Cell +localization +gene +media +cultures +set +protein levels +A-C +size +membrane +biological replicates +inhibitor +strain +patients +growth +Table +NS +et +form +Methods +age +culture +basal +KD \ No newline at end of file From 2cebb39b04a527c32473810af4ff746c97914f00 Mon Sep 17 00:00:00 2001 From: galileosteinberg Date: Fri, 26 Jul 2024 12:30:23 -0400 Subject: [PATCH 10/19] Fixed formatting issues --- gilda/app/app.py | 122 +++++++++++++++++++++++------------------------ 1 file changed, 60 insertions(+), 62 deletions(-) diff --git a/gilda/app/app.py b/gilda/app/app.py index c26f7f3..b1156c2 100644 --- a/gilda/app/app.py +++ b/gilda/app/app.py @@ -51,56 +51,54 @@ description='The normalized text corresponding to the text entry, ' 'used for lookups.', example='egf receptor'), - 'text' : fields.String( - description='The text entry that was matched.', - example='EGF receptor' + 'text': fields.String( + description='The text entry that was matched.', + example='EGF receptor' ), - 'db' : fields.String( - description='The database / namespace corresponding to the ' - 'grounded term.', - example='HGNC' + 'db': fields.String( + description='The database / namespace corresponding to the ' + 'grounded term.', + example='HGNC' ), 'id': fields.String( - description='The identifier of the grounded term within the ' - 'database / namespace.', - example='3236' + description='The identifier 
of the grounded term within the ' + 'database / namespace.', + example='3236' ), 'entry_name': fields.String( - description='The standardized name corresponding to the grounded ' - 'term.', - example='EGFR' + description='The standardized name corresponding to the grounded ' + 'term.', + example='EGFR' ), 'status': fields.String( - description='The relationship of the text entry to the grounded ' - 'term, e.g., synonym.', - example='curated' + description='The relationship of the text entry to the grounded ' + 'term, e.g., synonym.', + example='curated' ), 'source': fields.String( - description='The source from which the term was obtained.', - example='famplex' + description='The source from which the term was obtained.', + example='famplex' ), 'organism': fields.String( - description='If the term is a gene/protein, this field provides ' - 'the taxonomy identifier of the species to which ' - 'it belongs.', - example='9606' + description='If the term is a gene/protein, this field provides ' + 'the taxonomy identifier of the species to which ' + 'it belongs.', + example='9606' ), 'source_db': fields.String( - description='In some cases the term\'s db/id was mapped from another ' - 'db/id pair given in the original source. If this is the ' - 'case, this field provides the original source db.'), + description='In some cases the term\'s db/id was mapped from another ' + 'db/id pair given in the original source. If this is the ' + 'case, this field provides the original source db.'), 'source_id': fields.String( - description='In some cases the term\'s db/id was mapped from another ' - 'db/id pair given in the original source. If this is the ' - 'case, this field provides the original source ID.') - + description='In some cases the term\'s db/id was mapped from another ' + 'db/id pair given in the original source. If this is the ' + 'case, this field provides the original source ID.') } ) scored_match_model = api.model( "ScoredMatch", - {'term': fields.Nested(term_model, - description='The term that was matched'), + {'term': fields.Nested(term_model, description='The term that was matched'), 'url': fields.String( description='Identifiers.org URL for the matched term.', example='https://identifiers.org/hgnc:3236' @@ -110,15 +108,15 @@ example=0.9845 ), 'match': fields.Nested(api.model('Match', {}), - description='Additional metadata about the nature of the match.' - ), - 'subsumed_terms': fields.List(fields.Nested(term_model), - description='In some cases multiple terms with the same db/id ' - 'matched the input string, potentially with different ' - 'scores, and only the first one is exposed in the ' - 'scored match\'s term attribute (see above). This field ' - 'provides additional terms with the same db/id that ' - 'matched the input for additional traceability.') + description='Additional metadata about the nature of the match.' + ), + 'subsumed_terms': fields.List(fields.Nested(term_model), + description='In some cases multiple terms with the same db/id ' + 'matched the input string, potentially with different ' + 'scores, and only the first one is exposed in the ' + 'scored match\'s term attribute (see above). This field ' + 'provides additional terms with the same db/id that ' + 'matched the input for additional traceability.') } ) @@ -129,26 +127,26 @@ "e.g. 
HGNC.", required=True, example='HGNC'), - 'id': fields.String( - description="Identifier within the given database", - required=True, - example='3236' - ), - 'status': fields.String( - description="If provided, only entity texts of the given status are " - "returned (e.g., curated, name, synonym, former_name).", - required=False, - enum=['curated', 'name', 'synonym', 'former_name'], - example='synonym' - ), - 'source': fields.String( - description="If provided, only entity texts collected from the given " - "source are returned.This is useful if terms grounded to " - "IDs in a given database are collected from multiple " - "different sources.", - required=False, - example='uniprot' - ) + 'id': fields.String( + description="Identifier within the given database", + required=True, + example='3236' + ), + 'status': fields.String( + description="If provided, only entity texts of the given status are " + "returned (e.g., curated, name, synonym, former_name).", + required=False, + enum=['curated', 'name', 'synonym', 'former_name'], + example='synonym' + ), + 'source': fields.String( + description="If provided, only entity texts collected from the given " + "source are returned.This is useful if terms grounded to " + "IDs in a given database are collected from multiple " + "different sources.", + required=False, + example='uniprot' + ) } ) @@ -162,8 +160,8 @@ ner_input_model = api.model('NERInput', { 'text': fields.String(required=True, description='Text on which to perform' ' NER', - example='The EGF receptor binds EGF which is an interaction' - 'important in cancer.'), + example='The EGF receptor binds EGF which is an ' + 'interaction important in cancer.'), 'organisms': fields.List(fields.String, example=['9606'], description='An optional list of taxonomy ' 'species IDs defining a priority list' From 99a327ed2889f79ced961e925a09ee2970395def Mon Sep 17 00:00:00 2001 From: galileosteinberg Date: Fri, 26 Jul 2024 12:32:03 -0400 Subject: [PATCH 11/19] NER script derives BioIDBenchmarker class --- benchmarks/bioid_ner_benchmark.py | 162 ++---------------------------- 1 file changed, 7 insertions(+), 155 deletions(-) diff --git a/benchmarks/bioid_ner_benchmark.py b/benchmarks/bioid_ner_benchmark.py index a23c7c4..e25c6fb 100644 --- a/benchmarks/bioid_ner_benchmark.py +++ b/benchmarks/bioid_ner_benchmark.py @@ -14,6 +14,8 @@ import click import pystow import gilda +from benchmarks.bioid_evaluation import BioIDBenchmarker +from benchmarks.bioid_evaluation import fplx_members from gilda.ner import annotate import logging @@ -43,8 +45,9 @@ BO_MISSING_XREFS = set() -class BioIDNERBenchmarker: +class BioIDNERBenchmarker(BioIDBenchmarker): def __init__(self): + super().__init__() print("Instantiating benchmarker...") self.equivalences = self._load_equivalences() self.paper_level_grounding = defaultdict(set) @@ -105,125 +108,13 @@ def process_xml_files(self): return pd.DataFrame(data) def _load_equivalences(self) -> Dict[str, List[str]]: - try: - with open(os.path.join(HERE, 'data', 'equivalences.json')) as f: + with open(os.path.join(HERE, 'data', 'equivalences.json')) as f: equivalences = json.load(f) - except FileNotFoundError: - print( - f"No Equivalences found at " - f"{os.path.join(HERE, 'data', 'equivalences.json')}. 
" - f"Proceeding without it.") - equivalences = {} return equivalences - @classmethod - def _normalize_ids(cls, curies: str) -> List[Tuple[str, str]]: - return [cls._normalize_id(y) for y in curies.split('|')] - - @staticmethod - def _normalize_id(curie): - """Convert ID into standardized format, f'{namespace}:{id}'.""" - if curie.startswith('CVCL'): - return curie.replace('_', ':') - split_id = curie.split(':', maxsplit=1) - if split_id[0] == 'Uberon': - return split_id[1] - if split_id[0] == 'Uniprot': - return f'UP:{split_id[1]}' - if split_id[0] in ['GO', 'CHEBI']: - return f'{split_id[0]}:{split_id[0]}:{split_id[1]}' - return curie - - def get_synonym_set(self, curies: Iterable[str]) -> Set[str]: - """Return set containing all elements in input list along with synonyms - """ - output = set() - for curie in curies: - output.update(self._get_equivalent_entities(curie)) - # We accept all FamPlex terms that cover some or all of the specific - # entries in the annotations - covered_fplx = {fplx_entry for fplx_entry, members - in fplx_members.items() if (members <= output)} - output |= {'FPLX:%s' % fplx_entry for fplx_entry in covered_fplx} - return output - - def _get_equivalent_entities(self, curie: str) -> Set[str]: - """Return set of equivalent entity groundings - - Uses set of equivalences in self.equiv_map as well as those - available in indra's hgnc, uniprot, and chebi clients. - """ - output = {curie} - prefix, identifier = curie.split(':', maxsplit=1) - for xref_prefix, xref_id in bio_ontology.get_mappings(prefix, - identifier): - output.add(f'{xref_prefix}:{xref_id}') - - # TODO these should all be in bioontology, eventually - for xref_curie in self.equivalences.get(curie, []): - if xref_curie in output: - continue - xref_prefix, xref_id = xref_curie.split(':', maxsplit=1) - if (prefix, xref_prefix) not in BO_MISSING_XREFS: - BO_MISSING_XREFS.add((prefix, xref_prefix)) - tqdm.write( - f'Bioontology v{bio_ontology.version} is missing mappings' - f' from {prefix} to {xref_prefix}') - output.add(xref_curie) - - if prefix == 'NCBI gene': - hgnc_id = get_hgnc_from_entrez(identifier) - if hgnc_id is not None: - output.add(f'HGNC:{hgnc_id}') - if prefix == 'UP': - hgnc_id = get_hgnc_id(identifier) - if hgnc_id is not None: - output.add(f'HGNC:{hgnc_id}') - if prefix == 'PubChem': - chebi_id = get_chebi_id_from_pubchem(identifier) - if chebi_id is not None: - output.add(f'CHEBI:{chebi_id}') - return output - - def _get_entity_type_helper(self, row) -> str: - if self._get_entity_type(row.obj) != 'Gene': - return self._get_entity_type(row.obj) - elif any(y.startswith('HGNC') for y in row.obj_synonyms): - return 'Human Gene' - else: - return 'Nonhuman Gene' - - @staticmethod - def _get_entity_type(groundings: Collection[str]) -> str: - """Get entity type based on entity groundings of text in corpus.""" - if any( - grounding.startswith('NCBI gene') or grounding.startswith('UP') - for grounding in groundings - ): - return 'Gene' - elif any(grounding.startswith('Rfam') for grounding in groundings): - return 'miRNA' - elif any( - grounding.startswith('CHEBI') or grounding.startswith('PubChem') - for grounding in groundings): - return 'Small Molecule' - elif any(grounding.startswith('GO') for grounding in groundings): - return 'Cellular Component' - elif any( - grounding.startswith('CVCL') or grounding.startswith('CL') - for grounding in groundings - ): - return 'Cell types/Cell lines' - elif any(grounding.startswith('UBERON') for grounding in groundings): - return 'Tissue/Organ' - elif any( - 
grounding.startswith('NCBI taxon') for grounding in groundings): - return 'Taxon' - else: - return 'unknown' - def _process_annotations_table(self): - """Extract relevant information from annotations table.""" + """Extract relevant information from annotations table. Modified for + NER. Overrides the super method.""" print("Extracting information from annotations table...") df = MODULE.ensure_tar_df( url=URL, @@ -250,28 +141,6 @@ def _process_annotations_table(self): self.paper_level_grounding[don_article, text].update(synonyms) return processed_data - @lru_cache(maxsize=None) - def _get_plaintext(self, don_article: str) -> str: - """Get plaintext content from XML file in BioID corpus - - Parameters - ---------- - don_article : - Identifier for paper used within corpus. - - Returns - ------- - : - Plaintext of specified article - """ - directory = MODULE.ensure_untar(url=URL, directory='BioIDtraining_2') - path = directory.joinpath('BioIDtraining_2', 'fulltext_bioc', - f'{don_article}.xml') - tree = etree.parse(path.as_posix()) - paragraphs = tree.xpath('//text') - paragraphs = [' '.join(text.itertext()) for text in paragraphs] - return '/n'.join(paragraphs) + '/n' - def annotate_entities_with_gilda(self): """Performs NER on the XML files using gilda.annotate()""" print("Annotating corpus with Gilda...") @@ -458,23 +327,6 @@ def get_results_tables(self): return self.counts_table, self.precision_recall -def get_famplex_members(): - from indra.databases import hgnc_client - fplx_entities = famplex.load_entities() - fplx_children = defaultdict(set) - for fplx_entity in fplx_entities: - members = famplex.individual_members('FPLX', fplx_entity) - for db_ns, db_id in members: - if db_ns == 'HGNC': - db_id = hgnc_client.get_current_hgnc_id(db_id) - if db_id: - fplx_children[fplx_entity].add('%s:%s' % (db_ns, db_id)) - return dict(fplx_children) - - -fplx_members = get_famplex_members() - - def main(results: str = RESULTS_DIR): results_path = os.path.expandvars(os.path.expanduser(results)) os.makedirs(results_path, exist_ok=True) From 0dff04bc3e84382fba799c06f7e6cab4c9a48437 Mon Sep 17 00:00:00 2001 From: Ben Gyori Date: Fri, 26 Jul 2024 13:17:48 -0400 Subject: [PATCH 12/19] Fix indentation --- gilda/app/app.py | 52 ++++++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/gilda/app/app.py b/gilda/app/app.py index b1156c2..bd5bd91 100644 --- a/gilda/app/app.py +++ b/gilda/app/app.py @@ -52,47 +52,47 @@ 'used for lookups.', example='egf receptor'), 'text': fields.String( - description='The text entry that was matched.', - example='EGF receptor' + description='The text entry that was matched.', + example='EGF receptor' ), 'db': fields.String( - description='The database / namespace corresponding to the ' - 'grounded term.', - example='HGNC' + description='The database / namespace corresponding to the ' + 'grounded term.', + example='HGNC' ), 'id': fields.String( - description='The identifier of the grounded term within the ' - 'database / namespace.', - example='3236' + description='The identifier of the grounded term within the ' + 'database / namespace.', + example='3236' ), 'entry_name': fields.String( - description='The standardized name corresponding to the grounded ' - 'term.', - example='EGFR' + description='The standardized name corresponding to the grounded ' + 'term.', + example='EGFR' ), 'status': fields.String( - description='The relationship of the text entry to the grounded ' - 'term, e.g., synonym.', - example='curated' + 
description='The relationship of the text entry to the grounded ' + 'term, e.g., synonym.', + example='curated' ), 'source': fields.String( - description='The source from which the term was obtained.', - example='famplex' + description='The source from which the term was obtained.', + example='famplex' ), 'organism': fields.String( - description='If the term is a gene/protein, this field provides ' - 'the taxonomy identifier of the species to which ' - 'it belongs.', - example='9606' + description='If the term is a gene/protein, this field provides ' + 'the taxonomy identifier of the species to which ' + 'it belongs.', + example='9606' ), 'source_db': fields.String( - description='In some cases the term\'s db/id was mapped from another ' - 'db/id pair given in the original source. If this is the ' - 'case, this field provides the original source db.'), + description='In some cases the term\'s db/id was mapped from another ' + 'db/id pair given in the original source. If this is the ' + 'case, this field provides the original source db.'), 'source_id': fields.String( - description='In some cases the term\'s db/id was mapped from another ' - 'db/id pair given in the original source. If this is the ' - 'case, this field provides the original source ID.') + description='In some cases the term\'s db/id was mapped from another ' + 'db/id pair given in the original source. If this is the ' + 'case, this field provides the original source ID.') } ) From 9a725f15ccf5ae632e8816452b0429b53c9b1771 Mon Sep 17 00:00:00 2001 From: Ben Gyori Date: Fri, 26 Jul 2024 13:22:06 -0400 Subject: [PATCH 13/19] Change more indentation --- gilda/app/app.py | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/gilda/app/app.py b/gilda/app/app.py index bd5bd91..a0caf74 100644 --- a/gilda/app/app.py +++ b/gilda/app/app.py @@ -110,7 +110,7 @@ 'match': fields.Nested(api.model('Match', {}), description='Additional metadata about the nature of the match.' ), - 'subsumed_terms': fields.List(fields.Nested(term_model), + 'subsumed_terms': fields.List(fields.Nested(term_model), description='In some cases multiple terms with the same db/id ' 'matched the input string, potentially with different ' 'scores, and only the first one is exposed in the ' @@ -123,29 +123,29 @@ get_names_input_model = api.model( "GetNamesInput", {'db': fields.String( - description="Capitalized name of the database for the grounding, " - "e.g. HGNC.", - required=True, - example='HGNC'), - 'id': fields.String( - description="Identifier within the given database", - required=True, - example='3236' - ), - 'status': fields.String( - description="If provided, only entity texts of the given status are " - "returned (e.g., curated, name, synonym, former_name).", - required=False, - enum=['curated', 'name', 'synonym', 'former_name'], - example='synonym' - ), - 'source': fields.String( - description="If provided, only entity texts collected from the given " - "source are returned.This is useful if terms grounded to " - "IDs in a given database are collected from multiple " - "different sources.", - required=False, - example='uniprot' + description="Capitalized name of the database for the grounding, " + "e.g. 
HGNC.", + required=True, + example='HGNC'), + 'id': fields.String( + description="Identifier within the given database", + required=True, + example='3236' + ), + 'status': fields.String( + description="If provided, only entity texts of the given status are " + "returned (e.g., curated, name, synonym, former_name).", + required=False, + enum=['curated', 'name', 'synonym', 'former_name'], + example='synonym' + ), + 'source': fields.String( + description="If provided, only entity texts collected from the given " + "source are returned.This is useful if terms grounded to " + "IDs in a given database are collected from multiple " + "different sources.", + required=False, + example='uniprot' ) } ) From c789f49e122c4fe01c2d5b122dd13f8e41f71d3c Mon Sep 17 00:00:00 2001 From: Ben Gyori Date: Fri, 26 Jul 2024 14:30:38 -0400 Subject: [PATCH 14/19] Sort NERN stoplist alphabetically --- gilda/resources/ner_stoplist.txt | 175 +++++++++++++++---------------- 1 file changed, 87 insertions(+), 88 deletions(-) diff --git a/gilda/resources/ner_stoplist.txt b/gilda/resources/ner_stoplist.txt index 55e141c..3d5d26a 100644 --- a/gilda/resources/ner_stoplist.txt +++ b/gilda/resources/ner_stoplist.txt @@ -1,172 +1,171 @@ -I -II -III +14 +A-C +ANOVA +Bar Bark +Bars +Cell +Cells +Control +Ctrl +DNA +Fig +KDKO +Left +Methods +NS +RNA +Right Rod +SD +SDS-PAGE +SEM Scott +Student +Table Task XREF_BIBR XREF_FIG [ ] acid +age alpha andD +animals ankle ankles +antibodies +antibody antigen +area +arrowheads +bar bark +bars +basal bean beta bi +binding +biological replicates bite blot cell cells +clones +condition +control crash cryptic +culture +cultures damage danger docking duet duration +et +experiment face fact fast fate feet +fig +figure finger fingers fist +fluorescence foot +form gain +gene +genes +genotype goat +group +growth hand hands head hip hips +hr +image +immunoblotting impact +individual +inhibitor injury +intensity ir knee knees lead +left leg legs light link links +localization +mM mark matrix +media +membrane +microscopy +mitochondrial +nM neck net +neurons +nm +one partial +patients +per +phosphorylation +plasmid +plasmids post prey probe processes +protein +protein levels +proteins +red result +right rod role sensor +set shoulder shoulders +size spatial +starvation +strain task time +tissue toe toes top -tube -water -wt -figure -fig -control -bars -bar -red -per -antibody -antibodies -right -left -SEM +transfection treatment -Cells -proteins -protein -SD -Student -group -µm -ANOVA +tube +type vs -nM -immunoblotting -animals -KO -Fig -experiment -fluorescence -starvation -intensity +water white -genes -mM -condition -Bars -transfection -area -type -image -one -plasmid +wt µM -neurons -microscopy -Right -binding -hr -SDS-PAGE -arrowheads -individual -Bar -phosphorylation -nm -genotype -Left -mitochondrial -Ctrl -14 -DNA -tissue -RNA -clones -Control -plasmids -Cell -localization -gene -media -cultures -set -protein levels -A-C -size -membrane -biological replicates -inhibitor -strain -patients -growth -Table -NS -et -form -Methods -age -culture -basal -KD \ No newline at end of file +µm From e49763856c103fd1d42ed96bb8c0add864fc0ded Mon Sep 17 00:00:00 2001 From: Ben Gyori Date: Fri, 26 Jul 2024 14:32:48 -0400 Subject: [PATCH 15/19] Adjust test for updated stopwords --- gilda/tests/test_ner.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/gilda/tests/test_ner.py b/gilda/tests/test_ner.py index 6c54758..c2d9b43 100644 --- a/gilda/tests/test_ner.py +++ 
From e49763856c103fd1d42ed96bb8c0add864fc0ded Mon Sep 17 00:00:00 2001
From: Ben Gyori
Date: Fri, 26 Jul 2024 14:32:48 -0400
Subject: [PATCH 15/19] Adjust test for updated stopwords

---
 gilda/tests/test_ner.py | 20 +++++++-------------
 1 file changed, 7 insertions(+), 13 deletions(-)

diff --git a/gilda/tests/test_ner.py b/gilda/tests/test_ner.py
index 6c54758..c2d9b43 100644
--- a/gilda/tests/test_ner.py
+++ b/gilda/tests/test_ner.py
@@ -40,20 +40,14 @@ def test_get_brat():
     assert isinstance(brat_str, str)
     match_str = dedent("""
-        T1\tEntity 4 11\tprotein
-        #1\tAnnotatorNotes T1\tCHEBI:36080
-        T2\tEntity 12 16\tBRAF
-        #2\tAnnotatorNotes T2\thgnc:1097
-        T3\tEntity 22 28\tkinase
-        #3\tAnnotatorNotes T3\tmesh:D010770
-        T4\tEntity 30 34\tBRAF
+        T1\tEntity 12 16\tBRAF
+        #1\tAnnotatorNotes T1\thgnc:1097
+        T2\tEntity 22 28\tkinase
+        #2\tAnnotatorNotes T2\tmesh:D010770
+        T3\tEntity 30 34\tBRAF
+        #3\tAnnotatorNotes T3\thgnc:1097
+        T4\tEntity 46 50\tBRAF
         #4\tAnnotatorNotes T4\thgnc:1097
-        T5\tEntity 40 44\tgene
-        #5\tAnnotatorNotes T5\tmesh:D005796
-        T6\tEntity 46 50\tBRAF
-        #6\tAnnotatorNotes T6\thgnc:1097
-        T7\tEntity 56 63\tprotein
-        #7\tAnnotatorNotes T7\tCHEBI:36080
     """).lstrip()
     assert brat_str == match_str

From 9fe059d12d76f8024cab3a8c74d68fa26676e26b Mon Sep 17 00:00:00 2001
From: Ben Gyori
Date: Fri, 26 Jul 2024 14:35:37 -0400
Subject: [PATCH 16/19] Reorganize imports

---
 benchmarks/bioid_ner_benchmark.py | 20 +++++++++----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/benchmarks/bioid_ner_benchmark.py b/benchmarks/bioid_ner_benchmark.py
index e25c6fb..35e1d6b 100644
--- a/benchmarks/bioid_ner_benchmark.py
+++ b/benchmarks/bioid_ner_benchmark.py
@@ -1,23 +1,21 @@
 import os
 import json
 import pathlib
+import logging
+from datetime import datetime
 from collections import defaultdict, Counter
-
-from functools import lru_cache
-import pandas as pd
 import xml.etree.ElementTree as ET
-
-from lxml import etree
-from tqdm import tqdm
-from datetime import datetime
 from typing import List, Tuple, Set, Dict, Optional, Iterable, Collection
-import click
+
 import pystow
+import pandas as pd
+from tqdm import tqdm
+
 import gilda
-from benchmarks.bioid_evaluation import BioIDBenchmarker
-from benchmarks.bioid_evaluation import fplx_members
 from gilda.ner import annotate
-import logging
+
+#from benchmarks.bioid_evaluation import fplx_members
+from benchmarks.bioid_evaluation import BioIDBenchmarker
 
 import famplex
 from indra.databases.chebi_client import get_chebi_id_from_pubchem

From c05e87282b64805783a67d23b43ea1c022fae4d1 Mon Sep 17 00:00:00 2001
From: Ben Gyori
Date: Fri, 26 Jul 2024 14:47:14 -0400
Subject: [PATCH 17/19] Update annotation test for new stopwords

---
 gilda/tests/test_ner.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/gilda/tests/test_ner.py b/gilda/tests/test_ner.py
index c2d9b43..67e0eb1 100644
--- a/gilda/tests/test_ner.py
+++ b/gilda/tests/test_ner.py
@@ -12,22 +12,19 @@ def test_annotate():
     assert isinstance(annotations, list)
 
     # Check that we get 7 annotations
-    assert len(annotations) == 7
+    assert len(annotations) == 4
 
     # Check that the annotations are for the expected words
     assert tuple(a.text for a in annotations) == (
-        'protein', 'BRAF', 'kinase', 'BRAF', 'gene', 'BRAF', 'protein')
+        'BRAF', 'kinase', 'BRAF', 'BRAF')
 
     # Check that the spans are correct
-    expected_spans = ((4, 11), (12, 16), (22, 28), (30, 34), (40, 44),
-                      (46, 50), (56, 63))
+    expected_spans = ((12, 16), (22, 28), (30, 34), (46, 50))
     actual_spans = tuple((a.start, a.end) for a in annotations)
     assert actual_spans == expected_spans
 
     # Check that the curies are correct
-    expected_curies = ("CHEBI:36080", "hgnc:1097", "mesh:D010770",
-                       "hgnc:1097", "mesh:D005796", "hgnc:1097",
-                       "CHEBI:36080")
+    expected_curies = ("hgnc:1097", "mesh:D010770", "hgnc:1097", "hgnc:1097")
     actual_curies = tuple(a.matches[0].term.get_curie() for a in annotations)
     assert actual_curies == expected_curies
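Note: the updated expectations can be reproduced interactively. The sentence
below is reconstructed from the asserted spans; the literal test text lives
earlier in test_ner.py and is not part of this diff:

    from gilda.ner import annotate

    text = 'The protein BRAF is a kinase. BRAF is a gene. BRAF is a protein.'
    for a in annotate(text):
        # With the expanded stoplist, generic words like 'protein' and
        # 'gene' are no longer annotated, leaving the four spans above.
        print(a.start, a.end, a.text, a.matches[0].term.get_curie())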
From 67e13ddd1690a117b1875f46c0de76e4e6dabbc2 Mon Sep 17 00:00:00 2001
From: galileosteinberg
Date: Fri, 26 Jul 2024 13:41:08 -0400
Subject: [PATCH 18/19] Removed debugging print statements and unused imports

---
 benchmarks/bioid_ner_benchmark.py | 40 ++------------------------------
 1 file changed, 2 insertions(+), 38 deletions(-)

diff --git a/benchmarks/bioid_ner_benchmark.py b/benchmarks/bioid_ner_benchmark.py
index 35e1d6b..5c0851b 100644
--- a/benchmarks/bioid_ner_benchmark.py
+++ b/benchmarks/bioid_ner_benchmark.py
@@ -5,7 +5,7 @@
 from datetime import datetime
 from collections import defaultdict, Counter
 import xml.etree.ElementTree as ET
-from typing import List, Tuple, Set, Dict, Optional, Iterable, Collection
+from typing import List, Dict
 
 import pystow
 import pandas as pd
@@ -17,12 +17,6 @@
 #from benchmarks.bioid_evaluation import fplx_members
 from benchmarks.bioid_evaluation import BioIDBenchmarker
 
-import famplex
-from indra.databases.chebi_client import get_chebi_id_from_pubchem
-from indra.databases.hgnc_client import get_hgnc_from_entrez
-from indra.databases.uniprot_client import get_hgnc_id
-from indra.ontology.bio import bio_ontology
-
 logging.getLogger('gilda.grounder').setLevel('WARNING')
 logger = logging.getLogger('bioid_ner_benchmark')
@@ -53,7 +47,6 @@ def __init__(self):
         self.annotations_df = self._process_annotations_table()  # csv annotations
         self.gilda_annotations_map = defaultdict(list)
         self.annotations_count = 0
-        # New field to store Gilda annotations
         self.counts_table = None
         self.precision_recall = None
@@ -155,30 +148,14 @@ def annotate_entities_with_gilda(self):
             full_text = self._get_plaintext(doc_id)
             gilda_annotations = annotate(text, context_text=full_text)
 
-            # for testing all matches for each entity, return_first = False.
             for annotation in gilda_annotations:
                 total_gilda_annotations += 1
                 self.gilda_annotations_map[(doc_id, figure)].append(annotation)
 
-                # if doc_id == '3868508' and figure == 'Figure_1-A':
-                #     tqdm.write(f"Scored NER Match: {annotation}")
-                #     tqdm.write(f"Annotated Text Segment: "
-                #                f"{text[annotation.start:annotation.end]} at "
-                #                f"indices {annotation.start} to {annotation.end}")
-                #     for i, scored_match in enumerate(annotation.matches):
-                #         tqdm.write(f"Scored Match {i + 1}: {scored_match}")
-                #         tqdm.write(
-                #             f"DB: {scored_match.term.db}, "
-                #             f"ID: {scored_match.term.id}")
-                #         tqdm.write(
-                #             f"Score: {scored_match.score}, "
-                #             f"Match: {scored_match.match}")
-                #         tqdm.write("\n")
-
         tqdm.write("Finished annotating corpus with Gilda...")
-        # tqdm.write(f"Total Gilda annotations: {total_gilda_annotations}")
+        tqdm.write(f"Total Gilda annotations: {total_gilda_annotations}")
 
     def evaluate_gilda_performance(self):
         """Calculates precision, recall, and F1"""
@@ -197,8 +174,6 @@ def evaluate_gilda_performance(self):
                    row['first left'], row['last right'])
             ref_dict[key].append((set(row['obj']), row['obj_synonyms']))
 
-        # print(f"Total reference annotations: {len(ref_dict)}")
-
         for (doc_id, figure), annotations in (
                 tqdm(self.gilda_annotations_map.items(),
                      desc="Evaluating Annotations")):
@@ -219,14 +194,6 @@ def evaluate_gilda_performance(self):
                         match_found = True
                         break
 
-                # if match_found:
-                #     if doc_id == '3868508' and figure == "Figure_1-A":
-                #         print(f"Gilda Annotation: {annotation}")
-                #         print(f"Match Found: {match_found}")
-                #         print(f"Matching Reference: {matching_refs}")
-
-                #     break
-
                 if match_found:
                     break
@@ -236,9 +203,6 @@ def evaluate_gilda_performance(self):
                 if annotation.matches:  # Check if there are any matches
                     metrics['top_match']['fp'] += 1
 
-        # print(f"20 Most Common False Positives: "
-        #       f"{false_positives_counter.most_common(20)}")
-
         # False negative calculation using ref dict
         for key, refs in tqdm(ref_dict.items(),
                               desc="Calculating False Negatives"):
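Note: the tp/fp/fn tallies accumulated above feed the standard precision,
recall, and F1 definitions. A generic sketch of that arithmetic for reference
(not the benchmark's exact code):

    def precision_recall_f1(tp: int, fp: int, fn: int):
        # Computed separately for the 'all_matches' and 'top_match' tallies.
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        return precision, recall, f1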
From b915c1c7396476a69152a951e62e9f9c38641cb2 Mon Sep 17 00:00:00 2001
From: galileosteinberg
Date: Fri, 26 Jul 2024 14:36:24 -0400
Subject: [PATCH 19/19] Added readme output.

---
 benchmarks/bioid_ner_benchmark.py | 67 +++++++++++++++++++++++++++----
 1 file changed, 60 insertions(+), 7 deletions(-)

diff --git a/benchmarks/bioid_ner_benchmark.py b/benchmarks/bioid_ner_benchmark.py
index 5c0851b..7f24e9c 100644
--- a/benchmarks/bioid_ner_benchmark.py
+++ b/benchmarks/bioid_ner_benchmark.py
@@ -5,6 +5,7 @@
 from datetime import datetime
 from collections import defaultdict, Counter
 import xml.etree.ElementTree as ET
+from textwrap import dedent
 from typing import List, Dict
 
 import pystow
@@ -49,6 +50,7 @@ def __init__(self):
         self.annotations_count = 0
         self.counts_table = None
         self.precision_recall = None
+        self.false_positives_counter = Counter()
 
     def process_xml_files(self):
         """Extract relevant information from XML files."""
@@ -166,7 +168,7 @@ def evaluate_gilda_performance(self):
             'top_match': {'tp': 0, 'fp': 0, 'fn': 0}
         }
 
-        false_positives_counter = Counter()
+
         ref_dict = defaultdict(list)
 
         for _, row in self.annotations_df.iterrows():
@@ -199,7 +201,7 @@ def evaluate_gilda_performance(self):
 
             if not match_found:
                 metrics['all_matches']['fp'] += 1
-                false_positives_counter[annotation.text] += 1
+                self.false_positives_counter[annotation.text] += 1
 
                 if annotation.matches:  # Check if there are any matches
                     metrics['top_match']['fp'] += 1
@@ -275,7 +277,7 @@ def evaluate_gilda_performance(self):
         self.precision_recall = precision_recall
 
         os.makedirs(RESULTS_DIR, exist_ok=True)
-        false_positives_df = pd.DataFrame(false_positives_counter.items(),
+        false_positives_df = pd.DataFrame(self.false_positives_counter.items(),
                                           columns=['False Positive Text',
                                                    'Count'])
         false_positives_df = false_positives_df.sort_values(by='Count',
@@ -285,8 +287,10 @@ def evaluate_gilda_performance(self):
 
         print("Finished evaluating performance...")
 
-    def get_results_tables(self):
-        return self.counts_table, self.precision_recall
+    def get_tables(self):
+        return (self.counts_table,
+                self.precision_recall,
+                self.false_positives_counter)
 
 
 def main(results: str = RESULTS_DIR):
@@ -296,7 +300,7 @@ def main(results: str = RESULTS_DIR):
     benchmarker = BioIDNERBenchmarker()
     benchmarker.annotate_entities_with_gilda()
     benchmarker.evaluate_gilda_performance()
-    counts, precision_recall = benchmarker.get_results_tables()
+    counts, precision_recall, false_positives_counter = benchmarker.get_tables()
 
     print(f"Counts Table:")
     print(counts.to_markdown(index=False))
@@ -304,7 +308,56 @@ def main(results: str = RESULTS_DIR):
     print(precision_recall.to_markdown(index=False))
 
     time = datetime.now().strftime('%y%m%d-%H%M%S')
-    result_stub = pathlib.Path(results_path).joinpath(f'benchmark_{time}')
+
+    outname = f'benchmark_{time}'
+    result_stub = pathlib.Path(results_path).joinpath(outname)
+
+    caption0 = dedent(f"""\
+        # Gilda NER Benchmarking
+
+        Gilda: v{gilda.__version__}
+        Date: {time}
+        """)
+
+    caption1 = dedent("""\
+        ## Table 1
+
+        The counts of true positives, false positives, and false negatives
+        for Gilda annotations in the corpus where only Gilda's "Top Match"
+        grounding (top score grounding) returns the correct match and where
+        any Gilda grounding returns a correct match.
+        """)
+    table1 = counts.to_markdown(index=False)
+
+    caption2 = dedent("""\
+        ## Table 2
+
+        Precision, recall, and F1 Score values for Gilda performance where
+        Gilda's "Top Match" grounding (top score grounding) returns the
+        correct match and where any Gilda grounding returns a correct match.
+        """)
+    table2 = precision_recall.to_markdown(index=False)
+
+    caption3 = dedent("""\
+        ## 50 Most Common False Positive Words
+
+        A list of the 50 most common false positive annotations created by
+        Gilda.
+        """)
+    top_50_false_positives = false_positives_counter.most_common(50)
+    false_positives_list = '\n'.join(
+        [f'- {word}: {count}' for word, count in top_50_false_positives])
+
+    output = '\n\n'.join([
+        caption0,
+        caption1, table1,
+        caption2, table2,
+        caption3, false_positives_list
+    ])
+
+    md_path = result_stub.with_suffix(".md")
+    with open(md_path, 'w') as f:
+        f.write(output)
+
     counts.to_csv(result_stub.with_suffix(".counts.csv"), index=False)
     precision_recall.to_csv(result_stub.with_suffix(".precision_recall.csv"),
                             index=False)
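Note: with this final patch applied, a full benchmark run can be started
roughly as follows, assuming the repository root is on PYTHONPATH; whether
the module also installs an if __name__ == '__main__' guard is not shown in
these diffs:

    from benchmarks.bioid_ner_benchmark import main

    # Writes benchmark_<timestamp>.md plus .counts.csv and
    # .precision_recall.csv under results/bioid_ner_performance/<version>/.
    main()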