Source code for pipeline.utils.loaderutils.dataiterator
"""Implement a general data loader that does not load all the data into
memory, in order to deal with large datasets.
"""
from __future__ import annotations
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data
from .pathandling import get_input_paths


class LazyDatasetBase(Dataset):
    """Dataset that lazily loads one input file per item instead of keeping
    the whole dataset in memory.
    """

    def __init__(
        self,
        input_dir: str,
        n_events: int | None = None,
        shuffle: bool = False,
        seed: int | None = None,
        **kwargs,
    ):
        """
        Args:
            input_dir: Directory containing the input files.
            n_events: Number of events (input files) to select. ``None`` selects all of them.
            shuffle: Whether to shuffle the order of the input files.
            seed: Random seed used when ``shuffle`` is ``True``.
            **kwargs: Default keyword arguments forwarded to :py:meth:`fetch_dataset`
                (and ultimately to :py:func:`torch.load`).
        """
        self.input_dir = str(input_dir)
        self.input_paths = get_input_paths(
            input_dir=self.input_dir,
            n_events=n_events,
            shuffle=shuffle,
            seed=seed,
        )
        self.fetch_dataset_kwargs = kwargs

    def __len__(self) -> int:
        """Number of input files."""
        return len(self.input_paths)
    def fetch_dataset(self, input_path: str, map_location: str = "cpu", **kwargs):
        """Load one PyTorch data object from disk.

        Args:
            input_path: Path to the stored PyTorch data object.
            map_location: Device onto which the data is loaded. A ``map_location``
                given at construction time (through ``**kwargs`` of ``__init__``)
                takes precedence over this argument.
            **kwargs: Other keyword arguments passed to :py:func:`torch.load`.

        Returns:
            The loaded PyTorch data object.
        """
        fetch_dataset_kwargs = self.fetch_dataset_kwargs.copy()
        map_location_kwargs = fetch_dataset_kwargs.pop("map_location", None)
        return torch.load(
            input_path,
            map_location=(
                map_location_kwargs if map_location_kwargs is not None else map_location
            ),
            **fetch_dataset_kwargs,
            **kwargs,
        )
    def __getitem__(self, idx: int) -> Data:
        """Load and return the data object stored in the ``idx``-th input file."""
        input_path = self.input_paths[idx]
        return self.fetch_dataset(input_path=input_path)
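

# Minimal usage sketch (illustration only, not part of the original module).
# It assumes ``input_dir`` points at a directory of ``.pt`` files, each one
# holding a single ``torch_geometric.data.Data`` object saved with
# ``torch.save``; the directory name and parameter values below are
# hypothetical.
#
#     from torch_geometric.loader import DataLoader
#
#     dataset = LazyDatasetBase(
#         input_dir="path/to/events",
#         n_events=100,
#         shuffle=True,
#         seed=42,
#         map_location="cpu",  # stored as a default for torch.load
#     )
#     loader = DataLoader(dataset, batch_size=8)
#     for batch in loader:
#         ...  # each file is read from disk only when its item is requested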