Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Begin cleanup #13

Merged
merged 23 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install tox
pip install tox hatch
- name: Test linting
run:
hatch run lint:style
- name: Test with mypy
run:
tox -e mypy
Expand Down
12 changes: 11 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ name = "semra"
dynamic = ["version"]
description = 'Semantic Mapping Reasoning Assembler'
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.8"
license = "MIT"
keywords = []
authors = [
Expand All @@ -16,6 +16,7 @@ authors = [
classifiers = [
"Development Status :: 4 - Beta",
"Programming Language :: Python",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand All @@ -33,6 +34,7 @@ dependencies = [
"bioontologies",
"pyobo",
"typing_extensions",
"rdflib", # remove after https://github.com/biopragmatics/bioregistry/pull/1030 is released
]

[project.optional-dependencies]
Expand Down Expand Up @@ -84,6 +86,7 @@ dependencies = [
"black[jupyter]>=23.1.0",
"mypy>=1.0.0",
"ruff>=0.0.243",
"pydantic",
]
[tool.hatch.envs.lint.scripts]
typing = "mypy --install-types --non-interactive --ignore-missing-imports {args:src/semra tests}"
Expand All @@ -105,6 +108,9 @@ all = [
target-version = ["py39"]
line-length = 120

[tool.mypy]
plugins = ["pydantic.mypy"]

[tool.ruff]
target-version = "py39"
line-length = 120
Expand Down Expand Up @@ -146,6 +152,10 @@ ignore = [
"EM102", "EM101",
# Ignore pickle security warnings
"S301",
# Ignore upgrading type annotations
"UP006", "UP007", "UP035",
# Ignore shadowing python builtins (because we use 'license')
"A001", "A002", "A003",
]
unfixable = [
# Don't touch unused imports
Expand Down
87 changes: 51 additions & 36 deletions src/semra/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import itertools as itt
import logging
import typing as t
from collections import Counter, defaultdict
from collections.abc import Iterable
from typing import cast
Expand All @@ -22,15 +23,22 @@
KNOWLEDGE_MAPPING,
NARROW_MATCH,
)
from semra.struct import Evidence, Mapping, ReasonedEvidence, Reference, Triple, triple_key
from semra.struct import (
Evidence,
Mapping,
ReasonedEvidence,
Reference,
Triple,
triple_key,
)

logger = logging.getLogger(__name__)

PREDICATE_KEY = "predicate"
EVIDENCE_KEY = "evidence"

#: An index allows for the aggregation of evidences for each core triple
Index = dict[Triple, list[Evidence]]
Index = t.Dict[Triple, t.List[Evidence]]


def _tqdm(mappings: Iterable[Mapping], desc: str | None = None, *, progress: bool = True):
Expand All @@ -43,7 +51,7 @@ def _tqdm(mappings: Iterable[Mapping], desc: str | None = None, *, progress: boo
)


def count_source_target(mappings: Iterable[Mapping]) -> Counter[t.Tuple[str, str]]:
    """Count how many deduplicated triples connect each (subject prefix, object prefix) pair.

    :param mappings: An iterable of mappings to index and tally
    :return: A counter keyed by (subject prefix, object prefix)
    """
    counter: Counter[t.Tuple[str, str]] = Counter()
    # get_index() deduplicates evidences per triple, so each core triple is tallied once.
    for subject, _predicate, obj in get_index(mappings):
        counter[subject.prefix, obj.prefix] += 1
    return counter
return Counter((s.prefix, o.prefix) for s, _, o in get_index(mappings))

Expand All @@ -65,18 +73,18 @@ def print_source_target_counts(mappings: Iterable[Mapping], minimum: int = 0) ->

def get_index(mappings: Iterable[Mapping], *, progress: bool = True) -> Index:
    """Aggregate and deduplicate evidences for each core triple.

    :param mappings: An iterable of mappings to group
    :param progress: Whether to show a progress bar while iterating
    :return: A dictionary from each core (subject, predicate, object) triple
        to its deduplicated list of evidences
    """
    # Fix: the diff artifact left two consecutive annotated assignments to ``dd``
    # (the pre-backport and post-backport annotation); only one allocation belongs here.
    dd: t.DefaultDict[Triple, t.List[Evidence]] = defaultdict(list)
    for mapping in _tqdm(mappings, desc="Indexing mappings", progress=progress):
        dd[mapping.triple].extend(mapping.evidence)
    return {triple: deduplicate_evidence(evidence) for triple, evidence in dd.items()}


def assemble_evidences(mappings: t.List[Mapping], *, progress: bool = True) -> t.List[Mapping]:
    """Merge mappings that share a core triple, pooling their deduplicated evidences.

    :param mappings: A list of mappings, possibly containing several entries
        for the same (subject, predicate, object) triple
    :param progress: Whether to show progress bars during indexing/unindexing
    :return: One mapping per distinct core triple, carrying the combined,
        deduplicated evidences (via :func:`get_index` then :func:`unindex`)
    """
    # Fix: the diff artifact kept both the old (builtin-generic) and new (typing-backport)
    # signature lines; only the backported signature belongs in the merged file.
    index = get_index(mappings, progress=progress)
    return unindex(index, progress=progress)


def infer_reversible(mappings: list[Mapping], *, progress: bool = True) -> list[Mapping]:
def infer_reversible(mappings: t.List[Mapping], *, progress: bool = True) -> t.List[Mapping]:
rv = []
for mapping in _tqdm(mappings, desc="Infer reverse", progress=progress):
rv.append(mapping)
Expand Down Expand Up @@ -111,7 +119,7 @@ def flip(mapping: Mapping) -> Mapping | None:
)


def to_graph(mappings: list[Mapping]) -> nx.DiGraph:
def to_graph(mappings: t.List[Mapping]) -> nx.DiGraph:
"""Convert mappings into a directed graph data model."""
graph = nx.DiGraph()
for mapping in mappings:
Expand All @@ -123,7 +131,7 @@ def to_graph(mappings: list[Mapping]) -> nx.DiGraph:
return graph


def from_graph(graph: nx.DiGraph) -> t.List[Mapping]:
    """Extract mappings from a directed graph data model.

    :param graph: A directed graph whose edges carry predicate and evidence data,
        as produced by :func:`to_graph`
    :return: A list with one mapping reconstructed per edge (via ``_from_edge``)
    """
    # Fix: the diff artifact kept both the old (builtin-generic) and new (typing-backport)
    # signature lines; only the backported signature belongs in the merged file.
    return [_from_edge(graph, s, o) for s, o in graph.edges()]

Expand All @@ -133,7 +141,7 @@ def _from_edge(graph: nx.DiGraph, s: Reference, o: Reference) -> Mapping:
return Mapping(s=s, p=data[PREDICATE_KEY], o=o, evidence=data[EVIDENCE_KEY])


def _condense_predicates(predicates: list[Reference]) -> Reference | None:
def _condense_predicates(predicates: t.List[Reference]) -> Reference | None:
predicate_set = set(predicates)
if predicate_set == {EXACT_MATCH}:
return EXACT_MATCH
Expand All @@ -145,8 +153,8 @@ def _condense_predicates(predicates: list[Reference]) -> Reference | None:


def infer_chains(
mappings: list[Mapping], *, backwards: bool = True, progress: bool = True, cutoff: int = 5
) -> list[Mapping]:
mappings: t.List[Mapping], *, backwards: bool = True, progress: bool = True, cutoff: int = 5
) -> t.List[Mapping]:
"""Apply graph-based reasoning over mapping chains to infer new mappings.

:param mappings: A list of input mappings
Expand Down Expand Up @@ -198,7 +206,7 @@ def tabulate_index(index: Index) -> str:
"""Tabulate"""
from tabulate import tabulate

rows: list[tuple[str, str, str, str]] = []
rows: t.List[t.Tuple[str, str, str, str]] = []

def key(pair):
return triple_key(pair[0])
Expand All @@ -218,16 +226,16 @@ def infer_mutual_dbxref_mutations(
mappings: Iterable[Mapping],
prefixes: set[str],
confidence: float | None = None,
) -> list[Mapping]:
) -> t.List[Mapping]:
pairs = {(s, t) for s, t in itt.product(prefixes, repeat=2) if s != t}
return infer_dbxref_mutations(mappings, pairs=pairs, confidence=confidence)


def infer_dbxref_mutations(
mappings: Iterable[Mapping],
pairs: dict[tuple[str, str], float] | Iterable[tuple[str, str]],
pairs: t.Dict[t.Tuple[str, str], float] | Iterable[t.Tuple[str, str]],
confidence: float | None = None,
) -> list[Mapping]:
) -> t.List[Mapping]:
"""Upgrade database cross-references into exact matches for the given pairs.

:param mappings: A list of mappings
Expand All @@ -249,12 +257,12 @@ def infer_dbxref_mutations(

def infer_mutations(
mappings: Iterable[Mapping],
pairs: dict[tuple[str, str], float],
pairs: t.Dict[t.Tuple[str, str], float],
old: Reference,
new: Reference,
*,
progress: bool = False,
) -> list[Mapping]:
) -> t.List[Mapping]:
"""Infer mappings with alternate predicates for the given prefix pairs.

:param mappings: Mappings to infer from
Expand Down Expand Up @@ -286,7 +294,7 @@ def infer_mutations(
return rv


def keep_prefixes(mappings: Iterable[Mapping], prefixes: Iterable[str], *, progress: bool = True) -> list[Mapping]:
def keep_prefixes(mappings: Iterable[Mapping], prefixes: Iterable[str], *, progress: bool = True) -> t.List[Mapping]:
"""Filter out mappings whose subject or object are not in the given list of prefixes."""
prefixes = set(prefixes)
return [
Expand Down Expand Up @@ -314,7 +322,7 @@ def keep_object_prefixes(mappings: Iterable[Mapping], prefixes: str | Iterable[s
]


def filter_prefixes(mappings: Iterable[Mapping], prefixes: Iterable[str], *, progress: bool = True) -> list[Mapping]:
def filter_prefixes(mappings: Iterable[Mapping], prefixes: Iterable[str], *, progress: bool = True) -> t.List[Mapping]:
"""Filter out mappings whose subject or object are in the given list of prefixes."""
prefixes = set(prefixes)
return [
Expand All @@ -324,7 +332,7 @@ def filter_prefixes(mappings: Iterable[Mapping], prefixes: Iterable[str], *, pro
]


def filter_self_matches(mappings: Iterable[Mapping], *, progress: bool = True) -> list[Mapping]:
def filter_self_matches(mappings: Iterable[Mapping], *, progress: bool = True) -> t.List[Mapping]:
"""Filter out mappings within the same resource."""
return [
mapping
Expand All @@ -333,7 +341,9 @@ def filter_self_matches(mappings: Iterable[Mapping], *, progress: bool = True) -
]


def filter_mappings(mappings: list[Mapping], skip_mappings: list[Mapping], *, progress: bool = True) -> list[Mapping]:
def filter_mappings(
mappings: t.List[Mapping], skip_mappings: t.List[Mapping], *, progress: bool = True
) -> t.List[Mapping]:
"""Filter out mappings in the second set from the first set."""
skip_triples = {skip_mapping.triple for skip_mapping in skip_mappings}
return [
Expand All @@ -343,18 +353,18 @@ def filter_mappings(mappings: list[Mapping], skip_mappings: list[Mapping], *, pr
]


M2MIndex = defaultdict[tuple[str, str], defaultdict[str, defaultdict[str, list[Mapping]]]]
M2MIndex = t.DefaultDict[t.Tuple[str, str], t.DefaultDict[str, t.DefaultDict[str, t.List[Mapping]]]]


def get_many_to_many(mappings: list[Mapping]) -> list[Mapping]:
def get_many_to_many(mappings: t.List[Mapping]) -> t.List[Mapping]:
"""Get many-to-many mappings, disregarding predicate type."""
forward: M2MIndex = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
backward: M2MIndex = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
for mapping in mappings:
forward[mapping.s.prefix, mapping.o.prefix][mapping.s.identifier][mapping.o.identifier].append(mapping)
backward[mapping.s.prefix, mapping.o.prefix][mapping.o.identifier][mapping.s.identifier].append(mapping)

index: defaultdict[Triple, list[Evidence]] = defaultdict(list)
index: t.DefaultDict[Triple, t.List[Evidence]] = defaultdict(list)
for preindex in [forward, backward]:
for d1 in preindex.values():
for d2 in d1.values():
Expand All @@ -366,15 +376,20 @@ def get_many_to_many(mappings: list[Mapping]) -> list[Mapping]:
return rv


def filter_many_to_many(mappings: list[Mapping], *, progress: bool = True) -> list[Mapping]:
def filter_many_to_many(mappings: t.List[Mapping], *, progress: bool = True) -> t.List[Mapping]:
"""Filter out many to many mappings."""
skip_mappings = get_many_to_many(mappings)
return filter_mappings(mappings, skip_mappings, progress=progress)


def project(
mappings: list[Mapping], source_prefix: str, target_prefix: str, *, return_sus: bool = False, progress: bool = False
) -> list[Mapping] | tuple[list[Mapping], list[Mapping]]:
mappings: t.List[Mapping],
source_prefix: str,
target_prefix: str,
*,
return_sus: bool = False,
progress: bool = False,
) -> t.List[Mapping] | t.Tuple[t.List[Mapping], t.List[Mapping]]:
"""Ensure that each identifier only appears as the subject of one mapping."""
mappings = keep_subject_prefixes(mappings, source_prefix, progress=progress)
mappings = keep_object_prefixes(mappings, target_prefix, progress=progress)
Expand All @@ -386,13 +401,13 @@ def project(
return mappings


def project_dict(mappings: list[Mapping], source_prefix: str, target_prefix: str) -> dict[str, str]:
def project_dict(mappings: t.List[Mapping], source_prefix: str, target_prefix: str) -> t.Dict[str, str]:
"""Get a dictionary from source identifiers to target identifiers."""
mappings = cast(list[Mapping], project(mappings, source_prefix, target_prefix))
mappings = cast(t.List[Mapping], project(mappings, source_prefix, target_prefix))
return {mapping.s.identifier: mapping.o.identifier for mapping in mappings}


def prioritize(mappings: list[Mapping], priority: list[str]) -> list[Mapping]:
def prioritize(mappings: t.List[Mapping], priority: t.List[str]) -> t.List[Mapping]:
"""Get a priority star graph.

:param mappings:
Expand All @@ -403,7 +418,7 @@ def prioritize(mappings: list[Mapping], priority: list[str]) -> list[Mapping]:
exact_mappings = len(mappings)

graph = to_graph(mappings).to_undirected()
rv: list[Mapping] = []
rv: t.List[Mapping] = []
for component in tqdm(nx.connected_components(graph), unit="component", unit_scale=True):
o = _get_priority(component, priority)
if o is None:
Expand All @@ -427,7 +442,7 @@ def prioritize(mappings: list[Mapping], priority: list[str]) -> list[Mapping]:
return rv


def _get_priority(component: list[Reference], priority: list[str]) -> Reference | None:
def _get_priority(component: t.List[Reference], priority: t.List[str]) -> t.Optional[Reference]:
prefix_to_references = defaultdict(list)
for c in component:
prefix_to_references[c.prefix].append(c)
Expand All @@ -444,7 +459,7 @@ def _get_priority(component: list[Reference], priority: list[str]) -> Reference
return None


def unindex(index: Index, *, progress: bool = True) -> list[Mapping]:
def unindex(index: Index, *, progress: bool = True) -> t.List[Mapping]:
"""Convert a mapping index into a list of mapping objects."""
return [
Mapping.from_triple(triple, evidence=evidence)
Expand All @@ -454,13 +469,13 @@ def unindex(index: Index, *, progress: bool = True) -> list[Mapping]:
]


def deduplicate_evidence(evidence: list[Evidence]) -> list[Evidence]:
def deduplicate_evidence(evidence: t.List[Evidence]) -> t.List[Evidence]:
"""Deduplicate a list of evidences based on their "key" function."""
d = {e.key(): e for e in evidence}
return list(d.values())


def validate_mappings(mappings: list[Mapping], *, progress: bool = True) -> None:
def validate_mappings(mappings: t.List[Mapping], *, progress: bool = True) -> None:
"""Validate mappings against the Bioregistry and raise an error on the first invalid."""
import bioregistry

Expand Down Expand Up @@ -489,7 +504,7 @@ def validate_mappings(mappings: list[Mapping], *, progress: bool = True) -> None
raise ValueError(f"banana in mapping object: {mapping}")


def summarize_prefixes(mappings: list[Mapping]) -> pd.DataFrame:
def summarize_prefixes(mappings: t.List[Mapping]) -> pd.DataFrame:
"""Get a dataframe summarizing the prefixes appearing in the mappings."""
import bioregistry

Expand Down
Loading
Loading