#!/usr/bin/env python3
"""A script that plots figures to explain the embedding.
"""
import typing
import os
import os.path as op
from argparse import ArgumentParser
import numpy as np
import numpy.typing as npt
import torch
from torch_geometric.data import Data
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from utils.commonutils.config import cdirs
from utils.plotutils.plotconfig import configure_matplotlib
from utils.plotutils import plotools
from utils.scriptutils.parser import add_predefined_arguments
# Apply the project's matplotlib configuration (``configure_matplotlib`` from
# ``utils.plotutils.plotconfig``) once at import time, before any figure in
# this module is created.
configure_matplotlib()
def get_random_batch_file(
    experiment_name: str,
    stage: str,
    test_dataset_name: str,
    rng: np.random.Generator,
) -> Data:
    """Load one randomly chosen batch file of a test dataset.

    Files are looked up in
    ``<cdirs.data_directory>/<experiment_name>/<stage>/test/<test_dataset_name>``.

    Args:
        experiment_name: Name of the experiment directory.
        stage: Processing-stage directory (e.g. ``"embedding_processed"``).
        test_dataset_name: Name of the test dataset directory.
        rng: Random generator used to pick the file.

    Returns:
        The object stored in the chosen file, loaded with
        ``map_location="cpu"``.

    Raises:
        FileNotFoundError: If the directory contains no candidate file.
    """
    input_dir = op.join(
        cdirs.data_directory, experiment_name, stage, "test", test_dataset_name
    )
    # ``os.scandir`` yields entries in an arbitrary, filesystem-dependent
    # order: sort the paths so that a given seed always picks the same file.
    # Entries named "done" are skipped (presumably a completion marker written
    # by the processing stage — TODO confirm).
    with os.scandir(input_dir) as entries:
        possible_input_paths = sorted(
            entry.path
            for entry in entries
            if entry.is_file() and entry.name != "done"
        )
    if not possible_input_paths:
        raise FileNotFoundError(f"No files found in {input_dir}")
    input_path = possible_input_paths[rng.integers(len(possible_input_paths))]
    # NOTE: ``torch.load`` unpickles the file; only use it on trusted data.
    return torch.load(input_path, map_location="cpu")
def get_df_edges(
    edge_indices: npt.ArrayLike,
    df_hits: pd.DataFrame,
) -> pd.DataFrame:
    """Build an edge dataframe annotated with the columns of both end hits.

    Args:
        edge_indices: Array-like whose first row holds the ``hit_idx`` of the
            "left" hit of each edge and whose second row holds the
            ``hit_idx`` of the "right" hit.
        df_hits: Hit dataframe with a ``hit_idx`` column.

    Returns:
        One row per edge, in the input order. The columns are
        ``hit_idx_left`` and ``hit_idx_right``, then every other column ``c``
        of ``df_hits`` as ``c_left``, then as ``c_right``. The join is a left
        join, so the hit columns of an edge whose hit is missing from
        ``df_hits`` are NaN.
    """
    edge_array = np.asarray(edge_indices)
    df_edges = pd.DataFrame(
        {"hit_idx_left": edge_array[0], "hit_idx_right": edge_array[1]}
    )
    # Attach the hit columns of each end in turn; suffixing every column of
    # ``df_hits`` turns ``hit_idx`` into the join key of that end.
    for end_suffix in ("_left", "_right"):
        df_edges = df_edges.merge(
            df_hits.add_suffix(end_suffix),
            on=f"hit_idx{end_suffix}",
            how="left",
        )
    return df_edges
def filter_edges_planes(
    df_edges: pd.DataFrame, first_plane: int, plane_range: int
) -> pd.DataFrame:
    """Keep the edges whose two end hits both lie in a window of planes.

    The window is ``[first_plane, first_plane + plane_range]`` (inclusive).
    Both ends are checked against both bounds, whatever the plane order of
    the two ends. Every kept edge therefore connects two hits that are
    themselves inside the window.

    Args:
        df_edges: Edges with ``plane_left`` and ``plane_right`` columns, e.g.
            as returned by ``get_df_edges``.
        first_plane: First plane of the window.
        plane_range: Number of planes after ``first_plane`` in the window.

    Returns:
        The selected edges, re-indexed from 0.
    """
    last_plane = first_plane + plane_range
    # ``between`` is inclusive on both sides; NaN planes (hits missing from
    # the hit dataframe) compare False and are dropped.
    in_window = df_edges["plane_left"].between(first_plane, last_plane) & df_edges[
        "plane_right"
    ].between(first_plane, last_plane)
    return df_edges[in_window].reset_index(drop=True)
