# Source code for chemicalx.data.drugfeatureset

"""A module for the drug feature set."""

from collections import UserDict
from typing import Dict, Iterable, Mapping, Union

import torch
from torchdrug.data import Molecule

from chemicalx.compat import Graph, PackedGraph

# Public names exported by this module (what ``from ... import *`` exposes).
__all__ = [
    "DrugFeatureSet",
]


class DrugFeatureSet(UserDict, Mapping[str, Mapping[str, Union[torch.FloatTensor, Molecule]]]):
    """Dictionary-like container mapping drug identifiers to per-drug data.

    When built through :meth:`from_dict`, every value is a dictionary with a
    ``"features"`` entry and a ``"molecule"`` entry.
    """

    @classmethod
    def from_dict(cls, data: Dict[str, Dict]) -> "DrugFeatureSet":
        """Build a drug feature set from a raw data dictionary.

        :param data: A dictionary keyed by drug identifier; each value must
            provide a ``"features"`` entry and a ``"smiles"`` entry.
        :returns: A drug feature set storing, for each drug, the features as a
            single-row float tensor and the molecule created from the
            ``"smiles"`` entry.
        """
        entries = {}
        for drug, raw in data.items():
            # Reshape to one row per drug so rows can be concatenated later.
            feature_row = torch.FloatTensor(raw["features"]).view(1, -1)
            molecule = Molecule.from_smiles(raw["smiles"])
            entries[drug] = {"features": feature_row, "molecule": molecule}
        return cls(entries)

    def get_feature_matrix(self, drugs: Iterable[str]) -> torch.FloatTensor:
        """Collect the stored feature tensors of the requested drugs.

        :param drugs: Drug identifiers, in the order the rows should appear.
        :returns: The ``"features"`` tensors of the drugs concatenated with
            :func:`torch.cat`.
        """
        rows = []
        for drug in drugs:
            rows.append(self.data[drug]["features"])
        return torch.cat(rows)

    def get_molecules(self, drugs: Iterable[str]) -> PackedGraph:
        """Collect the stored molecules of the requested drugs.

        :param drugs: Drug identifiers, in the order the molecules should be packed.
        :returns: The ``"molecule"`` entries of the drugs packed together
            with ``Graph.pack``.
        """
        molecules = []
        for drug in drugs:
            molecules.append(self.data[drug]["molecule"])
        return Graph.pack(molecules)