90 changes: 53 additions & 37 deletions molpipeline/estimators/chemprop/component_wrapper.py
@@ -1,4 +1,4 @@
"""Wrapper classes for the chemprop components to make them compatible with scikit-learn."""
"""Wrapper classes for the chemprop components for compatibility with scikit-learn."""

import abc
from collections.abc import Iterable
@@ -11,16 +11,9 @@
from chemprop.nn.agg import MeanAggregation as _MeanAggregation
from chemprop.nn.agg import SumAggregation as _SumAggregation
from chemprop.nn.ffn import MLP
from chemprop.nn.loss import LossFunction
from chemprop.nn.message_passing import BondMessagePassing as _BondMessagePassing
from chemprop.nn.message_passing import MessagePassing
from chemprop.nn.metrics import (
BinaryAUROCMetric,
CrossEntropyMetric,
Metric,
MSEMetric,
SIDMetric,
)
from chemprop.nn.metrics import ChempropMetric
from chemprop.nn.predictors import BinaryClassificationFFN as _BinaryClassificationFFN
from chemprop.nn.predictors import BinaryDirichletFFN as _BinaryDirichletFFN
from chemprop.nn.predictors import EvidentialFFN as _EvidentialFFN
@@ -31,21 +24,22 @@
from chemprop.nn.predictors import MveFFN as _MveFFN
from chemprop.nn.predictors import RegressionFFN as _RegressionFFN
from chemprop.nn.predictors import SpectralFFN as _SpectralFFN
from chemprop.nn.predictors import _FFNPredictorBase as _Predictor
from chemprop.nn.predictors import _FFNPredictorBase as _Predictor # noqa: PLC2701
from chemprop.nn.transforms import UnscaleTransform
from chemprop.nn.utils import Activation, get_activation_function
from sklearn.base import BaseEstimator
from torch import Tensor, nn

from molpipeline.estimators.chemprop.loss_wrapper import (
MSE,
SID,
BCELoss,
BinaryDirichletLoss,
BinaryAUROC,
CrossEntropyLoss,
DirichletLoss,
EvidentialLoss,
MSELoss,
MulticlassDirichletLoss,
MulticlassMCCMetric,
MVELoss,
SIDLoss,
)


@@ -86,7 +80,9 @@ def __init__(
undirected : bool, optional (default=False)
Whether to use undirected edges.
d_vd : int or None, optional (default=None)
Dimension of additional vertex descriptors that will be concatenated to the hidden features before readout
Dimension of additional vertex descriptors that will be concatenated to the
hidden features before readout

"""
super().__init__(
d_v,
@@ -114,9 +110,14 @@ def reinitialize_network(self) -> Self:
-------
Self
The reinitialized network.

"""
self.W_i, self.W_h, self.W_o, self.W_d = self.setup(
self.d_v, self.d_e, self.d_h, self.d_vd, self.bias
self.d_v,
self.d_e,
self.d_h,
self.d_vd,
self.bias,
)
self.dropout = nn.Dropout(self.dropout_rate)
if isinstance(self.activation, str):
@@ -137,6 +138,7 @@ def set_params(self, **params: Any) -> Self:
-------
Self
The model with the new parameters.

"""
super().set_params(**params)
self.reinitialize_network()
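
Reviewer note: the value of pairing BaseEstimator with reinitialize_network() is that hyperparameters changed through set_params() take effect immediately in the torch layers. A minimal sketch of the intended round trip (a hypothetical usage example; parameter names such as d_h are assumed from the chemprop defaults, not verified against this PR):

    from molpipeline.estimators.chemprop.component_wrapper import BondMessagePassing

    mp = BondMessagePassing()   # chemprop defaults for d_v, d_e, d_h
    mp.set_params(d_h=600)      # assumed to call reinitialize_network(), rebuilding
                                # W_i, W_h, W_o for the new hidden width
    assert mp.get_params()["d_h"] == 600  # scikit-learn introspection sees the change
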
@@ -147,8 +149,8 @@ def set_params(self, **params: Any) -> Self:
class PredictorWrapper(_Predictor, BaseEstimator, abc.ABC): # type: ignore
"""Abstract wrapper for the Predictor class."""

_T_default_criterion: LossFunction
_T_default_metric: Metric
_T_default_criterion: ChempropMetric
_T_default_metric: ChempropMetric

def __init__( # pylint: disable=too-many-positional-arguments
self,
@@ -158,7 +160,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
n_layers: int = 1,
dropout: float = 0,
activation: str = "relu",
criterion: LossFunction | None = None,
criterion: ChempropMetric | None = None,
task_weights: Tensor | None = None,
threshold: float | None = None,
output_transform: UnscaleTransform | None = None,
@@ -190,6 +192,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
Transformations to apply to the output. None defaults to UnscaleTransform.
kwargs : Any
Additional keyword arguments.

"""
if task_weights is None:
task_weights = torch.ones(n_tasks)
@@ -226,6 +229,7 @@ def input_dim(self, value: int) -> None:
----------
value : int
The dimension of input.

"""
self._input_dim = value

@@ -242,6 +246,7 @@ def n_tasks(self, value: int) -> None:
----------
value : int
The number of tasks.

"""
self._n_tasks = value

@@ -252,6 +257,7 @@ def reinitialize_fnn(self) -> Self:
-------
Self
The reinitialized feedforward network.

"""
self.ffn = MLP.build(
input_dim=self.input_dim,
@@ -275,6 +281,7 @@ def set_params(self, **params: Any) -> Self:
-------
Self
The model with the new parameters.

"""
super().set_params(**params)
self.reinitialize_fnn()
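
Same contract as the message-passing wrapper: set_params() rebuilds the feedforward network via reinitialize_fnn(), so architectural changes are applied rather than silently stored. A hedged sketch (n_layers and hidden_dim are assumed to be exposed constructor arguments, as the signature above suggests):

    ffn = RegressionFFN()                        # one hidden layer by default
    ffn.set_params(n_layers=3, hidden_dim=500)   # re-runs MLP.build(...) with new sizes
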
@@ -285,8 +292,8 @@ class RegressionFFN(PredictorWrapper, _RegressionFFN): # type: ignore
"""A wrapper for the RegressionFFN class."""

n_targets: int = 1
_T_default_criterion = MSELoss
_T_default_metric = MSEMetric
_T_default_criterion = MSE
_T_default_metric = MSE
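
Since criterion=None in the constructor falls back to the class-level _T_default_criterion, this rename also changes which wrapped loss a bare RegressionFFN trains with. A sketch of the expected behaviour (assuming the chemprop base class instantiates the default criterion when none is supplied):

    from molpipeline.estimators.chemprop.loss_wrapper import MSE

    reg = RegressionFFN()                  # criterion=None
    assert isinstance(reg.criterion, MSE)  # built from _T_default_criterion
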


class MveFFN(PredictorWrapper, _MveFFN): # type: ignore
@@ -308,25 +315,25 @@ class BinaryClassificationFFN(PredictorWrapper, _BinaryClassificationFFN): # type: ignore

n_targets: int = 1
_T_default_criterion = BCELoss
_T_default_metric = BinaryAUROCMetric
_T_default_metric = BinaryAUROC


class BinaryDirichletFFN(PredictorWrapper, _BinaryDirichletFFN): # type: ignore
"""A wrapper for the BinaryDirichletFFN class."""

n_targets: int = 2
_T_default_criterion = BinaryDirichletLoss
_T_default_metric = BinaryAUROCMetric
_T_default_criterion = DirichletLoss
_T_default_metric = BinaryAUROC


class MulticlassClassificationFFN(PredictorWrapper, _MulticlassClassificationFFN): # type: ignore
"""A wrapper for the MulticlassClassificationFFN class."""

n_targets: int = 1
_T_default_criterion = CrossEntropyLoss
_T_default_metric = CrossEntropyMetric
_T_default_metric = MulticlassMCCMetric

def __init__( # pylint: disable=too-many-positional-arguments
def __init__( # pylint: disable=too-many-positional-arguments # noqa: PLR0917
self,
n_classes: int,
n_tasks: int = 1,
@@ -335,7 +342,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
n_layers: int = 1,
dropout: float = 0.0,
activation: str = "relu",
criterion: LossFunction | None = None,
criterion: ChempropMetric | None = None,
task_weights: Tensor | None = None,
threshold: float | None = None,
output_transform: UnscaleTransform | None = None,
@@ -366,6 +373,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
Threshold for binary classification.
output_transform : UnscaleTransform or None, optional (default=None)
Transformations to apply to the output. None defaults to UnscaleTransform.

"""
super().__init__(
n_tasks,
@@ -388,23 +396,24 @@ class MulticlassDirichletFFN(PredictorWrapper, _MulticlassDirichletFFN): # type: ignore
"""A wrapper for the MulticlassDirichletFFN class."""

n_targets: int = 1
_T_default_criterion = MulticlassDirichletLoss
_T_default_metric = CrossEntropyMetric
_T_default_criterion = DirichletLoss
_T_default_metric = MulticlassMCCMetric


class SpectralFFN(PredictorWrapper, _SpectralFFN): # type: ignore
"""A wrapper for the SpectralFFN class."""

n_targets: int = 1
_T_default_criterion = SIDLoss
_T_default_metric = SIDMetric
_T_default_criterion = SID
_T_default_metric = SID


class MPNN(_MPNN, BaseEstimator):
"""A wrapper for the MPNN class.

The MPNN is the main model class in chemprop. It consists of a message passing network, an aggregation function,
and a feedforward network for prediction.
The MPNN is the main model class in chemprop. It consists of a message passing
network, an aggregation function, and a feedforward network for prediction.

"""

bn: nn.BatchNorm1d | nn.Identity
@@ -415,7 +424,7 @@ def __init__(
agg: Aggregation,
predictor: PredictorWrapper,
batch_norm: bool = True,
metric_list: Iterable[Metric] | None = None,
metric_list: Iterable[ChempropMetric] | None = None,
warmup_epochs: int = 2,
init_lr: float = 1e-4,
max_lr: float = 1e-3,
@@ -443,6 +452,7 @@ def __init__(
The maximum learning rate.
final_lr : float, optional (default=1e-4)
The final learning rate.

"""
super().__init__(
message_passing=message_passing,
@@ -465,6 +475,7 @@ def reinitialize_network(self) -> Self:
-------
Self
The reinitialized network.

"""
if self.batch_norm:
self.bn = nn.BatchNorm1d(self.message_passing.output_dim)
@@ -473,9 +484,11 @@ def reinitialize_network(self) -> Self:

if self.metric_list is None:
# pylint: disable=protected-access
self.metrics = [self.predictor._T_default_metric, self.criterion]
self.metrics = nn.ModuleList(
[self.predictor._T_default_metric(), self.criterion], # noqa: SLF001
)
else:
self.metrics = [*list(self.metric_list), self.criterion]
self.metrics = nn.ModuleList([*list(self.metric_list), self.criterion])

return self
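
The branch above is the observable contract of reinitialize_network(): with metric_list=None the model tracks the predictor's default metric plus the training criterion, otherwise the user-supplied metrics plus the criterion, now held in an nn.ModuleList so they are registered as submodules and move with the model. A hedged sketch of both paths (constructors taken from this file, behaviour not re-verified):

    mp, ffn = BondMessagePassing(), RegressionFFN()
    model = MPNN(message_passing=mp, agg=MeanAggregation(), predictor=ffn)
    [type(m).__name__ for m in model.metrics]   # expected: ['MSE', 'MSE']

    model.set_params(metric_list=[MSE()])       # set_params() reinitializes, so the
    [type(m).__name__ for m in model.metrics]   # list becomes user metrics + criterion
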

@@ -491,6 +504,7 @@ def set_params(self, **params: Any) -> Self:
-------
Self
The model with the new parameters.

"""
super().set_params(**params)
self.reinitialize_network()
@@ -499,7 +513,7 @@ def set_params(self, **params: Any) -> Self:

# pylint: disable=too-many-ancestors
class MeanAggregation(_MeanAggregation, BaseEstimator):
"""Aggregate the graph-level representation by averaging the node representations."""
"""Aggregate the graph-level representation by averaging node representations."""

def __init__(self, dim: int = 0):
"""Initialize the MeanAggregation class.
@@ -508,6 +522,7 @@ def __init__(self, dim: int = 0):
----------
dim : int, optional (default=0)
The dimension to aggregate over. See torch_scatter.scatter for more details.

"""
super().__init__(dim)

@@ -523,5 +538,6 @@ def __init__(self, dim: int = 0):
----------
dim : int, optional (default=0)
The dimension to aggregate over. See torch_scatter.scatter for more details.

"""
super().__init__(dim)
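
Taken together, the wrapped components compose into a full model exactly like their chemprop counterparts, with the whole stack introspectable by scikit-learn. A minimal end-to-end sketch under the assumptions above (output_dim on the message-passing module is assumed from chemprop; this example is not part of the diff):

    from molpipeline.estimators.chemprop.component_wrapper import (
        MPNN,
        BondMessagePassing,
        MeanAggregation,
        RegressionFFN,
    )

    mp = BondMessagePassing()                     # bond-level message passing
    ffn = RegressionFFN(input_dim=mp.output_dim)  # FFN sized to the MP output
    model = MPNN(message_passing=mp, agg=MeanAggregation(), predictor=ffn)
    model.get_params(deep=True)                   # nested params, grid-search ready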