"""A module that contains helper functions for building tracks from triplets.
"""
import typing
import torch
from utils.tools import tarray
[docs]def get_filtered_triplet_indices(
    triplet_indices: typing.Dict[str, torch.Tensor],
    triplet_scores: typing.Dict[str, torch.Tensor],
    triplet_score_cut: float | typing.Dict[str, float],
) -> typing.Dict[str, torch.Tensor]:
    """Filter the triplets that have a score lower than the required minimal score.
    Args:
        triplet_indices: dictionary that associates a type of triplet with
            the triplet indices
        triplet_scores: dictionary that associates a type of triplet with the scores
            of the triplets
        triplet_score_cut: minimal triplet score required
    Returns:
        dictionary that associates a type of triplet with the filtered triplet indices
    """
    if isinstance(triplet_score_cut, float):
        triplet_score_cuts = {
            triplet_name: triplet_score_cut for triplet_name in triplet_indices.keys()
        }
    else:
        assert isinstance(triplet_score_cut, dict)
        triplet_score_cuts = triplet_score_cut
    # Filter triplets
    triplet_indices = {
        triplet_name: triplet_index[
            :, triplet_scores[triplet_name] > triplet_score_cuts[triplet_name]
        ]
        for triplet_name, triplet_index in triplet_indices.items()
    }
    return triplet_indices 
def connected_edges_to_connected_hits(
    edge_index: torch.Tensor,
    df_connected_edges: tarray.DataFrame,
) -> tarray.DataFrame:
    """Turn connected components of edges into connected components of hits.

    Both hits of every edge inherit the ``track_id`` of the component the edge
    belongs to.

    Args:
        edge_index: tensor of edge indices; column ``i`` holds the two hit
            indices of edge ``i`` (first row and second row).
        df_connected_edges: a dataframe with columns ``edge_idx`` and ``track_id``,
            defining the connected components of edges.

    Returns:
        Dataframe with the 2 columns ``track_id`` and ``hit_idx`` (in that order),
        with duplicated rows removed.
    """
    on_gpu = edge_index.device.type == "cuda"
    xp = tarray.get_numpy_or_cupy(on_gpu)
    edge_ids = xp.arange(edge_index.shape[1])

    # One row per (hit, edge) pair: all first-row hits, then all second-row hits,
    # each paired with the ID of the edge it comes from.
    hit_column = torch.cat((edge_index[0], edge_index[1]))
    edge_column = xp.concatenate((edge_ids, edge_ids))
    df_hit_edge = tarray.to_dataframe(
        {"hit_idx": hit_column, "edge_idx": edge_column},
        use_cuda=on_gpu,
    )

    df_hits = df_connected_edges.merge(df_hit_edge, how="inner", on="edge_idx")
    # A hit shared by several edges of the same component appears only once.
    return df_hits[["track_id", "hit_idx"]].drop_duplicates()  # type: ignore
