Pipeline
A collection of full training and evaluation pipelines.
- class Result(model, predictions, losses, train_time, evaluation_time, metrics)
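A result package bundling the trained model, its predictions, the training losses, the training and evaluation times, and the evaluation metrics.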
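The docs list Result's fields but not their types. The sketch below is a hypothetical stand-in that shows the shape of the package as a plain dataclass. The name `ResultSketch` and every type annotation are assumptions, not chemicalx's definition, and the values are toy placeholders rather than output from a real run.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, List


@dataclass
class ResultSketch:
    """Hypothetical stand-in mirroring the field names of Result; all types are assumed."""

    model: Any                 # the trained model object
    predictions: Any           # assumed: predictions on the held-out split
    losses: List[float]        # assumed: recorded training loss values
    train_time: float          # assumed: seconds spent training
    evaluation_time: float     # assumed: seconds spent evaluating
    metrics: Dict[str, float]  # assumed: metric name -> score


# Illustrative toy values only, not output from any real run.
result = ResultSketch(
    model=None,
    predictions=[0.9, 0.2],
    losses=[0.69, 0.41],
    train_time=1.5,
    evaluation_time=0.1,
    metrics={"accuracy": 1.0},
)
print([f.name for f in fields(ResultSketch)])
# ['model', 'predictions', 'losses', 'train_time', 'evaluation_time', 'metrics']
```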
- pipeline(*, dataset, model, model_kwargs=None, optimizer_cls=torch.optim.Adam, optimizer_kwargs=None, loss_cls=torch.nn.BCELoss, loss_kwargs=None, batch_size=512, epochs, context_features, drug_features, drug_molecules, train_size=None, random_state=None, metrics=None, device=None)
Run the training and evaluation pipeline.
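To make the stages of a training and evaluation pipeline concrete, here is a stdlib-only toy. It is not chemicalx's implementation: a one-feature logistic regression stands in for the model, synthetic (feature, label) pairs stand in for drug-pair triples, and accuracy stands in for the configured metrics. The names `split` and `toy_pipeline` are hypothetical. It does mirror the documented knobs: a `train_size` ratio that falls back to 0.8 when None, a `random_state` seed (None means unseeded), `batch_size`, `epochs`, and a binary cross-entropy loss. It returns its outputs in the same order as Result's fields.

```python
import math
import random
import time
from typing import List, Optional, Tuple

Example = Tuple[float, float]  # hypothetical toy encoding: (feature, label in {0, 1})


def split(data: List[Example], train_size: Optional[float], random_state: Optional[int]):
    """Shuffle and split; the ratio falls back to 0.8 when train_size is None."""
    ratio = 0.8 if train_size is None else train_size
    rng = random.Random(random_state)  # random_state=None gives an unseeded shuffle
    shuffled = list(data)
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * ratio)
    return shuffled[:cut], shuffled[cut:]


def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))


def bce(p: float, y: float, eps: float = 1e-7) -> float:
    """Binary cross-entropy for one probability/label pair, clipped for stability."""
    p = min(max(p, eps), 1 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def toy_pipeline(data, *, epochs, batch_size=512, lr=0.5, train_size=None, random_state=None):
    train, test = split(data, train_size, random_state)
    w, b = 0.0, 0.0  # toy "model": logistic regression on one feature
    losses: List[float] = []

    start = time.perf_counter()
    for _ in range(epochs):
        for i in range(0, len(train), batch_size):
            batch = train[i:i + batch_size]
            grad_w = grad_b = batch_loss = 0.0
            for x, y in batch:
                p = sigmoid(w * x + b)
                batch_loss += bce(p, y)
                grad_w += (p - y) * x  # gradient of BCE w.r.t. w for a sigmoid output
                grad_b += p - y
            n = len(batch)
            w -= lr * grad_w / n  # mini-batch gradient descent step
            b -= lr * grad_b / n
            losses.append(batch_loss / n)  # one mean loss per mini-batch
    train_time = time.perf_counter() - start

    start = time.perf_counter()
    predictions = [sigmoid(w * x + b) for x, _ in test]
    correct = sum((p >= 0.5) == (y == 1) for p, (_, y) in zip(predictions, test))
    evaluation_time = time.perf_counter() - start

    # Same order as Result: model, predictions, losses, train_time, evaluation_time, metrics
    return (w, b), predictions, losses, train_time, evaluation_time, {"accuracy": correct / len(test)}


# Synthetic, linearly separable toy data: label is 1 when the feature is positive.
data_rng = random.Random(0)
data = [(x, 1.0 if x > 0 else 0.0) for x in (data_rng.uniform(-1, 1) for _ in range(1000))]
model, preds, losses, t_train, t_eval, metrics = toy_pipeline(
    data, epochs=20, batch_size=64, random_state=42
)
print(len(preds))  # 200 held-out predictions (1000 examples, 0.8 train ratio)
print(metrics)
```

Recording one loss per mini-batch and timing the two phases separately is what makes `losses`, `train_time`, and `evaluation_time` natural outputs of a loop like this one.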
- Parameters
  - dataset (Union[str, DatasetLoader, Type[DatasetLoader], None]) – The dataset can be specified in one of three ways:
    - The name of the dataset
    - A subclass of chemicalx.DatasetLoader
    - An instance of a chemicalx.DatasetLoader
  - model (Union[str, Model, Type[Model], None]) – The model can be specified in one of three ways:
    - The name of the model
    - A subclass of chemicalx.Model
    - An instance of a chemicalx.Model
  - model_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments to pass through to the model constructor. Only relevant if model is passed by name or class.
  - optimizer_cls (Type[Optimizer]) – The optimizer class to use. Currently defaults to torch.optim.Adam.
  - optimizer_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments to pass through to the optimizer constructor.
  - loss_cls (Type[_Loss]) – The loss class to use. Defaults to torch.nn.BCELoss.
  - loss_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments to pass through to the loss constructor.
  - batch_size (int) – The batch size.
  - epochs (int) – The number of epochs to train.
  - context_features (bool) – Whether the batch should include biological context features.
  - drug_features (bool) – Whether the batch should include drug features.
  - drug_molecules (bool) – Whether the batch should include drug molecules.
  - train_size (Optional[float]) – The ratio of training triples. Defaults to 0.8 if None is passed.
  - random_state (Optional[int]) – The random seed for splitting the triples. Default is 42. Set to None for no fixed seed.
  - metrics (Optional[Sequence[str]]) – The list of metrics to use.
- Return type
  Result
- Returns
  A result object with the trained model and evaluation results.
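The dataset and model parameters each accept a name, a subclass, or an instance, and model_kwargs only matters in the first two cases. Below is a minimal sketch of that resolution pattern, assuming a simple name-to-class registry. The `Model` base class, `ToyModel`, `REGISTRY`, and `resolve` are all hypothetical and are not chemicalx's code. The `None` option in the type hint is left out because its behavior is not described above.

```python
from typing import Any, Dict, Mapping, Optional, Type, Union


class Model:
    """Hypothetical base class standing in for chemicalx.Model."""


class ToyModel(Model):
    """Hypothetical model with one constructor argument."""

    def __init__(self, hidden: int = 8) -> None:
        self.hidden = hidden


# Hypothetical registry mapping lowercase names to model classes.
REGISTRY: Dict[str, Type[Model]] = {"toymodel": ToyModel}


def resolve(
    query: Union[str, Model, Type[Model]],
    kwargs: Optional[Mapping[str, Any]] = None,
) -> Model:
    """Turn a name, a subclass, or an instance into a model instance."""
    if isinstance(query, Model):
        # An instance is already constructed, so kwargs are not used.
        return query
    if isinstance(query, str):
        cls = REGISTRY[query.lower()]  # raises KeyError for unknown names
    elif isinstance(query, type) and issubclass(query, Model):
        cls = query
    else:
        raise TypeError(f"cannot resolve {query!r} to a Model")
    # Name or class: construct it, passing kwargs through to the constructor.
    return cls(**(kwargs or {}))


a = resolve("ToyModel", {"hidden": 16})  # by name
b = resolve(ToyModel, {"hidden": 4})     # by class
c = resolve(ToyModel(hidden=2))          # by instance
print(a.hidden, b.hidden, c.hidden)      # 16 4 2
```

The same `cls(**(kwargs or {}))` idiom is one plausible way optimizer_kwargs and loss_kwargs could be applied to optimizer_cls and loss_cls, though that is an assumption about the implementation rather than something stated here.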