"""A module that contains various ways of applying a kNN.
"""
from __future__ import annotations
import torch
import faiss
import faiss.contrib.torch_utils
from .torchutils import get_groupby_indices
# Device used when a caller does not choose one explicitly: the GPU whenever
# torch reports CUDA as available, the CPU otherwise.
if torch.cuda.is_available():
    default_device = "cuda"
else:
    default_device = "cpu"
def build_edges_exatrkx(
    query: torch.Tensor,
    database: torch.Tensor,
    query_indices: torch.Tensor | None = None,
    r_max: float = 1.0,
    k_max: int = 10,
    device: str | torch.device | None = None,
) -> torch.Tensor:
    """Connect each ``query`` point to its nearest ``database`` points.

    For every query row, the ``k_max`` nearest database rows (L2 metric) are
    searched with ``faiss``. Pairs farther than ``r_max`` are discarded and
    self-loops are removed.

    NOTE: These KNN/FRNN algorithms return the distances**2.
    Therefore we need to be careful when comparing them to the target distances
    (r_val, r_test), and to the margin parameter (which is L1 distance).

    Args:
        query: 2D tensor of query coordinates, one row per point.
        database: 2D tensor of coordinates searched for neighbours, one row
            per point (its second dimension is the feature dimension).
        query_indices: Optional tensor giving, for each row of ``query``, the
            index reported as the edge start. Presumably the global index of
            the query point in the same index space as ``database`` rows
            (self-loop removal compares the two) - confirm with callers.
        r_max: Maximum (non-squared) distance for two points to be connected.
        k_max: Maximum number of neighbours searched per query point.
        device: ``"cuda"`` (or ``"cuda:0"``) to run ``faiss.knn_gpu``,
            ``"cpu"`` to use a flat CPU index. Defaults to the module-level
            ``default_device``.

    Returns:
        ``(2, n_edges)`` tensor: row 0 holds start indices (query row, or
        ``query_indices`` entry when given), row 1 holds ``database`` row
        indices.

    Raises:
        ValueError: If ``device`` is not one of the supported values.
    """
    if device is None:
        device = default_device
    elif str(device) == "cuda:0":
        device = "cuda"

    distances: torch.Tensor
    end_indices: torch.Tensor
    if str(device) == "cuda":
        res = faiss.StandardGpuResources()
        distances, end_indices = faiss.knn_gpu(  # type: ignore
            res=res, xq=query, xb=database, k=k_max
        )
    elif str(device) == "cpu":
        index = faiss.IndexFlatL2(database.shape[1])
        index.add(database)  # type: ignore
        distances, end_indices = index.search(query, k_max)  # type: ignore
    else:
        raise ValueError(f"Device {device} is not recognised.")

    # Row ``i`` of ``end_indices`` holds the neighbours of query ``i``. Build the
    # matching start indices on the device faiss actually returned results on,
    # so both rows of the edge list always live on the same device.
    start_indices = (
        torch.arange(end_indices.shape[0], device=end_indices.device, dtype=torch.long)
        .unsqueeze(1)
        .expand_as(end_indices)
    )

    # faiss distances are squared, hence ``r_max**2``. faiss pads with index -1
    # when fewer than ``k_max`` neighbours exist; drop those explicitly so they
    # cannot slip through when ``r_max`` is very large (or infinite).
    keep = (distances <= r_max**2) & (end_indices >= 0)
    edge_list = torch.stack((start_indices[keep], end_indices[keep]))

    # Reset indices subset to correct global index
    if query_indices is not None:
        edge_list[0] = query_indices[edge_list[0]]

    # Remove self-loops
    return edge_list[:, edge_list[0] != edge_list[1]]
