"""A module that contains helper functions for building tracks from triplets.
"""
import typing
import torch
from utils.tools import tarray
[docs]def get_filtered_triplet_indices(
triplet_indices: typing.Dict[str, torch.Tensor],
triplet_scores: typing.Dict[str, torch.Tensor],
triplet_score_cut: float | typing.Dict[str, float],
) -> typing.Dict[str, torch.Tensor]:
"""Filter the triplets that have a score lower than the required minimal score.
Args:
triplet_indices: dictionary that associates a type of triplet with
the triplet indices
triplet_scores: dictionary that associates a type of triplet with the scores
of the triplets
triplet_score_cut: minimal triplet score required
Returns:
dictionary that associates a type of triplet with the filtered triplet indices
"""
if isinstance(triplet_score_cut, float):
triplet_score_cuts = {
triplet_name: triplet_score_cut for triplet_name in triplet_indices.keys()
}
else:
assert isinstance(triplet_score_cut, dict)
triplet_score_cuts = triplet_score_cut
# Filter triplets
triplet_indices = {
triplet_name: triplet_index[
:, triplet_scores[triplet_name] > triplet_score_cuts[triplet_name]
]
for triplet_name, triplet_index in triplet_indices.items()
}
return triplet_indices
def connected_edges_to_connected_hits(
    edge_index: torch.Tensor,
    df_connected_edges: tarray.DataFrame,
) -> tarray.DataFrame:
    """Turn connected components of edges into connected components of hits.

    Args:
        edge_index: tensor of edge indices. Row 0 and row 1 hold the two hits
            of each edge; edges are indexed along the second dimension.
        df_connected_edges: a dataframe with columns ``edge_idx`` and ``track_id``,
            defining the connected components of edges.

    Returns:
        Dataframe with 2 columns ``track_id`` and ``hit_idx``, without
        duplicated rows.
    """
    # Stay on the same backend (CPU or GPU) as the edge tensor
    use_cuda = edge_index.device.type == "cuda"
    np_or_cp = tarray.get_numpy_or_cupy(use_cuda)

    # Every edge contributes its two hits, so the edge numbering is repeated
    # once for each row of ``edge_index``.
    edge_indices_idx = np_or_cp.arange(edge_index.shape[1])
    df_edges = tarray.to_dataframe(
        {
            "hit_idx": torch.cat((edge_index[0], edge_index[1])),
            "edge_idx": np_or_cp.concatenate((edge_indices_idx, edge_indices_idx)),
        },
        use_cuda=use_cuda,
    )

    # Propagate the track ID of each edge to its hits. A hit shared by several
    # edges of the same track would appear more than once, hence the
    # de-duplication.
    df_connected_hits = df_connected_edges.merge(
        df_edges,
        how="inner",
        on="edge_idx",
    )
    return df_connected_hits[["track_id", "hit_idx"]].drop_duplicates()  # type: ignore
