# Source code for chemicalx.data.contextfeatureset

"""A module for the context feature set class."""

from collections import UserDict
from typing import Iterable, Mapping, Sequence

import torch

__all__ = [
    "ContextFeatureSet",
]


class ContextFeatureSet(UserDict, Mapping[str, torch.FloatTensor]):
    """Context feature set for biological/chemical context feature vectors.

    A dictionary from context identifiers to float tensors. When built via
    :meth:`from_dict`, every value is a single row of shape ``(1, n)``, so
    several contexts can be stacked row-wise by :meth:`get_feature_matrix`.
    """

    @classmethod
    def from_dict(cls, data: Mapping[str, Sequence[float]]) -> "ContextFeatureSet":
        """Generate a context feature set from a data dictionary.

        :param data: A mapping from context identifiers to their raw feature values.
        :returns: A context feature set whose values are float tensors reshaped
            to a single row, i.e. ``view(1, -1)``.
        """
        rows = {}
        for context_id, raw_values in data.items():
            # Store each context as one row so contexts concatenate along dim 0.
            rows[context_id] = torch.FloatTensor(raw_values).view(1, -1)
        return cls(rows)

    def get_feature_matrix(self, contexts: Iterable[str]) -> torch.FloatTensor:
        """Get the feature matrix for a list of contexts.

        :param contexts: A list of context identifiers.
        :returns: A matrix of context features, built by concatenating each
            context's stored tensor along dimension 0, in the given order.
        :raises KeyError: If an identifier is not present in the feature set.
        """
        # Look up directly in the backing ``data`` dict (not via self[...]).
        # NOTE(review): an empty ``contexts`` hands torch.cat an empty list,
        # which raises rather than returning an empty matrix.
        row_blocks = list(map(self.data.__getitem__, contexts))
        return torch.cat(row_blocks, dim=0)