# Source code for pipeline.utils.modelutils.batches
"""A module used to handle the lists of batches stored in a model.
"""
from __future__ import annotations
import typing
import numpy as np
from torch_geometric.data import Data
from .basemodel import ModelBase
def get_batches(model: ModelBase, partition: str) -> typing.List[Data]:
    """Get the list of batches stored in the given model for a partition.

    Args:
        model: PyTorch model inheriting from :py:class:`ModelBase`
        partition: ``train``, ``val``, ``test`` (for the current already loaded
            test sample) or the name of a test dataset

    Returns:
        List of PyTorch Geometric data objects

    Raises:
        AssertionError: if the selected list of batches is ``None``, i.e., no
            batches were loaded for the requested partition.

    Notes:
        Any ``partition`` other than ``train``, ``val`` or ``test`` is treated
        as the name of a test dataset: ``model.load_testset`` is called with it
        before ``model.testset`` is read, so the test sample held by the model
        is expected to be replaced as a side effect.

        NOTE(review): the original documentation states that the input
        directories are saved as hyperparameters in the model; presumably this
        is what allows ``load_testset`` to work from a dataset name alone.
        Confirm against :py:class:`ModelBase`.
    """
    # Use correct batches
    if partition == "train":
        batches = model.trainset
    elif partition == "val":
        batches = model.valset
    elif partition == "test":
        # Current, already loaded test sample
        batches = model.testset
    else:
        # Any other value is a test dataset name that must be loaded first
        model.load_testset(test_dataset_name=partition)
        batches = model.testset

    # Raise explicitly instead of using an ``assert`` statement: ``assert`` is
    # removed when Python runs with ``-O``, which would let ``None`` leak to the
    # caller. ``AssertionError`` is kept so existing handlers still work.
    if batches is None:
        raise AssertionError(
            "Error, list of batches is `None`: no batches were loaded"
        )
    return batches
def select_subset(
    batches: typing.List[Data], n_events: int | None = None, seed: int | None = None
) -> typing.List[Data]:
    """Randomly select a subset of batches.

    Args:
        batches: overall list of batches
        n_events: Maximal number of events to select. It is converted with
            :py:func:`int`. If ``None``, or not smaller than the number of
            available batches, no selection is performed.
        seed: Seed for reproducible randomness, forwarded to
            :py:func:`numpy.random.default_rng`

    Returns:
        List of PyTorch Data objects. When no selection is performed, this is
        the input list itself (not a copy). Otherwise it is a new list of
        ``n_events`` distinct batches, in the order they were drawn.

    Raises:
        ValueError: if ``n_events`` is negative.
    """
    if n_events is None:
        return batches

    n_events = int(n_events)
    if n_events < 0:
        # Fail early with a clear message instead of NumPy's
        # "negative dimensions are not allowed" raised from ``rng.choice``.
        raise ValueError(f"`n_events` must be non-negative, got {n_events}")
    if n_events >= len(batches):
        # Nothing to cut: fewer batches available than requested
        return batches

    # Randomly select a subset of ``n_events`` events, without repetition
    rng = np.random.default_rng(seed=seed)
    indices = rng.choice(len(batches), n_events, replace=False)
    return [batches[idx] for idx in indices]