# Source code for pipeline.utils.plotutils.tracks

"""A module that contains utilities to plot tracks.
"""
import typing

import numpy as np
import numpy.typing as npt
import pandas as pd
from matplotlib.pyplot import cm  # type: ignore
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from .plotools import add_text


def _filter_and_set_up_colors(
    df_units: pd.DataFrame, df_colors: pd.DataFrame, unity: str | typing.List[str]
) -> pd.DataFrame:
    n_particles = df_colors["color_idx"].max() + 1
    df_units = df_units.merge(
        df_units.groupby(unity)["particle_id"]
        .count()
        .rename("n_particles")
        .reset_index(),
        on=unity,
        how="left",
    )
    df_units = df_units.drop_duplicates(unity).merge(
        df_colors, on="particle_id", how="left"
    )
    df_units.drop("particle_id", axis=1, inplace=True)
    df_units.loc[df_units["n_particles"] >= 2, "color_idx"] = n_particles
    return df_units


def plot_tracks(
    df_hits_particles: pd.DataFrame,
    df_edges: pd.DataFrame,
    particle_ids: npt.ArrayLike,
    list_axes: typing.List[typing.Tuple[str, str]] | None = None,
    lhcb: bool = False,
) -> typing.Tuple[Figure, npt.NDArray]:
    """Plots tracks given their track IDs.

    Each selected particle gets its own colour from the ``rainbow`` colormap.
    Hits and edges shared by two or more of the selected particles are drawn
    in black.

    Args:
        df_hits_particles: dataframe of hits-particles with columns ``hit_idx``
            and ``particle_id``, and the coordinates ``x``, ``y`` and ``z``
        df_edges: dataframe of edges-particles with columns ``hit_idx_left``,
            ``hit_idx_right`` and ``particle_id``
        particle_ids: list of particle IDs to plot
        list_axes: list of 2-tuples corresponding to the x- and y-axes of each
            figure. Defaults to ``[("y", "x"), ("z", "x")]``.
        lhcb: whether to call ``add_text`` on every subplot (presumably an
            experiment label). The ``va`` passed to it is ``"bottom"`` or
            ``"top"`` depending on where the hit with the smallest x-axis
            coordinate lies within the y-range.

    Returns:
        Matplotlib figure and 1D array of axes, one per entry of ``list_axes``
        (also when ``list_axes`` has a single entry).
    """
    particle_ids = np.asarray(particle_ids)
    n_particles = particle_ids.shape[0]
    if list_axes is None:
        list_axes = [("y", "x"), ("z", "x")]

    # Restricting the columns also guarantees that ``df_edges`` carries no
    # coordinate columns that could clash with the merges below.
    df_edges = df_edges[["hit_idx_left", "hit_idx_right", "particle_id"]]
    df_hits_particles = df_hits_particles[["hit_idx", "particle_id", "x", "y", "z"]]

    # One rainbow colour per particle, plus a trailing black colour at index
    # ``n_particles`` used for hits and edges shared by several particles.
    colors = np.concatenate(
        (
            cm.rainbow(np.linspace(0, 1, n_particles)),
            np.array([[0.0, 0.0, 0.0, 1.0]]),
        ),
        axis=0,
    )
    df_colors = pd.DataFrame(
        {
            "particle_id": particle_ids,
            "color_idx": np.arange(n_particles),
        }
    )

    df_hits_particles_reduced = _filter_and_set_up_colors(
        df_units=df_hits_particles[df_hits_particles["particle_id"].isin(particle_ids)],
        df_colors=df_colors,
        unity="hit_idx",
    )
    df_edges_reduced = _filter_and_set_up_colors(
        df_units=df_edges[df_edges["particle_id"].isin(particle_ids)],
        df_colors=df_colors,
        unity=["hit_idx_left", "hit_idx_right"],
    )

    # Attach the hit information of both end-points to every edge; hit columns
    # are suffixed with the side (e.g. ``x_left`` / ``x_right``).
    for side in ["left", "right"]:
        df_edges_reduced = df_edges_reduced.merge(
            df_hits_particles_reduced.rename(
                columns={
                    column: f"{column}_{side}"
                    for column in df_hits_particles_reduced.columns
                }
            ),
            on=f"hit_idx_{side}",
            how="left",
        )

    # ``squeeze=False`` always yields a 2D array of axes, so that a single
    # entry in ``list_axes`` does not produce a bare (non-iterable) Axes.
    fig, mpl_axes = plt.subplots(
        1, len(list_axes), figsize=(8 * len(list_axes), 6), squeeze=False
    )
    mpl_axes = mpl_axes[0]

    # Colours do not depend on the projection: compute them once.
    hit_colors = colors[df_hits_particles_reduced["color_idx"].to_numpy().astype(int)]
    edge_colors = colors[df_edges_reduced["color_idx"].to_numpy().astype(int)]

    for mpl_ax, (x_axis, y_axis) in zip(mpl_axes, list_axes):
        mpl_ax.set_xlabel(f"${x_axis}$ [mm]")
        mpl_ax.set_ylabel(f"${y_axis}$ [mm]")
        mpl_ax.grid(color="grey", alpha=0.2)

        # Plot all the hits
        mpl_ax.scatter(
            df_hits_particles_reduced[x_axis],
            df_hits_particles_reduced[y_axis],
            c=hit_colors,
        )

        # Plot all the edges, one segment per edge
        for x_left, x_right, y_left, y_right, edge_color in zip(
            df_edges_reduced[f"{x_axis}_left"].to_numpy(),
            df_edges_reduced[f"{x_axis}_right"].to_numpy(),
            df_edges_reduced[f"{y_axis}_left"].to_numpy(),
            df_edges_reduced[f"{y_axis}_right"].to_numpy(),
            edge_colors,
        ):
            mpl_ax.plot([x_left, x_right], [y_left, y_right], color=edge_color)

        if lhcb:
            # Pick ``va`` from the position, within the current y-range, of the
            # hit with the smallest x-axis coordinate.
            ymin, ymax = mpl_ax.get_ylim()
            x_values = df_hits_particles_reduced[x_axis]
            y_at_xmin = df_hits_particles_reduced.loc[
                x_values == x_values.min(), y_axis
            ].min()
            va = "bottom" if y_at_xmin - ymin > ymax - y_at_xmin else "top"
            add_text(mpl_ax, ha="left", va=va)

    fig.tight_layout()
    return fig, mpl_axes