Source code for pipeline.utils.graphutils.torchutils

"""A Python module that defines utilities for PyTorch computations.
"""
from __future__ import annotations
import torch


def get_groupby_indices(
    sorted_tensor: torch.Tensor,
    expected_unique_values: torch.Tensor | None = None,
    end_padding: int = 0,
) -> torch.Tensor:
    """Get the array of grouping indices.

    Args:
        sorted_tensor: A tensor of values sorted in ascending order. The
            groups follow the ascending order of :py:func:`torch.unique`, so
            the returned indices only delimit contiguous slices of
            ``sorted_tensor`` when it is sorted.
        expected_unique_values: The expected unique values in the sorted
            tensor. This allows to consider values missing from
            ``sorted_tensor``: each missing value yields an empty group (two
            equal consecutive indices). When some values are missing, the
            counts of the present values are written, in ascending value
            order, into the positions of ``expected_unique_values`` flagged as
            present, so this tensor should be sorted ascending and contain no
            duplicates.
        end_padding: Number of times the last value of the returned 1D
            tensor (i.e. ``len(sorted_tensor)``) is appended to it. Must be
            non-negative.

    Returns:
        1D tensor of grouping indices, i.e., the indices of the starts of a
        new group in ``sorted_tensor``, starting from 0, and including the
        last value ``len(sorted_tensor)``. This tensor allows to loop over
        slices of unique values of ``sorted_tensor``.

    Raises:
        ValueError: If ``end_padding`` is negative, or if ``sorted_tensor``
            contains values that are not in ``expected_unique_values``.
    """
    if end_padding < 0:
        raise ValueError(f"`end_padding` must be non-negative, got {end_padding}.")

    unique_values, unique_value_counts = torch.unique(sorted_tensor, return_counts=True)
    value_counts = unique_value_counts

    if expected_unique_values is not None:
        if not torch.all(torch.isin(unique_values, expected_unique_values)):
            raise ValueError(
                "The tensor `sorted_tensor` contain values that are not in "
                "the tensor `expected_unique_values`."
            )
        # Computed once, used both for the "all present?" test and as mask.
        is_present = torch.isin(expected_unique_values, unique_values)
        if not torch.all(is_present):
            # Missing expected values get a zero count -> empty groups.
            value_counts = torch.zeros(
                expected_unique_values.shape,
                dtype=unique_value_counts.dtype,
                device=unique_value_counts.device,
            )
            value_counts[is_present] = unique_value_counts

    n_groups = value_counts.shape[0]
    run_lengths = torch.zeros(
        (n_groups + 1 + end_padding,),
        dtype=value_counts.dtype,
        device=value_counts.device,
    )
    # run_lengths[0] stays 0; run_lengths[i + 1] is the end of group i.
    run_lengths[1 : n_groups + 1] = torch.cumsum(value_counts, dim=0)
    # Repeat the final boundary (the total count) ``end_padding`` times.
    run_lengths[n_groups + 1 :] = run_lengths[n_groups]
    return run_lengths
def scatter_reduce(
    src: torch.Tensor,
    index: torch.Tensor,
    reduce: str,
    dim_size: int,
) -> torch.Tensor:
    """A scatter reduce for ``dim=0`` that works with ONNX export.

    It uses the experimental function :py:func:`torch.scatter_reduce`. The
    arguments match the ones of the pytorch-scatter library, with a
    supplementary argument ``reduce`` (``sum``, ``prod``, ``mean``, ``amax``
    or ``amin``) that allows to choose the reduction.

    Args:
        src: 2D tensor of shape ``(N, C)`` whose rows are scattered.
        index: 1D tensor of length ``N``; row ``i`` of ``src`` is reduced into
            output row ``index[i]``.
        reduce: One of ``sum``, ``prod``, ``mean``, ``amax`` or ``amin``.
        dim_size: Number of rows of the output.

    Returns:
        Tensor of shape ``(dim_size, C)`` with the dtype and device of
        ``src``. Output rows that receive no source row are 0. For integer
        dtypes, ``mean`` uses floor division (as ``torch.scatter_reduce``).

    Raises:
        ValueError: If ``reduce`` is not a supported reduction.

    Note:
        ``include_self=True`` is kept for ONNX export support, so the initial
        output value always takes part in the reduction. To keep it from
        biasing the result, the output is initialised with the neutral
        element of the reduction, ``mean`` is computed as sum / count, and
        rows that receive nothing are reset to 0.
    """
    # torch.scatter_reduce needs an index with the shape of ``src``: every
    # column of row ``i`` is sent to output row ``index[i]``.
    expanded_index = index.unsqueeze(0).expand(src.shape[1], -1).T

    if reduce == "sum":
        return _scatter_rows(src, expanded_index, dim_size, 0, "sum")

    if reduce == "mean":
        totals = _scatter_rows(src, expanded_index, dim_size, 0, "sum")
        # clamp: rows with no contribution give 0 / 1 = 0 instead of 0 / 0.
        counts = _scatter_rows(
            torch.ones_like(src), expanded_index, dim_size, 0, "sum"
        ).clamp(min=1)
        if src.is_floating_point():
            return totals / counts
        return torch.div(totals, counts, rounding_mode="floor")

    identity = _reduction_identity(reduce, src.dtype)  # raises on unknown reduce
    reduced = _scatter_rows(src, expanded_index, dim_size, identity, reduce)
    # Rows that received nothing still hold the identity (1, +/-inf, ...).
    counts = _scatter_rows(torch.ones_like(src), expanded_index, dim_size, 0, "sum")
    return torch.where(counts > 0, reduced, torch.zeros_like(reduced))


def _scatter_rows(
    values: torch.Tensor,
    expanded_index: torch.Tensor,
    dim_size: int,
    fill: float,
    reduce: str,
) -> torch.Tensor:
    """Scatter-reduce ``values`` along dim 0 into a ``fill``-initialised output."""
    return torch.scatter_reduce(
        input=torch.full(
            (dim_size, values.shape[1]), fill, dtype=values.dtype, device=values.device
        ),
        dim=0,
        index=expanded_index,
        src=values,
        reduce=reduce,
        include_self=True,  # For support of ONNX export
    )


def _reduction_identity(reduce: str, dtype: torch.dtype) -> float:
    """Return the neutral element of ``reduce`` (prod/amax/amin) for ``dtype``."""
    if reduce == "prod":
        return 1
    if reduce in ("amax", "amin"):
        if dtype.is_floating_point:
            return float("-inf") if reduce == "amax" else float("inf")
        info = torch.iinfo(dtype)
        return info.min if reduce == "amax" else info.max
    raise ValueError(
        f"Unsupported reduction {reduce!r}; expected one of "
        "'sum', 'prod', 'mean', 'amax' or 'amin'."
    )