Source code for pipeline.utils.graphutils.edgebuilding
"""A module for building graph edges in various ways.
"""
import torch
from .torchutils import get_groupby_indices
[docs]def get_random_pairs_plane_by_plane(
n_random: int,
planes: torch.Tensor,
query_indices: torch.Tensor,
n_planes: int,
plane_range: int | None = None,
) -> torch.Tensor:
"""Build random edges from query hits.
Args:
n_random: Number of random pairs by query point
planes: 1D tensor of planes of all the points
query_indices: indices of the query points
plane_range: for each plane, random edges will be drawn from the query point
to a points belonging to one of the next ``plane_range`` planes.
A ``None`` means that the plane range that is considered is infinite.
Returns:
Random edge indices drawn from query hits.
"""
query_planes = planes[query_indices]
if query_indices.numel():
# Randomly select query points
indices_for_source = torch.randint(
low=0,
high=query_indices.shape[0],
size=(n_random * query_indices.shape[0],),
device=query_indices.device,
)
indices_for_source = torch.sort(indices_for_source).values
source_query_indices = query_indices[indices_for_source]
source_query_planes = query_planes[indices_for_source]
# Group by edges
unique_source_query_planes, source_query_plane_counts = torch.unique(
source_query_planes, return_counts=True
)
destination_plane_run_lengths = get_groupby_indices(
sorted_tensor=planes,
expected_unique_values=torch.arange(
n_planes, device=planes.device, dtype=planes.dtype
),
end_padding=plane_range if plane_range is not None else 0,
)
list_destination_indices = []
for source_query_plane, source_query_plane_count in zip(
unique_source_query_planes, source_query_plane_counts
):
# Indices to delimitate the next `query_plane` planes
idx_next_start = destination_plane_run_lengths[source_query_plane + 1]
idx_next_stop = (
destination_plane_run_lengths[source_query_plane + plane_range + 1]
if plane_range is not None
else destination_plane_run_lengths[-1]
)
if idx_next_start == idx_next_stop:
destination_indices_plane = torch.full(
size=(source_query_plane_count,),
fill_value=-1,
device=query_indices.device,
dtype=torch.int64,
)
else:
destination_indices_plane = torch.randint(
low=idx_next_start, # type: ignore
high=idx_next_stop, # type: ignore
size=(source_query_plane_count,),
device=query_indices.device,
)
list_destination_indices.append(destination_indices_plane)
destination_indices = torch.cat(list_destination_indices, dim=0)
random_edge_indices = torch.stack((source_query_indices, destination_indices))
# Remove edge indices that are not valid
random_edge_indices = random_edge_indices[:, random_edge_indices[1] != -1]
else:
random_edge_indices = torch.tensor(
[], dtype=torch.long, device=query_indices.device
).reshape(2, 0)
return random_edge_indices