Skip to content

Commit

Permalink
Update typing
Browse files (browse the repository at this point in the history)
  • Loading branch information
cthoyt committed Jan 22, 2024
1 parent 0df22de commit 179118c
Showing 1 changed file with 44 additions and 36 deletions.
80 changes: 44 additions & 36 deletions src/semra/api.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -4,6 +4,7 @@

import itertools as itt
import logging
import typing as t
from collections import Counter, defaultdict
from collections.abc import Iterable
from typing import cast
Expand All @@ -22,15 +23,22 @@
KNOWLEDGE_MAPPING,
NARROW_MATCH,
)
from semra.struct import Evidence, Mapping, ReasonedEvidence, Reference, Triple, triple_key
from semra.struct import (
Evidence,
Mapping,
ReasonedEvidence,
Reference,
Triple,
triple_key,
)

logger = logging.getLogger(__name__)

PREDICATE_KEY = "predicate"
EVIDENCE_KEY = "evidence"

#: An index allows for the aggregation of evidences for each core triple
Index = dict[Triple, list[Evidence]]
Index = t.Dict[Triple, t.List[Evidence]]


def _tqdm(mappings: Iterable[Mapping], desc: str | None = None, *, progress: bool = True):
Expand All @@ -43,7 +51,7 @@ def _tqdm(mappings: Iterable[Mapping], desc: str | None = None, *, progress: boo
)


def count_source_target(mappings: Iterable[Mapping]) -> Counter[tuple[str, str]]:
def count_source_target(mappings: Iterable[Mapping]) -> Counter[t.Tuple[str, str]]:
"""Count source prefix-target prefix pairs."""
return Counter((s.prefix, o.prefix) for s, _, o in get_index(mappings))

Expand All @@ -65,18 +73,18 @@ def print_source_target_counts(mappings: Iterable[Mapping], minimum: int = 0) ->

def get_index(mappings: Iterable[Mapping], *, progress: bool = True) -> Index:
"""Aggregate and deduplicate evidences for each core triple."""
dd: defaultdict[Triple, list[Evidence]] = defaultdict(list)
dd: t.DefaultDict[Triple, t.List[Evidence]] = defaultdict(list)
for mapping in _tqdm(mappings, desc="Indexing mappings", progress=progress):
dd[mapping.triple].extend(mapping.evidence)
return {triple: deduplicate_evidence(evidence) for triple, evidence in dd.items()}


def assemble_evidences(mappings: list[Mapping], *, progress: bool = True) -> list[Mapping]:
def assemble_evidences(mappings: t.List[Mapping], *, progress: bool = True) -> t.List[Mapping]:
index = get_index(mappings, progress=progress)
return unindex(index, progress=progress)


def infer_reversible(mappings: list[Mapping], *, progress: bool = True) -> list[Mapping]:
def infer_reversible(mappings: t.List[Mapping], *, progress: bool = True) -> t.List[Mapping]:
rv = []
for mapping in _tqdm(mappings, desc="Infer reverse", progress=progress):
rv.append(mapping)
Expand Down Expand Up @@ -111,7 +119,7 @@ def flip(mapping: Mapping) -> Mapping | None:
)


def to_graph(mappings: list[Mapping]) -> nx.DiGraph:
def to_graph(mappings: t.List[Mapping]) -> nx.DiGraph:
"""Convert mappings into a directed graph data model."""
graph = nx.DiGraph()
for mapping in mappings:
Expand All @@ -123,7 +131,7 @@ def to_graph(mappings: list[Mapping]) -> nx.DiGraph:
return graph


def from_graph(graph: nx.DiGraph) -> list[Mapping]:
def from_graph(graph: nx.DiGraph) -> t.List[Mapping]:
"""Extract mappings from a directed graph data model."""
return [_from_edge(graph, s, o) for s, o in graph.edges()]

Expand All @@ -133,7 +141,7 @@ def _from_edge(graph: nx.DiGraph, s: Reference, o: Reference) -> Mapping:
return Mapping(s=s, p=data[PREDICATE_KEY], o=o, evidence=data[EVIDENCE_KEY])


def _condense_predicates(predicates: list[Reference]) -> Reference | None:
def _condense_predicates(predicates: t.List[Reference]) -> Reference | None:
predicate_set = set(predicates)
if predicate_set == {EXACT_MATCH}:
return EXACT_MATCH
Expand All @@ -145,8 +153,8 @@ def _condense_predicates(predicates: list[Reference]) -> Reference | None:


