"""A Python module for building dataframes directly from a batched PyTorch
Geometric data object.
"""
import typing
import torch
from torch_geometric.data import Data
from ..tools import tarray
def get_df_hits(batch: Data, hit_columns: typing.List[str]) -> tarray.DataFrame:
    """Build the dataframe of hits from a batch.

    Args:
        batch: PyTorch Data object holding one tensor per name in
            ``hit_columns``.
        hit_columns: Names of the hit tensors to copy into the dataframe.
            The first name decides both the number of rows and whether the
            dataframe is built on the GPU.

    Returns:
        Dataframe with a ``hit_idx`` column (``0 .. n_rows - 1``) followed by
        one column per entry of ``hit_columns``.
    """
    reference_tensor = batch[hit_columns[0]]
    on_gpu = reference_tensor.device.type == "cuda"
    array_module = tarray.get_numpy_or_cupy(use_cuda=on_gpu)
    # ``hit_idx`` comes first; a hit column with the same name overrides it.
    columns = {"hit_idx": array_module.arange(reference_tensor.shape[0])}
    for name in hit_columns:
        columns[name] = batch[name]
    return tarray.to_dataframe(columns, use_cuda=on_gpu)
def get_df_hits_particles_from_particle_id_hit_idx(
    particle_id_hit_idx: torch.Tensor,
) -> tarray.DataFrame:
    """Turn a ``particle_id_hit_idx`` tensor into a hits-particles dataframe.

    Args:
        particle_id_hit_idx: Tensor indexed as ``[:, 0]`` for particle IDs and
            ``[:, 1]`` for hit indices (presumably one row per hit-particle
            association — confirm against the producer).

    Returns:
        Dataframe with columns ``particle_id`` and ``hit_idx``, built on the
        GPU when the input tensor lives on a CUDA device.
    """
    particle_ids = particle_id_hit_idx[:, 0]
    hit_indices = particle_id_hit_idx[:, 1]
    return tarray.to_dataframe(
        {"particle_id": particle_ids, "hit_idx": hit_indices},
        use_cuda=particle_id_hit_idx.device.type == "cuda",
    )
def get_df_hits_particles(
    batch: Data,
    particle_columns: typing.List[str] | None = None,
) -> tarray.DataFrame:
    """Get the dataframe of hits-particles.

    Args:
        batch: PyTorch Data object that contains the tensor
            ``particle_id_hit_idx``. When ``particle_columns`` is given, it
            must also contain ``unique_particle_id`` and one tensor
            ``particle_<column>`` per requested column.
        particle_columns: A list of particle columns to merge into the output
            dataframe. The particle column names are expected to be prefixed
            by ``particle_`` in ``batch``; the prefix is not kept in the
            output column names.

    Returns:
        Dataframe of hits-particles with columns ``particle_id``, ``hit_idx``
        and, if requested, the columns ``particle_columns``. The particle
        columns are attached with a left merge on ``particle_id``, so rows
        whose ``particle_id`` is absent from ``unique_particle_id`` get
        missing values there.
    """
    df_hits_particles = get_df_hits_particles_from_particle_id_hit_idx(
        particle_id_hit_idx=batch["particle_id_hit_idx"]
    )
    if particle_columns is None:
        return df_hits_particles

    # Particle-level table keyed by ``particle_id``. It is built on the same
    # device / host as the hits-particles dataframe so that both can be merged.
    df_particles = tarray.to_dataframe(
        {
            "particle_id": batch["unique_particle_id"],
            **{
                particle_column: batch[f"particle_{particle_column}"]
                for particle_column in particle_columns
            },
        },
        use_cuda=tarray.get_use_cuda_from_dataframe(df_hits_particles),
    )
    return df_hits_particles.merge(
        df_particles,
        on=["particle_id"],
        how="left",
    )
