Source code for pipeline.Processing.splitting

"""A module that allows to handle the splitting of the overall dataset into a train
and a validation set.
"""
import typing
import numpy as np


[docs]def randomly_split_list( list_values: list, sizes: typing.List[int], seed: int | None = None, ) -> typing.List[list]: """Split a list into sub-lists of given sizes, without repetition. The total size may be smaller that the size of the original list. Args: list_values: list to split sizes: list of the sizes of the list to produce seed: random seed Returns: Splitted list """ total_size = sum(sizes) assert total_size <= len(list_values), ( f"{total_size} elements were requested, but only {len(list_values)} are " "available." ) rng = np.random.default_rng(seed=seed) list_values_shuffled = rng.choice(list_values, size=total_size, replace=False) cumulative_sizes = np.cumsum(sizes).tolist() return [ list_values_shuffled[start_idx:end_idx].tolist() for start_idx, end_idx in zip([0] + cumulative_sizes[:-1], cumulative_sizes) ]