def build_edges_faiss(
    start_coords: torch.Tensor,
    end_coords: torch.Tensor,
    k_max: int,
    squared_distance_max: float,
    res: faiss.StandardGpuResources | None = None,
    enforce_cpu: bool = False,
) -> torch.Tensor:
    """Apply a kNN using ``faiss``.

    The CPU execution is much, much faster than the GPU one in
    our case.

    Args:
        start_coords: 2D tensor of coordinates of the starting points to
            connect to end points.
        end_coords: 2D tensor of coordinates of the points that can be
            connected to starting points.
        k_max: Maximum number of neighbours to connect to each
            starting point.
        squared_distance_max: Squared distance below which (strictly) two
            points are considered neighbours.
        res: Faiss GPU resource. Created on demand when the GPU is used and
            none is given.
        enforce_cpu: If ``True``, run the search on CPU even when the inputs
            live on a GPU; the resulting edges are moved back to the original
            device of ``start_coords``.

    Returns:
        ``(2, n_edges)`` tensor of edge indices built by the kNN: row 0
        indexes ``start_coords``, row 1 indexes ``end_coords``.
    """
    original_device = start_coords.device
    start_coords_on_gpu = original_device.type == "cuda"
    use_gpu = start_coords_on_gpu and not enforce_cpu

    # Nothing to connect: skip faiss (and any GPU resource creation) entirely.
    if start_coords.shape[0] == 0 or end_coords.shape[0] == 0:
        return torch.empty((2, 0), dtype=torch.long, device=original_device)

    if enforce_cpu:
        start_coords = start_coords.cpu()
        end_coords = end_coords.cpu()
    if res is None and use_gpu:
        res = faiss.StandardGpuResources()

    # Apply kNN
    dimension = end_coords.shape[1]
    index = (
        faiss.GpuIndexFlatL2(res, dimension)
        if use_gpu
        else faiss.IndexFlatL2(dimension)
    )
    index.add(end_coords)  # type: ignore
    distances, end_indices = index.search(start_coords, k_max)  # type: ignore

    # Row ``i`` of ``end_indices`` holds the neighbours of start point ``i``.
    start_indices = (
        torch.arange(
            end_indices.shape[0], device=end_indices.device, dtype=end_indices.dtype
        )
        .unsqueeze(1)
        .expand_as(end_indices)
    )

    # faiss pads with index -1 when fewer than ``k_max`` neighbours exist; drop
    # those explicitly so they cannot pass an infinite distance threshold.
    keep = (distances < squared_distance_max) & (end_indices >= 0)
    edges = torch.stack((start_indices[keep], end_indices[keep]))

    # Put results back on the caller's device (e.g. ``cuda:1``), not on
    # whatever GPU happens to be current. No-op when already there.
    return edges.to(original_device)