def infer_chains(
mappings: list[Mapping], *, backwards: bool = True, progress: bool = True, cutoff: int = 5
) -> list[Mapping]:
mappings: t.List[Mapping], *, backwards: bool = True, progress: bool = True, cutoff: int = 5
) -> t.List[Mapping]:
"""Apply graph-based reasoning over mapping chains to infer new mappings.
:param mappings: A list of input mappings
Expand Down Expand Up @@ -198,7 +206,7 @@ def tabulate_index(index: Index) -> str:
"""Tabulate"""
from tabulate import tabulate

rows: list[tuple[str, str, str, str]] = []
rows: t.List[t.Tuple[str, str, str, str]] = []

def key(pair):
return triple_key(pair[0])
Expand All @@ -218,16 +226,16 @@ def infer_mutual_dbxref_mutations(
mappings: Iterable[Mapping],
prefixes: set[str],
confidence: float | None = None,
) -> list[Mapping]:
) -> t.List[Mapping]:
pairs = {(s, t) for s, t in itt.product(prefixes, repeat=2) if s != t}
return infer_dbxref_mutations(mappings, pairs=pairs, confidence=confidence)


def infer_dbxref_mutations(
mappings: Iterable[Mapping],
pairs: dict[tuple[str, str], float] | Iterable[tuple[str, str]],
pairs: t.Dict[t.Tuple[str, str], float] | Iterable[t.Tuple[str, str]],
confidence: float | None = None,
) -> list[Mapping]:
) -> t.List[Mapping]:
"""Upgrade database cross-references into exact matches for the given pairs.
:param mappings: A list of mappings
Expand All @@ -249,12 +257,12 @@ def infer_dbxref_mutations(

def infer_mutations(
mappings: Iterable[Mapping],
pairs: dict[tuple[str, str], float],
pairs: t.Dict[t.Tuple[str, str], float],
old: Reference,
new: Reference,
*,
progress: bool = False,
) -> list[Mapping]:
) -> t.List[Mapping]:
"""Infer mappings with alternate predicates for the given prefix pairs.
:param mappings: Mappings to infer from
Expand Down Expand Up @@ -286,7 +294,7 @@ def infer_mutations(
return rv


def keep_prefixes(mappings: Iterable[Mapping], prefixes: Iterable[str], *, progress: bool = True) -> list[Mapping]:
def keep_prefixes(mappings: Iterable[Mapping], prefixes: Iterable[str], *, progress: bool = True) -> t.List[Mapping]:
"""Filter out mappings whose subject or object are not in the given list of prefixes."""
prefixes = set(prefixes)
return [
Expand Down Expand Up @@ -314,7 +322,7 @@ def keep_object_prefixes(mappings: Iterable[Mapping], prefixes: str | Iterable[s
]


def filter_prefixes(mappings: Iterable[Mapping], prefixes: Iterable[str], *, progress: bool = True) -> list[Mapping]:
def filter_prefixes(mappings: Iterable[Mapping], prefixes: Iterable[str], *, progress: bool = True) -> t.List[Mapping]:
"""Filter out mappings whose subject or object are in the given list of prefixes."""
prefixes = set(prefixes)
return [
Expand All @@ -324,7 +332,7 @@ def filter_prefixes(mappings: Iterable[Mapping], prefixes: Iterable[str], *, pro
]


def filter_self_matches(mappings: Iterable[Mapping], *, progress: bool = True) -> list[Mapping]:
def filter_self_matches(mappings: Iterable[Mapping], *, progress: bool = True) -> t.List[Mapping]:
"""Filter out mappings within the same resource."""
return [
mapping
Expand All @@ -333,7 +341,7 @@ def filter_self_matches(mappings: Iterable[Mapping], *, progress: bool = True) -
]


def filter_mappings(mappings: list[Mapping], skip_mappings: list[Mapping], *, progress: bool = True) -> list[Mapping]:
def filter_mappings(mappings: t.List[Mapping], skip_mappings: t.List[Mapping], *, progress: bool = True) -> t.List[Mapping]:
"""Filter out mappings in the second set from the first set."""
skip_triples = {skip_mapping.triple for skip_mapping in skip_mappings}
return [
Expand All @@ -343,18 +351,18 @@ def filter_mappings(mappings: list[Mapping], skip_mappings: list[Mapping], *, pr
]


M2MIndex = defaultdict[tuple[str, str], defaultdict[str, defaultdict[str, list[Mapping]]]]
M2MIndex = t.DefaultDict[t.Tuple[str, str], t.DefaultDict[str, t.DefaultDict[str, t.List[Mapping]]]]


def get_many_to_many(mappings: list[Mapping]) -> list[Mapping]:
def get_many_to_many(mappings: t.List[Mapping]) -> t.List[Mapping]:
"""Get many-to-many mappings, disregarding predicate type."""
forward: M2MIndex = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
backward: M2MIndex = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
for mapping in mappings:
forward[mapping.s.prefix, mapping.o.prefix][mapping.s.identifier][mapping.o.identifier].append(mapping)
backward[mapping.s.prefix, mapping.o.prefix][mapping.o.identifier][mapping.s.identifier].append(mapping)

index: defaultdict[Triple, list[Evidence]] = defaultdict(list)
index: t.DefaultDict[Triple, t.List[Evidence]] = defaultdict(list)
for preindex in [forward, backward]:
for d1 in preindex.values():
for d2 in d1.values():
Expand All @@ -366,15 +374,15 @@ def get_many_to_many(mappings: list[Mapping]) -> list[Mapping]:
return rv


def filter_many_to_many(mappings: list[Mapping], *, progress: bool = True) -> list[Mapping]:
def filter_many_to_many(mappings: t.List[Mapping], *, progress: bool = True) -> t.List[Mapping]:
"""Filter out many to many mappings."""
skip_mappings = get_many_to_many(mappings)
return filter_mappings(mappings, skip_mappings, progress=progress)


def project(
mappings: list[Mapping], source_prefix: str, target_prefix: str, *, return_sus: bool = False, progress: bool = False
) -> list[Mapping] | tuple[list[Mapping], list[Mapping]]:
mappings: t.List[Mapping], source_prefix: str, target_prefix: str, *, return_sus: bool = False, progress: bool = False
) -> t.List[Mapping] | t.Tuple[t.List[Mapping], t.List[Mapping]]:
"""Ensure that each identifier only appears as the subject of one mapping."""
mappings = keep_subject_prefixes(mappings, source_prefix, progress=progress)
mappings = keep_object_prefixes(mappings, target_prefix, progress=progress)
Expand All @@ -386,13 +394,13 @@ def project(
return mappings


def project_dict(mappings: list[Mapping], source_prefix: str, target_prefix: str) -> dict[str, str]:
def project_dict(mappings: t.List[Mapping], source_prefix: str, target_prefix: str) -> t.Dict[str, str]:
"""Get a dictionary from source identifiers to target identifiers."""
mappings = cast(list[Mapping], project(mappings, source_prefix, target_prefix))
mappings = cast(t.List[Mapping], project(mappings, source_prefix, target_prefix))
return {mapping.s.identifier: mapping.o.identifier for mapping in mappings}


def prioritize(mappings: list[Mapping], priority: list[str]) -> list[Mapping]:
def prioritize(mappings: t.List[Mapping], priority: t.List[str]) -> t.List[Mapping]:
"""Get a priority star graph.
:param mappings:
Expand All @@ -403,7 +411,7 @@ def prioritize(mappings: list[Mapping], priority: list[str]) -> list[Mapping]:
exact_mappings = len(mappings)

