# Source code for chemicalx.models.base

"""Base classes for models and utilities."""

from abc import ABC, abstractmethod

from torch import nn

from chemicalx.data import DrugPairBatch

__all__ = [
    "UnimplementedModel",
    "Model",
]


class UnimplementedModel:
    """Placeholder base class for ChemicalX models that are not implemented yet.

    Instances do nothing beyond remembering the single value they were
    constructed with.
    """

    def __init__(self, x: int) -> None:
        """Create the placeholder and keep ``x`` on the instance.

        :param x: An integer, stored unchanged as the ``x`` attribute.
        """
        self.x: int = x
class Model(nn.Module, ABC):
    """Abstract base class shared by all ChemicalX models.

    It combines :class:`torch.nn.Module` with :class:`abc.ABC`, so subclasses
    must provide the abstract :meth:`unpack` hook declared below before they
    can be instantiated.
    """

    @abstractmethod
    def unpack(self, batch: DrugPairBatch):
        """Convert a batch into the inputs expected by :func:`forward`.

        :param batch: A batch object
        :returns: A tuple whose items are passed as the positional arguments
            of this model's :func:`forward` method.
        """