Source code for pipeline.utils.graphutils.knn

"""A module that contains various ways of applying a kNN.
"""
from __future__ import annotations
import torch
import faiss
import faiss.contrib.torch_utils

from .torchutils import get_groupby_indices

# Device name used when a caller does not specify one: "cuda" if PyTorch
# reports an available CUDA device at import time, otherwise "cpu".
default_device = "cuda" if torch.cuda.is_available() else "cpu"


def build_edges_exatrkx(
    query: torch.Tensor,
    database: torch.Tensor,
    query_indices: torch.Tensor | None = None,
    r_max: float = 1.0,
    k_max: int = 10,
    device: str | torch.device | None = None,
) -> torch.Tensor:
    """Connect every query point to its close neighbours in ``database``.

    A ``faiss`` kNN search retrieves, for each point of ``query``, its
    ``k_max`` nearest points in ``database``. Only the pairs whose squared
    distance is at most ``r_max**2`` are kept, and edges whose two indices are
    equal (self-loops) are removed.

    NOTE: These KNN/FRNN algorithms return the distances**2. Therefore we need
    to be careful when comparing them to the target distances (r_val, r_test),
    and to the margin parameter (which is L1 distance)

    Args:
        query: Coordinates of the points to find neighbours for.
        database: 2D tensor of the coordinates of the candidate neighbours
            (its second dimension is the dimension of the search index).
        query_indices: If provided, row ``i`` of ``query`` is relabelled as
            ``query_indices[i]`` in the first row of the returned edges.
        r_max: Maximum (non-squared) distance between two connected points.
        k_max: Number of nearest neighbours retrieved for each query point.
        device: ``"cuda"`` (or ``"cuda:0"``) to run the search on GPU,
            ``"cpu"`` to run it on CPU. ``None`` falls back to the
            module-level ``default_device``.

    Returns:
        Tensor with 2 rows: the query indices and the database indices of the
        retained edges.

    Raises:
        ValueError: If ``device`` is neither ``"cuda"``, ``"cuda:0"`` nor
            ``"cpu"``.
    """
    if device is None:
        device = default_device
    elif str(device) == "cuda:0":
        device = "cuda"

    distances: torch.Tensor
    end_indices: torch.Tensor
    if str(device) == "cuda":
        res = faiss.StandardGpuResources()
        distances, end_indices = faiss.knn_gpu(  # type: ignore
            res=res, xq=query, xb=database, k=k_max
        )
    elif str(device) == "cpu":
        index = faiss.IndexFlatL2(database.shape[1])
        index.add(database)  # type: ignore
        distances, end_indices = index.search(query, k_max)  # type: ignore
    else:
        raise ValueError(f"Device {device} is not recognised.")

    # Row ``i`` of ``start_indices`` is filled with ``i`` so that it lines up
    # element-wise with ``end_indices``. It is created on the device of the
    # search result (rather than on ``device``) so that both rows of the edge
    # list can always be stacked together.
    start_indices = (
        torch.arange(
            end_indices.shape[0], device=end_indices.device, dtype=torch.long
        )
        .unsqueeze(1)
        .expand_as(end_indices)
    )

    # faiss returns squared distances, hence the squared radius.
    within_radius = distances <= r_max**2
    edge_list = torch.stack(
        (start_indices[within_radius], end_indices[within_radius])
    )

    # Reset indices subset to correct global index
    if query_indices is not None:
        edge_list[0] = query_indices[edge_list[0]]

    # Remove self-loops
    edge_list = edge_list[:, edge_list[0] != edge_list[1]]
    return edge_list
def build_edges_faiss(
    start_coords: torch.Tensor,
    end_coords: torch.Tensor,
    k_max: int,
    squared_distance_max: float,
    res: faiss.StandardGpuResources | None = None,
    enforce_cpu: bool = False,
) -> torch.Tensor:
    """Apply a kNN using ``faiss``.

    The CPU execution is much, much faster than the GPU one in our case.

    Args:
        start_coords: Coordinates of the starting points to connect to end
            points
        end_coords: 2D tensor of the coordinates of the points that can be
            connected to starting points
        k_max: Maximum number of neighbours to connect to each starting point.
        squared_distance_max: Upper bound (exclusive) on the **squared**
            distance for two points to be considered as neighbours
        res: Faiss GPU resource. Only used when the search runs on GPU; one is
            created if not provided.
        enforce_cpu: Run the search on CPU even if ``start_coords`` lives on a
            CUDA device. The result is then moved back to the device
            ``start_coords`` was on.

    Returns:
        Edge indices built by the kNN, as a tensor with 2 rows: indices in
        ``start_coords`` and indices in ``end_coords``.
    """
    original_device = start_coords.device
    use_gpu = original_device.type == "cuda" and not enforce_cpu

    if enforce_cpu:
        start_coords = start_coords.cpu()
        end_coords = end_coords.cpu()

    # Apply kNN
    if use_gpu:
        if res is None:
            res = faiss.StandardGpuResources()
        index = faiss.GpuIndexFlatL2(res, end_coords.shape[1])
    else:
        index = faiss.IndexFlatL2(end_coords.shape[1])
    index.add(end_coords)  # type: ignore
    distances, end_indices = index.search(start_coords, k_max)  # type: ignore

    if enforce_cpu and original_device.type == "cuda":
        # Put the tensors back on the device the inputs came from (and not
        # merely on the current default CUDA device).
        distances = distances.to(original_device)
        end_indices = end_indices.to(original_device)

    # Build array of start indices: row ``i`` is filled with ``i`` so that it
    # lines up element-wise with ``end_indices``.
    start_indices = (
        torch.arange(
            end_indices.shape[0],
            device=end_indices.device,
            dtype=end_indices.dtype,
        )
        .unsqueeze(1)
        .expand_as(end_indices)
    )

    # Build array of edges. faiss sets the label to -1 when fewer than
    # ``k_max`` neighbours exist; such padding entries are never valid edges.
    distance_mask = (distances < squared_distance_max) & (end_indices >= 0)
    return torch.stack((start_indices[distance_mask], end_indices[distance_mask]))
