Source code for pipeline.utils.graphutils.edgeutils

"""Utilities that operate exclusively on graph edges.
"""

import typing
import torch
from utils.tools import tarray


def sort_edge_nodes(edges: torch.Tensor, ordering_tensor: torch.Tensor) -> None:
    """Order the two nodes of every edge by ascending ``ordering_tensor`` value.

    The operation is performed in place on ``edges``; nothing is returned.

    Args:
        edges: Two-dimensional tensor of edges, indexed as ``edges[0]`` (first
            node of each edge) and ``edges[1]`` (second node of each edge),
            i.e. of shape ``(2, n_edges)``.
        ordering_tensor: Tensor of per-node values, indexed by node index.
            After the call, the first node of an edge never has a strictly
            greater value than the second node of that edge.
    """
    first_node_values = ordering_tensor[edges[0]]
    second_node_values = ordering_tensor[edges[1]]

    # Columns whose first node carries the larger value must be swapped.
    needs_swap = first_node_values > second_node_values

    # Flipping along dimension 0 exchanges the two endpoints of each
    # selected edge; the result is written back into the same columns.
    swapped_columns = edges[:, needs_swap].flip(0)
    edges[:, needs_swap] = swapped_columns
def remove_duplicate_edges(
    edge_indices: torch.Tensor,
    edge_tensors: typing.List[torch.Tensor],
) -> typing.Tuple[torch.Tensor, typing.List[torch.Tensor]]:
    """Drop repeated edges and apply the same row selection to edge tensors.

    The edge indices and the companion tensors are gathered into one dataframe
    (through ``tarray.to_dataframe``), rows that repeat an
    ``(idx_left, idx_right)`` pair are dropped with ``drop_duplicates``, and
    the surviving columns are converted back to tensors.

    Args:
        edge_indices: The edge indices; ``edge_indices[0]`` and
            ``edge_indices[1]`` hold the two endpoints of each edge.
        edge_tensors: Per-edge tensors that must stay aligned with
            ``edge_indices``.

    Returns:
        A tuple ``(unique_edge_indices, unique_edge_tensors)`` with the
        de-duplicated edge indices and the correspondingly filtered tensors,
        in the same order as ``edge_tensors``.
    """
    # The dataframe backend is selected from the device of the input.
    on_gpu = edge_indices.device.type == "cuda"

    columns = {
        "idx_left": edge_indices[0],
        "idx_right": edge_indices[1],
    }
    for position, tensor in enumerate(edge_tensors):
        columns[f"_column_{position}"] = tensor

    edge_frame = tarray.to_dataframe(columns, use_cuda=on_gpu)
    deduplicated = edge_frame.drop_duplicates(["idx_left", "idx_right"])

    # The two index columns are converted together and then transposed so the
    # result is indexed like the input (endpoints along the first dimension).
    index_columns = deduplicated[["idx_left", "idx_right"]]
    unique_edge_indices = tarray.series_to_tensor(index_columns).T  # type: ignore

    unique_edge_tensors = []
    for position, _ in enumerate(edge_tensors):
        filtered_column = deduplicated[f"_column_{position}"]
        unique_edge_tensors.append(
            tarray.series_to_tensor(filtered_column)  # type: ignore
        )

    return unique_edge_indices, unique_edge_tensors
def compute_edge_labels_from_pid_only(
    edge_indices: torch.Tensor,
    particle_ids: torch.Tensor,
) -> torch.Tensor:
    """Label each edge as true or false using particle IDs alone.

    An edge is labelled true when both of its hits carry the same particle ID
    and that ID is not 0 (0 marks noise). Can be used for training.

    Args:
        edge_indices: 2D tensor whose columns are pairs of hit indices, each
            column describing one edge.
        particle_ids: Particle ID of every hit, indexed by hit index.

    Returns:
        1D boolean tensor with one entry per column of ``edge_indices``,
        telling whether the corresponding edge is a true edge.
    """
    first_hit_pids = particle_ids[edge_indices[0]]
    second_hit_pids = particle_ids[edge_indices[1]]

    # Both hits of the edge belong to the same particle.
    same_particle = first_hit_pids == second_hit_pids
    # The particle ID is not 0, i.e. the hit is not noise. Checking one
    # endpoint suffices because it is combined with the equality above.
    not_noise = first_hit_pids != 0

    return same_particle & not_noise