def forked_track_links_to_tracks(
    df_track_links_with_fork: tarray.DataFrame,
    df_connected_edges: tarray.DataFrame,
) -> tarray.DataFrame:
    """Form tracks from links between tracks corresponding to forks from one track
    to more than 1 track.

    Every track link is considered as a new track. Each link receives a fresh
    ``track_id``, numbered after the largest existing one, and gets the edges
    of both tracks it links. Tracks involved in at least one link are removed;
    their edges survive only through the new tracks. An edge can therefore
    appear several times, with different ``track_id``.

    Note: a ``new_track_id`` column is added in place to
    ``df_track_links_with_fork``.

    Args:
        df_track_links_with_fork: dataframe with columns
            ``track_id_left`` and ``track_id_right`` that defined tracks linked
            to one another
        df_connected_edges: dataframe of connected components of edges,
            with columns ``edge_idx`` and ``track_id``

    Returns:
        Dataframe of connected components of edges, with the new forked tracks:
        the untouched rows of ``df_connected_edges`` followed by the
        (``edge_idx``, ``track_id``) rows of the new tracks. The row index is
        not reset.
    """
    use_cuda = tarray.get_use_cuda_from_dataframe(df_connected_edges)
    xp = tarray.get_numpy_or_cupy(use_cuda)
    xdf = tarray.get_pandas_or_cudf(use_cuda)

    # Number the new tracks right after the largest existing track ID
    # (this writes a column into the caller's dataframe).
    max_track_id = df_connected_edges["track_id"].max()  # type: ignore
    n_forked_links = df_track_links_with_fork.shape[0]
    df_track_links_with_fork["new_track_id"] = xp.arange(
        max_track_id + 1,
        max_track_id + n_forked_links + 1,
    )

    # Long format: one row per (new track, original track it is built from).
    df_links_long = df_track_links_with_fork.melt(
        id_vars=["new_track_id"],
        value_vars=["track_id_left", "track_id_right"],
        value_name="track_id",
        var_name="column",
    )

    # Original tracks involved in a fork: their edges get copied into the new
    # tracks and the original tracks themselves are dropped.
    forked_track_ids = xp.unique(
        tarray.series_to_array(
            df_track_links_with_fork[["track_id_left", "track_id_right"]]  # type: ignore
        )
    )
    is_forked_edge = df_connected_edges["track_id"].isin(  # type: ignore
        forked_track_ids
    )

    # Copy the edges of every original track into each new track built from it.
    df_forked_edges = df_links_long.merge(
        df_connected_edges[is_forked_edge],  # type: ignore
        on="track_id",
        how="inner",
    )
    df_forked_edges = df_forked_edges[["edge_idx", "new_track_id"]].rename(  # type: ignore
        columns={"new_track_id": "track_id"}
    )

    return xdf.concat(
        (
            df_connected_edges[~is_forked_edge],  # type: ignore
            df_forked_edges,
        ),
        axis=0,
    )
def update_dataframe_track_links(
    df_track_links: tarray.DataFrame, df_new_labels: tarray.DataFrame
) -> tarray.DataFrame:
    """Update the dataframe of track links with a new definition of tracks.

    Tracks are here connected edges. ``track_id_left`` and ``track_id_right``
    are both looked up in ``df_new_labels``, one after the other. The lookups
    are inner joins, so links whose old track ID has no entry in
    ``df_new_labels`` are dropped. An old ID mapped to several new IDs
    duplicates the link.

    Args:
        df_track_links: dataframe of track links with columns
            ``track_id_left`` and ``track_id_right``
        df_new_labels: dataframe that defines new connection between tracks,
            with columns ``old_track_id`` and ``track_id``

    Returns:
        Updated dataframe of track links, with the same columns as
        ``df_track_links``, where the old track IDs are replaced by the new
        ones defined in ``df_new_labels``.
    """
    output_columns = df_track_links.columns
    for side in ("left", "right"):
        old_column = f"old_track_id_{side}"
        new_column = f"track_id_{side}"
        # Label mapping expressed with this side's column names.
        df_side_labels = df_new_labels.rename(
            columns={"old_track_id": old_column, "track_id": new_column}
        )  # type: ignore
        df_renamed_links = df_track_links.rename(columns={new_column: old_column})
        df_track_links = df_renamed_links.merge(  # type: ignore
            df_side_labels,
            how="inner",
            on=old_column,
        )[output_columns]  # type: ignore
    return df_track_links
def update_dataframe_connected_edges(
    df_connected_edges: tarray.DataFrame,
    df_new_labels: tarray.DataFrame,
) -> tarray.DataFrame:
    """Relabel the connected components of edges with new track IDs.

    Args:
        df_connected_edges: dataframe of connected components of edges, with
            (at least) a ``track_id`` column.
        df_new_labels: dataframe that maps ``old_track_id`` to ``track_id``.

    Returns:
        Dataframe with the same columns as ``df_connected_edges``, in which each
        ``track_id`` is replaced by the ``track_id`` that ``df_new_labels``
        associates with it. This is a left join: edges whose track ID is
        absent from ``df_new_labels`` are kept with a missing ``track_id``.
        An old ID mapped to several new IDs duplicates the edge.
    """
    output_columns = df_connected_edges.columns
    df_relabelled = df_connected_edges.rename(  # type: ignore
        columns={"track_id": "old_track_id"}
    )
    df_relabelled = df_relabelled.merge(
        df_new_labels,
        how="left",
        on="old_track_id",
    )
    # Drop ``old_track_id`` (and any extra label column), keeping the input layout.
    return df_relabelled[output_columns]  # type: ignore