Source code for chemicalx.models.gcnbmp

"""An implementation of the GCNBMP model."""

from typing import List, Optional, Tuple, Union

import torch
import torchdrug
from more_itertools import chunked, pairwise
from torch import nn
from torch.fft import fft, ifft
from torch.nn import functional as F  # noqa:N812
from torch_scatter import scatter_add
from torchdrug import core, layers
from torchdrug.data import PackedGraph

from chemicalx.constants import TORCHDRUG_NODE_FEATURES
from chemicalx.data import DrugPairBatch
from chemicalx.models import Model

__all__ = [
    "GCNBMP",
]


def circular_correlation(left: torch.FloatTensor, right: torch.FloatTensor) -> torch.FloatTensor:
    """Compute the circular correlation of two vectors ``left`` and ``right`` via their Fast Fourier Transforms.

    :param left: the left vector
    :param right: the right vector
    :returns: Joint representation by circular correlation.
    """
    left_x_cfft = torch.conj(fft(left))
    right_x_fft = fft(right)
    circ_corr = ifft(torch.mul(left_x_cfft, right_x_fft))
    return circ_corr.real


class Highway(nn.Module):
    """The Highway update layer from [srivastava2015]_.

    .. [srivastava2015] Srivastava, R. K., *et al.* (2015).
       `Highway Networks <http://arxiv.org/abs/1505.00387>`_.
       *arXiv*, 1505.00387.
    """

    def __init__(self, input_size: int, prev_input_size: int):
        """Instantiate the Highway update layer.

        :param input_size: Current representation size.
        :param prev_input_size: Size of the representation obtained by the previous convolutional layer.
        """
        super().__init__()
        total_size = input_size + prev_input_size
        self.proj = nn.Linear(total_size, input_size)
        self.transform = nn.Linear(total_size, input_size)
        self.transform.bias.data.fill_(-2.0)

    def forward(self, current: torch.Tensor, previous: torch.Tensor) -> torch.Tensor:
        """Compute the gated update.

        :param current: Current layer node representations.
        :param previous: Previous layer node representations.
        :returns: The highway-updated inputs.
        """
        concat_inputs = torch.cat((current, previous), 1)
        proj_result = F.relu(self.proj(concat_inputs))
        proj_gate = F.sigmoid(self.transform(concat_inputs))
        gated = (proj_gate * proj_result) + ((1 - proj_gate) * current)
        return gated


class AttentionPooling(nn.Module):
    """The attention pooling layer from [chen2020]_."""

    def __init__(self, molecule_channels: int, hidden_channels: int):
        """Instantiate the attention pooling layer.

        :param molecule_channels: Input node features.
        :param hidden_channels: Final node representation.
        """
        super(AttentionPooling, self).__init__()
        total_features_channels = molecule_channels + hidden_channels
        # weights here must be shared across all nodes according to the paper
        self.lin = nn.Linear(total_features_channels, hidden_channels)
        self.last_rep = nn.Linear(hidden_channels, hidden_channels)

    def forward(self, input_rep: torch.Tensor, final_rep: torch.Tensor, graph_index: torch.Tensor) -> torch.Tensor:
        """
        Compute an attention-based readout using the input and output layers of the RGCN encoder for one molecule.

        :param input_rep: Input nodes representations.
        :param final_rep: Final nodes representations.
        :param graph_index: Node to graph readout index.
        :returns: Graph-level representation.
        """
        att = torch.sigmoid(self.lin(torch.cat((input_rep, final_rep), dim=1)))
        g = att.mul(self.last_rep(final_rep))
        g = scatter_add(g, graph_index, dim=0)
        return g


