CASTER
- class CASTER(*, drug_channels, encoder_hidden_channels=32, encoder_output_channels=32, decoder_hidden_channels=32, hidden_channels=32, out_hidden_channels=32, out_channels=1, lambda3=1e-05, magnifying_factor=100)
Bases: chemicalx.models.base.Model
An implementation of the CASTER model from [huang2020].
See also
This model was suggested in https://github.com/AstraZeneca/chemicalx/issues/17
- huang2020
Huang, K., et al. (2020). CASTER: Predicting drug interactions with chemical substructure representation. AAAI 2020 - 34th AAAI Conference on Artificial Intelligence, 702–709.
Methods Summary
- dictionary_encoder(drug_pair_features_latent, dictionary_features_latent): Perform a forward pass of the dictionary encoder submodule.
- forward(drug_pair_features): Run a forward pass of the CASTER model.
- unpack(batch): Return the “functional representation” of drug pairs, as defined in the original implementation.
Methods Documentation
- dictionary_encoder(drug_pair_features_latent, dictionary_features_latent)
Perform a forward pass of the dictionary encoder submodule.
- Parameters
  - drug_pair_features_latent (FloatTensor) – encoder output for the input drug_pair_features (batch_size x encoder_output_channels)
  - dictionary_features_latent (FloatTensor) – projection of the drug_pair_features using the dictionary basis (encoder_output_channels x drug_channels)
- Return type
  FloatTensor
- Returns
  the sparse code X_o (batch_size x drug_channels)
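As we read the CASTER paper, the sparse code is the ridge-regression solution that expresses a drug pair's latent vector z in terms of the dictionary basis B. The pure-Python sketch below illustrates that idea for a single drug pair, assuming that `lambda3` acts as the ridge penalty and that B has shape encoder_output_channels x drug_channels, as documented above. It solves the normal equations (B^T B + lambda3 * I) r = B^T z directly. The helper names `solve` and `sparse_code` are hypothetical; this is not the chemicalx implementation, which works on batched tensors.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def solve(a: Matrix, b: Vector) -> Vector:
    """Solve a @ x = b with Gauss-Jordan elimination and partial pivoting."""
    n = len(a)
    # Work on an augmented copy so the caller's matrices are left untouched.
    m = [row[:] + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        # Swap in the row with the largest pivot for numerical stability.
        pivot = max(range(col, n), key=lambda row: abs(m[row][col]))
        m[col], m[pivot] = m[pivot], m[col]
        p = m[col][col]
        m[col] = [v / p for v in m[col]]
        # Eliminate this column from every other row.
        for r in range(n):
            if r != col:
                factor = m[r][col]
                m[r] = [vr - factor * vc for vr, vc in zip(m[r], m[col])]
    return [m[i][n] for i in range(n)]


def sparse_code(z: Vector, basis: Matrix, lambda3: float = 1e-5) -> Vector:
    """Ridge code r minimising ||z - B r||^2 + lambda3 * ||r||^2 (assumed objective).

    z:     latent vector of one drug pair (length d = encoder_output_channels)
    basis: dictionary B with shape d x k (k = drug_channels)
    """
    d, k = len(basis), len(basis[0])
    # Left-hand side of the normal equations: B^T B + lambda3 * I (k x k).
    gram = [
        [sum(basis[t][i] * basis[t][j] for t in range(d)) + (lambda3 if i == j else 0.0)
         for j in range(k)]
        for i in range(k)
    ]
    # Right-hand side: B^T z (length k).
    rhs = [sum(basis[t][i] * z[t] for t in range(d)) for i in range(k)]
    return solve(gram, rhs)
```

For example, with B = [[1, 0], [0, 1], [1, 1]] and z = [1, 2, 3], z lies exactly in the span of the basis columns, so `sparse_code(z, B, lambda3=0.0)` recovers the coefficients [1.0, 2.0]. A small positive `lambda3` keeps B^T B + lambda3 * I invertible even when the basis is rank-deficient.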
- forward(drug_pair_features)
Run a forward pass of the CASTER model.
- Parameters
  drug_pair_features (FloatTensor) – functional representation of each drug pair (see the unpack method)
- Return type
  Tuple[FloatTensor, ...]
- Returns
  a tuple of tensors containing:
  - prediction_scores: predicted target scores for each drug pair
  - reconstructed: input drug pair vectors reconstructed by the encoder-decoder chain
  - dictionary_encoded: drug pair features encoded by the dictionary encoder submodule
  - dictionary_features_latent: projection of the encoded drug pair features using the dictionary basis
  - drug_pair_features_latent: encoder output for the input drug_pair_features
  - drug_pair_features: a copy of the input unpacked drug_pair_features (needed for loss calculation)
- unpack(batch)
Return the “functional representation” of drug pairs, as defined in the original implementation.
- Parameters
  batch (DrugPairBatch) – batch of drug pairs
- Return type
  Tuple[FloatTensor]
- Returns
  each pair represented as a single vector x, where x_1 and x_2 are the feature vectors of the two drugs and x^i = 1 if either x_1^i >= 1 or x_2^i >= 1 (x^i = 0 otherwise)
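For a single pair, the rule above amounts to an element-wise OR over substructure occurrences. The pure-Python sketch below illustrates it on plain lists; the function name `functional_representation` is hypothetical, and chemicalx itself operates on batched tensors taken from a DrugPairBatch.

```python
from typing import List, Sequence


def functional_representation(x1: Sequence[float], x2: Sequence[float]) -> List[float]:
    """Merge two drugs' substructure vectors into one binary pair vector.

    Entry i is 1.0 when substructure i occurs in either drug (value >= 1), else 0.0.
    """
    if len(x1) != len(x2):
        raise ValueError("both drugs must use the same substructure vocabulary")
    return [1.0 if a >= 1 or b >= 1 else 0.0 for a, b in zip(x1, x2)]
```

For example, `functional_representation([1, 0, 2, 0], [0, 0, 1, 3])` gives `[1.0, 0.0, 1.0, 1.0]`. When both inputs are already 0/1 indicator vectors, the same result is an element-wise maximum, which is one way a batched tensor version could compute it.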