From 50c362a8eb54e9ebff36ca0a83e056f533ef7668 Mon Sep 17 00:00:00 2001 From: "Christian W. Feldmann" <128160984+c-w-feldmann@users.noreply.github.com> Date: Thu, 17 Oct 2024 10:37:14 +0200 Subject: [PATCH 01/18] Update requirements_chemprop to 2.0.5 --- requirements_chemprop.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_chemprop.txt b/requirements_chemprop.txt index 57b62453..ed31e961 100644 --- a/requirements_chemprop.txt +++ b/requirements_chemprop.txt @@ -1,2 +1,2 @@ -chemprop >=2.0.3, <=2.0.4 +chemprop >=2.0.3, <=2.0.5 lightning From 73e08f9339c6981c1686e780df9488b2faffd4d9 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Mon, 30 Jun 2025 14:58:43 +0200 Subject: [PATCH 02/18] require chemprop==2.2.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 2eb94220..f0ce021f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ dependencies = [ [project.optional-dependencies] chemprop = [ - "chemprop>=2.0.3,<=2.0.4", + "chemprop>=2.2.0,<=2.0.0", "lightning>=2.5.1", ] notebooks = [ From 77834be99058130fd155effb023780a99275aad2 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Tue, 30 Sep 2025 15:49:56 +0200 Subject: [PATCH 03/18] bump chemprop version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c7dd03d2..6681e446 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ [project.optional-dependencies] chemprop = [ - "chemprop>=2.2.0,<=2.0.0", + "chemprop>=2.2.0,<=2.3.0", "lightning>=2.5.1", ] notebooks = [ From 0ca4d7229df1818cb6d21c7ee5759e4c71dc44f2 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Tue, 30 Sep 2025 16:02:44 +0200 Subject: [PATCH 04/18] reroute loss --- .../estimators/chemprop/loss_wrapper.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/molpipeline/estimators/chemprop/loss_wrapper.py b/molpipeline/estimators/chemprop/loss_wrapper.py index 11ecde9c..5bd8ab70 100644 --- a/molpipeline/estimators/chemprop/loss_wrapper.py +++ b/molpipeline/estimators/chemprop/loss_wrapper.py @@ -3,15 +3,14 @@ from typing import Any import torch -from chemprop.nn.loss import BCELoss as _BCELoss -from chemprop.nn.loss import BinaryDirichletLoss as _BinaryDirichletLoss -from chemprop.nn.loss import CrossEntropyLoss as _CrossEntropyLoss -from chemprop.nn.loss import EvidentialLoss as _EvidentialLoss -from chemprop.nn.loss import LossFunction as _LossFunction -from chemprop.nn.loss import MSELoss as _MSELoss -from chemprop.nn.loss import MulticlassDirichletLoss as _MulticlassDirichletLoss -from chemprop.nn.loss import MVELoss as _MVELoss -from chemprop.nn.loss import SIDLoss as _SIDLoss +from chemprop.nn.metrics import BCELoss as _BCELoss +from chemprop.nn.metrics import DirichletLoss as _BinaryDirichletLoss +from chemprop.nn.metrics import CrossEntropyLoss as _CrossEntropyLoss +from chemprop.nn.metrics import EvidentialLoss as _EvidentialLoss +from chemprop.nn.metrics import LossFunction as _LossFunction +from chemprop.nn.metrics import MSELoss as _MSELoss +from chemprop.nn.metrics import MVELoss as _MVELoss +from chemprop.nn.metrics import SIDLoss as _SIDLoss from numpy.typing import ArrayLike @@ -92,10 +91,6 @@ class MSELoss(LossFunctionParamMixin, _MSELoss): """Mean squared error loss function.""" -class MulticlassDirichletLoss(LossFunctionParamMixin, _MulticlassDirichletLoss): - """Multiclass Dirichlet loss 
function.""" - - class MVELoss(LossFunctionParamMixin, _MVELoss): """Mean value entropy loss function.""" From ec482386c279af4ee57630163a5d08b84e8e69a5 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 09:04:20 +0200 Subject: [PATCH 05/18] rename loss --- molpipeline/estimators/chemprop/component_wrapper.py | 5 ++--- molpipeline/estimators/chemprop/loss_wrapper.py | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py index 71f8ae46..2a6b55e6 100644 --- a/molpipeline/estimators/chemprop/component_wrapper.py +++ b/molpipeline/estimators/chemprop/component_wrapper.py @@ -39,11 +39,10 @@ from molpipeline.estimators.chemprop.loss_wrapper import ( BCELoss, - BinaryDirichletLoss, + DirichletLoss, CrossEntropyLoss, EvidentialLoss, MSELoss, - MulticlassDirichletLoss, MVELoss, SIDLoss, ) @@ -315,7 +314,7 @@ class BinaryDirichletFFN(PredictorWrapper, _BinaryDirichletFFN): # type: ignore """A wrapper for the BinaryDirichletFFN class.""" n_targets: int = 2 - _T_default_criterion = BinaryDirichletLoss + _T_default_criterion = DirichletLoss _T_default_metric = BinaryAUROCMetric diff --git a/molpipeline/estimators/chemprop/loss_wrapper.py b/molpipeline/estimators/chemprop/loss_wrapper.py index 5bd8ab70..b30d94b8 100644 --- a/molpipeline/estimators/chemprop/loss_wrapper.py +++ b/molpipeline/estimators/chemprop/loss_wrapper.py @@ -4,7 +4,7 @@ import torch from chemprop.nn.metrics import BCELoss as _BCELoss -from chemprop.nn.metrics import DirichletLoss as _BinaryDirichletLoss +from chemprop.nn.metrics import DirichletLoss as _DirichletLoss from chemprop.nn.metrics import CrossEntropyLoss as _CrossEntropyLoss from chemprop.nn.metrics import EvidentialLoss as _EvidentialLoss from chemprop.nn.metrics import LossFunction as _LossFunction @@ -75,8 +75,8 @@ class BCELoss(LossFunctionParamMixin, _BCELoss): """Binary cross-entropy loss function.""" -class BinaryDirichletLoss(LossFunctionParamMixin, _BinaryDirichletLoss): - """Binary Dirichlet loss function.""" +class DirichletLoss(LossFunctionParamMixin, _DirichletLoss): + """Dirichlet loss function.""" class CrossEntropyLoss(LossFunctionParamMixin, _CrossEntropyLoss): From dcafc54b3ad3511d72cb171acae40762f099012c Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 09:06:49 +0200 Subject: [PATCH 06/18] update lock file --- uv.lock | 67 +++++++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 63 insertions(+), 4 deletions(-) diff --git a/uv.lock b/uv.lock index 85a06774..cd930d99 100644 --- a/uv.lock +++ b/uv.lock @@ -506,22 +506,24 @@ wheels = [ [[package]] name = "chemprop" -version = "2.0.4" +version = "2.2.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "astartes", extra = ["molecules"] }, { name = "configargparse" }, + { name = "descriptastorus" }, { name = "lightning" }, { name = "numpy" }, { name = "pandas" }, { name = "rdkit" }, + { name = "rich" }, { name = "scikit-learn" }, { name = "scipy" }, { name = "torch" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/58/19/b34a4a70515ae73edb5349c0f2eacc352feb3f9ecf406379daab8c20396f/chemprop-2.0.4.tar.gz", hash = "sha256:701bf0e8202c9d4c1e4b15f0c50535d276872f2c5959a9abb51d6c3907809dc1", size = 30420796 } +sdist = { url = "https://files.pythonhosted.org/packages/81/61/257e7d444a5c11f4651764260444f894383e8c85087961cdf9321932f699/chemprop-2.2.1.tar.gz", hash = 
"sha256:d33a680917b67c779939b5a1e346709ecc3c4ece6003de3fb98472a23c1c7021", size = 139025 } wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/b5/49781be83dffcd2210f24e93c184c20ce6829aa408ff4e4330ada419519c/chemprop-2.0.4-py3-none-any.whl", hash = "sha256:a1c9bcf3ff1a84dae6693751f72be8bf0898e344dcec074e397e5a6aee9cf37f", size = 92323 }, + { url = "https://files.pythonhosted.org/packages/0e/62/8c3f40a0dcdad1b40aad421db67dabc9b00aab8e6d2a9c85d4c115933199/chemprop-2.2.1-py3-none-any.whl", hash = "sha256:90ce9bd38aef1d0d6015b342cd2b034677147d60850b3f67bfb28818237cba01", size = 144148 }, ] [[package]] @@ -729,6 +731,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bb/78/983efd23200921d9edb6bd40512e1aa04af553d7d5a171e50f9b2b45d109/coverage-7.10.4-py3-none-any.whl", hash = "sha256:065d75447228d05121e5c938ca8f0e91eed60a1eb2d1258d42d5084fecfc3302", size = 208365 }, ] +[[package]] +name = "crepes" +version = "0.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pandas" }, + { name = "scipy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/54/02/cd1460a4d9432fdf32a9f781613277d579b99f49b5756c9c630adf0f95c5/crepes-0.8.0.tar.gz", hash = "sha256:ce11cc2befe824db146b06e0de0b1d491cf74fe86832815e9a63bc2bec324410", size = 44996 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/cc/66e152ce1b40227ed011f2fb1b10640cdf30a41df4913721ab9cc86418dd/crepes-0.8.0-py3-none-any.whl", hash = "sha256:879b8eac4fe80de628a6f418ac710e323e537233e477ca0e745faaa205f05872", size = 38700 }, +] + [[package]] name = "cycler" version = "0.12.1" @@ -777,6 +793,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604 }, ] +[[package]] +name = "descriptastorus" +version = "2.8.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pandas-flavor" }, + { name = "rdkit" }, + { name = "scipy" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/3c/67bfc03db3f6527a54a0f9b696809bca55eb6de27928829218ab7e6c40e7/descriptastorus-2.8.0-py3-none-any.whl", hash = "sha256:cc0c8201d6d9d8534dc8a45ce629aeaf4cf1e62ba848b9b1e392b2600ca834dc", size = 2127218 }, +] + [[package]] name = "dill" version = "0.4.0" @@ -1760,6 +1790,7 @@ wheels = [ name = "molpipeline" source = { editable = "." 
} dependencies = [ + { name = "crepes" }, { name = "joblib" }, { name = "loguru" }, { name = "matplotlib" }, @@ -1801,7 +1832,8 @@ dev = [ [package.metadata] requires-dist = [ - { name = "chemprop", marker = "extra == 'chemprop'", specifier = ">=2.0.3,<=2.0.4" }, + { name = "chemprop", marker = "extra == 'chemprop'", specifier = ">=2.2.0,<=2.3.0" }, + { name = "crepes", specifier = ">=0.8.0" }, { name = "joblib", specifier = ">=1.3.0" }, { name = "jupyterlab", marker = "extra == 'notebooks'", specifier = ">=4.4.0" }, { name = "lightning", marker = "extra == 'chemprop'", specifier = ">=2.5.1" }, @@ -2347,6 +2379,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/f9/07086f5b0f2a19872554abeea7658200824f5835c58a106fa8f2ae96a46c/pandas-2.3.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:5db9637dbc24b631ff3707269ae4559bce4b7fd75c1c4d7e13f40edc42df4444", size = 13189044 }, ] +[[package]] +name = "pandas-flavor" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pandas" }, + { name = "xarray" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f4/f3/be418c6244854bf66e3ec08c996a4b2b666536f2d09c1b9f2de8dd0eff73/pandas_flavor-0.7.0.tar.gz", hash = "sha256:617bf9f96902017afc9bd284f611592bce91806d3c7ae34ad64f6edab3edaf7e", size = 11057 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/e6/71ed4d95676098159b533c4a4c424cf453fec9614edaff1a0633fe228eef/pandas_flavor-0.7.0-py3-none-any.whl", hash = "sha256:7ee81e834b111e424679776f49c51abcffec88203b3ff0df2c9cb75550e06b1a", size = 8375 }, +] + [[package]] name = "pandocfilters" version = "1.5.1" @@ -3810,6 +3855,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e1/07/c6fe3ad3e685340704d314d765b7912993bcb8dc198f0e7a89382d37974b/win32_setctime-1.2.0-py3-none-any.whl", hash = "sha256:95d644c4e708aba81dc3704a116d8cbc974d70b3bdb8be1d150e36be6e9d1390", size = 4083 }, ] +[[package]] +name = "xarray" +version = "2025.9.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "packaging" }, + { name = "pandas" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d0/5d/e139112a463336c636d4455494f3227b7f47a2e06ca7571e6b88158ffc06/xarray-2025.9.1.tar.gz", hash = "sha256:f34a27a52c13d1f3cceb7b27276aeec47021558363617dd7ef4f4c8b379011c0", size = 3057322 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0e/a7/6eeb32e705d510a672f74135f538ad27f87f3d600845bfd3834ea3a77c7e/xarray-2025.9.1-py3-none-any.whl", hash = "sha256:3e9708db0d7915c784ed6c227d81b398dca4957afe68d119481f8a448fc88c44", size = 1364411 }, +] + [[package]] name = "yarl" version = "1.20.1" From ef02b244e50f060c65a5bb8b7aa20ee6b8ad1113 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 09:15:20 +0200 Subject: [PATCH 07/18] Fix Loss --- .../estimators/chemprop/component_wrapper.py | 42 ++++++++++++++----- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py index 2a6b55e6..d95b2756 100644 --- a/molpipeline/estimators/chemprop/component_wrapper.py +++ b/molpipeline/estimators/chemprop/component_wrapper.py @@ -1,4 +1,4 @@ -"""Wrapper classes for the chemprop components to make them compatible with scikit-learn.""" +"""Wrapper classes for the chemprop components for compatibility with scikit-learn.""" import abc from collections.abc import Iterable @@ -31,7 +31,7 @@ from 
chemprop.nn.predictors import MveFFN as _MveFFN from chemprop.nn.predictors import RegressionFFN as _RegressionFFN from chemprop.nn.predictors import SpectralFFN as _SpectralFFN -from chemprop.nn.predictors import _FFNPredictorBase as _Predictor +from chemprop.nn.predictors import _FFNPredictorBase as _Predictor # noqa: PLC2701 from chemprop.nn.transforms import UnscaleTransform from chemprop.nn.utils import Activation, get_activation_function from sklearn.base import BaseEstimator @@ -39,8 +39,8 @@ from molpipeline.estimators.chemprop.loss_wrapper import ( BCELoss, - DirichletLoss, CrossEntropyLoss, + DirichletLoss, EvidentialLoss, MSELoss, MVELoss, @@ -85,7 +85,9 @@ def __init__( undirected : bool, optional (default=False) Whether to use undirected edges. d_vd : int or None, optional (default=None) - Dimension of additional vertex descriptors that will be concatenated to the hidden features before readout + Dimension of additional vertex descriptors that will be concatenated to the + hidden features before readout + """ super().__init__( d_v, @@ -113,9 +115,14 @@ def reinitialize_network(self) -> Self: ------- Self The reinitialized network. + """ self.W_i, self.W_h, self.W_o, self.W_d = self.setup( - self.d_v, self.d_e, self.d_h, self.d_vd, self.bias + self.d_v, + self.d_e, + self.d_h, + self.d_vd, + self.bias, ) self.dropout = nn.Dropout(self.dropout_rate) if isinstance(self.activation, str): @@ -136,6 +143,7 @@ def set_params(self, **params: Any) -> Self: ------- Self The model with the new parameters. + """ super().set_params(**params) self.reinitialize_network() @@ -189,6 +197,7 @@ def __init__( # pylint: disable=too-many-positional-arguments Transformations to apply to the output. None defaults to UnscaleTransform. kwargs : Any Additional keyword arguments. + """ if task_weights is None: task_weights = torch.ones(n_tasks) @@ -225,6 +234,7 @@ def input_dim(self, value: int) -> None: ---------- value : int The dimension of input. + """ self._input_dim = value @@ -241,6 +251,7 @@ def n_tasks(self, value: int) -> None: ---------- value : int The number of tasks. + """ self._n_tasks = value @@ -251,6 +262,7 @@ def reinitialize_fnn(self) -> Self: ------- Self The reinitialized feedforward network. + """ self.ffn = MLP.build( input_dim=self.input_dim, @@ -274,6 +286,7 @@ def set_params(self, **params: Any) -> Self: ------- Self The model with the new parameters. + """ super().set_params(**params) self.reinitialize_fnn() @@ -325,7 +338,7 @@ class MulticlassClassificationFFN(PredictorWrapper, _MulticlassClassificationFFN _T_default_criterion = CrossEntropyLoss _T_default_metric = CrossEntropyMetric - def __init__( # pylint: disable=too-many-positional-arguments + def __init__( # pylint: disable=too-many-positional-arguments #noqa: PLR0917 self, n_classes: int, n_tasks: int = 1, @@ -365,6 +378,7 @@ def __init__( # pylint: disable=too-many-positional-arguments Threshold for binary classification. output_transform : UnscaleTransform or None, optional (default=None) Transformations to apply to the output. None defaults to UnscaleTransform. 
+ """ super().__init__( n_tasks, @@ -387,7 +401,7 @@ class MulticlassDirichletFFN(PredictorWrapper, _MulticlassDirichletFFN): # type """A wrapper for the MulticlassDirichletFFN class.""" n_targets: int = 1 - _T_default_criterion = MulticlassDirichletLoss + _T_default_criterion = DirichletLoss _T_default_metric = CrossEntropyMetric @@ -402,8 +416,9 @@ class SpectralFFN(PredictorWrapper, _SpectralFFN): # type: ignore class MPNN(_MPNN, BaseEstimator): """A wrapper for the MPNN class. - The MPNN is the main model class in chemprop. It consists of a message passing network, an aggregation function, - and a feedforward network for prediction. + The MPNN is the main model class in chemprop. It consists of a message passing + network, an aggregation function, and a feedforward network for prediction. + """ bn: nn.BatchNorm1d | nn.Identity @@ -442,6 +457,7 @@ def __init__( The maximum learning rate. final_lr : float, optional (default=1e-4) The final learning rate. + """ super().__init__( message_passing=message_passing, @@ -464,6 +480,7 @@ def reinitialize_network(self) -> Self: ------- Self The reinitialized network. + """ if self.batch_norm: self.bn = nn.BatchNorm1d(self.message_passing.output_dim) @@ -472,7 +489,7 @@ def reinitialize_network(self) -> Self: if self.metric_list is None: # pylint: disable=protected-access - self.metrics = [self.predictor._T_default_metric, self.criterion] + self.metrics = [self.predictor._T_default_metric, self.criterion] # noqa: SLF001 else: self.metrics = [*list(self.metric_list), self.criterion] @@ -490,6 +507,7 @@ def set_params(self, **params: Any) -> Self: ------- Self The model with the new parameters. + """ super().set_params(**params) self.reinitialize_network() @@ -498,7 +516,7 @@ def set_params(self, **params: Any) -> Self: # pylint: disable=too-many-ancestors class MeanAggregation(_MeanAggregation, BaseEstimator): - """Aggregate the graph-level representation by averaging the node representations.""" + """Aggregate the graph-level representation by averaging node representations.""" def __init__(self, dim: int = 0): """Initialize the MeanAggregation class. @@ -507,6 +525,7 @@ def __init__(self, dim: int = 0): ---------- dim : int, optional (default=0) The dimension to aggregate over. See torch_scater.scatter for more details. + """ super().__init__(dim) @@ -522,5 +541,6 @@ def __init__(self, dim: int = 0): ---------- dim : int, optional (default=0) The dimension to aggregate over. See torch_scater.scatter for more details. 
+ """ super().__init__(dim) From bf95cefca636c7120f2a637fd507c5df12897d54 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 09:40:25 +0200 Subject: [PATCH 08/18] Update loss to metrics --- .../estimators/chemprop/loss_wrapper.py | 33 ++++++++++++------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/molpipeline/estimators/chemprop/loss_wrapper.py b/molpipeline/estimators/chemprop/loss_wrapper.py index b30d94b8..006bf4db 100644 --- a/molpipeline/estimators/chemprop/loss_wrapper.py +++ b/molpipeline/estimators/chemprop/loss_wrapper.py @@ -1,16 +1,17 @@ """Wrapper for Chemprop loss functions.""" -from typing import Any +from typing import Any, override import torch +from chemprop.nn.metrics import MSE as _MSE +from chemprop.nn.metrics import SID as _SID from chemprop.nn.metrics import BCELoss as _BCELoss -from chemprop.nn.metrics import DirichletLoss as _DirichletLoss +from chemprop.nn.metrics import ChempropMetric as _ChempropMetric from chemprop.nn.metrics import CrossEntropyLoss as _CrossEntropyLoss +from chemprop.nn.metrics import DirichletLoss as _DirichletLoss from chemprop.nn.metrics import EvidentialLoss as _EvidentialLoss -from chemprop.nn.metrics import LossFunction as _LossFunction -from chemprop.nn.metrics import MSELoss as _MSELoss +from chemprop.nn.metrics import MulticlassMCCMetric as _MulticlassMCCMetric from chemprop.nn.metrics import MVELoss as _MVELoss -from chemprop.nn.metrics import SIDLoss as _SIDLoss from numpy.typing import ArrayLike @@ -19,7 +20,7 @@ class LossFunctionParamMixin: _original_task_weights: ArrayLike - def __init__(self: _LossFunction, task_weights: ArrayLike) -> None: + def __init__(self: _ChempropMetric, task_weights: ArrayLike) -> None: """Initialize the loss function. Parameters @@ -32,7 +33,8 @@ def __init__(self: _LossFunction, task_weights: ArrayLike) -> None: self._original_task_weights = task_weights # pylint: disable=unused-argument - def get_params(self: _LossFunction, deep: bool = True) -> dict[str, Any]: + @override + def get_params(self: _ChempropMetric, deep: bool = True) -> dict[str, Any]: """Get the parameters of the loss function. Parameters @@ -44,10 +46,11 @@ def get_params(self: _LossFunction, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] The parameters of the loss function. + """ return {"task_weights": self._original_task_weights} - def set_params(self: _LossFunction, **params: Any) -> _LossFunction: + def set_params(self: _ChempropMetric, **params: Any) -> _ChempropMetric: """Set the parameters of the loss function. Parameters @@ -59,13 +62,15 @@ def set_params(self: _LossFunction, **params: Any) -> _LossFunction: ------- Self The loss function with the new parameters. 
+ """ task_weights = params.pop("task_weights", None) if task_weights is not None: self._original_task_weights = task_weights state_dict = self.state_dict() state_dict["task_weights"] = torch.as_tensor( - task_weights, dtype=torch.float + task_weights, + dtype=torch.float, ).view(1, -1) self.load_state_dict(state_dict) return self @@ -79,6 +84,10 @@ class DirichletLoss(LossFunctionParamMixin, _DirichletLoss): """Dirichlet loss function.""" +class MulticlassMCCMetric(LossFunctionParamMixin, _MulticlassMCCMetric): + """Multiclass Matthews correlation coefficient metric.""" + + class CrossEntropyLoss(LossFunctionParamMixin, _CrossEntropyLoss): """Cross-entropy loss function.""" @@ -87,7 +96,7 @@ class EvidentialLoss(LossFunctionParamMixin, _EvidentialLoss): """Evidential loss function.""" -class MSELoss(LossFunctionParamMixin, _MSELoss): +class MSELoss(LossFunctionParamMixin, _MSE): """Mean squared error loss function.""" @@ -95,5 +104,5 @@ class MVELoss(LossFunctionParamMixin, _MVELoss): """Mean value entropy loss function.""" -class SIDLoss(LossFunctionParamMixin, _SIDLoss): - """SID loss function.""" +class SID(LossFunctionParamMixin, _SID): + """SID score function.""" From ed43a548204bafc7f53d984713d85f6937d4fd99 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 09:43:05 +0200 Subject: [PATCH 09/18] make override safe for py3.11 --- molpipeline/estimators/chemprop/loss_wrapper.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/molpipeline/estimators/chemprop/loss_wrapper.py b/molpipeline/estimators/chemprop/loss_wrapper.py index 006bf4db..00194681 100644 --- a/molpipeline/estimators/chemprop/loss_wrapper.py +++ b/molpipeline/estimators/chemprop/loss_wrapper.py @@ -1,6 +1,11 @@ """Wrapper for Chemprop loss functions.""" -from typing import Any, override +from typing import Any + +try: # required for Python < 3.12 + from typing import override +except ImportError: + from typing_extensions import override import torch from chemprop.nn.metrics import MSE as _MSE From 586a447c04dd14e3fdf7940d78badf942517b2bf Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 09:46:35 +0200 Subject: [PATCH 10/18] Rename MSELoss to MSE --- molpipeline/estimators/chemprop/loss_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/molpipeline/estimators/chemprop/loss_wrapper.py b/molpipeline/estimators/chemprop/loss_wrapper.py index 00194681..2130cedc 100644 --- a/molpipeline/estimators/chemprop/loss_wrapper.py +++ b/molpipeline/estimators/chemprop/loss_wrapper.py @@ -101,7 +101,7 @@ class EvidentialLoss(LossFunctionParamMixin, _EvidentialLoss): """Evidential loss function.""" -class MSELoss(LossFunctionParamMixin, _MSE): +class MSE(LossFunctionParamMixin, _MSE): """Mean squared error loss function.""" From d92ffd422d191e0f4d175a3d71ffdf7467bef190 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 09:49:38 +0200 Subject: [PATCH 11/18] Add BinaryAUROC --- molpipeline/estimators/chemprop/loss_wrapper.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/molpipeline/estimators/chemprop/loss_wrapper.py b/molpipeline/estimators/chemprop/loss_wrapper.py index 2130cedc..dd185f7f 100644 --- a/molpipeline/estimators/chemprop/loss_wrapper.py +++ b/molpipeline/estimators/chemprop/loss_wrapper.py @@ -11,6 +11,7 @@ from chemprop.nn.metrics import MSE as _MSE from chemprop.nn.metrics import SID as _SID from chemprop.nn.metrics import BCELoss as _BCELoss +from chemprop.nn.metrics import BinaryAUROC as _BinaryAUROC 
from chemprop.nn.metrics import ChempropMetric as _ChempropMetric from chemprop.nn.metrics import CrossEntropyLoss as _CrossEntropyLoss from chemprop.nn.metrics import DirichletLoss as _DirichletLoss @@ -111,3 +112,7 @@ class MVELoss(LossFunctionParamMixin, _MVELoss): class SID(LossFunctionParamMixin, _SID): """SID score function.""" + + +class BinaryAUROC(LossFunctionParamMixin, _BinaryAUROC): + """Binary area under the receiver operating characteristic curve metric.""" From 4f7fc16df3e9770f39ae969c8cbda6a9f5e03336 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 09:52:31 +0200 Subject: [PATCH 12/18] Update component wrappers --- .../estimators/chemprop/component_wrapper.py | 41 ++++++++----------- 1 file changed, 18 insertions(+), 23 deletions(-) diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py index d95b2756..f0e0ec96 100644 --- a/molpipeline/estimators/chemprop/component_wrapper.py +++ b/molpipeline/estimators/chemprop/component_wrapper.py @@ -11,16 +11,9 @@ from chemprop.nn.agg import MeanAggregation as _MeanAggregation from chemprop.nn.agg import SumAggregation as _SumAggregation from chemprop.nn.ffn import MLP -from chemprop.nn.loss import LossFunction from chemprop.nn.message_passing import BondMessagePassing as _BondMessagePassing from chemprop.nn.message_passing import MessagePassing -from chemprop.nn.metrics import ( - BinaryAUROCMetric, - CrossEntropyMetric, - Metric, - MSEMetric, - SIDMetric, -) +from chemprop.nn.metrics import ChempropMetric from chemprop.nn.predictors import BinaryClassificationFFN as _BinaryClassificationFFN from chemprop.nn.predictors import BinaryDirichletFFN as _BinaryDirichletFFN from chemprop.nn.predictors import EvidentialFFN as _EvidentialFFN @@ -38,13 +31,15 @@ from torch import Tensor, nn from molpipeline.estimators.chemprop.loss_wrapper import ( + MSE, + SID, BCELoss, + BinaryAUROC, CrossEntropyLoss, DirichletLoss, EvidentialLoss, - MSELoss, + MulticlassMCCMetric, MVELoss, - SIDLoss, ) @@ -154,8 +149,8 @@ def set_params(self, **params: Any) -> Self: class PredictorWrapper(_Predictor, BaseEstimator, abc.ABC): # type: ignore """Abstract wrapper for the Predictor class.""" - _T_default_criterion: LossFunction - _T_default_metric: Metric + _T_default_criterion: ChempropMetric + _T_default_metric: ChempropMetric def __init__( # pylint: disable=too-many-positional-arguments self, @@ -165,7 +160,7 @@ def __init__( # pylint: disable=too-many-positional-arguments n_layers: int = 1, dropout: float = 0, activation: str = "relu", - criterion: LossFunction | None = None, + criterion: ChempropMetric | None = None, task_weights: Tensor | None = None, threshold: float | None = None, output_transform: UnscaleTransform | None = None, @@ -297,8 +292,8 @@ class RegressionFFN(PredictorWrapper, _RegressionFFN): # type: ignore """A wrapper for the RegressionFFN class.""" n_targets: int = 1 - _T_default_criterion = MSELoss - _T_default_metric = MSEMetric + _T_default_criterion = MSE + _T_default_metric = MSE class MveFFN(PredictorWrapper, _MveFFN): # type: ignore @@ -320,7 +315,7 @@ class BinaryClassificationFFN(PredictorWrapper, _BinaryClassificationFFN): # ty n_targets: int = 1 _T_default_criterion = BCELoss - _T_default_metric = BinaryAUROCMetric + _T_default_metric = BinaryAUROC class BinaryDirichletFFN(PredictorWrapper, _BinaryDirichletFFN): # type: ignore @@ -328,7 +323,7 @@ class BinaryDirichletFFN(PredictorWrapper, _BinaryDirichletFFN): # type: ignore n_targets: int = 2 
_T_default_criterion = DirichletLoss - _T_default_metric = BinaryAUROCMetric + _T_default_metric = BinaryAUROC class MulticlassClassificationFFN(PredictorWrapper, _MulticlassClassificationFFN): # type: ignore @@ -336,7 +331,7 @@ class MulticlassClassificationFFN(PredictorWrapper, _MulticlassClassificationFFN n_targets: int = 1 _T_default_criterion = CrossEntropyLoss - _T_default_metric = CrossEntropyMetric + _T_default_metric = MulticlassMCCMetric def __init__( # pylint: disable=too-many-positional-arguments #noqa: PLR0917 self, @@ -347,7 +342,7 @@ def __init__( # pylint: disable=too-many-positional-arguments #noqa: PLR0917 n_layers: int = 1, dropout: float = 0.0, activation: str = "relu", - criterion: LossFunction | None = None, + criterion: ChempropMetric | None = None, task_weights: Tensor | None = None, threshold: float | None = None, output_transform: UnscaleTransform | None = None, @@ -402,15 +397,15 @@ class MulticlassDirichletFFN(PredictorWrapper, _MulticlassDirichletFFN): # type n_targets: int = 1 _T_default_criterion = DirichletLoss - _T_default_metric = CrossEntropyMetric + _T_default_metric = MulticlassMCCMetric class SpectralFFN(PredictorWrapper, _SpectralFFN): # type: ignore """A wrapper for the SpectralFFN class.""" n_targets: int = 1 - _T_default_criterion = SIDLoss - _T_default_metric = SIDMetric + _T_default_criterion = SID + _T_default_metric = SID class MPNN(_MPNN, BaseEstimator): @@ -429,7 +424,7 @@ def __init__( agg: Aggregation, predictor: PredictorWrapper, batch_norm: bool = True, - metric_list: Iterable[Metric] | None = None, + metric_list: Iterable[ChempropMetric] | None = None, warmup_epochs: int = 2, init_lr: float = 1e-4, max_lr: float = 1e-3, From ea7167dd77886fc3c558f11dc51acf0108da89d0 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 10:47:18 +0200 Subject: [PATCH 13/18] Update imports --- .../test_chemprop/test_component_wrapper.py | 12 +++--- test_extras/test_chemprop/test_models.py | 41 ++++++++++++------- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/test_extras/test_chemprop/test_component_wrapper.py b/test_extras/test_chemprop/test_component_wrapper.py index 53a901aa..0487533a 100644 --- a/test_extras/test_chemprop/test_component_wrapper.py +++ b/test_extras/test_chemprop/test_component_wrapper.py @@ -2,7 +2,7 @@ import unittest -from chemprop.nn.loss import LossFunction +from chemprop.nn.metrics import ChempropMetric from sklearn.base import clone from torch import nn @@ -94,10 +94,10 @@ def test_get_set_params(self) -> None: mpnn2.set_params(**params1) for param_name, param in mpnn1.get_params(deep=True).items(): param2 = mpnn2.get_params(deep=True)[param_name] - # Classes are cloned, so they are not equal, but they should be the same class - # Since (here) objects are identical if their parameters are identical, and since all - # their parameters are listed flat in the params dicts, all objects are identical if - # param dicts are identical. + # Classes are cloned, so they are not equal, but they should be the same + # class. Since (here) objects are identical if their parameters are + # identical, and since all their parameters are listed flat in the params + # dicts, all objects are identical if param dicts are identical. 
if hasattr(param, "get_params"): self.assertEqual(param.__class__, param2.__class__) else: @@ -115,7 +115,7 @@ def test_clone(self) -> None: clone_param = mpnn_clone.get_params(deep=True)[param_name] if hasattr(param, "get_params"): self.assertEqual(param.__class__, clone_param.__class__) - elif isinstance(param, LossFunction): + elif isinstance(param, ChempropMetric): self.assertEqual( param.state_dict()["task_weights"], clone_param.state_dict()["task_weights"], diff --git a/test_extras/test_chemprop/test_models.py b/test_extras/test_chemprop/test_models.py index 2e7dd001..0c22be9f 100644 --- a/test_extras/test_chemprop/test_models.py +++ b/test_extras/test_chemprop/test_models.py @@ -5,7 +5,7 @@ from collections.abc import Iterable import torch -from chemprop.nn.loss import MSELoss +from chemprop.nn.metrics import MSE from sklearn.base import clone from molpipeline.estimators.chemprop.component_wrapper import ( @@ -69,7 +69,9 @@ def test_get_params(self) -> None: raise ValueError(f"{param_name} should be a type.") else: self.assertEqual( - orig_params[param_name], param, f"Test failed for {param_name}" + orig_params[param_name], + param, + f"Test failed for {param_name}", ) new_params = { @@ -82,7 +84,7 @@ def test_get_params(self) -> None: chemprop_model.set_params(**new_params) model_params = chemprop_model.get_params(deep=True) for param_name, param in new_params.items(): - if param_name in {"model__agg"}: + if param_name == "model__agg": self.assertIsInstance(model_params[param_name], type(param)) continue self.assertEqual(param, model_params[param_name]) @@ -98,8 +100,8 @@ def test_classifier_methods(self) -> None: """Test the classifier methods.""" chemprop_model = get_chemprop_model_binary_classification_mpnn() # pylint: disable=protected-access - self.assertTrue(chemprop_model._is_binary_classifier()) - self.assertFalse(chemprop_model._is_multiclass_classifier()) + self.assertTrue(chemprop_model._is_binary_classifier()) # noqa: SLF001 + self.assertFalse(chemprop_model._is_multiclass_classifier()) # noqa: SLF001 # pylint: enable=protected-access self.assertTrue(hasattr(chemprop_model, "predict_proba")) @@ -128,7 +130,8 @@ def test_json_serialization(self) -> None: param_dict = chemprop_model_copy.get_params(deep=True) self.assertSetEqual( - set(param_dict.keys()), set(DEFAULT_BINARY_CLASSIFICATION_PARAMS.keys()) + set(param_dict.keys()), + set(DEFAULT_BINARY_CLASSIFICATION_PARAMS.keys()), ) for param_name, param in DEFAULT_BINARY_CLASSIFICATION_PARAMS.items(): if param_name in NO_IDENTITY_CHECK: @@ -144,7 +147,9 @@ def test_json_serialization(self) -> None: self.assertTrue(torch.allclose(param, param_dict[param_name])) else: self.assertEqual( - param_dict[param_name], param, f"Test failed for {param_name}" + param_dict[param_name], + param, + f"Test failed for {param_name}", ) @@ -176,7 +181,9 @@ def test_get_params(self) -> None: raise ValueError(f"{param_name} should be a type.") else: self.assertEqual( - param_dict[param_name], param, f"Test failed for {param_name}" + param_dict[param_name], + param, + f"Test failed for {param_name}", ) def test_set_params(self) -> None: @@ -204,7 +211,7 @@ def test_get_params(self) -> None: param_dict = chemprop_model.get_params(deep=True) expected_params = dict(DEFAULT_REGRESSION_PARAMS) expected_params["model__predictor"] = RegressionFFN - expected_params["model__predictor__criterion"] = MSELoss + expected_params["model__predictor__criterion"] = MSE self.assertSetEqual(set(param_dict.keys()), set(expected_params.keys())) for param_name, 
param in expected_params.items(): if param_name in NO_IDENTITY_CHECK: @@ -218,7 +225,9 @@ def test_get_params(self) -> None: raise ValueError(f"{param_name} should be a type.") else: self.assertEqual( - param_dict[param_name], param, f"Test failed for {param_name}" + param_dict[param_name], + param, + f"Test failed for {param_name}", ) @@ -236,7 +245,8 @@ def test_get_params(self) -> None: """ n_classes = 3 chemprop_model = ChempropMulticlassClassifier( - lightning_trainer__accelerator="cpu", n_classes=n_classes + lightning_trainer__accelerator="cpu", + n_classes=n_classes, ) param_dict = chemprop_model.get_params(deep=True) expected_params = dict(DEFAULT_MULTICLASS_CLASSIFICATION_PARAMS) # Shallow copy @@ -257,13 +267,16 @@ def test_get_params(self) -> None: self.assertTrue(torch.allclose(param_dict[param_name], param)) else: self.assertEqual( - param_dict[param_name], param, f"Test failed for {param_name}" + param_dict[param_name], + param, + f"Test failed for {param_name}", ) def test_set_params(self) -> None: """Test the set_params methods.""" chemprop_model = ChempropMulticlassClassifier( - lightning_trainer__accelerator="cpu", n_classes=3 + lightning_trainer__accelerator="cpu", + n_classes=3, ) chemprop_model.set_params(**DEFAULT_SET_PARAMS) params = { @@ -278,7 +291,7 @@ def test_set_params(self) -> None: self.assertEqual(current_params[param], value) def test_error_for_multiclass_predictor(self) -> None: - """Test the error for using a multiclass predictor for a binary classification model.""" + """Test error raised by using a multiclass predictor for bin. classification.""" bond_encoder = BondMessagePassing() agg = SumAggregation() with self.assertRaises(ValueError): From 253d3a1ff836c3a3aa73a91894e7e204c5306c22 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 10:47:55 +0200 Subject: [PATCH 14/18] Add default value for task weight --- molpipeline/estimators/chemprop/loss_wrapper.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/molpipeline/estimators/chemprop/loss_wrapper.py b/molpipeline/estimators/chemprop/loss_wrapper.py index dd185f7f..605cfd31 100644 --- a/molpipeline/estimators/chemprop/loss_wrapper.py +++ b/molpipeline/estimators/chemprop/loss_wrapper.py @@ -26,12 +26,12 @@ class LossFunctionParamMixin: _original_task_weights: ArrayLike - def __init__(self: _ChempropMetric, task_weights: ArrayLike) -> None: + def __init__(self: _ChempropMetric, task_weights: ArrayLike | float = 1.0) -> None: """Initialize the loss function. Parameters ---------- - task_weights : ArrayLike + task_weights : ArrayLike | float, optional The weights for each task. 
""" From 103ab32238c1e23ce7f86713e3aa2489172cee30 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 10:51:11 +0200 Subject: [PATCH 15/18] Wrap metrics in ModuleList --- molpipeline/estimators/chemprop/component_wrapper.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/molpipeline/estimators/chemprop/component_wrapper.py b/molpipeline/estimators/chemprop/component_wrapper.py index f0e0ec96..e095176c 100644 --- a/molpipeline/estimators/chemprop/component_wrapper.py +++ b/molpipeline/estimators/chemprop/component_wrapper.py @@ -484,9 +484,11 @@ def reinitialize_network(self) -> Self: if self.metric_list is None: # pylint: disable=protected-access - self.metrics = [self.predictor._T_default_metric, self.criterion] # noqa: SLF001 + self.metrics = nn.ModuleList( + [self.predictor._T_default_metric(), self.criterion], # noqa: SLF001 + ) else: - self.metrics = [*list(self.metric_list), self.criterion] + self.metrics = nn.ModuleList([*list(self.metric_list), self.criterion]) return self From 61b61e45e32a3185224d711b0c0a1bed13a27fd1 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 10:52:35 +0200 Subject: [PATCH 16/18] Rename metric --- .../chemprop_test_utils/compare_models.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/test_extras/test_chemprop/chemprop_test_utils/compare_models.py b/test_extras/test_chemprop/chemprop_test_utils/compare_models.py index b9e00b4a..253c0c1d 100644 --- a/test_extras/test_chemprop/chemprop_test_utils/compare_models.py +++ b/test_extras/test_chemprop/chemprop_test_utils/compare_models.py @@ -4,7 +4,7 @@ from unittest import TestCase import torch -from chemprop.nn.loss import LossFunction +from chemprop.nn.metrics import ChempropMetric from lightning.pytorch.accelerators import Accelerator from lightning.pytorch.profilers.base import PassThroughProfiler from sklearn.base import BaseEstimator @@ -12,7 +12,9 @@ def compare_params( - test_case: TestCase, model_a: BaseEstimator, model_b: BaseEstimator + test_case: TestCase, + model_a: BaseEstimator, + model_b: BaseEstimator, ) -> None: """Compare the parameters of two models. @@ -24,6 +26,7 @@ def compare_params( The first model. model_b : BaseEstimator The second model. 
+ """ model_a_params = model_a.get_params(deep=True) model_b_params = model_b.get_params(deep=True) @@ -34,7 +37,7 @@ def compare_params( if hasattr(param_a, "get_params"): test_case.assertTrue(hasattr(param_b, "get_params")) test_case.assertNotEqual(id(param_a), id(param_b)) - elif isinstance(param_a, LossFunction): + elif isinstance(param_a, ChempropMetric): test_case.assertEqual( param_a.state_dict()["task_weights"], param_b.state_dict()["task_weights"], @@ -44,7 +47,8 @@ def compare_params( test_case.assertEqual(type(param_a), type(param_b)) elif isinstance(param_a, torch.Tensor): test_case.assertTrue( - torch.equal(param_a, param_b), f"Test failed for {param_name}" + torch.equal(param_a, param_b), + f"Test failed for {param_name}", ) elif param_name == "lightning_trainer__callbacks": test_case.assertIsInstance(param_b, Sequence) From f18b6716df4855fba6aea8b688f94275e02ce876 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 10:58:18 +0200 Subject: [PATCH 17/18] Remove override --- molpipeline/estimators/chemprop/loss_wrapper.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/molpipeline/estimators/chemprop/loss_wrapper.py b/molpipeline/estimators/chemprop/loss_wrapper.py index 605cfd31..2d05b899 100644 --- a/molpipeline/estimators/chemprop/loss_wrapper.py +++ b/molpipeline/estimators/chemprop/loss_wrapper.py @@ -2,11 +2,6 @@ from typing import Any -try: # required for Python < 3.12 - from typing import override -except ImportError: - from typing_extensions import override - import torch from chemprop.nn.metrics import MSE as _MSE from chemprop.nn.metrics import SID as _SID @@ -39,8 +34,7 @@ def __init__(self: _ChempropMetric, task_weights: ArrayLike | float = 1.0) -> No self._original_task_weights = task_weights # pylint: disable=unused-argument - @override - def get_params(self: _ChempropMetric, deep: bool = True) -> dict[str, Any]: + def get_params(self: _ChempropMetric, deep: bool = True) -> dict[str, Any]: # noqa: ARG002 """Get the parameters of the loss function. Parameters From addc7429616eb22752a3715c8fe438216a95ccb7 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 1 Oct 2025 11:04:20 +0200 Subject: [PATCH 18/18] Disable too many ancestors for BinaryAUROC --- molpipeline/estimators/chemprop/loss_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/molpipeline/estimators/chemprop/loss_wrapper.py b/molpipeline/estimators/chemprop/loss_wrapper.py index 2d05b899..b1df29f9 100644 --- a/molpipeline/estimators/chemprop/loss_wrapper.py +++ b/molpipeline/estimators/chemprop/loss_wrapper.py @@ -108,5 +108,5 @@ class SID(LossFunctionParamMixin, _SID): """SID score function.""" -class BinaryAUROC(LossFunctionParamMixin, _BinaryAUROC): +class BinaryAUROC(LossFunctionParamMixin, _BinaryAUROC): # pylint: disable=too-many-ancestors """Binary area under the receiver operating characteristic curve metric."""