"""A module that contains utilities to plot tracks.
"""
import typing
import numpy as np
import numpy.typing as npt
import pandas as pd
from matplotlib.pyplot import cm # type: ignore
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from .plotools import add_text
def _filter_and_set_up_colors(
df_units: pd.DataFrame, df_colors: pd.DataFrame, unity: str | typing.List[str]
) -> pd.DataFrame:
n_particles = df_colors["color_idx"].max() + 1
df_units = df_units.merge(
df_units.groupby(unity)["particle_id"]
.count()
.rename("n_particles")
.reset_index(),
on=unity,
how="left",
)
df_units = df_units.drop_duplicates(unity).merge(
df_colors, on="particle_id", how="left"
)
df_units.drop("particle_id", axis=1, inplace=True)
df_units.loc[df_units["n_particles"] >= 2, "color_idx"] = n_particles
return df_units
def plot_tracks(
    df_hits_particles: pd.DataFrame,
    df_edges: pd.DataFrame,
    particle_ids: npt.ArrayLike,
    list_axes: typing.List[typing.Tuple[str, str]] | None = None,
    lhcb: bool = False,
) -> typing.Tuple[Figure, npt.NDArray]:
    """Plots tracks given their particle IDs.

    Each requested particle gets its own colour; hits and edges associated
    with two or more of the requested particles are drawn in black.

    Args:
        df_hits_particles: dataframe of hits-particles with columns
            ``hit_idx`` and ``particle_id``, and the coordinates
            ``x``, ``y`` and ``z``
        df_edges: dataframe of edges-particles with columns
            ``hit_idx_left``, ``hit_idx_right`` and ``particle_id``
        particle_ids: list of particle IDs to plot
        list_axes: list of 2-tuples corresponding to the x- and y-axes of each
            figure. Defaults to ``[("y", "x"), ("z", "x")]``.
        lhcb: if ``True``, call ``add_text`` on every axes, with a vertical
            alignment chosen from the position of the leftmost hit.

    Returns:
        Matplotlib figure and 1D array of axes (one per entry of
        ``list_axes``).
    """
    particle_ids = np.asarray(particle_ids)
    n_particles = particle_ids.shape[0]
    if list_axes is None:
        list_axes = [("y", "x"), ("z", "x")]

    df_edges = df_edges[["hit_idx_left", "hit_idx_right", "particle_id"]]
    df_hits_particles = df_hits_particles[["hit_idx", "particle_id", "x", "y", "z"]]

    # One rainbow colour per particle; the extra last entry (black) is the
    # colour of hits/edges shared by several of the plotted particles.
    colors = np.concatenate(
        (
            cm.rainbow(np.linspace(0, 1, n_particles)),
            np.array([[0.0, 0.0, 0.0, 1.0]]),
        ),
        axis=0,
    )
    df_colors = pd.DataFrame(
        {
            "particle_id": particle_ids,
            "color_idx": np.arange(n_particles),
        }
    )

    df_hits_particles_reduced = _filter_and_set_up_colors(
        df_units=df_hits_particles[df_hits_particles["particle_id"].isin(particle_ids)],
        df_colors=df_colors,
        unity="hit_idx",
    )
    df_edges_reduced = _filter_and_set_up_colors(
        df_units=df_edges[df_edges["particle_id"].isin(particle_ids)],
        df_colors=df_colors,
        unity=["hit_idx_left", "hit_idx_right"],
    )

    # Attach the hit columns (suffixed ``_left`` / ``_right``) to both ends of
    # every edge. ``df_edges`` was restricted above to its three index
    # columns, so no pre-existing ``x``/``y``/``z`` columns can clash here.
    for side in ("left", "right"):
        df_edges_reduced = df_edges_reduced.merge(
            df_hits_particles_reduced.add_suffix(f"_{side}"),
            on=f"hit_idx_{side}",
            how="left",
        )

    # RGBA colour of every hit / edge, computed once for all the axes
    hit_colors = colors[df_hits_particles_reduced["color_idx"].to_numpy(dtype=int)]
    edge_colors = colors[df_edges_reduced["color_idx"].to_numpy(dtype=int)]

    n_figures = len(list_axes)
    # ``squeeze=False`` keeps an array of axes even for a single subplot (the
    # default would return a bare ``Axes``, which ``zip`` cannot iterate).
    # Row 0 is the same 1D array the default returns for several subplots.
    fig, mpl_axes = plt.subplots(
        1, n_figures, figsize=(8 * n_figures, 6), squeeze=False
    )
    mpl_axes = mpl_axes[0]

    for mpl_ax, axes in zip(mpl_axes, list_axes):
        x_axis, y_axis = axes[0], axes[1]
        mpl_ax.set_xlabel(f"${x_axis}$ [mm]")
        mpl_ax.set_ylabel(f"${y_axis}$ [mm]")
        mpl_ax.grid(color="grey", alpha=0.2)

        # Plot all the hits
        mpl_ax.scatter(
            df_hits_particles_reduced[x_axis],
            df_hits_particles_reduced[y_axis],
            c=hit_colors,
        )

        # Plot all the edges, one segment each. Iterating over plain arrays
        # avoids building a pandas Series per row as ``iterrows`` does.
        for x_left, x_right, y_left, y_right, color in zip(
            df_edges_reduced[f"{x_axis}_left"].to_numpy(),
            df_edges_reduced[f"{x_axis}_right"].to_numpy(),
            df_edges_reduced[f"{y_axis}_left"].to_numpy(),
            df_edges_reduced[f"{y_axis}_right"].to_numpy(),
            edge_colors,
        ):
            mpl_ax.plot([x_left, x_right], [y_left, y_right], color=color)

        if lhcb:
            ymin, ymax = mpl_ax.get_ylim()
            x_values = df_hits_particles_reduced[x_axis]
            y_at_xmin = df_hits_particles_reduced.loc[
                x_values == x_values.min(), y_axis
            ].min()
            # Choose the vertical alignment passed to ``add_text`` from which
            # half of the y-range the leftmost hit(s) lie in (presumably so
            # the text avoids the tracks).
            if y_at_xmin - ymin > ymax - y_at_xmin:
                add_text(mpl_ax, ha="left", va="bottom")
            else:
                add_text(mpl_ax, ha="left", va="top")

    fig.tight_layout()
    return fig, mpl_axes