"""A module that defines functions to go from triplets to tracks.
"""
from __future__ import annotations
import typing
import torch
try:
    import cupy as cp
except ImportError:
    cp = None
try:
    import cudf
except ImportError:
    cudf = None
from . import triplets
from .components import connected_components
from utils.tools import tarray
[docs]def split_triplet_indices_two_connected_components(
    triplet_indices: typing.Dict[str, torch.Tensor],
    edge_index: torch.Tensor | None = None,
    strategy: str | None = None,
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
    """Split triplets into 2 categories:
    * Triplets to merge now with a first connected component algorithm
    * Triplets to merge after removing the duplicate triplets, after the first \
    connected component algorithm.
    Args:
        triplet_indices: dictionary that associates a type of triplet with
            the triplet indices
        edge_index: edge indices. Only used in the ``no_multiple_central_hit``
            strategy
        strategy: splitting strategy to use, which correspond to which articulations
            to merge in the first connected component algorithm.
            All the strategies are expected to produce the very same result.
            Default is ``no_multiple_edge``.
    Returns:
        Triplet indices to merge by the first connected component algorithm,
        and triplet indices to merge the second one.
    """
    if strategy is None:
        strategy = "2cc_no_multiple_edge"
    if strategy == "2cc_without_articulation":
        triplet_index_to_merge_after_connected_components = triplet_indices[
            "articulation"
        ]
        triplet_index_to_merge_now = torch.cat(
            (
                triplet_indices["elbow_left"],
                triplet_indices["elbow_right"],
            ),
            dim=-1,
        )
    else:
        if strategy == "2cc_no_multiple_central_hit":
            assert (
                edge_index is not None
            ), "`no_multiple_central_hit` strategy required `edge_index`."
            triplet_central_hit_idx_counts = tarray.count_occurences(
                edge_index[1][triplet_indices["articulation"][0]]
            )
            mask_triplet_to_merge_now = triplet_central_hit_idx_counts == 1
        elif strategy == "2cc_no_multiple_edge":
            triplet_left_counts = tarray.count_occurences(
                triplet_indices["articulation"][0]
            )
            triplet_right_counts = tarray.count_occurences(
                triplet_indices["articulation"][1]
            )
            mask_triplet_to_merge_now = (triplet_left_counts == 1) & (
                triplet_right_counts == 1
            )
        else:
            raise ValueError("`strategy` not recognised.")
        triplet_index_to_merge_now = torch.cat(
            (
                triplet_indices["articulation"][:, mask_triplet_to_merge_now],
                triplet_indices["elbow_left"],
                triplet_indices["elbow_right"],
            ),
            dim=-1,
        )
        triplet_index_to_merge_after_connected_components = triplet_indices[
            "articulation"
        ][:, ~mask_triplet_to_merge_now]
    return (
        triplet_index_to_merge_now,
        triplet_index_to_merge_after_connected_components,
    ) 
# GPU implementation: it relies on cupy and cudf (and on cugraph, imported on
# call), so it is only defined when cupy and cudf could be imported.
if cp is not None and cudf is not None:
    CuDfDataFrame = cudf.DataFrame
    from .components import cure_max_node_idx
[docs]    def build_tracks_from_triplets_2cc(
        triplet_indices: typing.Dict[str, torch.Tensor],
        edge_index: torch.Tensor,
        strategy: str | None = None,
    ) -> CuDfDataFrame:
        """Build tracks from triplets with two connected component steps (GPU).

        Steps:

        1. The triplets are split by
           :func:`split_triplet_indices_two_connected_components` into triplets
           to merge now and (forked) triplets to merge later.
        2. A weakly connected component algorithm (cuGraph) is run on the edges
           linked by the triplets to merge now. The component labels are
           renumbered from 0 to the number of components - 1 and used as track
           IDs of the edges.
        3. The remaining triplets are turned into links between track IDs,
           together with their central hit (``edge_index[1]`` of their first
           edge), and duplicate links are removed.
        4. Links whose central hit occurs only once (``fork == 1``) are not
           actual forks: the tracks they link are merged by a second weakly
           connected component algorithm.
        5. The remaining links, whose central hit is shared, are relabelled,
           deduplicated and handled by ``triplets.forked_track_links_to_tracks``.

        Args:
            triplet_indices: dictionary that associates a type of triplet with
                the triplet indices.
            edge_index: edge indices; ``edge_index.shape[1]`` is used as the
                number of edges.
            strategy: splitting strategy forwarded to
                :func:`split_triplet_indices_two_connected_components`.

        Returns:
            Output of ``triplets.connected_edges_to_connected_hits`` — presumably
            a cuDF dataframe associating hits with track IDs (TODO confirm).
        """
        import cugraph
        (
            triplet_index,
            forked_triplet_index,
        ) = split_triplet_indices_two_connected_components(
            triplet_indices=triplet_indices,
            edge_index=edge_index,
            strategy=strategy,
        )
        # Weakly Connected Component Algorithm applied on edges (instead of hits)
        # Multiple triplets sharing the same central hit are not considered in this
        # first step
        df_triplet_index = cudf.DataFrame(
            {
                "source": cp.asarray(triplet_index[0]),
                "destination": cp.asarray(triplet_index[1]),
            }
        )
        graph = cugraph.from_cudf_edgelist(df_triplet_index, renumber=False)
        # NOTE(review): `cure_max_node_idx` presumably adds the vertices absent
        # from the graph (edges not involved in any triplet), up to
        # `max_node_idx`, as their own components — confirm in `.components`.
        df_labels = cure_max_node_idx(
            df_labels=cugraph.weakly_connected_components(graph),  # type: ignore
            max_node_idx=edge_index.shape[1] - 1,
        ).rename(
            columns={
                "vertex": "edge_idx",
                "labels": "track_id",
            },
        )
        # Renumber the track IDs so that they go from 0 to the number of tracks - 1
        _, indices = cp.unique(df_labels["track_id"].to_cupy(), return_inverse=True)  # type: ignore
        df_labels["track_id"] = indices  # type: ignore
        # Now let's handle the remaining triplets
        df_remaining_triplets = cudf.DataFrame(
            {
                "track_id_left": cp.asarray(forked_triplet_index[0]),
                "track_id_right": cp.asarray(forked_triplet_index[1]),
                "central_hit_idx": cp.asarray(edge_index[1][forked_triplet_index[0]]),
            }
        )
        # Add `track_id` information
        # This allows transforming a "triplet", that is, a link between 2 edges,
        # into a link between 2 tracks (edge indices are mapped to track IDs).
        df_remaining_triplets = triplets.update_dataframe_track_links(
            df_track_links=df_remaining_triplets,
            df_new_labels=df_labels.rename(columns={"edge_idx": "old_track_id"}),  # type: ignore
        )
        # Transform triplets into links between track IDs
        # Remove duplicate links
        df_track_links = df_remaining_triplets.drop_duplicates(
            ["central_hit_idx", "track_id_left", "track_id_right"]
        )
        # First consider links that are effectively not a fork
        # (`fork` = number of links sharing the same central hit)
        df_track_links["fork"] = tarray.count_occurences(  # type: ignore
            df_track_links["central_hit_idx"].to_cupy()  # type: ignore
        )
        df_track_links_wo_fork = df_track_links[df_track_links["fork"] == 1][  # type: ignore
            ["track_id_left", "track_id_right"]
        ]
        df_track_links_with_fork = df_track_links[df_track_links["fork"] >= 2][  # type: ignore
            ["track_id_left", "track_id_right"]
        ]
        if df_track_links_wo_fork.shape[0]:  # type: ignore
            # Second connected component step, this time on track IDs
            df_new_labels = cure_max_node_idx(
                df_labels=cugraph.weakly_connected_components(  # type: ignore
                    cugraph.from_cudf_edgelist(
                        df_track_links_wo_fork,
                        source="track_id_left",
                        destination="track_id_right",
                        renumber=False,
                    )
                ),
                max_node_idx=df_labels["track_id"].max(),  # type: ignore
            ).rename(
                columns={
                    "vertex": "old_track_id",
                    "labels": "track_id",
                }
            )
            # assign the new labels to `df_labels`
            df_labels = triplets.update_dataframe_connected_edges(
                df_connected_edges=df_labels,  # type: ignore
                df_new_labels=df_new_labels,  # type: ignore
            )
            # assign the new labels to `df_track_links_with_fork`
            df_track_links_with_fork = triplets.update_dataframe_track_links(
                df_track_links=df_track_links_with_fork,  # type: ignore
                df_new_labels=df_new_labels,  # type: ignore
            )
            # Relabelling can make previously distinct links identical
            df_track_links_with_fork = df_track_links_with_fork.drop_duplicates(
                ["track_id_left", "track_id_right"]
            )
        if df_track_links_with_fork.shape[0]:  # type: ignore
            df_labels = triplets.forked_track_links_to_tracks(
                df_track_links_with_fork=df_track_links_with_fork,  # type: ignore
                df_connected_edges=df_labels,  # type: ignore
            )
        # Convert edge-level track IDs into hit-level track IDs
        df_tracks = triplets.connected_edges_to_connected_hits(
            edge_index=edge_index, df_connected_edges=df_labels  # type: ignore
        )
        return df_tracks  # type: ignore 
def split_articulations(
    triplet_index: torch.Tensor,
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
    """Separate the articulations that can be merged right away from the others.

    An articulation is merged right away only when its first edge is not the
    first edge of any other articulation, and its second edge is not the second
    edge of any other articulation (as counted by ``tarray.count_occurences`` on
    each row). The other articulations share an edge and may need hits to be
    copied.

    Args:
        triplet_index: articulation triplet indices; row 0 and row 1 hold the
            indices of the two edges of each triplet.

    Returns:
        The triplets to merge now, and the triplets that share an edge.
    """
    first_edges = triplet_index[0]
    second_edges = triplet_index[1]
    first_edge_unique = tarray.count_occurences(first_edges) == 1
    second_edge_unique = tarray.count_occurences(second_edges) == 1
    mergeable = first_edge_unique & second_edge_unique
    mergeable_triplets = triplet_index[:, mergeable]
    shared_edge_triplets = triplet_index[:, ~mergeable]
    return mergeable_triplets, shared_edge_triplets
def _resolve_label_chains(labels: torch.Tensor) -> torch.Tensor:
    """Make every entry of ``labels`` point directly to the end of its chain.

    ``labels[i]`` is the label of element ``i`` and must satisfy
    ``labels[i] <= i``. Following ``i -> labels[i] -> labels[labels[i]] ...``
    therefore always ends on a fixed point. Pointer jumping (``labels[labels]``)
    is applied until nothing changes; the invariant guarantees termination.
    """
    while True:
        jumped = labels[labels]
        if torch.equal(jumped, labels):
            return labels
        labels = jumped


def connect_elbows(
    triplet_indices: typing.Dict[str, torch.Tensor],
    n_edges: int,
) -> typing.Tuple[tarray.DataFrame, torch.Tensor]:
    """Connect the left and right elbows.

    Edges linked by an elbow triplet are given a common label. The left elbows
    are merged first, then the right elbows. Each edge ends up labelled with the
    smallest edge index it is connected to through the elbows. The labels are
    then renumbered from 0 to the number of components - 1, and the
    articulation triplets are expressed with these renumbered labels.

    Args:
        triplet_indices: dictionary that associates a type of triplet with
            the triplet indices. It must contain the ``"articulation"``,
            ``"elbow_left"`` and ``"elbow_right"`` keys. It is not modified.
        n_edges: number of edges

    Returns:
        Tuple of a dataframe and a tensor.
        The dataframe defines the small connected components formed by
        connecting the elbows, through its two columns ``edge_idx`` and
        ``track_id``.
        The tensor corresponds to the articulation link indices between these
        connected components, with duplicate links removed.
    """
    device = triplet_indices["articulation"].device
    dtype = triplet_indices["articulation"].dtype
    use_cuda = device.type == "cuda"
    np_or_cp = tarray.get_numpy_or_cupy(use_cuda=use_cuda)
    # Shallow copy: the elbows are popped below, and the caller's dictionary
    # must not lose its entries.
    remaining_triplet_indices = dict(triplet_indices)
    # Label of each original edge; initially every edge is its own label
    old_to_new_indices = torch.arange(n_edges, device=device, dtype=dtype)
    # First merge left elbows, then right elbows
    for triplet_name in ["elbow_left", "elbow_right"]:
        triplet_index = remaining_triplet_indices.pop(triplet_name)
        df_triplet = tarray.to_dataframe(
            {
                "triplet_idx_left": triplet_index[0],
                "triplet_idx_right": triplet_index[1],
            },
            use_cuda=use_cuda,
        )
        df_triplet["triplet_idx_min"] = df_triplet[
            ["triplet_idx_left", "triplet_idx_right"]
        ].min(  # type: ignore
            axis=1
        )
        df_triplet["triplet_idx_max"] = df_triplet[
            ["triplet_idx_left", "triplet_idx_right"]
        ].max(  # type: ignore
            axis=1
        )
        # Assign to the larger index of each pair the smallest index it is
        # connected to
        df_new_labels = (
            df_triplet.groupby("triplet_idx_max")["triplet_idx_min"]  # type: ignore
            .min()
            .reset_index()
            .rename(columns={"triplet_idx_max": "edge_idx", "triplet_idx_min": "label"})
        )
        # Label map of this pass only, in the current label space
        # (labels produced by the previous passes).
        pass_labels = torch.arange(n_edges, device=device, dtype=dtype)
        pass_labels[
            tarray.series_to_tensor(df_new_labels["edge_idx"]).to(dtype=dtype)
        ] = tarray.series_to_tensor(df_new_labels["label"]).to(dtype=dtype)
        # Follow chains such as 5 -> 3 -> 1; valid because each label is the
        # minimum of a pair and is therefore never larger than its index.
        pass_labels = _resolve_label_chains(pass_labels)
        # Compose with the previous passes. Edges relabelled by an earlier pass
        # must follow their representative when the representative itself is
        # relabelled now (overwriting only the representative's entry would
        # leave them with a stale label).
        old_to_new_indices = pass_labels[old_to_new_indices]
        # The remaining triplets are already in the current label space
        remaining_triplet_indices = {
            name: pass_labels[index]
            for name, index in remaining_triplet_indices.items()
        }
    articulation_triplet_index = remaining_triplet_indices.pop("articulation")
    df_labels = tarray.to_dataframe(
        {
            "edge_idx": np_or_cp.arange(old_to_new_indices.shape[0]),
            "track_id": old_to_new_indices,
        },
        use_cuda=use_cuda,
    )
    # Renumber the track IDs so that they go from 0 to the number of tracks - 1
    if df_labels.shape[0]:
        _, indices = np_or_cp.unique(  # type: ignore
            tarray.series_to_array(df_labels["track_id"]), return_inverse=True
        )
        df_labels["track_id"] = indices
        # Also propagate this renumbering into `articulation_triplet_index`.
        # Every final label L satisfies `old_to_new_indices[L] == L`, so the
        # renumbered track ID of label L is `indices[L]`.
        edge_idx_to_track_id = tarray.array_to_tensor(indices)
        articulation_triplet_index = edge_idx_to_track_id[articulation_triplet_index]
    # Remove duplicate articulations due to elbow connections
    articulation_triplet_index = torch.unique(articulation_triplet_index, dim=1)
    return df_labels, articulation_triplet_index
def filter_single_edges(
    df_labels: tarray.DataFrame,
    edge_score: torch.Tensor,
    edge_score_cut: float = 0.7,
):
    """Drop the low-score edges that are alone in their track.

    An edge of ``df_labels`` is kept when its track holds at least 2 edges, or
    when its score is at least ``edge_score_cut``.

    Args:
        df_labels: dataframe with ``edge_idx`` and ``track_id`` columns.
        edge_score: score of each edge; entry ``i`` is the score of edge ``i``.
        edge_score_cut: minimal score for an edge alone in its track to be kept.

    Returns:
        The filtered dataframe. It also carries the ``n_edges`` column (number
        of edges in the track) and the ``edge_score`` column added by the merges.
    """
    use_cuda = tarray.get_use_cuda_from_dataframe(df_labels)
    array_module = tarray.get_numpy_or_cupy(use_cuda)

    track_sizes = df_labels.groupby("track_id")["edge_idx"]  # type: ignore
    df_track_sizes = track_sizes.count().rename("n_edges").reset_index()
    df_edge_scores = tarray.to_dataframe(  # type: ignore
        {
            "edge_idx": array_module.arange(edge_score.shape[0]),
            "edge_score": edge_score,
        },
        use_cuda=use_cuda,
    )

    df_with_sizes = df_labels.merge(df_track_sizes, how="left", on="track_id")
    df_with_scores = df_with_sizes.merge(
        df_edge_scores, on=["edge_idx"], how="left"
    )

    in_multi_edge_track = df_with_scores["n_edges"] >= 2  # type: ignore
    passes_score_cut = df_with_scores["edge_score"] >= edge_score_cut  # type: ignore
    return df_with_scores[in_multi_edge_track | passes_score_cut]
def build_tracks_from_triplets_1cc(
    triplet_indices: typing.Dict[str, torch.Tensor],
    edge_index: torch.Tensor,
    edge_score: torch.Tensor | None = None,
    single_edge_score_cut: float | None = None,
) -> tarray.DataFrame:
    """Build tracks from triplets with a single connected component step.

    Steps:

    1. Edges linked by elbows are grouped into small components
       (:func:`connect_elbows`).
    2. The articulations are split by :func:`split_articulations`. Those that
       do not share an edge are merged with a connected component algorithm
       run on the component labels.
    3. The other (forked) articulations are turned into links between track
       IDs, deduplicated and handled by ``triplets.forked_track_links_to_tracks``.
    4. If ``edge_score`` is given, edges alone in their track with a score below
       ``single_edge_score_cut`` are removed (:func:`filter_single_edges`).
    5. The edge labels are converted into hit labels.

    Args:
        triplet_indices: dictionary that associates a type of triplet
            (``"articulation"``, ``"elbow_left"``, ``"elbow_right"``) with the
            triplet indices. It is not modified.
        edge_index: edge indices; ``edge_index.shape[1]`` is the number of
            edges. Its device selects whether the dataframes built here use
            CUDA.
        edge_score: optional score of each edge, used to filter single-edge
            tracks.
        single_edge_score_cut: minimal score for an edge alone in its track to
            be kept. Required when ``edge_score`` is given.

    Returns:
        Output of ``triplets.connected_edges_to_connected_hits``.

    Raises:
        ValueError: if ``edge_score`` is given without ``single_edge_score_cut``.
    """
    # Validate up front (and not with `assert`, stripped under `python -O`)
    # instead of failing after all the track building work.
    if edge_score is not None and single_edge_score_cut is None:
        raise ValueError(
            "`single_edge_score_cut` must be provided when `edge_score` is given."
        )
    use_cuda = edge_index.device.type == "cuda"
    # Pass a shallow copy: `connect_elbows` pops entries from the dictionary
    # it receives, and the caller's dictionary must be left intact.
    df_labels, articulation_triplet_index = connect_elbows(
        triplet_indices=dict(triplet_indices),
        n_edges=edge_index.shape[1],
    )
    not_forked_triplet_index, forked_triplet_index = split_articulations(
        triplet_index=articulation_triplet_index
    )
    # Weakly Connected Component Algorithm applied on the elbow components
    # (instead of hits), using only the articulations that do not fork
    df_not_forked_triplet_index = tarray.to_dataframe(
        {
            "source": not_forked_triplet_index[0],
            "destination": not_forked_triplet_index[1],
        },
        use_cuda=use_cuda,
    )
    df_new_labels = connected_components(
        df_edges=df_not_forked_triplet_index, max_node_idx=df_labels["track_id"].max()  # type: ignore
    ).rename(
        columns={
            "vertex": "old_track_id",
            "labels": "track_id",
        },
    )
    df_labels = triplets.update_dataframe_connected_edges(
        df_connected_edges=df_labels,
        df_new_labels=df_new_labels,
    )
    # Now let's handle the forked triplets, expressed with the new track IDs
    df_remaining_triplets = tarray.to_dataframe(
        {
            "track_id_left": forked_triplet_index[0],
            "track_id_right": forked_triplet_index[1],
        },
        use_cuda=use_cuda,
    )
    df_remaining_triplets = triplets.update_dataframe_track_links(
        df_track_links=df_remaining_triplets,
        df_new_labels=df_new_labels,
    )
    # Remove duplicate links
    df_track_links_with_fork = df_remaining_triplets.drop_duplicates(
        ["track_id_left", "track_id_right"]
    )
    if df_track_links_with_fork.shape[0]:  # type: ignore
        df_labels = triplets.forked_track_links_to_tracks(
            df_track_links_with_fork=df_track_links_with_fork,
            df_connected_edges=df_labels,
        )
    if edge_score is not None:
        df_labels = filter_single_edges(
            df_labels,
            edge_score=edge_score,
            edge_score_cut=single_edge_score_cut,  # type: ignore
        )
    return triplets.connected_edges_to_connected_hits(
        edge_index=edge_index,
        df_connected_edges=df_labels,  # type: ignore
    )