# Source code for pipeline.utils.graphutils.edgeutils
"""A module that defines utilities to handle edges exclusively.
"""
import typing
import torch
from utils.tools import tarray
def sort_edge_nodes(edges: torch.Tensor, ordering_tensor: torch.Tensor) -> None:
    """Orient every edge so that its first node has the lower ordering value.

    The operation is performed **in place**: ``edges`` is modified and nothing
    is returned. Edges whose two nodes have equal ordering values are left
    untouched.

    Args:
        edges: Two-dimensional integer tensor of edges, with shape
            ``(2, n_edges)``. Each column holds the two node indices of an edge.
        ordering_tensor: Tensor of values for the nodes, indexed by node index.
            After the call, the first node of every edge has a value that is
            lower than or equal to the value of the second node.
    """
    # Edges whose first node currently has the larger ordering value.
    needs_swap = ordering_tensor[edges[0]] > ordering_tensor[edges[1]]
    # Boolean-mask indexing returns a copy, so flipping it along dim 0 and
    # writing it back swaps the two endpoints of the selected columns only.
    edges[:, needs_swap] = edges[:, needs_swap].flip(0)
def remove_duplicate_edges(
    edge_indices: torch.Tensor,
    edge_tensors: typing.List[torch.Tensor],
) -> typing.Tuple[torch.Tensor, typing.List[torch.Tensor]]:
    """Remove duplicate edges in ``edge_indices`` and propagate the removal
    to the other per-edge tensors in ``edge_tensors``.

    Two edges are duplicates when both their left and right indices match;
    ``(a, b)`` and ``(b, a)`` are therefore considered different edges.

    Args:
        edge_indices: the edge indices; row 0 holds the left node indices and
            row 1 the right node indices.
        edge_tensors: a list of per-edge tensors, each aligned with the
            columns of ``edge_indices``.

    Returns:
        Updated ``edge_indices`` and ``edge_tensors``, in that order. The
        returned list has one tensor per entry of ``edge_tensors``, in the
        same order.
    """
    # The dataframe backend follows the device of the input tensor.
    use_cuda = edge_indices.device.type == "cuda"

    # One row per edge; the extra tensors ride along under positional column
    # names so that they are filtered together with the indices.
    columns = {
        "idx_left": edge_indices[0],
        "idx_right": edge_indices[1],
    }
    for i, edge_tensor in enumerate(edge_tensors):
        columns[f"_column_{i}"] = edge_tensor
    df_edges = tarray.to_dataframe(columns, use_cuda=use_cuda)

    # Deduplicate on the index pair only; the other columns do not take part
    # in the comparison.
    df_edges = df_edges.drop_duplicates(["idx_left", "idx_right"])

    # The two index columns are transposed back to the (2, n_edges) layout.
    unique_edge_indices = tarray.series_to_tensor(df_edges[["idx_left", "idx_right"]]).T  # type: ignore
    unique_edge_tensors = [
        tarray.series_to_tensor(df_edges[f"_column_{i}"])  # type: ignore
        for i in range(len(edge_tensors))
    ]
    return unique_edge_indices, unique_edge_tensors
def compute_edge_labels_from_pid_only(
    edge_indices: torch.Tensor,
    particle_ids: torch.Tensor,
) -> torch.Tensor:
    """Compute the array of labels that indicate whether an edge is True or False.

    An edge is labelled True when its two hits share the same particle ID and
    that particle ID is not 0 (0 marks noise). Can be used for training.

    Args:
        edge_indices: 2D tensor whose columns are two hit indices, corresponding
            to an edge
        particle_ids: particle IDs of the hits, indexed by hit index

    Returns:
        1D boolean tensor whose size is equal to the number of columns in
        ``edge_indices``, and that indicates whether the corresponding edge in
        ``edge_indices`` is a True edge or not.
    """
    # Gather the particle ID of each endpoint once and reuse it below.
    left_particle_ids = particle_ids[edge_indices[0]]
    right_particle_ids = particle_ids[edge_indices[1]]

    # Both nodes of the edge belong to the same particle...
    same_particle = left_particle_ids == right_particle_ids
    # ...and that particle is not noise (`particle_id` 0). Checking one side is
    # enough because the two sides are required to be equal.
    not_noise = left_particle_ids != 0
    return same_particle & not_noise