class GCNBMPEncoder(nn.Module, core.Configurable):
    """The drug encoding backbone from [chen2020]_."""

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Union[int, List[int]],
        num_relations: int,
        edge_input_dim: Optional[int] = None,
        batch_norm: Optional[bool] = False,
        activation: Optional[str] = "sigmoid",
    ):
        """Instantiate the GCN-BMP encoder.

        :param input_dim: Input dimensions.
        :param hidden_dims: Hidden dimensions.
        :param num_relations: Number of relations.
        :param edge_input_dim: Dimension of edge features.
        :param batch_norm: Apply batch normalization on nodes or not.
        :param activation: Activation function.
        """
        super().__init__()
        if isinstance(hidden_dims, int):
            hidden_dims = [hidden_dims]
        self.input_dim = input_dim
        self.dims = [input_dim, *hidden_dims]

        self.layers = nn.ModuleList()
        for left_dim, right_dim in pairwise(self.dims):
            self.layers.extend(
                (
                    layers.RelationalGraphConv(
                        left_dim, right_dim, num_relations, edge_input_dim, batch_norm, activation
                    ),
                    Highway(right_dim, left_dim),
                )
            )

    def forward(self, graph: torchdrug.data.graph.PackedGraph, input_node_features: torch.Tensor) -> dict:
        """Compute the node representations and the graph representation(s).

        :param graph: Batch of molecular graphs.
        :param input_node_features: Input node representations
        :returns: Node representation matrix.
        """
        hiddens = []
        layer_input = input_node_features
        prev_gcn = input_node_features

        for conv, highway in chunked(self.layers, 2):
            hidden = conv(graph, layer_input)
            hiddens.append(hidden)
            hidden = highway(hidden, prev_gcn)
            hiddens.append(hidden)
            layer_input = hidden
            prev_gcn = hiddens[-2]
        node_feature = hiddens[-1]
        return {"node_feature": node_feature}


[docs]class GCNBMP(Model): """An implementation of the GCN-BMP model from [chen2020]_. .. seealso:: This model was suggested in https://github.com/AstraZeneca/chemicalx/issues/21 .. [chen2020] Chen, X., *et al.* (2020). `GCN-BMP: Investigating graph representation learning for DDI prediction task <https://doi.org/10.1016/j.ymeth.2020.05.014>`_. *Methods*, 179, 47–54. """ def __init__( self, *, molecule_channels: int = TORCHDRUG_NODE_FEATURES, num_relations: int = 4, # TODO: This default value should be set by a dataset-specific constant hidden_channels: int = 16, hidden_conv_layers: int = 1, out_channels: int = 1, ): """Instantiate the GCN-BMP model. :param molecule_channels: The number of node-level features. :param num_relations: Number of edge types. :param hidden_channels: The number of hidden layer neurons in the input layer. :param hidden_conv_layers: The number of hidden layers in the encoder. :param out_channels: The number of output channels. """ super().__init__() self.graph_convolutions = GCNBMPEncoder( molecule_channels, [hidden_channels for _ in range(hidden_conv_layers)], num_relations ) self.attention_readout = AttentionPooling(molecule_channels, hidden_channels) self.final = nn.Sequential( nn.Linear(hidden_channels, out_channels), nn.Sigmoid(), )
[docs] def unpack(self, batch: DrugPairBatch) -> Tuple[PackedGraph, PackedGraph]: """Return the left and right drugs PackedGraphs.""" return ( batch.drug_molecules_left, batch.drug_molecules_right, )
def _forward_molecules(self, molecules: PackedGraph) -> torch.FloatTensor: """ Run a forward pass of the encoder layers. :param molecules: The batched molecular graphs. :returns: A matrix of molecular features """ features = self.graph_convolutions(molecules, molecules.data_dict["node_feature"])["node_feature"] features = self.attention_readout(molecules.data_dict["node_feature"], features, molecules.node2graph) return features def _combine_sides(self, left: torch.FloatTensor, right: torch.FloatTensor) -> torch.FloatTensor: return circular_correlation(left, right)
[docs] def forward(self, molecules_left: PackedGraph, molecules_right: PackedGraph) -> torch.FloatTensor: """ Run a forward pass of the GCN-BMP model. :param molecules_left: The graph of left drug and node features. :param molecules_right: The graph of right drug and node features. :returns: A column vector of predicted synergy scores. """ features_left = self._forward_molecules(molecules_left) features_right = self._forward_molecules(molecules_right) hidden = self._combine_sides(features_left, features_right) return self.final(hidden)