Source code for pipeline.utils.modelutils.batches

"""A module used to handle the lists of batches stored in a model.
"""
from __future__ import annotations
import typing
import numpy as np
from torch_geometric.data import Data
from .basemodel import ModelBase


def get_batches(model: ModelBase, partition: str) -> typing.List[Data]:
    """Return the list of batches held by ``model`` for a given partition.

    Args:
        model: PyTorch model inheriting from :py:class:`ModelBase`
        partition: ``train``, ``val``, ``test`` (for the test sample that is
            currently loaded) or the name of a test dataset, which is then
            loaded into the model through ``model.load_testset``

    Returns:
        List of PyTorch Geometric data objects

    Raises:
        AssertionError: if the selected list of batches is ``None``

    Notes:
        The input directories are saved as hyperparameters in the model.
        This is why it is possible to get the data input directories
        from a model.
    """
    # Partition names that map directly onto an attribute of the model
    builtin_partitions = (
        ("train", "trainset"),
        ("val", "valset"),
        ("test", "testset"),
    )
    attribute_name = next(
        (attr for name, attr in builtin_partitions if partition == name),
        None,
    )

    if attribute_name is None:
        # Any other value is treated as the name of a test dataset: loading
        # it replaces the test set stored in the model
        model.load_testset(test_dataset_name=partition)
        attribute_name = "testset"

    selected = getattr(model, attribute_name)
    assert (
        selected is not None
    ), "Error, list of batches is `None`: no batches were loaded"
    return selected


def select_subset(
    batches: typing.List[Data], n_events: int | None = None, seed: int | None = None
) -> typing.List[Data]:
    """Draw a random subset of at most ``n_events`` batches.

    Args:
        batches: overall list of batches
        n_events: Maximal number of events to select. If ``None``, or if it
            is not smaller than the number of batches, ``batches`` is
            returned unchanged (the same list object).
        seed: Seed for reproducible randomness

    Returns:
        List of PyTorch Data objects
    """
    # No limit requested: nothing to select
    if n_events is None:
        return batches

    max_events = int(n_events)
    n_available = len(batches)
    # The limit does not cut anything off: keep the input list as it is
    if max_events >= n_available:
        return batches

    # Draw ``max_events`` distinct positions, without replacement
    generator = np.random.default_rng(seed=seed)
    chosen_positions = generator.choice(n_available, max_events, replace=False)
    return [batches[position] for position in chosen_positions]