# Source code for chemicalx.data.drugpairbatch

"""A module for the drug pair batch class."""

from dataclasses import dataclass
from typing import Optional, TypeVar

import pandas as pd
import torch
from torch.types import Device

from chemicalx.compat import PackedGraph

__all__ = ["DrugPairBatch"]

#: Constrained type variable for the values that ``_move_tensor`` relocates:
#: either a dense float tensor or a packed molecular graph.
X = TypeVar("X", torch.FloatTensor, PackedGraph)


@dataclass
class DrugPairBatch:
    """Container holding one labeled batch of drug pairs.

    Every field is optional and may be ``None``.
    """

    #: Table whose columns hold the drug pair, the context and the label.
    identifiers: Optional[pd.DataFrame]
    #: Molecular feature matrix for the left-hand drug of each pair.
    drug_features_left: Optional[torch.FloatTensor]
    #: Packed molecular graphs for the left-hand drugs.
    drug_molecules_left: Optional[PackedGraph]
    #: Molecular feature matrix for the right-hand drug of each pair.
    drug_features_right: Optional[torch.FloatTensor]
    #: Packed molecular graphs for the right-hand drugs.
    drug_molecules_right: Optional[PackedGraph]
    #: Matrix of biological/chemical context features.
    context_features: Optional[torch.FloatTensor]
    #: Vector holding one label per drug pair.
    labels: Optional[torch.FloatTensor]

    def to(self, device: Device) -> "DrugPairBatch":
        """Build a new batch whose tensor-valued fields are placed on ``device``.

        This batch itself is left untouched (out of place). The ``identifiers``
        dataframe is not moved; the returned batch refers to the same object.
        Fields that are ``None`` stay ``None``.

        :param device: Destination device, forwarded to each field's ``.to``.
        :returns: A freshly constructed :class:`DrugPairBatch`.
        """
        # Listed in the same order as the dataclass fields, so the moves
        # happen in the same sequence as an explicit keyword-argument call.
        tensor_field_names = (
            "drug_features_left",
            "drug_molecules_left",
            "drug_features_right",
            "drug_molecules_right",
            "context_features",
            "labels",
        )
        relocated = {
            field_name: _move_tensor(getattr(self, field_name), device)
            for field_name in tensor_field_names
        }
        return DrugPairBatch(identifiers=self.identifiers, **relocated)
def _move_tensor(x: Optional[X], device: Device) -> Optional[X]:
    """Return ``x.to(device)``, or ``None`` when ``x`` is ``None``.

    :param x: A tensor or packed graph to relocate, or ``None``.
    :param device: Destination device passed straight to ``x.to``.
    :returns: The relocated value, or ``None`` if nothing was given.
    """
    return None if x is None else x.to(device)