Source code for pipeline.utils.graphutils.edgebuilding

"""A module for building edges in various ways.
"""

import torch

from .torchutils import get_groupby_indices


[docs]def get_random_pairs_plane_by_plane( n_random: int, planes: torch.Tensor, query_indices: torch.Tensor, n_planes: int, plane_range: int | None = None, ) -> torch.Tensor: """Build random edges from query hits. Args: n_random: Number of random pairs by query point planes: 1D tensor of planes of all the points query_indices: indices of the query points plane_range: for each plane, random edges will be drawn from the query point to a points belonging to one of the next ``plane_range`` planes. A ``None`` means that the plane range that is considered is infinite. Returns: Random edge indices drawn from query hits. """ query_planes = planes[query_indices] if query_indices.numel(): # Randomly select query points indices_for_source = torch.randint( low=0, high=query_indices.shape[0], size=(n_random * query_indices.shape[0],), device=query_indices.device, ) indices_for_source = torch.sort(indices_for_source).values source_query_indices = query_indices[indices_for_source] source_query_planes = query_planes[indices_for_source] # Group by edges unique_source_query_planes, source_query_plane_counts = torch.unique( source_query_planes, return_counts=True ) destination_plane_run_lengths = get_groupby_indices( sorted_tensor=planes, expected_unique_values=torch.arange( n_planes, device=planes.device, dtype=planes.dtype ), end_padding=plane_range if plane_range is not None else 0, ) list_destination_indices = [] for source_query_plane, source_query_plane_count in zip( unique_source_query_planes, source_query_plane_counts ): # Indices to delimitate the next `query_plane` planes idx_next_start = destination_plane_run_lengths[source_query_plane + 1] idx_next_stop = ( destination_plane_run_lengths[source_query_plane + plane_range + 1] if plane_range is not None else destination_plane_run_lengths[-1] ) if idx_next_start == idx_next_stop: destination_indices_plane = torch.full( size=(source_query_plane_count,), fill_value=-1, device=query_indices.device, dtype=torch.int64, ) else: 
destination_indices_plane = torch.randint( low=idx_next_start, # type: ignore high=idx_next_stop, # type: ignore size=(source_query_plane_count,), device=query_indices.device, ) list_destination_indices.append(destination_indices_plane) destination_indices = torch.cat(list_destination_indices, dim=0) random_edge_indices = torch.stack((source_query_indices, destination_indices)) # Remove edge indices that are not valid random_edge_indices = random_edge_indices[:, random_edge_indices[1] != -1] else: random_edge_indices = torch.tensor( [], dtype=torch.long, device=query_indices.device ).reshape(2, 0) return random_edge_indices