"""Dataframe utilities for hits-particles and edges (``scripts.trackfactory.dfutils``)."""

import typing

import numpy as np
import numpy.typing as npt
import pandas as pd
from torch_geometric.data import Data

from utils.tools import tarray
from utils.plotutils.plotools import save_fig
from utils.plotutils.tracks import plot_tracks


def get_df_hits_particles_from_batch(batch: Data) -> pd.DataFrame:
    """Load dataframe of hits-particles from a PyTorch batch.

    Args:
        batch: PyTorch Geometric data object. It is indexed with the keys
            ``particle_id_hit_idx`` (column 0 read as particle IDs, column 1
            as hit indices), ``plane``, ``un_x``, ``un_y`` and ``un_z``; the
            per-hit tensors are indexed by hit index.

    Returns:
        Dataframe of hits-particles with one row per row of
        ``particle_id_hit_idx`` and the columns ``particle_id``, ``hit_idx``,
        ``plane``, the coordinates ``x``, ``y`` and ``z``, and ``min_plane`` /
        ``max_plane``: the smallest and largest plane reached by the row's
        particle.
    """
    particle_id_hit_idx = batch["particle_id_hit_idx"]
    particle_ids = particle_id_hit_idx[:, 0].numpy()
    hit_indices = particle_id_hit_idx[:, 1].numpy()
    df_hits_particles = pd.DataFrame(
        {
            "particle_id": particle_ids,
            "hit_idx": hit_indices,
            "plane": batch["plane"].numpy()[hit_indices],
            # Convert through ``.numpy()`` like the other columns instead of
            # handing a tensor slice to pandas (older pandas versions turn it
            # into an object column of 0-d tensors).
            **{
                axis: batch[f"un_{axis}"].numpy()[hit_indices]
                for axis in ["x", "y", "z"]
            },
        },
    )
    # First and last plane reached by each particle, broadcast back onto every
    # hit of that particle (row order and index are left untouched).
    plane_by_particle = df_hits_particles.groupby("particle_id")["plane"]
    return df_hits_particles.assign(
        min_plane=plane_by_particle.transform("min"),
        max_plane=plane_by_particle.transform("max"),
    )
def add_particle_information(
    truncated_path: str,
    df_hits_particles: pd.DataFrame,
    particle_columns: typing.List[str],
) -> pd.DataFrame:
    """Add particle information to the dataframe of hits-particles.

    Args:
        truncated_path: Path prefix; the file read is
            ``truncated_path + "-particles.parquet"``.
        df_hits_particles: Dataframe of hits-particles with a ``particle_id``
            column.
        particle_columns: Columns to read from the particles file, in
            addition to ``particle_id``, and attach to each row.

    Returns:
        A new dataframe: ``df_hits_particles`` left-merged on ``particle_id``
        with the requested particle columns. Rows whose particle is absent
        from the file get missing values in those columns.
    """
    particles_path = truncated_path + "-particles.parquet"
    columns_to_read = ["particle_id"] + particle_columns
    df_particles = pd.read_parquet(particles_path, columns=columns_to_read)
    return df_hits_particles.merge(df_particles, how="left", on="particle_id")
def get_df_edges(
    edge_indices: npt.ArrayLike,
    df_hits_particles: pd.DataFrame,
) -> pd.DataFrame:
    """Build the dataframe of edges whose two hits belong to the same particle.

    Args:
        edge_indices: Array-like whose first row holds the left hit index and
            whose second row holds the right hit index of each edge
            (presumably of shape ``(2, n_edges)``).
        df_hits_particles: Dataframe of hits-particles with at least the
            columns ``hit_idx`` and ``particle_id``.

    Returns:
        Dataframe with ``hit_idx_left``, ``hit_idx_right``, ``particle_id``
        and every other column of ``df_hits_particles`` suffixed with
        ``_left`` / ``_right``. Because a hit may belong to several particles,
        an edge appears once per particle shared by both of its hits. Edges
        with a hit missing from ``df_hits_particles`` are dropped (the NaN
        particle IDs never compare equal). The index is not reset.
    """
    edge_array = np.asarray(edge_indices)
    df_edges = pd.DataFrame(
        {"hit_idx_left": edge_array[0], "hit_idx_right": edge_array[1]}
    )

    # Attach the hit-particle rows of each endpoint, columns suffixed by side.
    for side in ("left", "right"):
        df_edges = df_edges.merge(
            df_hits_particles.add_suffix(f"_{side}"),
            on=f"hit_idx_{side}",
            how="left",
        )

    # Keep only the (edge, particle) combinations where both ends agree.
    same_particle = df_edges["particle_id_left"] == df_edges["particle_id_right"]
    return (
        df_edges.loc[same_particle]
        .drop(columns="particle_id_left")
        .rename(columns={"particle_id_right": "particle_id"})
    )