def build_edges_plane_by_plane(
    coords: torch.Tensor,
    planes: torch.Tensor,
    k_max: int,
    squared_distance_max: float,
    n_planes: int,
    plane_range: int | None = None,
    start_coords: torch.Tensor | None = None,
    start_planes: torch.Tensor | None = None,
    start_indices: torch.Tensor | None = None,
    enforce_cpu: bool = True,
) -> torch.Tensor:
    """Build edges by applying a kNN for each plane, to build edges between
    this plane and the next ``plane_range``.

    The loop over the planes is sequential but this is not a requirement.

    Args:
        coords: 2D tensor of (embedded) coordinates of the points to apply
            the kNN on. The points must be sorted by plane number.
        planes: 1D tensor of plane number for each point
        k_max: Maximum number of neighbours to connect to each starting point.
        squared_distance_max: Maximum squared distance for two points to be
            considered as neighbours
        n_planes: Number of planes. Plane numbers are expected to be
            ``0, 1, ..., n_planes - 1``.
        plane_range: Maximum number of planes 2 connected points can be
            separated by. A ``None`` value means that the plane_range is
            infinite
        start_coords: the coordinates of the starting points, in the case
            where they are different from ``coords``
        start_planes: the plane numbers of the starting points, in the case
            where ``start_coords`` is provided.
        start_indices: the corresponding indices in ``coords`` of
            ``start_coords`` and ``start_planes``. If provided, it is not
            necessary to provide either ``start_coords`` or ``start_planes``.
        enforce_cpu: Whether to run every kNN on CPU even if ``coords`` is on
            a CUDA device.

    Returns:
        2D tensor of edge indices built by the kNNs. It has 2 rows and no
        column if there is no pair of planes to connect.

    Raises:
        ValueError: If ``start_coords`` is provided without ``start_planes``
            nor ``start_indices``.
    """
    use_gpu = coords.device.type == "cuda" and not enforce_cpu
    # A single GPU resource object is shared by all the per-plane searches.
    res = faiss.StandardGpuResources() if use_gpu else None

    # Assert planes is sorted
    assert torch.all(planes[:-1] <= planes[1:]), "`planes` is not sorted."

    # Groupby plane. The result is used below as slice boundaries:
    # entries ``p`` and ``p + 1`` delimit the points of plane ``p``.
    # The padding keeps ``plane + 1 + plane_range`` within bounds for the
    # last planes.
    end_padding = plane_range - 1 if plane_range is not None else 0
    expected_unique_planes = torch.arange(
        n_planes, device=planes.device, dtype=planes.dtype
    )
    plane_run_lengths = get_groupby_indices(
        sorted_tensor=planes,
        expected_unique_values=expected_unique_planes,
        end_padding=end_padding,
    )

    # Handle `start_planes`
    if start_planes is None:
        if start_indices is None:
            if start_coords is not None:
                raise ValueError(
                    "`start_indices` and `start_planes` are None but `start_coords` "
                    "is not None. Thus, `start_planes` cannot be inferred."
                )
            start_planes = planes
        else:
            start_planes = planes[start_indices]

    # Handle `start_coords`
    if start_coords is None:
        start_coords = coords if start_indices is None else coords[start_indices]

    # Handle `start_plane_run_lengths`.
    # The boundaries of ``planes`` can only be reused when the starting points
    # are the points themselves. Whenever ``start_planes`` is a distinct
    # tensor (derived from ``start_indices`` or supplied by the caller), the
    # boundaries must be computed from it, otherwise ``start_coords`` would be
    # sliced with offsets that belong to ``coords``.
    if start_planes is planes:
        start_plane_run_lengths = plane_run_lengths
    else:
        start_plane_run_lengths = get_groupby_indices(
            sorted_tensor=start_planes,
            expected_unique_values=expected_unique_planes,
            end_padding=end_padding,
        )

    list_edge_indices = []
    # The last plane has no following plane to connect to.
    for plane in range(n_planes - 1):
        # Current plane
        idx_current_start = start_plane_run_lengths[plane]
        idx_current_stop = start_plane_run_lengths[plane + 1]

        # Indices to delimitate the next planes
        idx_next_start = plane_run_lengths[plane + 1]
        idx_next_stop = (
            plane_run_lengths[plane + 1 + plane_range]
            if plane_range is not None
            else plane_run_lengths[-1]
        )

        # Build edges between the starting points of the current plane and
        # the points of the next `plane_range` planes
        edge_indices = build_edges_faiss(
            start_coords=start_coords[idx_current_start:idx_current_stop],
            end_coords=coords[idx_next_start:idx_next_stop],
            squared_distance_max=squared_distance_max,
            k_max=k_max,
            res=res,
            enforce_cpu=enforce_cpu,
        )

        # Increment the hit indices in ``edge_indices``
        # as the kNN was applied to a slice of coordinates
        edge_indices[0] += idx_current_start
        edge_indices[1] += idx_next_start
        list_edge_indices.append(edge_indices)

    if not list_edge_indices:
        # Fewer than 2 planes: nothing to connect (``torch.concat`` would
        # raise on an empty list).
        return torch.empty((2, 0), dtype=torch.long, device=start_coords.device)

    edge_indices = torch.concat(list_edge_indices, dim=1)

    # Map positions in ``start_coords`` back to indices in ``coords``
    if start_indices is not None:
        edge_indices[0] = start_indices[edge_indices[0]]

    return edge_indices