CASTER

class CASTER(*, drug_channels, encoder_hidden_channels=32, encoder_output_channels=32, decoder_hidden_channels=32, hidden_channels=32, out_hidden_channels=32, out_channels=1, lambda3=1e-05, magnifying_factor=100)

Bases: chemicalx.models.base.Model

An implementation of the CASTER model from [huang2020].

See also

This model was suggested in https://github.com/AstraZeneca/chemicalx/issues/17

huang2020

Huang, K., et al. (2020). CASTER: Predicting drug interactions with chemical substructure representation. AAAI 2020 - 34th AAAI Conference on Artificial Intelligence, 702–709.

Methods Summary

dictionary_encoder(…)

Perform a forward pass of the dictionary encoder submodule.

forward(drug_pair_features)

Run a forward pass of the CASTER model.

unpack(batch)

Return the “functional representation” of drug pairs, as defined in the original implementation.

Methods Documentation

dictionary_encoder(drug_pair_features_latent, dictionary_features_latent)

Perform a forward pass of the dictionary encoder submodule.

Parameters
  • drug_pair_features_latent (FloatTensor) – encoder output for the input drug_pair_features (batch_size x encoder_output_channels)

  • dictionary_features_latent (FloatTensor) – projection of the drug_pair_features using the dictionary basis (encoder_output_channels x drug_channels)

Return type

FloatTensor

Returns

sparse code X_o: (batch_size x drug_channels)
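
The lambda3 constructor parameter and the shapes listed above are consistent with a ridge-regularized least-squares projection onto the dictionary, which is how the CASTER paper describes this step. The derivation below is a sketch under that assumption, not a statement of what this implementation computes: the symbols Z, B, R and λ are introduced here for illustration, and treating λ as lambda3 is an assumption.

```latex
% Z: drug_pair_features_latent, B: dictionary_features_latent, R: sparse code X_o
% n = batch_size, d = encoder_output_channels, k = drug_channels
% \lambda > 0 is assumed to correspond to lambda3
\begin{aligned}
R^{*} &= \arg\min_{R \in \mathbb{R}^{n \times k}}
         \lVert Z - R B^{\top} \rVert_F^2 + \lambda \lVert R \rVert_F^2,
         \qquad Z \in \mathbb{R}^{n \times d},\; B \in \mathbb{R}^{d \times k} \\
0     &= -2\,(Z - R^{*} B^{\top})\,B + 2\lambda R^{*} \\
R^{*} &= Z B\,(B^{\top} B + \lambda I_k)^{-1}
\end{aligned}
```

With these shapes, Z B is (batch_size x drug_channels) and B^T B + λ I_k is (drug_channels x drug_channels), so R* has the documented output shape. For λ > 0 the matrix B^T B + λ I_k is positive definite, so the inverse exists.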

forward(drug_pair_features)

Run a forward pass of the CASTER model.

Parameters

drug_pair_features (FloatTensor) – functional representation of each drug pair (see unpack method)

Return type

Tuple[FloatTensor, ...]

Returns

A tuple of tensors containing:

  • prediction_scores – predicted target scores for each drug pair

  • reconstructed – input drug pair vectors reconstructed by the encoder-decoder chain

  • dictionary_encoded – drug pair features encoded by the dictionary encoder submodule

  • dictionary_features_latent – projection of the encoded drug pair features using the dictionary basis

  • drug_pair_features_latent – encoder output for the input drug_pair_features

  • drug_pair_features – a copy of the input unpacked drug_pair_features (needed for loss calculation)

unpack(batch)

Return the “functional representation” of drug pairs, as defined in the original implementation.

Parameters

batch (DrugPairBatch) – batch of drug pairs

Return type

Tuple[FloatTensor]

Returns

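A 1-tuple holding the pair representation: each pair is represented as a single vector x with x^i = 1 if either x_1^i >= 1 or x_2^i >= 1, where x_1 and x_2 are the feature vectors of the two drugs in the pair.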
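
The rule above can be illustrated with a minimal sketch on plain Python lists. The function name functional_representation is hypothetical and not part of chemicalx. The sketch assumes that positions where neither drug reaches 1 are set to 0, which the description implies but does not state. The library method itself operates on the tensors carried by a DrugPairBatch.

```python
from typing import List


def functional_representation(left: List[float], right: List[float]) -> List[float]:
    """Combine two per-drug feature vectors into one pair vector (illustrative only)."""
    if len(left) != len(right):
        raise ValueError("both drugs must have the same number of features")
    # x^i = 1 if either drug has a value of at least 1 at position i.
    # Setting all other positions to 0 is an assumption of this sketch.
    return [1.0 if (a >= 1 or b >= 1) else 0.0 for a, b in zip(left, right)]


# Two toy feature vectors, e.g. substructure counts for each drug.
pair = functional_representation([0.0, 2.0, 0.0, 1.0], [0.0, 0.0, 1.0, 3.0])
print(pair)  # [0.0, 1.0, 1.0, 1.0]
```

The result has the same length as each input, so under this reading a pair is fed to forward as a single vector rather than as two separate drug vectors.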