# Source code for pipeline.Processing.splitting
"""A module that allows to handle the splitting of the overall dataset into a train
and a validation set.
"""
import typing
import numpy as np
[docs]def randomly_split_list(
list_values: list,
sizes: typing.List[int],
seed: int | None = None,
) -> typing.List[list]:
"""Split a list into sub-lists of given sizes, without repetition. The total size
may be smaller that the size of the original list.
Args:
list_values: list to split
sizes: list of the sizes of the list to produce
seed: random seed
Returns:
Splitted list
"""
total_size = sum(sizes)
assert total_size <= len(list_values), (
f"{total_size} elements were requested, but only {len(list_values)} are "
"available."
)
rng = np.random.default_rng(seed=seed)
list_values_shuffled = rng.choice(list_values, size=total_size, replace=False)
cumulative_sizes = np.cumsum(sizes).tolist()
return [
list_values_shuffled[start_idx:end_idx].tolist()
for start_idx, end_idx in zip([0] + cumulative_sizes[:-1], cumulative_sizes)
]