def forked_track_links_to_tracks(
    df_track_links_with_fork: tarray.DataFrame,
    df_connected_edges: tarray.DataFrame,
) -> tarray.DataFrame:
    """Form tracks from links between tracks corresponding to forks from one track
    to more than 1 track.

    Every track link is considered as a new track.

    Args:
        df_track_links_with_fork: dataframe with columns
            ``track_id_left`` and ``track_id_right`` that define tracks linked
            to one another. A column ``new_track_id`` is added to this
            dataframe in place.
        df_connected_edges: dataframe of connected components of edges,
            with columns ``edge_idx`` and ``track_id``

    Returns:
        Dataframe of connected components of edges, where the components
        involved in a link are replaced by the new forked tracks.
    """
    use_cuda = tarray.get_use_cuda_from_dataframe(df_connected_edges)
    np_or_cp = tarray.get_numpy_or_cupy(use_cuda)
    pd_or_cudf = tarray.get_pandas_or_cudf(use_cuda)

    # Each forked link will be associated with a new track whose ``track_id``
    # is given by a new column ``new_track_id``. New IDs start right after the
    # largest existing one so that they cannot collide with current tracks.
    max_track_id = df_connected_edges["track_id"].max()  # type: ignore
    df_track_links_with_fork["new_track_id"] = np_or_cp.arange(
        max_track_id + 1,
        max_track_id + df_track_links_with_fork.shape[0] + 1,
    )

    # Build the new tracks
    # This will copy the edges: one row per (link, side), so that both the left
    # and the right component of a link end up under the same ``new_track_id``
    df_track_links_with_fork_melted = df_track_links_with_fork.melt(
        id_vars=["new_track_id"],
        value_vars=["track_id_left", "track_id_right"],
        value_name="track_id",
        var_name="column",
    )

    # Remove the individual connected components the forked tracks are created from
    track_ids_to_copy = np_or_cp.unique(
        tarray.series_to_array(
            df_track_links_with_fork[["track_id_left", "track_id_right"]]  # type: ignore
        )
    )
    df_connected_edges_to_remove_mask = df_connected_edges["track_id"].isin(  # type: ignore
        track_ids_to_copy
    )

    # Edges of the linked components, relabelled with the new track IDs
    df_new_labels = df_track_links_with_fork_melted.merge(
        df_connected_edges[df_connected_edges_to_remove_mask],  # type: ignore
        on="track_id",
        how="inner",
    )[
        ["edge_idx", "new_track_id"]
    ].rename(  # type: ignore
        columns={"new_track_id": "track_id"}  # type: ignore
    )

    # Concatenate the new tracks to the dataframe of tracks
    df_connected_edges = pd_or_cudf.concat(
        (
            df_connected_edges[~df_connected_edges_to_remove_mask],  # type: ignore
            df_new_labels,
        ),
        axis=0,
    )
    return df_connected_edges
def update_dataframe_track_links(
    df_track_links: tarray.DataFrame, df_new_labels: tarray.DataFrame
) -> tarray.DataFrame:
    """Update the dataframe of track links with a new definition of tracks.

    Tracks are here connected edges.

    Args:
        df_track_links: dataframe of track links with columns
            ``track_id_left`` and ``track_id_right``
        df_new_labels: dataframe that defines new connection between tracks,
            with columns ``old_track_id`` and ``track_id``

    Returns:
        Updated dataframe of track links, where the old track IDs are replaced by
        the new one, as defined in ``df_new_labels``. The columns are the same,
        in the same order, as in the input ``df_track_links``. Because the
        IDs are translated through an inner join, a link whose left or right
        ID has no entry in ``df_new_labels`` is dropped.
    """
    # The column layout is restored after each merge, so it can be read once.
    columns = df_track_links.columns

    for side in ["left", "right"]:
        old_column = f"old_track_id_{side}"
        new_column = f"track_id_{side}"

        # Translation table expressed with the column names of this side
        df_side_labels = df_new_labels.rename(
            columns={"old_track_id": old_column, "track_id": new_column}
        )  # type: ignore

        # Set the current IDs aside under ``old_column`` and bring in the new
        # ones under ``new_column``; selecting ``columns`` then discards the
        # old IDs.
        df_track_links = df_track_links.rename(
            columns={new_column: old_column}
        ).merge(  # type: ignore
            df_side_labels,
            how="inner",
            on=old_column,
        )[
            columns
        ]  # type: ignore

    return df_track_links
def update_dataframe_connected_edges(
    df_connected_edges: tarray.DataFrame,
    df_new_labels: tarray.DataFrame,
) -> tarray.DataFrame:
    """Update the dataframe of connected edges with a new definition of tracks.

    Args:
        df_connected_edges: dataframe of connected components of edges, with
            a ``track_id`` column holding the current track IDs
        df_new_labels: dataframe that maps the current track IDs to the new
            ones, with columns ``old_track_id`` and ``track_id``

    Returns:
        Dataframe with the same columns as ``df_connected_edges``, where
        ``track_id`` holds the new track IDs. The translation is a left join,
        so every row is kept; a row whose current ID has no entry in
        ``df_new_labels`` ends up with a missing ``track_id``.
    """
    # Set the current IDs aside so that the merge brings in the new ``track_id``
    df_with_old_ids = df_connected_edges.rename(
        columns={"track_id": "old_track_id"}
    )  # type: ignore
    df_updated = df_with_old_ids.merge(
        df_new_labels,
        how="left",
        on="old_track_id",
    )
    # Restore the original column layout, which drops ``old_track_id``
    return df_updated[df_connected_edges.columns]  # type: ignore