def compute_edge_counts(df_edges: pd.DataFrame) -> pd.DataFrame:
    """Attach to each edge row the number of particles sharing that edge.

    Args:
        df_edges: Dataframe with ``hit_idx_left``, ``hit_idx_right`` and
            ``particle_id`` columns (one row per edge and particle).

    Returns:
        A new dataframe equal to ``df_edges`` plus an ``n_particles`` column:
        the count of non-null ``particle_id`` values among the rows with the
        same (``hit_idx_left``, ``hit_idx_right``) pair. The merge yields a
        fresh ``RangeIndex``.
    """
    edge_key = ["hit_idx_left", "hit_idx_right"]
    per_edge_counts = (
        df_edges.groupby(edge_key)["particle_id"]
        .agg(n_particles="count")
        .reset_index()
    )
    return df_edges.merge(per_edge_counts, on=edge_key, how="left")
def compute_n_particles_hit(df_hits_particles: pd.DataFrame) -> pd.DataFrame:
    """Add an ``n_particles_hit`` column computed from the ``hit_idx`` values.

    The column is ``tarray.count_occurences`` applied to the ``hit_idx``
    values; judging by the helper's name this is, for each row, how often its
    hit index occurs (i.e. how many particles share the hit) — confirm against
    ``utils.tools.tarray``.

    Note:
        ``df_hits_particles`` is modified in place; the same object is
        returned for convenience.
    """
    hit_idx_values = df_hits_particles["hit_idx"].to_numpy()
    occurrences = tarray.count_occurences(hit_idx_values)
    df_hits_particles["n_particles_hit"] = occurrences
    return df_hits_particles
def find_connected_particle_ids(
    particle_id: int, df_hits_particles: pd.DataFrame
) -> npt.NDArray:
    """Return the IDs of all particles sharing at least one hit with a particle.

    Args:
        particle_id: Particle whose hits are looked up.
        df_hits_particles: Dataframe with ``particle_id`` and ``hit_idx``
            columns.

    Returns:
        Unique particle IDs, in order of first appearance, of every row whose
        hit also belongs to ``particle_id``. This includes ``particle_id``
        itself when it has hits, and is empty when it has none.
    """
    hits = df_hits_particles
    own_hit_indices = hits.loc[hits["particle_id"] == particle_id, "hit_idx"]
    shares_a_hit = hits["hit_idx"].isin(own_hit_indices)
    return hits.loc[shares_a_hit, "particle_id"].unique()
def plot_connected_particle_ids(
    particle_ids: typing.Iterable[int],
    df_hits_particles: pd.DataFrame,
    df_true_edges: pd.DataFrame,
    n_plots_max: int,
    output_wpath: str,
    lhcb: bool = False,
) -> int:
    """Plot groups of particles connected through shared hits.

    For each particle in ``particle_ids`` not already covered by a previous
    plot, the particle and all particles sharing a hit with it are plotted
    together and saved. Every particle shown in a plot is then skipped.

    Args:
        particle_ids: Candidate particles, processed in iteration order.
        df_hits_particles: Dataframe of hits-particles (``particle_id``,
            ``hit_idx`` and whatever ``plot_tracks`` needs).
        df_true_edges: Edges forwarded to ``plot_tracks``.
        n_plots_max: Maximum number of figures to produce; a value ``<= 0``
            produces none.
        output_wpath: Output path template, formatted with ``particle_id``.
        lhcb: Forwarded to ``plot_tracks``.

    Returns:
        The number of figures produced.
    """
    n_plots = 0
    # Set rather than list: membership is tested for every candidate.
    already_plotted: typing.Set[int] = set()
    for particle_id in particle_ids:
        # Check the limit before plotting so ``n_plots_max <= 0`` yields no
        # figure (previously the check ran only after the first plot).
        if n_plots >= n_plots_max:
            break
        if particle_id in already_plotted:
            continue
        connected_particle_ids = find_connected_particle_ids(
            particle_id, df_hits_particles=df_hits_particles
        )
        fig, _ = plot_tracks(
            df_hits_particles=df_hits_particles,
            df_edges=df_true_edges,
            particle_ids=connected_particle_ids,
            lhcb=lhcb,
        )
        save_fig(
            fig,
            output_wpath.format(particle_id=particle_id),
        )
        already_plotted.update(connected_particle_ids)
        n_plots += 1
    return n_plots
def no_shared_edges(hit_indices: np.ndarray, df_edges: pd.DataFrame):
    """Narrow ``hit_indices`` down to hits connected only by unshared edges.

    Only edges with ``n_particles == 1`` are considered. The filtering is
    sequential: first keep the hits appearing as ``hit_idx_left`` of such an
    edge, then, among those, keep the hits also appearing as
    ``hit_idx_right`` of such an edge. A hit that is only ever a left (or
    only ever a right) endpoint is therefore dropped.
    NOTE(review): confirm that this is intended for hits at track ends.

    Args:
        hit_indices: Candidate hit indices.
        df_edges: Dataframe with ``hit_idx_left``, ``hit_idx_right`` and
            ``n_particles`` columns.

    Returns:
        The surviving hit indices as a NumPy array of unique values, ordered
        by first appearance in ``df_edges["hit_idx_right"]``.
    """
    single_particle_edges = df_edges[df_edges["n_particles"] == 1]
    for side in ("left", "right"):
        endpoint = single_particle_edges[f"hit_idx_{side}"]
        hit_indices = endpoint[endpoint.isin(hit_indices)].unique()
    return hit_indices