Source code for pipeline.utils.loaderutils.dataiterator

"""Implement a general data loader that does not load all the data into
memory, in order to deal with large datasets.
"""
from __future__ import annotations
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data

from .pathandling import get_input_paths


class LazyDatasetBase(Dataset):
    def __init__(
        self,
        input_dir: str,
        n_events: int | None = None,
        shuffle: bool = False,
        seed: int | None = None,
        **kwargs,
    ):
        self.input_dir = str(input_dir)
        self.input_paths = get_input_paths(
            input_dir=self.input_dir,
            n_events=n_events,
            shuffle=shuffle,
            seed=seed,
        )
        self.fetch_dataset_kwargs = kwargs

    def __len__(self) -> int:
        """Number of input files"""
        return len(self.input_paths)
    def fetch_dataset(self, input_path: str, map_location: str = "cpu", **kwargs):
        """Load and process one PyTorch dataset.

        Args:
            input_path: path to the PyTorch dataset
            map_location: location onto which to load the dataset
            **kwargs: other keyword arguments passed to :py:func:`torch.load`

        Returns:
            Loaded PyTorch data object
        """
        fetch_dataset_kwargs = self.fetch_dataset_kwargs.copy()
        map_location_kwargs = fetch_dataset_kwargs.pop("map_location", None)
        return torch.load(
            input_path,
            map_location=(
                map_location_kwargs
                if map_location_kwargs is not None
                else map_location
            ),
            **fetch_dataset_kwargs,
            **kwargs,
        )
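    # Note (illustrative, not in the original source): a ``map_location`` given
    # at construction time is stored in ``fetch_dataset_kwargs`` and therefore
    # takes precedence over the per-call default above, e.g.
    #
    #   dataset = LazyDatasetBase("events/", map_location="cuda:0")
    #   data = dataset.fetch_dataset("events/event_0.pt")  # loaded onto cuda:0
    #
    # The directory and file names here are hypothetical.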
    def __getitem__(self, idx: int) -> Data:
        """Load the dataset stored in the file at index ``idx``."""
        input_path = self.input_paths[idx]
        dataset = self.fetch_dataset(input_path=input_path)
        return dataset
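
# Usage sketch (an assumption, not part of the original module): because
# ``LazyDatasetBase`` only stores file paths and loads one file per
# ``__getitem__`` call, it can be wrapped in a standard DataLoader without
# keeping the whole dataset in memory.  ``batch_size=None`` disables automatic
# collation so each item stays a single ``torch_geometric.data.Data`` object.
# The directory path below is hypothetical.
#
#   from torch.utils.data import DataLoader
#
#   dataset = LazyDatasetBase(input_dir="events/", n_events=100, shuffle=True)
#   loader = DataLoader(dataset, batch_size=None, num_workers=2)
#   for data in loader:
#       ...  # process one lazily loaded Data object at a time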