# Source code for chemicalx.models.deepddi

"""An implementation of the DeepDDI model."""

import torch
from torch import nn

from chemicalx.data import DrugPairBatch
from chemicalx.models import Model

# Names exported by ``from chemicalx.models.deepddi import *``.
__all__ = ["DeepDDI"]


class DeepDDI(Model):
    """An implementation of the DeepDDI model from [ryu2018]_.

    .. seealso:: This model was suggested in https://github.com/AstraZeneca/chemicalx/issues/2

    .. [ryu2018] Ryu, J. Y., *et al.* (2018). `Deep learning improves prediction of drug–drug and drug–food
       interactions <https://doi.org/10.1073/pnas.1803294115>`_. *Proceedings of the National Academy of
       Sciences*, 115(18), E4304–E4311.
    """

    def __init__(
        self,
        *,
        drug_channels: int,
        hidden_channels: int = 2048,
        hidden_layers_num: int = 9,
        out_channels: int = 1,
    ):
        """Instantiate the DeepDDI model.

        :param drug_channels: The number of drug features.
        :param hidden_channels: The number of hidden layer neurons.
        :param hidden_layers_num: The number of hidden layers. Must be greater than 1.
        :param out_channels: The number of output channels.
        :raises ValueError: If ``hidden_layers_num`` is not greater than 1.
        """
        super().__init__()
        # Explicit check instead of ``assert``: assertions are stripped under ``python -O``,
        # which would silently let 0 or 1 through and build a single hidden layer.
        if hidden_layers_num <= 1:
            raise ValueError(f"hidden_layers_num must be greater than 1, got {hidden_layers_num}")
        # The first hidden block consumes the concatenated left/right drug features,
        # hence ``drug_channels * 2`` input features (see ``_combine_sides``).
        layers = self._hidden_block(drug_channels * 2, hidden_channels)
        for _ in range(hidden_layers_num - 1):
            layers.extend(self._hidden_block(hidden_channels, hidden_channels))
        # Output head: Sigmoid squashes each output channel into (0, 1).
        layers.extend([nn.Linear(hidden_channels, out_channels), nn.Sigmoid()])
        self.final = nn.Sequential(*layers)

    @staticmethod
    def _hidden_block(in_channels: int, out_channels: int) -> list:
        """Build the modules of one hidden layer: Linear -> ReLU -> BatchNorm1d -> ReLU.

        The modules are returned as a flat list, so the enclosing ``nn.Sequential`` numbers
        them exactly as if they had been written inline.

        :param in_channels: The number of input features of the linear layer.
        :param out_channels: The number of output features of the linear and batch-norm layers.
        :returns: A list of four ``nn.Module`` instances.
        """
        return [
            nn.Linear(in_channels, out_channels),
            nn.ReLU(),
            # ``momentum=None`` makes BatchNorm keep a cumulative (simple) moving average of its running stats.
            nn.BatchNorm1d(num_features=out_channels, affine=True, momentum=None),
            nn.ReLU(),
        ]

    def unpack(self, batch: DrugPairBatch):
        """Return the left drug features and right drug features of a batch.

        :param batch: A batch of drug pairs.
        :returns: A pair ``(drug_features_left, drug_features_right)`` in the positional
            order expected by :meth:`forward`.
        """
        return (
            batch.drug_features_left,
            batch.drug_features_right,
        )

    def _combine_sides(self, left: torch.FloatTensor, right: torch.FloatTensor) -> torch.FloatTensor:
        """Concatenate the left and right drug feature matrices along dimension 1.

        NOTE(review): assumes both inputs are ``(batch_size, drug_channels)`` matrices, so the result
        has ``2 * drug_channels`` columns as the first linear layer requires — confirm with callers.
        """
        return torch.cat([left, right], dim=1)

    def forward(
        self,
        drug_features_left: torch.FloatTensor,
        drug_features_right: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Run a forward pass of the DeepDDI model.

        :param drug_features_left: A matrix of head drug features.
        :param drug_features_right: A matrix of tail drug features.
        :returns: A matrix with ``out_channels`` columns of predicted interaction scores in (0, 1);
            with the default ``out_channels=1`` this is a column vector.
        """
        hidden = self._combine_sides(drug_features_left, drug_features_right)
        return self.final(hidden)