"""A module to build triplets from edges.
"""
import typing
import torch
from utils.graphutils import batch2df
from utils.tools import tarray
def from_df_edges_to_df_triplets(
    df_edges: tarray.DataFrame,
) -> typing.Dict[str, tarray.DataFrame]:
    """Build the dataframes of triplets (pairs of connected edges) from edges.

    Three kinds of triplets are built:

    * ``articulation``: the right hit of the first edge is the left hit of
      the second edge;
    * ``elbow_left``: the two edges share the same left hit;
    * ``elbow_right``: the two edges share the same right hit.

    For the elbows, only the pairs with ``edge_idx_1 < edge_idx_2`` are kept,
    so that an edge is never paired with itself and each unordered pair of
    edges appears exactly once.

    Args:
        df_edges: dataframe of edges, with columns ``edge_idx``,
            ``hit_idx_left`` and ``hit_idx_right``

    Returns:
        Dictionary that associates a triplet name (``articulation``,
        ``elbow_left`` and ``elbow_right``, inserted in this order) with the
        dataframe of triplets, with columns ``edge_idx_1`` and ``edge_idx_2``.
        The row index of the returned dataframes is not reset.
    """
    triplet_columns = ["edge_idx_1", "edge_idx_2"]
    dict_df_triplets = {}

    # Build articulations: the right hit of edge 1 is the left hit of edge 2.
    # ``edge_idx`` is the only overlapping column, so it is the one that
    # receives the ``_1`` / ``_2`` suffixes.
    dict_df_triplets["articulation"] = df_edges[
        ["edge_idx", "hit_idx_right"]
    ].merge(  # type: ignore
        df_edges[["edge_idx", "hit_idx_left"]],
        how="inner",
        left_on="hit_idx_right",
        right_on="hit_idx_left",
        suffixes=("_1", "_2"),
    )[
        triplet_columns
    ]

    # Build elbow left and right.
    # A tuple (not a set) is iterated so that the insertion order of the keys
    # in the returned dictionary does not depend on string hash randomization.
    for side in ("left", "right"):
        hit_column = f"hit_idx_{side}"
        df_side = df_edges[["edge_idx", hit_column]]
        df_elbows = df_side.merge(  # type: ignore
            df_side,
            how="inner",
            on=hit_column,
            suffixes=("_1", "_2"),
        )[triplet_columns]
        # The strict ``<`` removes both the edges connected with themselves
        # (1-1) and the duplicate elbows (keeps 1-2, drops 2-1).
        dict_df_triplets[f"elbow_{side}"] = df_elbows[
            df_elbows["edge_idx_1"] < df_elbows["edge_idx_2"]
        ]

    return dict_df_triplets
def from_edge_index_to_triplet_indices(
    edge_index: torch.Tensor,
) -> typing.Dict[str, torch.Tensor]:
    """Turn the tensor of edge indices into one tensor of triplet indices
    per kind of triplet.

    Args:
        edge_index: tensor of shape ``(2, n_edges)`` with the edge indices

    Returns:
        Dictionary that associates a triplet name with the tensor of triplet
        indices ``(2, n_triplets)``
    """
    # Go through dataframes: edges -> one dataframe of triplets per kind
    triplet_frames = from_df_edges_to_df_triplets(
        df_edges=batch2df.get_df_edges_from_edge_index(edge_index=edge_index)
    )

    triplet_indices = {}
    for name, frame in triplet_frames.items():
        edge_pairs = frame[["edge_idx_1", "edge_idx_2"]]
        # Transposed so that the two edge indices of a triplet end up along
        # the first dimension
        triplet_indices[name] = tarray.series_to_tensor(edge_pairs).T
    return triplet_indices
