Skip to content

Commit

Permalink
Fix scoring bug, and tests (#65)
Browse files Browse the repository at this point in the history
* test: make data loader test run and add split ids to index

* test: make plot use old split file

* test: update index length

* drop atom3d

* doc: column name

* refactor: drop unused files

* test: drop empty file

* test: add dataclass tests from pinder

* test: add transform tests

* log test error

* fix plot

* add torch dependency

* add torch to data requirements

* lint

* FIX: bug in scoring

* FIX: Bug in scoring

* fix: don't allow rigid docking on holo (just to be safe)

* type

* chore: fix hyperlink

* fix: fix failing tests

* fix: linting

* fix: mount data discrepancies

* fix: load split in index to avoid global issue

* lint and type

---------

Co-authored-by: OleinikovasV <[email protected]>
Co-authored-by: yusuf1759 <[email protected]>
  • Loading branch information
3 people authored Sep 28, 2024
1 parent b325c48 commit 1a8230e
Show file tree
Hide file tree
Showing 21 changed files with 98 additions and 683 deletions.
6 changes: 3 additions & 3 deletions docs/tutorial/api.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@
"of a column name with some given value.\n",
"Only those rows, that fulfill all conditions, are returned.\n",
"See the description of\n",
"[`pandas.read_parquet()`]https://pandas.pydata.org/docs/reference/api/pandas.read_parquet.html\n",
"[`pandas.read_parquet()`](https://pandas.pydata.org/docs/reference/api/pandas.read_parquet.html)\n",
"for more information on the filter syntax."
]
},
Expand Down Expand Up @@ -285,7 +285,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -402,7 +402,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.10.15"
}
},
"nbformat": 4,
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ dev = [
]
loader = [
"torch",
"atom3d",
]
plots = [
"matplotlib",
Expand Down
1 change: 1 addition & 0 deletions requirements_data.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
tabulate
pdb-validation @ git+https://git.scicore.unibas.ch/schwede/ligand-validation.git
mmpdb @ git+https://github.com/rdkit/mmpdb.git
torch
9 changes: 6 additions & 3 deletions src/plinder/core/loader/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from plinder.core.scores import query_index
from plinder.core.scores.query import FILTERS
from plinder.core.structure.structure import Structure
from plinder.core.utils.log import setup_logger

LOG = setup_logger(__name__)


class PlinderDataset(Dataset): # type: ignore
Expand Down Expand Up @@ -44,9 +47,9 @@ def __init__(
] = structure_featurizer,
**kwargs: Any,
):
self._system_ids = list(
set(query_index(splits=[split], filters=filters)["system_id"])
)
index = query_index(splits=[split], filters=filters)
LOG.info(f"Loading {index.system_id.nunique()} systems")
self._system_ids = list(set(index["system_id"]))
self._num_examples = len(self._system_ids)

self._featurizer = featurizer
Expand Down
4 changes: 2 additions & 2 deletions src/plinder/core/scores/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from duckdb import sql

from plinder.core.scores.query import FILTERS, make_query
from plinder.core.split.utils import get_split
from plinder.core.utils import cpl
from plinder.core.utils.config import get_config
from plinder.core.utils.log import setup_logger
Expand Down Expand Up @@ -51,7 +50,8 @@ def query_index(
df = sql(query).to_df()
if splits is None:
splits = ["train", "val"]
split_df = get_split(cfg=cfg)
split = cpl.get_plinder_path(rel=f"{cfg.data.splits}/{cfg.data.split_file}")
split_df = pd.read_parquet(split)
split_dict = dict(zip(split_df["system_id"], split_df["split"]))
df["split"] = df["system_id"].map(lambda x: split_dict.get(x, "unassigned"))
if "*" not in splits:
Expand Down
2 changes: 1 addition & 1 deletion src/plinder/core/split/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def from_files(
mms_count=mms_count,
)
if plindex_file is None:
plindex = get_plindex()
plindex = get_plindex().drop(columns=["split"])
else:
plindex = pd.read_parquet(plindex_file)
plotter.plindex = plotter.merge_splits_and_plindex(plindex)
Expand Down
2 changes: 1 addition & 1 deletion src/plinder/data/column_descriptions/ligands.tsv
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ ligand_bird_id str Ligand BIRD id
ligand_centroid list[float] Ligand center of geometry
ligand_smiles str Ligand SMILES based on OpenStructure dictionary lookup, or resolved SMILES if not in dictionary
ligand_resolved_smiles str SMILES of only resolved ligand atoms
ligand_rdkit_canonical_smiles str | None RDKit canonical SMILES
ligand_rdkit_canonical_smiles str | None RDKit canonical SMILES (Recommended)
ligand_molecular_weight float | None Molecular weight
ligand_crippen_clogp float | None Ligand Crippen MlogP, see https://www.rdkit.org/docs/source/rdkit.Chem.Crippen.html
ligand_num_rot_bonds int | None Number of rotatable bonds
Expand Down
101 changes: 0 additions & 101 deletions src/plinder/data/utils/cluster.py

This file was deleted.

Loading

0 comments on commit 1a8230e

Please sign in to comment.