MHCADDI
- class MHCADDI(*, atom_feature_channels=16, atom_type_channels=16, bond_type_channels=16, node_channels=16, edge_channels=16, hidden_channels=16, readout_channels=16, output_channels=1, dropout=0.5)
Bases: chemicalx.models.base.Model
An implementation of the MHCADDI model from [deac2019].
See also
This model was suggested in https://github.com/AstraZeneca/chemicalx/issues/13
- [deac2019] Deac, A., et al. (2019). Drug-Drug Adverse Effect Prediction with Graph Co-Attention. arXiv:1905.00534.
Methods Summary
- atom_comp(atom_features, atom_index) – Compute atom projection, a linear transformation of a learned atom embedding and the atom features.
- forward(drug_molecules_left, …) – Forward pass with the data.
- generate_outer_segmentation(graph_sizes_left, graph_sizes_right) – Calculate all pairwise edges between the atoms in a set of drug pairs.
- unpack(batch) – Adjust drug pair batch to model design.
Methods Documentation
- atom_comp(atom_features, atom_index)
Compute atom projection, a linear transformation of a learned atom embedding and the atom features.
- Parameters
atom_features (Tensor) – Atom input features.
atom_index (Tensor) – Index of atom type.
- Returns
Projected atom representations (the atom projection described above).
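The following is a pure-Python sketch of the projection described above. It assumes, and the documentation does not say, that the learned atom-type embedding is concatenated with the raw atom features before a single linear layer. The names embedding_table, weight and bias are hypothetical stand-ins for learned parameters. The real method works on Tensor inputs with learned PyTorch modules.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def atom_projection(
    atom_features: List[Vector],
    atom_index: List[int],
    embedding_table: Matrix,
    weight: Matrix,
    bias: Vector,
) -> List[Vector]:
    """Project each atom: weight @ [embedding(atom_type) ++ features] + bias.

    Hypothetical sketch: assumes the learned embedding and the raw features
    are concatenated before one linear layer.
    """
    projected = []
    for features, atom_type in zip(atom_features, atom_index):
        # Look up the learned embedding for this atom type and append the raw features.
        x = embedding_table[atom_type] + features
        # Linear layer, one output unit per weight row: y_j = sum_k W[j][k] * x[k] + b[j].
        y = [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(weight, bias)]
        projected.append(y)
    return projected


# Two atom types with 2-dimensional embeddings, one raw feature per atom, one output unit.
print(atom_projection([[0.5], [2.0]], [0, 1], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 1.0, 1.0]], [0.0]))
```

Each atom's output depends on both its type, through the embedding lookup, and its own feature values. This is the combination the method summary describes.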
- forward(drug_molecules_left, drug_molecules_right)
Forward pass with the data.
- Return type
FloatTensor
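As the title of [deac2019] indicates, the model builds on graph co-attention, where atoms of one drug attend over the atoms of the partner drug. The hedged sketch below is not the chemicalx forward code. It shows one simplified co-attention step in plain Python:

- dot-product scores over atom pairs, given as index lists of the form produced by generate_outer_segmentation;
- a softmax over the edges of each left atom;
- a weighted sum of right-atom states.

The function name co_attention and the dot-product scoring are assumptions. Only the left-to-right direction is shown.

```python
import math
from typing import Dict, List


def co_attention(
    left_states: List[List[float]],
    right_states: List[List[float]],
    outer_segmentation_index: List[int],
    outer_index: List[int],
) -> List[List[float]]:
    """Hypothetical sketch: each left atom attends over its partner right atoms.

    Scores are plain dot products, normalised with a softmax over the edges
    that share the same left atom.
    """
    # One score per edge (left atom i paired with right atom j).
    scores = [
        sum(a * b for a, b in zip(left_states[i], right_states[j]))
        for i, j in zip(outer_segmentation_index, outer_index)
    ]
    # Group edge positions by their left atom so the softmax runs per segment.
    segments: Dict[int, List[int]] = {}
    for edge, i in enumerate(outer_segmentation_index):
        segments.setdefault(i, []).append(edge)
    dim = len(right_states[0])
    messages = [[0.0] * dim for _ in left_states]
    for i, edges in segments.items():
        # Subtract the segment maximum before exponentiating for numerical stability.
        top = max(scores[e] for e in edges)
        weights = [math.exp(scores[e] - top) for e in edges]
        total = sum(weights)
        for e, w in zip(edges, weights):
            j = outer_index[e]
            for d in range(dim):
                messages[i][d] += (w / total) * right_states[j][d]
    return messages
```

The per-left-atom softmax is what makes the pairwise index lists necessary. Every left atom needs to know exactly which right atoms belong to its partner drug, so that attention never crosses into another drug pair in the batch.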
- generate_outer_segmentation(graph_sizes_left, graph_sizes_right)
Calculate all pairwise edges between the atoms in a set of drug pairs.
Example: given two sets of drug sizes:
graph_sizes_left = torch.tensor([1, 2])
graph_sizes_right = torch.tensor([3, 4])
Here the drug pairs have sizes (1, 3) and (2, 4). Every atom of a left drug is paired with every atom of its partner right drug, giving 1 × 3 + 2 × 4 = 11 edges.
This results in:
outer_segmentation_index = tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2])
outer_index = tensor([0, 1, 2, 3, 4, 5, 6, 3, 4, 5, 6])
- Parameters
graph_sizes_left (LongTensor) – List of graph sizes in the left drug batch.
graph_sizes_right (LongTensor) – List of graph sizes in the right drug batch.
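- Returns
Edge indices: outer_segmentation_index holds the batch index of the left-drug atom of each edge, and outer_index holds the batch index of the right-drug atom it is paired with.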
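The pairing rule in the example above can be reproduced with plain Python lists. This is an illustrative sketch only. The chemicalx method takes LongTensor inputs, and its internals may differ.

```python
from itertools import accumulate
from typing import List, Tuple


def generate_outer_segmentation(
    graph_sizes_left: List[int], graph_sizes_right: List[int]
) -> Tuple[List[int], List[int]]:
    """Pair every left-drug atom with every atom of its partner right drug (list-based sketch)."""
    # Offset of the first atom of each drug in the flattened (batched) atom list.
    left_offsets = [0, *accumulate(graph_sizes_left)]
    right_offsets = [0, *accumulate(graph_sizes_right)]
    outer_segmentation_index: List[int] = []
    outer_index: List[int] = []
    for pair, (n_left, n_right) in enumerate(zip(graph_sizes_left, graph_sizes_right)):
        for left_atom in range(left_offsets[pair], left_offsets[pair] + n_left):
            # One edge from this left atom to each atom of the partner right drug.
            outer_segmentation_index.extend([left_atom] * n_right)
            outer_index.extend(range(right_offsets[pair], right_offsets[pair] + n_right))
    return outer_segmentation_index, outer_index


print(generate_outer_segmentation([1, 2], [3, 4]))
# ([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2], [0, 1, 2, 3, 4, 5, 6, 3, 4, 5, 6])
```

Both lists have one entry per edge. Their length is therefore the sum, over drug pairs, of left size times right size.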
- unpack(batch)
Adjust drug pair batch to model design.
- Parameters
batch (DrugPairBatch) – Molecular data in a drug pair batch.
- Returns
Tuple of data.
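A model-specific unpack lets a generic training loop call forward without knowing which batch fields each model needs. The sketch below shows that calling pattern with a hypothetical ToyDrugPairBatch and ToyModel. The field names, and the assumption that unpack returns exactly the two arguments of forward, are illustrative and not taken from chemicalx.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class ToyDrugPairBatch:
    """Hypothetical stand-in for DrugPairBatch; the field names are assumptions."""

    drug_molecules_left: Any
    drug_molecules_right: Any
    labels: Any


class ToyModel:
    def unpack(self, batch: ToyDrugPairBatch) -> Tuple[Any, Any]:
        # Select only the fields this model's forward pass consumes, in argument order.
        return batch.drug_molecules_left, batch.drug_molecules_right

    def forward(self, drug_molecules_left: Any, drug_molecules_right: Any) -> str:
        # Placeholder computation so the calling pattern runs end to end.
        return f"{drug_molecules_left}+{drug_molecules_right}"


model = ToyModel()
batch = ToyDrugPairBatch("aspirin", "warfarin", labels=1)
# The training loop stays model-agnostic: it always calls forward(*unpack(batch)).
prediction = model.forward(*model.unpack(batch))
print(prediction)  # aspirin+warfarin
```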