def get_fig_ax(
    axes: typing.Tuple[str, str] = ("z", "x"), unit: str = "mm"
) -> typing.Tuple[Figure, Axes]:
    """Create an 8x6 figure with one axis labelled for a 2D projection.

    Args:
        axes: Names of the horizontal and vertical coordinates. They are
            wrapped in ``$...$`` in the axis labels.
        unit: Unit shown in brackets after both coordinate names.

    Returns:
        The new figure and its single axis, with a light grey grid.
    """
    figure, axis = plt.subplots(figsize=(8, 6))
    horizontal_name, vertical_name = axes[0], axes[1]
    axis.set_xlabel(f"${horizontal_name}$ [{unit}]")
    axis.set_ylabel(f"${vertical_name}$ [{unit}]")
    axis.grid(color="grey", alpha=0.2)
    return figure, axis
def plot_hits(
    ax: Axes, df_hits: pd.DataFrame, axes: typing.Tuple[str, str] = ("z", "x")
):
    """Draw every hit of ``df_hits`` as a small black dot on ``ax``.

    Args:
        ax: Axis to draw on.
        df_hits: Hits, with one column per coordinate named in ``axes``.
        axes: Columns used as horizontal and vertical coordinates.
    """
    horizontal_values = df_hits[axes[0]]
    vertical_values = df_hits[axes[1]]
    ax.scatter(x=horizontal_values, y=vertical_values, s=1, c="k")
def plot_edges(
    ax: Axes,
    df_edges: pd.DataFrame,
    axes: typing.Tuple[str, str] = ("z", "x"),
    **kwargs,
):
    """Draw each edge of ``df_edges`` as a straight segment on ``ax``.

    Args:
        ax: Axis to draw on.
        df_edges: Edges with ``<axis>_left`` and ``<axis>_right`` columns for
            both coordinate names in ``axes`` (e.g. from ``get_df_edges``).
        axes: Coordinates used as horizontal and vertical axes.
        **kwargs: Forwarded to ``Axes.plot`` (e.g. ``color``, ``alpha``).
            Without a colour, successive edges cycle through the axis colour
            cycle, exactly as with one ``plot`` call per edge.
    """
    if df_edges.empty:
        return
    # Shape (2, n_edges): row 0 holds the left ends, row 1 the right ends.
    # ``Axes.plot`` draws each column of 2D inputs as its own line. This
    # reproduces one ``plot`` call per edge in a single vectorised call
    # instead of a slow Python loop over ``iterrows``.
    horizontal = df_edges[[f"{axes[0]}_left", f"{axes[0]}_right"]].to_numpy().T
    vertical = df_edges[[f"{axes[1]}_left", f"{axes[1]}_right"]].to_numpy().T
    ax.plot(horizontal, vertical, **kwargs)
def _finalize_and_save(fig: Figure, ax: Axes, path: str) -> None:
    """Apply the shared ``plotools`` finishing steps to ``ax``, save, close.

    The steps are ``plotools.pad_on_top`` and ``plotools.add_text``, with the
    text anchored top-right. The figure is closed even if saving fails, so
    figures do not accumulate.
    """
    try:
        plotools.pad_on_top(ax)
        plotools.add_text(ax=ax, ha="right", va="top")
        plotools.save_fig(fig=fig, path=path)
    finally:
        plt.close(fig)


