# Source code for pipeline.TrackBuilding.triplets

"""A module that contains helper functions for building tracks from triplets.
"""
import typing
import torch


from utils.tools import tarray


[docs]def get_filtered_triplet_indices( triplet_indices: typing.Dict[str, torch.Tensor], triplet_scores: typing.Dict[str, torch.Tensor], triplet_score_cut: float | typing.Dict[str, float], ) -> typing.Dict[str, torch.Tensor]: """Filter the triplets that have a score lower than the required minimal score. Args: triplet_indices: dictionary that associates a type of triplet with the triplet indices triplet_scores: dictionary that associates a type of triplet with the scores of the triplets triplet_score_cut: minimal triplet score required Returns: dictionary that associates a type of triplet with the filtered triplet indices """ if isinstance(triplet_score_cut, float): triplet_score_cuts = { triplet_name: triplet_score_cut for triplet_name in triplet_indices.keys() } else: assert isinstance(triplet_score_cut, dict) triplet_score_cuts = triplet_score_cut # Filter triplets triplet_indices = { triplet_name: triplet_index[ :, triplet_scores[triplet_name] > triplet_score_cuts[triplet_name] ] for triplet_name, triplet_index in triplet_indices.items() } return triplet_indices
def connected_edges_to_connected_hits(
    edge_index: torch.Tensor,
    df_connected_edges: tarray.DataFrame,
) -> tarray.DataFrame:
    """Turn connected components of edges into connected components of hits.

    Every edge contributes both of its hits (rows 0 and 1 of ``edge_index``)
    to the track its edge belongs to; duplicated (track, hit) pairs are
    removed from the result.

    Args:
        edge_index: tensor of edge indices. Rows 0 and 1 are read as the two
            hits of each edge and ``edge_index.shape[1]`` as the number of
            edges.
        df_connected_edges: a dataframe with columns ``edge_idx`` and
            ``track_id``, defining the connected components of edges.

    Returns:
        Dataframe with the 2 columns ``track_id`` and ``hit_idx``, without
        duplicated rows.
    """
    # Stay on the device of the input tensor.
    on_gpu = edge_index.device.type == "cuda"
    # NOTE(review): presumably returns the numpy or the cupy module — confirm
    # against ``tarray.get_numpy_or_cupy``.
    xp = tarray.get_numpy_or_cupy(on_gpu)

    # Each edge index appears twice: once per hit of the edge.
    edge_ids = xp.arange(edge_index.shape[1])
    hit_edge_columns = {
        "hit_idx": torch.cat((edge_index[0], edge_index[1])),
        "edge_idx": xp.concatenate((edge_ids, edge_ids)),
    }
    df_hit_edge_pairs = tarray.to_dataframe(hit_edge_columns, use_cuda=on_gpu)

    # Attach the track of every edge to the hits of that edge.
    df_merged = df_connected_edges.merge(
        df_hit_edge_pairs,
        how="inner",
        on="edge_idx",
    )
    return df_merged[["track_id", "hit_idx"]].drop_duplicates()  # type: ignore
def update_dataframe_connected_edges(
    df_connected_edges: tarray.DataFrame,
    df_new_labels: tarray.DataFrame,
) -> tarray.DataFrame:
    """Replace the track IDs of connected edges using a relabelling table.

    The ``track_id`` column of ``df_connected_edges`` is renamed to
    ``old_track_id`` and left-joined with ``df_new_labels`` on that column.
    The result is then restricted to the original columns of
    ``df_connected_edges``.

    Args:
        df_connected_edges: dataframe of connected edges with a ``track_id``
            column.
        df_new_labels: dataframe with an ``old_track_id`` column used as the
            join key. NOTE(review): it presumably also provides the new
            ``track_id`` column selected in the output — confirm with callers.

    Returns:
        Dataframe with the same columns as ``df_connected_edges``. Because the
        join is a left join, rows without a match in ``df_new_labels`` are
        kept.
    """
    output_columns = df_connected_edges.columns
    df_old_labels = df_connected_edges.rename(columns={"track_id": "old_track_id"})
    df_relabelled = df_old_labels.merge(  # type: ignore
        df_new_labels,
        how="left",
        on="old_track_id",
    )
    return df_relabelled[output_columns]  # type: ignore