Source code for pipeline.utils.graphutils.edgeutils
"""A module that defines utilities to handle edges exclusively.
"""
import typing
import torch
from utils.tools import tarray
def sort_edge_nodes(edges: torch.Tensor, ordering_tensor: torch.Tensor) -> None:
    """Sort the two nodes of every edge in ascending value of ``ordering_tensor``.

    The sort is done in place: after the call, for every edge the first node has
    a value lower than or equal to the second node. Edges whose first node has a
    strictly greater value are flipped; ties are left unchanged.

    Args:
        edges: Two-dimensional tensor of edges, with shape ``(2, n_edges)``.
            Modified in place.
        ordering_tensor: Tensor of values for the nodes, indexed by node index.
            It is used to decide which node of an edge comes first.
    """
    # Edges whose first node currently has a strictly greater value than the
    # second one must be swapped.
    needs_flip = ordering_tensor[edges[0]] > ordering_tensor[edges[1]]
    # Boolean (advanced) indexing returns a copy, so the flipped right-hand side
    # is fully computed before it is written back into ``edges``.
    edges[:, needs_flip] = edges[:, needs_flip].flip(0)
def remove_duplicate_edges(
    edge_indices: torch.Tensor,
    edge_tensors: typing.List[torch.Tensor],
) -> typing.Tuple[torch.Tensor, typing.List[torch.Tensor]]:
    """Drop duplicated edges and apply the same filtering to companion tensors.

    Edges are compared on their ``(left, right)`` node pair only. Every tensor
    in ``edge_tensors`` is filtered with the same rows as ``edge_indices``, so
    the per-edge values stay aligned with the surviving edges.

    Args:
        edge_indices: the edge indices; row 0 holds the left node and row 1
            the right node of each edge
        edge_tensors: a list of per-edge tensors to filter alongside the edges

    Returns:
        The deduplicated ``edge_indices`` and the matching filtered
        ``edge_tensors``.
    """
    # NOTE(review): which duplicate survives depends on the DataFrame
    # implementation returned by ``tarray.to_dataframe``. pandas/cuDF keep the
    # first occurrence by default — confirm.
    on_gpu = edge_indices.device.type == "cuda"

    columns = {
        "idx_left": edge_indices[0],
        "idx_right": edge_indices[1],
    }
    for position, tensor in enumerate(edge_tensors):
        columns[f"_column_{position}"] = tensor

    frame = tarray.to_dataframe(columns, use_cuda=on_gpu)

    key_columns = ["idx_left", "idx_right"]
    frame = frame.drop_duplicates(key_columns)

    # Transpose back to the (2, n_edges) layout of the input.
    deduplicated_indices = tarray.series_to_tensor(frame[key_columns]).T  # type: ignore
    deduplicated_tensors = [
        tarray.series_to_tensor(frame[f"_column_{position}"])  # type: ignore
        for position in range(len(edge_tensors))
    ]
    return deduplicated_indices, deduplicated_tensors
def compute_edge_labels_from_pid_only(
    edge_indices: torch.Tensor,
    particle_ids: torch.Tensor,
) -> torch.Tensor:
    """Compute the array of labels that indicate whether an edge is True or False.

    Can be used for training. An edge is labelled True when both of its nodes
    carry the same particle ID and that ID is not 0 (0 denotes noise).

    Args:
        edge_indices: 2D tensor whose columns are two hit indices, corresponding
            to an edge
        particle_ids: tensor of particle IDs of the hits, indexed by hit index

    Returns:
        1D boolean tensor whose size is equal to the number of columns in
        ``edge_indices``. It indicates whether the corresponding edge in
        ``edge_indices`` is a True edge or not.
    """
    # Gather the particle ID of each endpoint once and reuse it for both checks.
    left_pids = particle_ids[edge_indices[0]]
    right_pids = particle_ids[edge_indices[1]]
    # Both nodes of the edge belong to the same particle ...
    same_particle = left_pids == right_pids
    # ... and that particle is not noise (particle ID 0).
    not_noise = left_pids != 0
    return same_particle & not_noise