Source code for chemicalx.models.deepsynergy

r"""An implementation of the DeepSynergy model."""

import torch
from torch import nn

from chemicalx.data import DrugPairBatch
from chemicalx.models import Model

__all__ = [
    "DeepSynergy",
]


[docs]class DeepSynergy(Model): r"""An implementation of the DeepSynergy model from [preuer2018]_. .. seealso:: This model was suggested in https://github.com/AstraZeneca/chemicalx/issues/16 .. [preuer2018] Preuer, K., *et al.* (2018). `DeepSynergy: predicting anti-cancer drug synergy with Deep Learning <https://doi.org/10.1093/bioinformatics/btx806>`_. *Bioinformatics*, 34(9), 1538–1546. """ def __init__( self, *, context_channels: int, drug_channels: int, input_hidden_channels: int = 32, middle_hidden_channels: int = 32, final_hidden_channels: int = 32, out_channels: int = 1, dropout_rate: float = 0.5, ): """Instantiate the DeepSynergy model. :param context_channels: The number of context features. :param drug_channels: The number of drug features. :param input_hidden_channels: The number of hidden layer neurons in the input layer. :param middle_hidden_channels: The number of hidden layer neurons in the middle layer. :param final_hidden_channels: The number of hidden layer neurons in the final layer. :param out_channels: The number of output channels. :param dropout_rate: The rate of dropout before the scoring head is used. """ super().__init__() self.final = nn.Sequential( nn.Linear(drug_channels + drug_channels + context_channels, input_hidden_channels), nn.ReLU(), nn.Linear(input_hidden_channels, middle_hidden_channels), nn.ReLU(), nn.Linear(middle_hidden_channels, final_hidden_channels), nn.ReLU(), nn.Dropout(dropout_rate), nn.Linear(final_hidden_channels, out_channels), nn.Sigmoid(), )
[docs] def unpack(self, batch: DrugPairBatch): """Return the context features, left drug features, and right drug features.""" return ( batch.context_features, batch.drug_features_left, batch.drug_features_right, )
[docs] def forward( self, context_features: torch.FloatTensor, drug_features_left: torch.FloatTensor, drug_features_right: torch.FloatTensor, ) -> torch.FloatTensor: """Run a forward pass of the DeepSynergy model. :param context_features: A matrix of biological context features. :param drug_features_left: A matrix of head drug features. :param drug_features_right: A matrix of tail drug features. :returns: A column vector of predicted synergy scores. """ hidden = torch.cat([context_features, drug_features_left, drug_features_right], dim=1) return self.final(hidden)