def compute_triplet_truths(
    df_triplets: tarray.DataFrame,
    df_edges_particles: tarray.DataFrame,
) -> None:
    """Add the target column ``y`` to the dataframe of triplets. In-place.

    A triplet is true if both of its edges are true (``y``), are associated
    with the same ``particle_id``, and this ``particle_id`` is not 0
    (presumably the "no particle" label; confirm against the caller).
    An edge may appear on several rows of ``df_edges_particles`` (one row per
    associated particle): the triplet is true as soon as one combination of
    rows satisfies the condition.

    The truth is written back through the ``triplet_idx`` column, so the
    result does not depend on the row index of ``df_triplets``.

    Args:
        df_triplets: dataframe of triplets, without truth information,
            with columns ``triplet_idx``, ``edge_idx_1`` and ``edge_idx_2``
        df_edges_particles: dataframe of edges, with truth information,
            with columns ``edge_idx``, ``y`` and ``particle_id``
    """
    # Only these columns are needed; dropping the others keeps the merges
    # narrow and avoids name clashes (e.g., a pre-existing ``y`` column).
    df_edge_truths = df_edges_particles[["edge_idx", "y", "particle_id"]]
    df_triplets_particles = df_triplets[
        ["triplet_idx", "edge_idx_1", "edge_idx_2"]
    ]

    # Attach the truth of edge 1, then of edge 2, as ``<column>_<edge_number>``
    for edge_number in (1, 2):
        df_triplets_particles = df_triplets_particles.merge(
            df_edge_truths.rename(  # type: ignore
                columns={
                    column: f"{column}_{edge_number}"
                    for column in df_edge_truths.columns
                }
            ),
            on=f"edge_idx_{edge_number}",
            how="left",
        )

    # ``df_triplets_particles`` is a fresh merge output (not a view of the
    # input), so adding a column here does not touch ``df_triplets``.
    df_triplets_particles["y"] = (  # type: ignore
        df_triplets_particles["y_1"]  # type: ignore
        & df_triplets_particles["y_2"]  # type: ignore
        & (
            df_triplets_particles["particle_id_1"]  # type: ignore
            == df_triplets_particles["particle_id_2"]  # type: ignore
        )
        & (df_triplets_particles["particle_id_1"] != 0)  # type: ignore
    )

    # One truth value per triplet: true if any (edge 1, edge 2) row is true
    triplet_truth = (
        df_triplets_particles[["triplet_idx", "y"]]  # type: ignore
        .groupby("triplet_idx", sort=True)["y"]  # type: ignore
        .max()
    )

    # ``triplet_truth`` is indexed by ``triplet_idx``: look it up through the
    # ``triplet_idx`` column instead of relying on index alignment, which is
    # only correct when the row index happens to equal ``triplet_idx``.
    df_triplets["y"] = df_triplets["triplet_idx"].map(triplet_truth)
def from_triplet_index_to_triplet_truth(
    triplet_index: torch.Tensor, df_edges_particles: tarray.DataFrame
) -> torch.Tensor:
    """Compute the truth tensor of the triplets described by a triplet index.

    Args:
        triplet_index: tensor of triplet indices, converted to a dataframe of
            triplets by ``batch2df.get_df_triplets_from_triplet_index``
        df_edges_particles: dataframe of edges with truth information, passed
            as is to :py:func:`compute_triplet_truths`

    Returns:
        Tensor built from the ``y`` column of the dataframe of triplets
    """
    triplet_frame = batch2df.get_df_triplets_from_triplet_index(
        triplet_index=triplet_index
    )
    # Fills the ``y`` column of ``triplet_frame`` in place (returns nothing)
    compute_triplet_truths(
        df_triplets=triplet_frame,
        df_edges_particles=df_edges_particles,
    )
    truth_column = triplet_frame["y"]
    return tarray.series_to_tensor(truth_column)
def get_triplet_truths_from_tensors(
    triplet_indices: typing.Dict[str, torch.Tensor],
    edge_index: torch.Tensor,
    edge_truth: torch.Tensor,
    particle_id_hit_idx: torch.Tensor,
) -> typing.Dict[str, torch.Tensor]:
    """Compute the truth tensor of every kind of triplet from raw tensors.

    Args:
        triplet_indices: dictionary that associates a triplet name with its
            tensor of triplet indices
        edge_index: tensor of edge indices
        edge_truth: tensor of edge truths, stored as the ``y`` column of the
            dataframe of edges
        particle_id_hit_idx: tensor handed to
            ``batch2df.get_df_hits_particles_from_particle_id_hit_idx``
            (presumably the particle ID - hit index associations; confirm
            against that helper)

    Returns:
        Dictionary that associates a triplet name with its truth tensor
    """
    # Hit-particle associations, then edges carrying their truth as ``y``
    hits_particles = batch2df.get_df_hits_particles_from_particle_id_hit_idx(
        particle_id_hit_idx
    )
    edges = batch2df.get_df_edges_from_edge_index(
        edge_index=edge_index, tensors={"y": edge_truth}
    )
    # Built once and shared by all the kinds of triplets
    edges_particles = batch2df.merge_df_hits_particles_to_edges(
        df_edges=edges,
        df_hits_particles=hits_particles,
        combine_particle_id=True,
    )

    truths = {}
    for name, index in triplet_indices.items():
        truths[name] = from_triplet_index_to_triplet_truth(
            triplet_index=index,
            df_edges_particles=edges_particles,
        )
    return truths