Source code for pipeline.utils.graphutils.torchutils
"""A python module that defines utilities using with PyTorch computation.
"""
from __future__ import annotations
import torch
def get_groupby_indices(
    sorted_tensor: torch.Tensor,
    expected_unique_values: torch.Tensor | None = None,
    end_padding: int = 0,
) -> torch.Tensor:
"""Get the array of grouping indices.
Args:
sorted_tensor: A tensor of sorted values.
expected_unique_values: The expected unique values in the sorted
tensor.
This allows to consider missing values in ``sorted_tensor``
end_padding: this parameter allows to append to the returned 1D tensor
the last value ``end_padding`` times.
Returns:
1D tensor of grouping indices, i.e., the indices of the starts
of a new group in ``sorted_tensor``, starting from 0, and including
the last value ``len(sorted_tensor)``.
This tensor allows to loop over slices of unique values of ``sorted_tensor``.
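
    Example:
        A minimal illustrative sketch (the printed outputs assume PyTorch's
        default ``int64`` counts returned by :py:func:`torch.unique`):

        >>> x = torch.tensor([0, 0, 1, 1, 1, 3])
        >>> idx = get_groupby_indices(x)
        >>> idx
        tensor([0, 2, 5, 6])
        >>> [x[a:b] for a, b in zip(idx[:-1].tolist(), idx[1:].tolist())]
        [tensor([0, 0]), tensor([1, 1, 1]), tensor([3])]
        >>> get_groupby_indices(x, expected_unique_values=torch.tensor([0, 1, 2, 3]))
        tensor([0, 2, 5, 5, 6])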
"""
    # Unique values appear in ascending order, together with their counts.
    unique_values, unique_value_counts = torch.unique(sorted_tensor, return_counts=True)
    if expected_unique_values is not None:
        if not torch.all(torch.isin(unique_values, expected_unique_values)):
            raise ValueError(
                "The tensor `sorted_tensor` contains values that are not in "
                "the tensor `expected_unique_values`."
            )
        if not torch.all(torch.isin(expected_unique_values, unique_values)):
            # Insert a zero count for each expected value missing from `sorted_tensor`.
            value_counts = torch.zeros(
                expected_unique_values.shape,
                dtype=unique_value_counts.dtype,
                device=unique_value_counts.device,
            )
            value_counts[
                torch.isin(expected_unique_values, unique_values)
            ] = unique_value_counts
        else:
            value_counts = unique_value_counts
    else:
        value_counts = unique_value_counts
    # Group starts are the cumulative counts, shifted by one leading zero.
    run_lengths = torch.zeros(
        (value_counts.shape[0] + 1 + end_padding,),
        dtype=value_counts.dtype,
        device=value_counts.device,
    )
    run_lengths[1 : value_counts.shape[0] + 1] = torch.cumsum(value_counts, dim=0)
    # Repeat the last index `end_padding` times.
    run_lengths[value_counts.shape[0] + 1 :] = run_lengths[value_counts.shape[0]]
    return run_lengths
def scatter_reduce(
    src: torch.Tensor,
    index: torch.Tensor,
    reduce: str,
    dim_size: int,
) -> torch.Tensor:
"""A scatter reduce for ``dim=0`` that works with ONNX export.
It uses the experimental function :py:func:`torch.scatter_reduce`.
The arguments match the ones of the pytorch-scatter library, with a supplementary
argument ``reduce`` (``sum``, ``prod``, ``mean``, ``amax`` or ``amin``) that
allows to choose the reduction.
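
    Example:
        A minimal illustrative sketch, assuming a 2D float ``src`` and a ``"sum"``
        reduction (rows of ``src`` that share an index are summed together):

        >>> src = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
        >>> index = torch.tensor([0, 1, 0])
        >>> scatter_reduce(src, index, reduce="sum", dim_size=2)
        tensor([[6., 8.],
                [3., 4.]])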
"""
    return torch.scatter_reduce(
        # Reduce into a zero-initialized output of shape (dim_size, num_features).
        input=torch.zeros((dim_size, src.shape[1]), dtype=src.dtype, device=src.device),
        dim=0,
        # Broadcast the 1D index so it has the same shape as `src`.
        index=index.unsqueeze(0).expand(src.shape[1], -1).T,
        src=src,
        reduce=reduce,
        include_self=True,  # For support of ONNX export; the zero initial values take part in the reduction.
    )