graph = to_graph(mappings).to_undirected()
rv: list[Mapping] = []
rv: t.List[Mapping] = []
for component in tqdm(nx.connected_components(graph), unit="component", unit_scale=True):
o = _get_priority(component, priority)
if o is None:
Expand All @@ -427,7 +435,7 @@ def prioritize(mappings: list[Mapping], priority: list[str]) -> list[Mapping]:
return rv


def _get_priority(component: list[Reference], priority: list[str]) -> Reference | None:
def _get_priority(component: t.List[Reference], priority: t.List[str]) -> t.Optional[Reference]:
prefix_to_references = defaultdict(list)
for c in component:
prefix_to_references[c.prefix].append(c)
Expand All @@ -444,7 +452,7 @@ def _get_priority(component: list[Reference], priority: list[str]) -> Reference
return None


def unindex(index: Index, *, progress: bool = True) -> list[Mapping]:
def unindex(index: Index, *, progress: bool = True) -> t.List[Mapping]:
"""Convert a mapping index into a list of mapping objects."""
return [
Mapping.from_triple(triple, evidence=evidence)
Expand All @@ -454,13 +462,13 @@ def unindex(index: Index, *, progress: bool = True) -> list[Mapping]:
]


def deduplicate_evidence(evidence: list[Evidence]) -> list[Evidence]:
def deduplicate_evidence(evidence: t.List[Evidence]) -> t.List[Evidence]:
"""Deduplicate a list of evidences based on their "key" function."""
d = {e.key(): e for e in evidence}
return list(d.values())


def validate_mappings(mappings: list[Mapping], *, progress: bool = True) -> None:
def validate_mappings(mappings: t.List[Mapping], *, progress: bool = True) -> None:
"""Validate mappings against the Bioregistry and raise an error on the first invalid."""
import bioregistry

Expand Down Expand Up @@ -489,7 +497,7 @@ def validate_mappings(mappings: list[Mapping], *, progress: bool = True) -> None
raise ValueError(f"banana in mapping object: {mapping}")


def summarize_prefixes(mappings: list[Mapping]) -> pd.DataFrame:
def summarize_prefixes(mappings: t.List[Mapping]) -> pd.DataFrame:
"""Get a dataframe summarizing the prefixes appearing in the mappings."""
import bioregistry

Expand Down

0 comments on commit 179118c

Please sign in to comment.