def build_edges_plane_by_plane(
    coords: torch.Tensor,
    planes: torch.Tensor,
    k_max: int,
    squared_distance_max: float,
    n_planes: int,
    plane_range: int | None = None,
    start_coords: torch.Tensor | None = None,
    start_planes: torch.Tensor | None = None,
    start_indices: torch.Tensor | None = None,
    enforce_cpu: bool = True,
) -> torch.Tensor:
    """Build edges by applying a kNN for each plane,
    to build edges between this plane and the next ``plane_range``.

    The loop over the planes is sequential but this is not a
    requirement.

    Args:
        coords: 2D tensor of (embedded) coordinates of the points
            to apply the kNN on. The points must be sorted by plane number.
        planes: 1D tensor of plane number for each point.
        k_max: Maximum number of neighbours to connect to each
            starting point.
        squared_distance_max: Squared distance below which two points are
            considered neighbours.
        n_planes: Number of planes; plane numbers are expected in
            ``range(n_planes)``.
        plane_range: Maximum number of planes 2 connected points can be
            separated by. A ``None`` value means that the plane_range is
            infinite.
        start_coords: The coordinates of the starting points, in the case where
            they are different from ``coords``.
        start_planes: The plane numbers of the starting points, required when
            ``start_coords`` is provided without ``start_indices``.
        start_indices: The corresponding indices in ``coords`` of
            ``start_coords`` and ``start_planes``. If provided, neither
            ``start_coords`` nor ``start_planes`` needs to be provided.
        enforce_cpu: Forwarded to ``build_edges_faiss``; when ``True`` the kNN
            runs on CPU even for GPU inputs.

    Returns:
        2D tensor of edge indices built by the kNNs, of shape ``(2, n_edges)``.
        Row 1 indexes ``coords``. Row 0 indexes ``coords`` when
        ``start_indices`` is given, otherwise the starting points
        (``start_coords``, or ``coords`` when it was not given).

    Raises:
        ValueError: If ``start_coords`` is given without ``start_planes`` nor
            ``start_indices``.
        AssertionError: If ``planes`` is not sorted.
    """
    use_gpu = coords.device.type == "cuda" and not enforce_cpu
    res = faiss.StandardGpuResources() if use_gpu else None

    # Assert planes is sorted
    assert torch.all(planes[:-1] <= planes[1:]), "`planes` is not sorted."

    # Groupby plane. The result is used below as run boundaries: entries ``p``
    # and ``p + 1`` delimit plane ``p``. The padding presumably keeps index
    # ``plane + 1 + plane_range`` in range for the last planes - TODO confirm
    # against ``get_groupby_indices``.
    expected_unique_planes = torch.arange(
        n_planes, device=planes.device, dtype=planes.dtype
    )
    end_padding = plane_range - 1 if plane_range is not None else 0
    plane_run_lengths = get_groupby_indices(
        sorted_tensor=planes,
        expected_unique_values=expected_unique_planes,
        end_padding=end_padding,
    )

    # Handle `start_planes`
    if start_planes is None:
        if start_indices is not None:
            start_planes = planes[start_indices]
        elif start_coords is not None:
            raise ValueError(
                "`start_indices` and `start_planes` are None but `start_coords` "
                "is not None. Thus, `start_planes` cannot be inferred."
            )
        else:
            start_planes = planes

    # Handle `start_coords`
    if start_coords is None:
        start_coords = coords if start_indices is None else coords[start_indices]

    # The start run boundaries must describe the *start* points. They can only
    # be reused from ``planes`` when the start points are ``coords`` itself;
    # otherwise (``start_indices`` or an explicit ``start_planes``) recompute.
    start_plane_run_lengths = (
        plane_run_lengths
        if start_planes is planes
        else get_groupby_indices(
            sorted_tensor=start_planes,
            expected_unique_values=expected_unique_planes,
            end_padding=end_padding,
        )
    )

    list_edge_indices = []
    # The last plane has no following plane to connect to.
    for plane in expected_unique_planes[:-1]:
        # Current plane
        idx_current_start = start_plane_run_lengths[plane]
        idx_current_stop = start_plane_run_lengths[plane + 1]
        # Indices to delimitate the next `plane_range` planes
        idx_next_start = plane_run_lengths[plane + 1]
        idx_next_stop = (
            plane_run_lengths[plane + 1 + plane_range]
            if plane_range is not None
            else plane_run_lengths[-1]
        )
        # Build edges between the current plane and the next ones
        edge_indices = build_edges_faiss(
            start_coords=start_coords[idx_current_start:idx_current_stop],
            end_coords=coords[idx_next_start:idx_next_stop],
            squared_distance_max=squared_distance_max,
            k_max=k_max,
            res=res,
            enforce_cpu=enforce_cpu,
        )
        # Increment the hit indices in ``edge_indices``
        # as the kNN was applied to a splice of coordinates
        edge_indices[0] += idx_current_start
        edge_indices[1] += idx_next_start
        list_edge_indices.append(edge_indices)

    if not list_edge_indices:
        # ``n_planes <= 1``: no pair of planes to connect, and ``torch.concat``
        # rejects an empty list.
        return torch.empty((2, 0), dtype=torch.long, device=coords.device)

    edge_indices = torch.concat(list_edge_indices, dim=1)
    if start_indices is not None:
        # Map start-point positions back to indices in ``coords``.
        edge_indices[0] = start_indices[edge_indices[0]]
    return edge_indices