def main() -> None:
    """Plot hits and edges of a random test batch to explain the embedding.

    Four figures are written, all restricted to the planes
    ``[first_plane, first_plane + plane_range]``:

    - ``hits_only``: the hits;
    - ``hits_and_true_edges``: the hits and all true edges;
    - ``chosen_hits_and_true_edges``: the hits and, in green, the true edges
      starting from ``n_hits`` randomly chosen hits of the first plane;
    - ``chosen_hits_and_edges``: the same, plus, in grey, the edges of the
      batch ``edge_index`` starting from these hits.
    """
    # ``description=`` is required: the first positional argument of
    # ``ArgumentParser`` is the program name, not the description.
    parser = ArgumentParser(description="Plot plots to explain the embedding.")
    add_predefined_arguments(
        parser, ["experiment_name", "test_dataset_name", "output_dir"]
    )
    parser.add_argument(
        "-s",
        "--seed",
        help="Random seed to use to randomly select file and hits.",
        type=int,
        default=13,
        required=False,
    )
    parser.add_argument(
        "-p",
        "--first_plane",
        help="First plane number to plot",
        type=int,
        default=10,
        required=False,
    )
    parser.add_argument(
        "-plane_range",
        "--plane_range",
        help="Longest plane distance of an edge.",
        type=int,
        default=2,
        required=False,
    )
    parser.add_argument(
        "-n",
        "--n_hits",
        help="Number of hits to focus on, in the first plane.",
        type=int,  # without it, "-n 3" would be passed on as the string "3"
        default=1,
        required=False,
    )
    parsed_args = parser.parse_args()
    experiment_name: str = parsed_args.experiment_name
    test_dataset_name: str = parsed_args.test_dataset_name
    seed: int = parsed_args.seed
    first_plane: int = parsed_args.first_plane
    plane_range: int = parsed_args.plane_range
    n_hits: int = parsed_args.n_hits
    output_dir: str = (
        parsed_args.output_dir
        if parsed_args.output_dir is not None
        else op.join(cdirs.analysis_directory, "embedding_explanation")
    )
    os.makedirs(output_dir, exist_ok=True)
    suffix = f"_{experiment_name}_{test_dataset_name}"

    rng = np.random.default_rng(seed=seed)
    batch = get_random_batch_file(
        experiment_name=experiment_name,
        stage="embedding_processed",
        test_dataset_name=test_dataset_name,
        rng=rng,
    )

    # Convert every column to NumPy before building the dataframe.
    # ``np.asarray`` accepts both CPU tensors and arrays.
    df_hits = pd.DataFrame(
        {
            "hit_idx": np.arange(batch["x"].shape[0]),
            "plane": batch["plane"].numpy(),
            **{axis: np.asarray(batch[f"un_{axis}"]) for axis in ["x", "y", "z"]},
        }
    )
    df_edges = get_df_edges(edge_indices=batch["edge_index"].numpy(), df_hits=df_hits)
    df_true_edges = get_df_edges(
        edge_indices=batch["signal_true_edges"].numpy(), df_hits=df_hits
    )

    last_plane = first_plane + plane_range
    df_hits_plane_range = df_hits[
        (df_hits["plane"] >= first_plane) & (df_hits["plane"] <= last_plane)
    ].reset_index(drop=True)
    df_edges_plane_range = filter_edges_planes(
        df_edges=df_edges, first_plane=first_plane, plane_range=plane_range
    )
    df_true_edges_plane_range = filter_edges_planes(
        df_edges=df_true_edges, first_plane=first_plane, plane_range=plane_range
    )

    # Randomly choose hits of the first plane among those starting a true
    # edge. A hit may start several true edges, so deduplicate first. Sample
    # without replacement so that no hit is chosen twice.
    first_plane_hit_indices = np.unique(
        df_true_edges_plane_range.loc[
            df_true_edges_plane_range["plane_left"] == first_plane, "hit_idx_left"
        ].to_numpy()
    )
    if first_plane_hit_indices.size < n_hits:
        raise ValueError(
            f"Requested {n_hits} hits, but only {first_plane_hit_indices.size} "
            f"hit(s) of plane {first_plane} start a true edge."
        )
    chosen_first_plane_hit_indices = rng.choice(
        first_plane_hit_indices, size=n_hits, replace=False
    )
    df_chosen_true_edges = df_true_edges_plane_range[
        df_true_edges_plane_range["hit_idx_left"].isin(chosen_first_plane_hit_indices)
    ]
    df_chosen_edges = df_edges_plane_range[
        df_edges_plane_range["hit_idx_left"].isin(chosen_first_plane_hit_indices)
    ]

    # 1. Hits only.
    fig, ax = get_fig_ax()
    plot_hits(ax=ax, df_hits=df_hits_plane_range)
    _finalize_and_save(fig, ax, op.join(output_dir, f"hits_only{suffix}"))

    # 2. Hits and all true edges.
    fig, ax = get_fig_ax()
    plot_edges(ax=ax, df_edges=df_true_edges_plane_range, alpha=0.5)
    plot_hits(ax=ax, df_hits=df_hits_plane_range)
    _finalize_and_save(fig, ax, op.join(output_dir, f"hits_and_true_edges{suffix}"))

    # 3. Hits and the true edges of the chosen hits.
    fig, ax = get_fig_ax()
    plot_edges(ax=ax, df_edges=df_chosen_true_edges, color="g")
    plot_hits(ax=ax, df_hits=df_hits_plane_range)
    _finalize_and_save(
        fig, ax, op.join(output_dir, f"chosen_hits_and_true_edges{suffix}")
    )

    # 4. Hits, all edges of the chosen hits (grey) and their true edges (green).
    fig, ax = get_fig_ax()
    plot_edges(ax=ax, df_edges=df_chosen_edges, color="grey", alpha=0.5)
    plot_edges(ax=ax, df_edges=df_chosen_true_edges, color="g")
    plot_hits(ax=ax, df_hits=df_hits_plane_range)
    _finalize_and_save(fig, ax, op.join(output_dir, f"chosen_hits_and_edges{suffix}"))


if __name__ == "__main__":
    main()