Source code for pipeline.TrackBuilding.triplets2tracks

"""A module that defines functions to go from triplets to tracks.
"""
from __future__ import annotations
import typing

import torch
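
# `cupy` and `cudf` are optional GPU dependencies: the module still imports
# without them, but the cudf/cugraph-based `build_tracks_from_triplets_2cc`
# path below is only defined when both are available.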

try:
    import cupy as cp

except ImportError:
    cp = None

try:
    import cudf

except ImportError:
    cudf = None


from . import triplets
from .components import connected_components
from utils.tools import tarray


def split_triplet_indices_two_connected_components(
    triplet_indices: typing.Dict[str, torch.Tensor],
    edge_index: torch.Tensor | None = None,
    strategy: str | None = None,
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
    """Split triplets into 2 categories:

    * Triplets to merge now with a first connected component algorithm
    * Triplets to merge after the first connected component algorithm, \
    once the duplicate triplets have been removed.

    Args:
        triplet_indices: dictionary that associates a type of triplet with
            the triplet indices
        edge_index: edge indices. Only used in the ``2cc_no_multiple_central_hit``
            strategy
        strategy: splitting strategy to use, which determines which articulations
            are merged in the first connected component algorithm. All the
            strategies are expected to produce the very same result.
            Default is ``2cc_no_multiple_edge``.

    Returns:
        Triplet indices to merge by the first connected component algorithm,
        and triplet indices to merge by the second one.
    """
    if strategy is None:
        strategy = "2cc_no_multiple_edge"

    if strategy == "2cc_without_articulation":
        triplet_index_to_merge_after_connected_components = triplet_indices[
            "articulation"
        ]
        triplet_index_to_merge_now = torch.cat(
            (
                triplet_indices["elbow_left"],
                triplet_indices["elbow_right"],
            ),
            dim=-1,
        )
    else:
        if strategy == "2cc_no_multiple_central_hit":
            assert (
                edge_index is not None
            ), "`2cc_no_multiple_central_hit` strategy requires `edge_index`."
            triplet_central_hit_idx_counts = tarray.count_occurences(
                edge_index[1][triplet_indices["articulation"][0]]
            )
            mask_triplet_to_merge_now = triplet_central_hit_idx_counts == 1
        elif strategy == "2cc_no_multiple_edge":
            triplet_left_counts = tarray.count_occurences(
                triplet_indices["articulation"][0]
            )
            triplet_right_counts = tarray.count_occurences(
                triplet_indices["articulation"][1]
            )
            mask_triplet_to_merge_now = (triplet_left_counts == 1) & (
                triplet_right_counts == 1
            )
        else:
            raise ValueError("`strategy` not recognised.")

        triplet_index_to_merge_now = torch.cat(
            (
                triplet_indices["articulation"][:, mask_triplet_to_merge_now],
                triplet_indices["elbow_left"],
                triplet_indices["elbow_right"],
            ),
            dim=-1,
        )
        triplet_index_to_merge_after_connected_components = triplet_indices[
            "articulation"
        ][:, ~mask_triplet_to_merge_now]

    return (
        triplet_index_to_merge_now,
        triplet_index_to_merge_after_connected_components,
    )
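
# Usage sketch (hypothetical tensors, for illustration only): each column of a
# triplet index pairs two edge indices. With the default `2cc_no_multiple_edge`
# strategy, an articulation is merged immediately only if neither of its edges
# appears in another articulation:
#
#     triplet_indices = {
#         "articulation": torch.tensor([[0, 1, 1], [2, 3, 4]]),
#         "elbow_left": torch.tensor([[5], [6]]),
#         "elbow_right": torch.empty((2, 0), dtype=torch.int64),
#     }
#     merge_now, merge_later = split_triplet_indices_two_connected_components(
#         triplet_indices=triplet_indices,
#     )
#     # merge_now contains (0, 2) and the elbow (5, 6); the articulations
#     # (1, 3) and (1, 4) share edge 1 and are deferred to `merge_later`.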
if cp is not None and cudf is not None:
    CuDfDataFrame = cudf.DataFrame

    from .components import cure_max_node_idx
    def build_tracks_from_triplets_2cc(
        triplet_indices: typing.Dict[str, torch.Tensor],
        edge_index: torch.Tensor,
        strategy: str | None = None,
    ) -> CuDfDataFrame:
        """Build tracks from triplets with two successive weakly connected
        component passes, on GPU through ``cudf`` and ``cugraph``.
        """
        import cugraph

        (
            triplet_index,
            forked_triplet_index,
        ) = split_triplet_indices_two_connected_components(
            triplet_indices=triplet_indices,
            edge_index=edge_index,
            strategy=strategy,
        )

        # Weakly Connected Component Algorithm applied on edges (instead of hits)
        # Multiple triplets sharing the same central hit are not considered in this
        # first step
        df_triplet_index = cudf.DataFrame(
            {
                "source": cp.asarray(triplet_index[0]),
                "destination": cp.asarray(triplet_index[1]),
            }
        )
        graph = cugraph.from_cudf_edgelist(df_triplet_index, renumber=False)
        df_labels = cure_max_node_idx(
            df_labels=cugraph.weakly_connected_components(graph),  # type: ignore
            max_node_idx=edge_index.shape[1] - 1,
        ).rename(
            columns={
                "vertex": "edge_idx",
                "labels": "track_id",
            },
        )

        # Renumber the track IDs so that they go from 0 to the number of tracks - 1
        _, indices = cp.unique(df_labels["track_id"].to_cupy(), return_inverse=True)  # type: ignore
        df_labels["track_id"] = indices  # type: ignore

        # Now let's handle the remaining triplets
        df_remaining_triplets = cudf.DataFrame(
            {
                "track_id_left": cp.asarray(forked_triplet_index[0]),
                "track_id_right": cp.asarray(forked_triplet_index[1]),
                "central_hit_idx": cp.asarray(edge_index[1][forked_triplet_index[0]]),
            }
        )

        # Add `track_id` information
        # This allows transforming a "triplet", that is, a link between 2 edges,
        # into a link between 2 tracks.
        df_remaining_triplets = triplets.update_dataframe_track_links(
            df_track_links=df_remaining_triplets,
            df_new_labels=df_labels.rename(columns={"edge_idx": "old_track_id"}),  # type: ignore
        )

        # Transform triplets into links between track IDs
        # Remove duplicate links
        df_track_links = df_remaining_triplets.drop_duplicates(
            ["central_hit_idx", "track_id_left", "track_id_right"]
        )

        # First consider links that are effectively not a fork
        df_track_links["fork"] = tarray.count_occurences(  # type: ignore
            df_track_links["central_hit_idx"].to_cupy()  # type: ignore
        )
        df_track_links_wo_fork = df_track_links[df_track_links["fork"] == 1][  # type: ignore
            ["track_id_left", "track_id_right"]
        ]
        df_track_links_with_fork = df_track_links[df_track_links["fork"] >= 2][  # type: ignore
            ["track_id_left", "track_id_right"]
        ]

        if df_track_links_wo_fork.shape[0]:  # type: ignore
            df_new_labels = cure_max_node_idx(
                df_labels=cugraph.weakly_connected_components(  # type: ignore
                    cugraph.from_cudf_edgelist(
                        df_track_links_wo_fork,
                        source="track_id_left",
                        destination="track_id_right",
                        renumber=False,
                    )
                ),
                max_node_idx=df_labels["track_id"].max(),  # type: ignore
            ).rename(
                columns={
                    "vertex": "old_track_id",
                    "labels": "track_id",
                }
            )
            # Assign the new labels to `df_labels`
            df_labels = triplets.update_dataframe_connected_edges(
                df_connected_edges=df_labels,  # type: ignore
                df_new_labels=df_new_labels,  # type: ignore
            )
            # Assign the new labels to `df_track_links_with_fork`
            df_track_links_with_fork = triplets.update_dataframe_track_links(
                df_track_links=df_track_links_with_fork,  # type: ignore
                df_new_labels=df_new_labels,  # type: ignore
            )
            df_track_links_with_fork = df_track_links_with_fork.drop_duplicates(
                ["track_id_left", "track_id_right"]
            )

        if df_track_links_with_fork.shape[0]:  # type: ignore
            df_labels = triplets.forked_track_links_to_tracks(
                df_track_links_with_fork=df_track_links_with_fork,  # type: ignore
                df_connected_edges=df_labels,  # type: ignore
            )

        df_tracks = triplets.connected_edges_to_connected_hits(
            edge_index=edge_index,
            df_connected_edges=df_labels,  # type: ignore
        )
        return df_tracks  # type: ignore
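
# Usage sketch (GPU-only path, hypothetical call): `build_tracks_from_triplets_2cc`
# additionally requires `cugraph`, and expects CUDA tensors:
#
#     df_tracks = build_tracks_from_triplets_2cc(
#         triplet_indices=triplet_indices,  # on a CUDA device
#         edge_index=edge_index,            # shape (2, n_edges), on a CUDA device
#     )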
def split_articulations(
    triplet_index: torch.Tensor,
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
    # Split triplets into 2 categories
    # - triplets to merge now
    # - triplets that may need copies of hits
    # Triplets that may need copies are articulations that share an edge
    triplet_left_counts = tarray.count_occurences(triplet_index[0])
    triplet_right_counts = tarray.count_occurences(triplet_index[1])
    mask_triplet_to_merge_now = (triplet_left_counts == 1) & (triplet_right_counts == 1)
    return (
        triplet_index[:, mask_triplet_to_merge_now],
        triplet_index[:, ~mask_triplet_to_merge_now],
    )
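
# Usage sketch (hypothetical tensor): articulations that share an edge index on
# either side may require duplicating hits, so they are deferred:
#
#     triplet_index = torch.tensor([[0, 1, 1], [2, 3, 4]])
#     to_merge_now, maybe_forked = split_articulations(triplet_index=triplet_index)
#     # to_merge_now is [[0], [2]]; maybe_forked is [[1, 1], [3, 4]]
#     # because edge 1 appears on the left side of two articulations.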
def connect_elbows(
    triplet_indices: typing.Dict[str, torch.Tensor],
    n_edges: int,
) -> typing.Tuple[tarray.DataFrame, torch.Tensor]:
    """Connect the left and right elbows.

    Args:
        triplet_indices: dictionary that associates a type of triplet with
            the triplet indices
        n_edges: number of edges

    Returns:
        Tuple of a dataframe and a tensor. The dataframe defines the small
        connected components formed by connecting the elbows, through its two
        columns ``edge_idx`` and ``track_id``. The tensor corresponds to the
        articulation link indices between these connected components.
    """
    use_cuda = triplet_indices["articulation"].device.type == "cuda"
    dtype = triplet_indices["articulation"].dtype
    np_or_cp = tarray.get_numpy_or_cupy(use_cuda=use_cuda)

    # First merge left elbows, then right elbows
    # Assign to each edge the smallest edge it is connected to by an elbow
    old_to_new_indices = torch.arange(
        n_edges,
        device=triplet_indices["articulation"].device,
        dtype=dtype,
    )
    for triplet_name in ["elbow_left", "elbow_right"]:
        triplet_index = triplet_indices.pop(triplet_name)
        df_triplet = tarray.to_dataframe(
            {
                "triplet_idx_left": triplet_index[0],
                "triplet_idx_right": triplet_index[1],
            },
            use_cuda=use_cuda,
        )
        df_triplet["triplet_idx_min"] = df_triplet[
            ["triplet_idx_left", "triplet_idx_right"]
        ].min(  # type: ignore
            axis=1
        )
        df_triplet["triplet_idx_max"] = df_triplet[
            ["triplet_idx_left", "triplet_idx_right"]
        ].max(  # type: ignore
            axis=1
        )

        # Assign to `triplet_idx_max` the smallest value it is connected to,
        # or itself
        df_new_labels = (
            df_triplet.groupby("triplet_idx_max")["triplet_idx_min"]  # type: ignore
            .min()
            .reset_index()
            .rename(columns={"triplet_idx_max": "edge_idx", "triplet_idx_min": "label"})
        )
        old_to_new_indices[
            tarray.series_to_tensor(df_new_labels["edge_idx"]).to(dtype=dtype)
        ] = tarray.series_to_tensor(df_new_labels["label"]).to(dtype=dtype)

        # Assign new indices to other triplet indices
        triplet_indices = {
            triplet_name: old_to_new_indices[triplet_index]
            for triplet_name, triplet_index in triplet_indices.items()
        }

    articulation_triplet_index = triplet_indices.pop("articulation")
    df_labels = tarray.to_dataframe(
        {
            "edge_idx": np_or_cp.arange(old_to_new_indices.shape[0]),
            "track_id": old_to_new_indices,
        },
        use_cuda=use_cuda,
    )

    # Renumber the track IDs so that they go from 0 to the number of tracks - 1
    if df_labels.shape[0]:
        _, indices = np_or_cp.unique(  # type: ignore
            tarray.series_to_array(df_labels["track_id"]), return_inverse=True
        )
        df_labels["track_id"] = indices
        # Also propagate this renumbering into `articulation_triplet_index`
        edge_idx_to_track_id = tarray.array_to_tensor(indices)
        articulation_triplet_index = edge_idx_to_track_id[articulation_triplet_index]

    # Remove duplicate articulations due to elbow connections
    articulation_triplet_index = torch.unique(articulation_triplet_index, dim=1)
    return df_labels, articulation_triplet_index
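
# Usage sketch (hypothetical call): merging elbows first shrinks the graph handed
# to the connected component step. Note that `connect_elbows` pops entries from
# the dictionary it receives, so pass a copy if the triplet indices are still
# needed afterwards:
#
#     df_labels, articulation_index = connect_elbows(
#         triplet_indices=dict(triplet_indices),
#         n_edges=edge_index.shape[1],
#     )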
def filter_single_edges(
    df_labels: tarray.DataFrame,
    edge_score: torch.Tensor,
    edge_score_cut: float = 0.7,
):
    """Keep only edges that belong to a track with at least 2 edges, or whose
    edge score is at least ``edge_score_cut``.
    """
    use_cuda = tarray.get_use_cuda_from_dataframe(df_labels)
    cp_or_np = tarray.get_numpy_or_cupy(use_cuda)
    df_labels = df_labels.merge(
        df_labels.groupby("track_id")["edge_idx"]  # type: ignore
        .count()
        .rename("n_edges")
        .reset_index(),
        how="left",
        on="track_id",
    )
    df_labels = df_labels.merge(
        tarray.to_dataframe(  # type: ignore
            {
                "edge_idx": cp_or_np.arange(edge_score.shape[0]),
                "edge_score": edge_score,
            },
            use_cuda=use_cuda,
        ),
        on=["edge_idx"],
        how="left",
    )
    return df_labels[
        (df_labels["n_edges"] >= 2) | (df_labels["edge_score"] >= edge_score_cut)  # type: ignore
    ]
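
# Usage sketch (hypothetical values): a track made of a single edge is kept only
# if that edge's score passes the cut:
#
#     df_filtered = filter_single_edges(
#         df_labels,
#         edge_score=edge_score,  # one score per edge, shape (n_edges,)
#         edge_score_cut=0.7,
#     )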
def build_tracks_from_triplets_1cc(
    triplet_indices: typing.Dict[str, torch.Tensor],
    edge_index: torch.Tensor,
    edge_score: torch.Tensor | None = None,
    single_edge_score_cut: float | None = None,
) -> tarray.DataFrame:
    """Build tracks from triplets with a single connected component pass:
    elbows are merged first, then non-forked articulations, and forked
    articulations are resolved last.
    """
    use_cuda = edge_index.device.type == "cuda"
    df_labels, articulation_triplet_index = connect_elbows(
        triplet_indices=triplet_indices,
        n_edges=edge_index.shape[1],
    )
    not_forked_triplet_index, forked_triplet_index = split_articulations(
        triplet_index=articulation_triplet_index
    )

    # Weakly Connected Component Algorithm applied on edges (instead of hits)
    # Forked triplets are not considered in this first step
    df_not_forked_triplet_index = tarray.to_dataframe(
        {
            "source": not_forked_triplet_index[0],
            "destination": not_forked_triplet_index[1],
        },
        use_cuda=use_cuda,
    )
    df_new_labels = connected_components(
        df_edges=df_not_forked_triplet_index,
        max_node_idx=df_labels["track_id"].max(),  # type: ignore
    ).rename(
        columns={
            "vertex": "old_track_id",
            "labels": "track_id",
        },
    )
    df_labels = triplets.update_dataframe_connected_edges(
        df_connected_edges=df_labels,
        df_new_labels=df_new_labels,
    )

    # Now let's handle the forked triplets
    df_remaining_triplets = tarray.to_dataframe(
        {
            "track_id_left": forked_triplet_index[0],
            "track_id_right": forked_triplet_index[1],
        },
        use_cuda=use_cuda,
    )
    df_remaining_triplets = triplets.update_dataframe_track_links(
        df_track_links=df_remaining_triplets,
        df_new_labels=df_new_labels,
    )

    # Remove duplicate links
    df_track_links_with_fork = df_remaining_triplets.drop_duplicates(
        ["track_id_left", "track_id_right"]
    )
    if df_track_links_with_fork.shape[0]:  # type: ignore
        df_labels = triplets.forked_track_links_to_tracks(
            df_track_links_with_fork=df_track_links_with_fork,
            df_connected_edges=df_labels,
        )

    if edge_score is not None:
        assert single_edge_score_cut is not None
        df_labels = filter_single_edges(
            df_labels,
            edge_score=edge_score,
            edge_score_cut=single_edge_score_cut,
        )

    return triplets.connected_edges_to_connected_hits(
        edge_index=edge_index,
        df_connected_edges=df_labels,  # type: ignore
    )
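
# Usage sketch (hypothetical call): the 1cc variant runs on CPU or GPU through
# `tarray`, and can optionally drop low-score single-edge tracks:
#
#     df_tracks = build_tracks_from_triplets_1cc(
#         triplet_indices=triplet_indices,
#         edge_index=edge_index,
#         edge_score=edge_score,       # optional
#         single_edge_score_cut=0.7,   # required when `edge_score` is given
#     )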