"""A module that defines functions to go from triplets to tracks.
"""
from __future__ import annotations
import typing
import torch
try:
import cupy as cp
except ImportError:
cp = None
try:
import cudf
except ImportError:
cudf = None
from . import triplets
from .components import connected_components
from utils.tools import tarray
def split_triplet_indices_two_connected_components(
triplet_indices: typing.Dict[str, torch.Tensor],
edge_index: torch.Tensor | None = None,
strategy: str | None = None,
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
"""Split triplets into 2 categories:
* Triplets to merge now with a first connected component algorithm
* Triplets to merge after removing the duplicate triplets, after the first \
connected component algorithm.
Args:
triplet_indices: dictionary that associates a type of triplet with
the triplet indices
edge_index: edge indices. Only used in the ``no_multiple_central_hit``
strategy
strategy: splitting strategy to use, which correspond to which articulations
to merge in the first connected component algorithm.
All the strategies are expected to produce the very same result.
Default is ``no_multiple_edge``.
Returns:
Triplet indices to merge by the first connected component algorithm,
and triplet indices to merge the second one.
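
    Example:
        A minimal illustrative call with hypothetical tensors. The two
        articulation triplets below share their right edge ``2``, so with the
        default strategy both are deferred to the second pass::

            merge_now, merge_later = split_triplet_indices_two_connected_components(
                triplet_indices={
                    "articulation": torch.tensor([[0, 1], [2, 2]]),
                    "elbow_left": torch.empty(2, 0, dtype=torch.long),
                    "elbow_right": torch.empty(2, 0, dtype=torch.long),
                },
            )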
"""
if strategy is None:
strategy = "2cc_no_multiple_edge"
if strategy == "2cc_without_articulation":
triplet_index_to_merge_after_connected_components = triplet_indices[
"articulation"
]
triplet_index_to_merge_now = torch.cat(
(
triplet_indices["elbow_left"],
triplet_indices["elbow_right"],
),
dim=-1,
)
else:
if strategy == "2cc_no_multiple_central_hit":
assert (
edge_index is not None
), "`no_multiple_central_hit` strategy required `edge_index`."
triplet_central_hit_idx_counts = tarray.count_occurences(
edge_index[1][triplet_indices["articulation"][0]]
)
mask_triplet_to_merge_now = triplet_central_hit_idx_counts == 1
elif strategy == "2cc_no_multiple_edge":
triplet_left_counts = tarray.count_occurences(
triplet_indices["articulation"][0]
)
triplet_right_counts = tarray.count_occurences(
triplet_indices["articulation"][1]
)
mask_triplet_to_merge_now = (triplet_left_counts == 1) & (
triplet_right_counts == 1
)
else:
raise ValueError("`strategy` not recognised.")
triplet_index_to_merge_now = torch.cat(
(
triplet_indices["articulation"][:, mask_triplet_to_merge_now],
triplet_indices["elbow_left"],
triplet_indices["elbow_right"],
),
dim=-1,
)
triplet_index_to_merge_after_connected_components = triplet_indices[
"articulation"
][:, ~mask_triplet_to_merge_now]
return (
triplet_index_to_merge_now,
triplet_index_to_merge_after_connected_components,
)
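# The GPU-accelerated implementation below is only defined when the optional
# cupy and cudf dependencies are available.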
if cp is not None and cudf is not None:
CuDfDataFrame = cudf.DataFrame
from .components import cure_max_node_idx
    def build_tracks_from_triplets_2cc(
        triplet_indices: typing.Dict[str, torch.Tensor],
        edge_index: torch.Tensor,
        strategy: str | None = None,
    ) -> CuDfDataFrame:
        """Build tracks from triplets on GPU, using two successive weakly
        connected component passes.

        Args:
            triplet_indices: dictionary that associates a type of triplet with
                the triplet indices
            edge_index: edge indices
            strategy: splitting strategy passed to
                ``split_triplet_indices_two_connected_components``

        Returns:
            Dataframe of connected hits that defines the tracks.
        """
        import cugraph
(
triplet_index,
forked_triplet_index,
) = split_triplet_indices_two_connected_components(
triplet_indices=triplet_indices,
edge_index=edge_index,
strategy=strategy,
)
        # Weakly connected component algorithm applied to edges (instead of hits).
        # Triplets flagged for the second pass (e.g. those sharing a central hit
        # or an edge, depending on the strategy) are not considered in this
        # first step
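        # For instance, triplets (e0, e1) and (e1, e2) put edges e0, e1 and e2
        # into one component, i.e. one track candidate (illustrative indices).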
df_triplet_index = cudf.DataFrame(
{
"source": cp.asarray(triplet_index[0]),
"destination": cp.asarray(triplet_index[1]),
}
)
graph = cugraph.from_cudf_edgelist(df_triplet_index, renumber=False)
df_labels = cure_max_node_idx(
df_labels=cugraph.weakly_connected_components(graph), # type: ignore
max_node_idx=edge_index.shape[1] - 1,
).rename(
columns={
"vertex": "edge_idx",
"labels": "track_id",
},
)
# Renumber the track IDs so that they go from 0 to the number of tracks - 1
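        # (e.g. raw labels [7, 7, 42] become [0, 0, 1])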
_, indices = cp.unique(df_labels["track_id"].to_cupy(), return_inverse=True) # type: ignore
df_labels["track_id"] = indices # type: ignore
# Now let's handle the remaining triplets
df_remaining_triplets = cudf.DataFrame(
{
"track_id_left": cp.asarray(forked_triplet_index[0]),
"track_id_right": cp.asarray(forked_triplet_index[1]),
"central_hit_idx": cp.asarray(edge_index[1][forked_triplet_index[0]]),
}
)
        # Add `track_id` information.
        # This turns a "triplet", that is, a link between 2 edges,
        # into a link between 2 tracks.
df_remaining_triplets = triplets.update_dataframe_track_links(
df_track_links=df_remaining_triplets,
df_new_labels=df_labels.rename(columns={"edge_idx": "old_track_id"}), # type: ignore
)
        # The triplets are now links between track IDs;
        # remove the duplicate links
df_track_links = df_remaining_triplets.drop_duplicates(
["central_hit_idx", "track_id_left", "track_id_right"]
)
        # Count how many links share each central hit: a link whose central hit
        # appears only once is effectively not a fork
df_track_links["fork"] = tarray.count_occurences( # type: ignore
df_track_links["central_hit_idx"].to_cupy() # type: ignore
)
df_track_links_wo_fork = df_track_links[df_track_links["fork"] == 1][ # type: ignore
["track_id_left", "track_id_right"]
]
df_track_links_with_fork = df_track_links[df_track_links["fork"] >= 2][ # type: ignore
["track_id_left", "track_id_right"]
]
if df_track_links_wo_fork.shape[0]: # type: ignore
df_new_labels = cure_max_node_idx(
df_labels=cugraph.weakly_connected_components( # type: ignore
cugraph.from_cudf_edgelist(
df_track_links_wo_fork,
source="track_id_left",
destination="track_id_right",
renumber=False,
)
),
max_node_idx=df_labels["track_id"].max(), # type: ignore
).rename(
columns={
"vertex": "old_track_id",
"labels": "track_id",
}
)
            # Assign the new labels to `df_labels`
df_labels = triplets.update_dataframe_connected_edges(
df_connected_edges=df_labels, # type: ignore
df_new_labels=df_new_labels, # type: ignore
)
            # Assign the new labels to `df_track_links_with_fork`
df_track_links_with_fork = triplets.update_dataframe_track_links(
df_track_links=df_track_links_with_fork, # type: ignore
df_new_labels=df_new_labels, # type: ignore
)
df_track_links_with_fork = df_track_links_with_fork.drop_duplicates(
["track_id_left", "track_id_right"]
)
if df_track_links_with_fork.shape[0]: # type: ignore
df_labels = triplets.forked_track_links_to_tracks(
df_track_links_with_fork=df_track_links_with_fork, # type: ignore
df_connected_edges=df_labels, # type: ignore
)
df_tracks = triplets.connected_edges_to_connected_hits(
edge_index=edge_index, df_connected_edges=df_labels # type: ignore
)
return df_tracks # type: ignore
def split_articulations(
    triplet_index: torch.Tensor,
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
    """Split articulation triplets into 2 categories:

    * triplets to merge now
    * triplets that may require copying hits, namely articulations
      that share an edge with another articulation

    Args:
        triplet_index: articulation triplet indices

    Returns:
        Triplet indices to merge now, and triplet indices that may require
        copying hits.
    """
    triplet_left_counts = tarray.count_occurences(triplet_index[0])
    triplet_right_counts = tarray.count_occurences(triplet_index[1])
    mask_triplet_to_merge_now = (triplet_left_counts == 1) & (triplet_right_counts == 1)
return (
triplet_index[:, mask_triplet_to_merge_now],
triplet_index[:, ~mask_triplet_to_merge_now],
)
def connect_elbows(
triplet_indices: typing.Dict[str, torch.Tensor],
n_edges: int,
) -> typing.Tuple[tarray.DataFrame, torch.Tensor]:
"""Connect the left and right elbows.
Args:
triplet_indices: dictionary that associates a type of triplet with
the triplet indices
n_edges: number of edges
Returns:
Tuple of a dataframe and a tensor.
        The dataframe defines the small connected components formed by connecting
the elbows, through its two columns ``edge_idx`` and ``track_id``.
The tensor corresponds to the articulation link indices between these
connected components.
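
    Example:
        A minimal illustrative call with hypothetical tensors: one left elbow
        connects edges ``0`` and ``1``, and one articulation links edges ``1``
        and ``2``::

            df_labels, articulation_index = connect_elbows(
                triplet_indices={
                    "articulation": torch.tensor([[1], [2]]),
                    "elbow_left": torch.tensor([[0], [1]]),
                    "elbow_right": torch.empty(2, 0, dtype=torch.long),
                },
                n_edges=3,
            )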
"""
use_cuda = triplet_indices["articulation"].device.type == "cuda"
dtype = triplet_indices["articulation"].dtype
np_or_cp = tarray.get_numpy_or_cupy(use_cuda=use_cuda)
# First merge left elbows, then right elbows
# Assign to each edge the smallest edge it is connected to by an elbow
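    # (e.g. elbows linking edges 1-3 and 1-5 relabel edges 3 and 5 to 1)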
old_to_new_indices = torch.arange(
n_edges,
device=triplet_indices["articulation"].device,
dtype=dtype,
)
for triplet_name in ["elbow_left", "elbow_right"]:
triplet_index = triplet_indices.pop(triplet_name)
df_triplet = tarray.to_dataframe(
{
"triplet_idx_left": triplet_index[0],
"triplet_idx_right": triplet_index[1],
},
use_cuda=use_cuda,
)
df_triplet["triplet_idx_min"] = df_triplet[
["triplet_idx_left", "triplet_idx_right"]
].min( # type: ignore
axis=1
)
df_triplet["triplet_idx_max"] = df_triplet[
["triplet_idx_left", "triplet_idx_right"]
].max( # type: ignore
axis=1
)
        # Assign to each `triplet_idx_max` the smallest `triplet_idx_min`
        # it is connected to
df_new_labels = (
df_triplet.groupby("triplet_idx_max")["triplet_idx_min"] # type: ignore
.min()
.reset_index()
.rename(columns={"triplet_idx_max": "edge_idx", "triplet_idx_min": "label"})
)
old_to_new_indices[
tarray.series_to_tensor(df_new_labels["edge_idx"]).to(dtype=dtype)
] = tarray.series_to_tensor(df_new_labels["label"]).to(dtype=dtype)
# Assign new indices to other triplet indices
triplet_indices = {
triplet_name: old_to_new_indices[triplet_index]
for triplet_name, triplet_index in triplet_indices.items()
}
articulation_triplet_index = triplet_indices.pop("articulation")
df_labels = tarray.to_dataframe(
{
"edge_idx": np_or_cp.arange(old_to_new_indices.shape[0]),
"track_id": old_to_new_indices,
},
use_cuda=use_cuda,
)
# Renumber the track IDs so that they go from 0 to the number of tracks - 1
if df_labels.shape[0]:
_, indices = np_or_cp.unique( # type: ignore
tarray.series_to_array(df_labels["track_id"]), return_inverse=True
)
df_labels["track_id"] = indices
# Also propagate this renumbering into `articulation_triplet_index`
edge_idx_to_track_id = tarray.array_to_tensor(indices)
articulation_triplet_index = edge_idx_to_track_id[articulation_triplet_index]
# Remove duplicate articulations due to elbow connections
articulation_triplet_index = torch.unique(articulation_triplet_index, dim=1)
return df_labels, articulation_triplet_index
def filter_single_edges(
    df_labels: tarray.DataFrame,
    edge_score: torch.Tensor,
    edge_score_cut: float = 0.7,
) -> tarray.DataFrame:
    """Filter out single-edge tracks whose edge score is below ``edge_score_cut``.

    Args:
        df_labels: dataframe that associates edges (``edge_idx``) with tracks
            (``track_id``)
        edge_score: edge scores
        edge_score_cut: score threshold below which a single-edge track
            is dropped

    Returns:
        The filtered dataframe.
    """
use_cuda = tarray.get_use_cuda_from_dataframe(df_labels)
cp_or_np = tarray.get_numpy_or_cupy(use_cuda)
df_labels = df_labels.merge(
df_labels.groupby("track_id")["edge_idx"] # type: ignore
.count()
.rename("n_edges")
.reset_index(),
how="left",
on="track_id",
)
df_labels = df_labels.merge(
tarray.to_dataframe( # type: ignore
{
"edge_idx": cp_or_np.arange(edge_score.shape[0]),
"edge_score": edge_score,
},
use_cuda=use_cuda,
),
on=["edge_idx"],
how="left",
)
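    # Keep multi-edge tracks unconditionally; keep single-edge tracks only if
    # their edge score passes the cut (e.g. with the default cut of 0.7, a
    # single-edge track scored 0.55 is dropped).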
return df_labels[
(df_labels["n_edges"] >= 2) | (df_labels["edge_score"] >= edge_score_cut) # type: ignore
]
def build_tracks_from_triplets_1cc(
    triplet_indices: typing.Dict[str, torch.Tensor],
    edge_index: torch.Tensor,
    edge_score: torch.Tensor | None = None,
    single_edge_score_cut: float | None = None,
) -> tarray.DataFrame:
    """Build tracks from triplets using a single connected component pass,
    after first collapsing the elbows.

    Args:
        triplet_indices: dictionary that associates a type of triplet with
            the triplet indices
        edge_index: edge indices
        edge_score: edge scores, used to filter single-edge tracks.
            If ``None``, no filtering is performed
        single_edge_score_cut: score threshold passed to
            ``filter_single_edges``. Required when ``edge_score`` is given

    Returns:
        Dataframe of connected hits that defines the tracks.
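
    Example:
        A typical illustrative call (inputs are hypothetical)::

            df_tracks = build_tracks_from_triplets_1cc(
                triplet_indices=triplet_indices,
                edge_index=edge_index,
                edge_score=edge_score,
                single_edge_score_cut=0.5,
            )
    """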
use_cuda = edge_index.device.type == "cuda"
df_labels, articulation_triplet_index = connect_elbows(
triplet_indices=triplet_indices,
n_edges=edge_index.shape[1],
)
not_forked_triplet_index, forked_triplet_index = split_articulations(
triplet_index=articulation_triplet_index
)
    # Weakly connected component algorithm applied to edges (instead of hits).
    # Triplets that fork are not considered in this first step
df_not_forked_triplet_index = tarray.to_dataframe(
{
"source": not_forked_triplet_index[0],
"destination": not_forked_triplet_index[1],
},
use_cuda=use_cuda,
)
df_new_labels = connected_components(
df_edges=df_not_forked_triplet_index, max_node_idx=df_labels["track_id"].max() # type: ignore
).rename(
columns={
"vertex": "old_track_id",
"labels": "track_id",
},
)
df_labels = triplets.update_dataframe_connected_edges(
df_connected_edges=df_labels,
df_new_labels=df_new_labels,
)
# Now let's handle the forked triplets
df_remaining_triplets = tarray.to_dataframe(
{
"track_id_left": forked_triplet_index[0],
"track_id_right": forked_triplet_index[1],
},
use_cuda=use_cuda,
)
df_remaining_triplets = triplets.update_dataframe_track_links(
df_track_links=df_remaining_triplets,
df_new_labels=df_new_labels,
)
# Remove duplicate links
df_track_links_with_fork = df_remaining_triplets.drop_duplicates(
["track_id_left", "track_id_right"]
)
if df_track_links_with_fork.shape[0]: # type: ignore
df_labels = triplets.forked_track_links_to_tracks(
df_track_links_with_fork=df_track_links_with_fork,
df_connected_edges=df_labels,
)
if edge_score is not None:
assert single_edge_score_cut is not None
df_labels = filter_single_edges(
df_labels,
edge_score=edge_score,
edge_score_cut=single_edge_score_cut,
)
return triplets.connected_edges_to_connected_hits(
edge_index=edge_index,
df_connected_edges=df_labels, # type: ignore
)