def merge_df_hits_particles_to_edges(
    df_edges: tarray.DataFrame,
    df_hits_particles: tarray.DataFrame,
    combine_particle_id: bool = False,
) -> tarray.DataFrame:
    """Merge the hits-particles dataframe onto both hits of each edge.

    Every column ``c`` of ``df_hits_particles`` is merged twice: once onto
    ``hit_idx_left`` (renamed ``c_left``) and once onto ``hit_idx_right``
    (renamed ``c_right``), each time with a left merge. A hit that appears in
    several rows of ``df_hits_particles`` therefore duplicates the edge once
    per matching row.

    Args:
        df_edges: Dataframe of edges, at least with columns
            ``hit_idx_left`` and ``hit_idx_right``.
        df_hits_particles: Dataframe of hits-particles, at least with columns
            ``hit_idx`` and ``particle_id``.
        combine_particle_id: Whether to replace ``particle_id_left`` and
            ``particle_id_right`` by a single ``particle_id`` column, equal
            to ``particle_id_left`` where both sides compare equal and to
            ``0`` elsewhere.

    Returns:
        Dataframe of edges-particles. ``df_edges`` itself is not modified.

    Raises:
        ValueError: If the two dataframes are not both on the GPU or both on
            the host.
    """
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if tarray.get_use_cuda_from_dataframe(
        df_edges
    ) != tarray.get_use_cuda_from_dataframe(df_hits_particles):
        raise ValueError("Provided dataframes are not on the same device / host.")

    # ``merge`` returns a new dataframe, so the in-place operations further
    # down never touch the caller's ``df_edges``.
    df_edges_particles = df_edges
    for side in ["left", "right"]:
        # Suffix every hits-particles column (``hit_idx`` included) with the
        # side, so the join key becomes ``hit_idx_<side>``.
        df_side = df_hits_particles.rename(
            columns={
                column: f"{column}_{side}" for column in df_hits_particles.columns
            }
        )
        df_edges_particles = df_edges_particles.merge(
            df_side,  # type: ignore
            on=f"hit_idx_{side}",
            how="left",
        )

    if combine_particle_id:
        # Keep a single ``particle_id``: the shared ID where the left and
        # right particle IDs compare equal, ``0`` everywhere else.
        same_particle_mask = (
            df_edges_particles["particle_id_left"]
            == df_edges_particles["particle_id_right"]
        )
        df_edges_particles["particle_id"] = 0
        df_edges_particles.loc[same_particle_mask, "particle_id"] = (  # type: ignore
            df_edges_particles["particle_id_left"]
        )
        df_edges_particles.drop(
            ["particle_id_left", "particle_id_right"], axis=1, inplace=True
        )
    return df_edges_particles
def get_df_edges(
    batch: Data,
    df_hits_particles: tarray.DataFrame | None = None,
    combine_particle_id: bool = False,
) -> tarray.DataFrame:
    """Get the dataframe of edges.

    Args:
        batch: PyTorch Data object that contains the tensors ``edge_index``
            and ``y``.
        df_hits_particles: Optional dataframe of hits-particles to merge onto
            the left and right hits of the dataframe of edges.
        combine_particle_id: Whether to combine ``particle_id_left`` and
            ``particle_id_right`` into a single ``particle_id`` column. Only
            used when ``df_hits_particles`` is provided.

    Returns:
        Dataframe of edges, with columns ``hit_idx_left``, ``hit_idx_right``,
        ``edge_idx``, ``y`` and the columns provided in ``df_hits_particles``
        suffixed by ``_left`` and ``_right`` (with ``particle_id_left`` /
        ``particle_id_right`` replaced by ``particle_id`` when
        ``combine_particle_id`` is set).
    """
    df_edges = get_df_edges_from_edge_index(
        edge_index=batch["edge_index"], tensors={"y": batch["y"]}
    )
    if df_hits_particles is None:
        return df_edges

    return merge_df_hits_particles_to_edges(  # type: ignore
        df_edges=df_edges,
        df_hits_particles=df_hits_particles,
        combine_particle_id=combine_particle_id,
    )
def get_df_edges_from_edge_index(
    edge_index: torch.Tensor,
    tensors: typing.Dict[str, torch.Tensor] | None = None,
) -> tarray.DataFrame:
    """Build the dataframe of edges from an ``edge_index`` tensor.

    Args:
        edge_index: Tensor whose row 0 holds the left hit indices and row 1
            the right hit indices; ``edge_index.shape[1]`` is the number of
            edges.
        tensors: Optional extra per-edge columns keyed by column name. They
            are added after the default columns and override any default
            column with the same name.

    Returns:
        Dataframe with columns ``hit_idx_left``, ``hit_idx_right``,
        ``edge_idx`` (``0 .. n_edges - 1``) and the columns of ``tensors``.
    """
    on_gpu = edge_index.device.type == "cuda"
    array_module = tarray.get_numpy_or_cupy(use_cuda=on_gpu)
    columns = {
        "hit_idx_left": edge_index[0],
        "hit_idx_right": edge_index[1],
        "edge_idx": array_module.arange(edge_index.shape[1]),
    }
    columns.update(tensors or {})
    # NOTE: duplicated (hit_idx_left, hit_idx_right) pairs are not checked here.
    return tarray.to_dataframe(columns, use_cuda=on_gpu)
def get_df_triplets_from_triplet_index(triplet_index: torch.Tensor) -> tarray.DataFrame:
    """Build the dataframe of triplets from a ``triplet_index`` tensor.

    Args:
        triplet_index: Tensor whose row 0 holds the first edge index and row 1
            the second edge index of each triplet; ``triplet_index.shape[1]``
            is the number of triplets.

    Returns:
        Dataframe with columns ``edge_idx_1``, ``edge_idx_2`` and
        ``triplet_idx`` (``0 .. n_triplets - 1``).
    """
    on_gpu = triplet_index.device.type == "cuda"
    array_module = tarray.get_numpy_or_cupy(use_cuda=on_gpu)
    first_edges, second_edges = triplet_index[0], triplet_index[1]
    columns = {
        "edge_idx_1": first_edges,
        "edge_idx_2": second_edges,
        "triplet_idx": array_module.arange(triplet_index.shape[1]),
    }
    return tarray.to_dataframe(columns, use_cuda=on_gpu)