MHCADDI

class MHCADDI(*, atom_feature_channels=16, atom_type_channels=16, bond_type_channels=16, node_channels=16, edge_channels=16, hidden_channels=16, readout_channels=16, output_channels=1, dropout=0.5)

Bases: chemicalx.models.base.Model

An implementation of the MHCADDI model from [deac2019].

See also

This model was suggested in https://github.com/AstraZeneca/chemicalx/issues/13

deac2019

Deac, A., et al. (2019). Drug-Drug Adverse Effect Prediction with Graph Co-Attention. arXiv, 1905.00534.

Methods Summary

atom_comp(atom_features, atom_index)

Compute the atom projection: a linear transformation of a learned atom-type embedding together with the atom features.

forward(drug_molecules_left, …)

Run a forward pass on the left and right drug molecules of a batch of drug pairs.

generate_outer_segmentation(…)

Calculate all pairwise edges between the atoms in a set of drug pairs.

unpack(batch)

Adjust a drug pair batch to the model's input format.

Methods Documentation

atom_comp(atom_features, atom_index)

Compute the atom projection: a linear transformation of a learned atom-type embedding together with the atom features.

Parameters
  • atom_features (Tensor) – Atom input features.

  • atom_index (Tensor) – Atom type indices.

Returns

Projected node (atom) representations.
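As an illustration of what an atom projection computes, the dependency-free sketch below uses plain lists in place of tensors. It assumes the type embedding and the raw atom features are concatenated before a single linear layer; the embedding table and weights are hypothetical fixed values, whereas the model learns them and operates on batched tensors.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]

# Hypothetical embedding table (learned in a real model): one row per atom type.
ATOM_TYPE_EMBEDDING: Matrix = [
    [0.0, 0.0],   # atom type 0
    [1.0, -1.0],  # atom type 1
]

# Hypothetical projection: maps a (2 embedding + 1 feature) = 3-dim input to 2 channels.
PROJECTION_WEIGHT: Matrix = [
    [1.0, 0.0, 2.0],
    [0.0, 1.0, 1.0],
]
PROJECTION_BIAS: Vector = [0.5, 0.0]


def linear(weight: Matrix, bias: Vector, x: Vector) -> Vector:
    """Compute weight @ x + bias for a single input vector."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def atom_comp(atom_features: List[Vector], atom_index: List[int]) -> List[Vector]:
    """Project each atom: look up its type embedding, concatenate it with the
    raw atom features (assumed combination step), then apply one linear map."""
    nodes = []
    for features, type_id in zip(atom_features, atom_index):
        embedding = ATOM_TYPE_EMBEDDING[type_id]
        nodes.append(linear(PROJECTION_WEIGHT, PROJECTION_BIAS, embedding + features))
    return nodes


print(atom_comp([[3.0], [1.0]], [1, 0]))  # [[7.5, 2.0], [2.5, 1.0]]
```

Each atom yields one vector of node-channel width (2 in this toy setup), which is why the result is a set of node representations rather than an index.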

forward(drug_molecules_left, drug_molecules_right)

Run a forward pass on the left and right drug molecules of a batch of drug pairs.

Return type

FloatTensor

generate_outer_segmentation(graph_sizes_left, graph_sizes_right)

Calculate all pairwise edges between the atoms in a set of drug pairs.

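Example: given two batches of drug sizes,

    graph_sizes_left = torch.tensor([1, 2])
    graph_sizes_right = torch.tensor([3, 4])

the drug pairs have sizes (1, 3) and (2, 4), so the first pair contributes 1 × 3 = 3 edges and the second 2 × 4 = 8 edges, 11 in total. Atom indices are global within each batch: the left drug of the second pair holds atoms 1 and 2, and its right drug holds atoms 3 to 6.

This results in:

    outer_segmentation_index = tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
    outer_index = tensor([0, 1, 2, 3, 4, 5, 6, 3, 4, 5, 6])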

Parameters
  • graph_sizes_left (LongTensor) – Number of atoms in each drug of the left drug batch.

  • graph_sizes_right (LongTensor) – Number of atoms in each drug of the right drug batch.

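Returns

Edge indices: outer_segmentation_index (the left-batch atom of each pairwise edge) and outer_index (the right-batch atom of each pairwise edge).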
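The index construction in the example can be reproduced with a minimal pure-Python sketch that uses lists in place of LongTensors. It illustrates the documented behaviour rather than reproducing the library's implementation; the loop order (left atom outer, right atom inner) is chosen to match the example output.

```python
from itertools import accumulate
from typing import List, Tuple


def generate_outer_segmentation(
    graph_sizes_left: List[int], graph_sizes_right: List[int]
) -> Tuple[List[int], List[int]]:
    """Return (outer_segmentation_index, outer_index) for all cross-drug atom pairs.

    Within each drug pair, every atom of the left drug is paired with every
    atom of the right drug. Indices are global within each batch, so each
    drug's local atom indices are shifted by the total size of earlier drugs.
    """
    # Offset of each drug's first atom in its batch (exclusive cumulative sum).
    left_offsets = [0] + list(accumulate(graph_sizes_left))[:-1]
    right_offsets = [0] + list(accumulate(graph_sizes_right))[:-1]

    outer_segmentation_index: List[int] = []
    outer_index: List[int] = []
    for n_left, n_right, off_left, off_right in zip(
        graph_sizes_left, graph_sizes_right, left_offsets, right_offsets
    ):
        for i in range(n_left):
            for j in range(n_right):
                outer_segmentation_index.append(off_left + i)  # left-batch atom
                outer_index.append(off_right + j)  # right-batch atom
    return outer_segmentation_index, outer_index


seg, idx = generate_outer_segmentation([1, 2], [3, 4])
print(seg)  # [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
print(idx)  # [0, 1, 2, 3, 4, 5, 6, 3, 4, 5, 6]
```

The number of edges is the sum over pairs of left size times right size; a tensor version would typically replace the nested loops with cumulative sums and repeat_interleave.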

unpack(batch)

Adjust drug pair batch to model design.

Parameters

batch (DrugPairBatch) – Molecular data in a drug pair batch.

Returns

Tuple of data to be passed to forward().
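Because forward() takes drug_molecules_left and drug_molecules_right, one plausible reading of unpack() is that it extracts those two fields from the batch so the model can be called as forward(*unpack(batch)). The sketch below assumes exactly that; ToyDrugPairBatch and the placeholder forward are hypothetical stand-ins, not the chemicalx classes.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class ToyDrugPairBatch:
    """Hypothetical stand-in for DrugPairBatch with only the two fields this
    sketch needs; the real batch class carries more data."""

    drug_molecules_left: Any
    drug_molecules_right: Any


def unpack(batch: ToyDrugPairBatch) -> Tuple[Any, Any]:
    """Select the inputs forward() expects, in its argument order (assumed)."""
    return batch.drug_molecules_left, batch.drug_molecules_right


def forward(drug_molecules_left: Any, drug_molecules_right: Any) -> str:
    """Hypothetical placeholder for the model's forward pass."""
    return f"{drug_molecules_left} x {drug_molecules_right}"


batch = ToyDrugPairBatch(drug_molecules_left="left graphs", drug_molecules_right="right graphs")
print(forward(*unpack(batch)))  # left graphs x right graphs
```

Keeping unpack() next to the model lets a generic training loop stay unaware of which batch fields a particular model consumes.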