Source code for scripts.plotfactory.plot_embedding_explanation

#!/usr/bin/env python3
"""A script that plots figures to explain the embedding.
"""
import typing
import os
import os.path as op
from argparse import ArgumentParser

import numpy as np
import numpy.typing as npt
import torch
from torch_geometric.data import Data
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure

from utils.commonutils.config import cdirs
from utils.plotutils.plotconfig import configure_matplotlib
from utils.plotutils import plotools
from utils.scriptutils.parser import add_predefined_arguments

configure_matplotlib()


def get_random_batch_file(
    experiment_name: str,
    stage: str,
    test_dataset_name: str,
    rng: np.random.Generator,
) -> Data:
    """Load one randomly chosen test batch file onto the CPU.

    Candidate files are the regular files found directly inside
    ``<cdirs.data_directory>/<experiment_name>/<stage>/test/<test_dataset_name>``,
    excluding a file named ``done`` (presumably a completion marker, to be
    confirmed against the code that writes this directory).

    The candidates are sorted before one is drawn: ``os.scandir`` yields
    entries in arbitrary order, so without sorting the same seed could
    select a different file from one machine or run to the next.

    Args:
        experiment_name: Name of the experiment sub-directory.
        stage: Name of the pipeline stage sub-directory.
        test_dataset_name: Name of the test dataset sub-directory.
        rng: Random generator used to draw the index of the file to load.

    Returns:
        The object returned by ``torch.load`` for the selected file.

    Raises:
        AssertionError: If the directory contains no candidate file.
    """
    input_dir = op.join(
        cdirs.data_directory, experiment_name, stage, "test", test_dataset_name
    )
    # The context manager releases the directory handle even if an entry
    # check raises part-way through the listing.
    with os.scandir(input_dir) as entries:
        possible_input_paths = sorted(
            entry.path
            for entry in entries
            if entry.is_file() and entry.name != "done"
        )
    assert possible_input_paths, f"No files found in {input_dir}"

    chosen_index = rng.integers(low=0, high=len(possible_input_paths), size=1)[0]
    return torch.load(possible_input_paths[chosen_index], map_location="cpu")
def get_df_edges(
    edge_indices: npt.ArrayLike,
    df_hits: pd.DataFrame,
) -> pd.DataFrame:
    """Build an edge table that carries the hit columns of both edge ends.

    Args:
        edge_indices: Array-like whose first row holds the ``hit_idx`` of the
            left end of every edge and whose second row holds the ``hit_idx``
            of the right end.
        df_hits: Table of hits. It must contain a ``hit_idx`` column, which is
            the merge key.

    Returns:
        A dataframe with one row per edge. It starts with ``hit_idx_left`` and
        ``hit_idx_right``, followed by every other column of ``df_hits``
        suffixed with ``_left``, then again suffixed with ``_right``.
    """
    endpoints = np.asarray(edge_indices)
    edges = pd.DataFrame(
        {
            "hit_idx_left": endpoints[0],
            "hit_idx_right": endpoints[1],
        }
    )

    # Attach all the hit columns to each end of the edge. A left merge keeps
    # every edge, even when its hit index is absent from ``df_hits``.
    for side in ("left", "right"):
        hits_for_side = df_hits.add_suffix(f"_{side}")
        edges = edges.merge(hits_for_side, on=f"hit_idx_{side}", how="left")

    return edges
def filter_edges_planes(
    df_edges: pd.DataFrame, first_plane: int, plane_range: int
) -> pd.DataFrame:
    """Keep the edges that lie within a window of planes.

    An edge is kept when its ``plane_left`` is at least ``first_plane`` and
    its ``plane_right`` is at most ``first_plane + plane_range``. Only these
    two bounds are checked: no upper bound is applied to ``plane_left`` and
    no lower bound to ``plane_right``.

    Args:
        df_edges: Table of edges with ``plane_left`` and ``plane_right``
            columns.
        first_plane: Lowest accepted value of ``plane_left``.
        plane_range: Width of the window; the highest accepted value of
            ``plane_right`` is ``first_plane + plane_range``.

    Returns:
        The selected rows, re-indexed from zero.
    """
    last_plane = first_plane + plane_range
    left_is_inside = df_edges["plane_left"] >= first_plane
    right_is_inside = df_edges["plane_right"] <= last_plane
    return df_edges.loc[left_is_inside & right_is_inside].reset_index(drop=True)
def get_fig_ax(
    axes: typing.Tuple[str, str] = ("z", "x"), unit: str = "mm"
) -> typing.Tuple[Figure, Axes]:
    """Create an 8x6 figure with labelled axes and a faint grey grid.

    Args:
        axes: Names of the quantities shown on the horizontal and vertical
            axes, in this order. Each is typeset in math mode in its label.
        unit: Unit written in square brackets after each axis name.

    Returns:
        The newly created figure and its single axes.
    """
    horizontal, vertical = axes[0], axes[1]

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.set_xlabel(f"${horizontal}$ [{unit}]")
    ax.set_ylabel(f"${vertical}$ [{unit}]")
    ax.grid(color="grey", alpha=0.2)
    return fig, ax
def plot_hits(
    ax: Axes, df_hits: pd.DataFrame, axes: typing.Tuple[str, str] = ("z", "x")
):
    """Draw every hit as a small black marker.

    Args:
        ax: Axes to draw on.
        df_hits: Table of hits containing one column per name in ``axes``.
        axes: Names of the columns used for the horizontal and vertical
            coordinates, in this order.
    """
    horizontal, vertical = axes[0], axes[1]
    ax.scatter(x=df_hits[horizontal], y=df_hits[vertical], s=1, c="k")
def plot_edges(
    ax: Axes,
    df_edges: pd.DataFrame,
    axes: typing.Tuple[str, str] = ("z", "x"),
    **kwargs,
):
    """Draw every edge as a straight segment between its two ends.

    Args:
        ax: Axes to draw on.
        df_edges: Table of edges. For each name in ``axes`` it must contain
            the columns ``<name>_left`` and ``<name>_right``.
        axes: Names of the coordinates used for the horizontal and vertical
            directions, in this order.
        **kwargs: Forwarded unchanged to every ``ax.plot`` call.
    """
    horizontal, vertical = axes[0], axes[1]
    endpoint_columns = (
        f"{horizontal}_left",
        f"{horizontal}_right",
        f"{vertical}_left",
        f"{vertical}_right",
    )

    # One ``ax.plot`` call per edge, so that each edge is its own line.
    for _, edge in df_edges.iterrows():
        h_left, h_right, v_left, v_right = (
            edge[column] for column in endpoint_columns
        )
        ax.plot((h_left, h_right), (v_left, v_right), **kwargs)
if __name__ == "__main__":
    # Command-line entry point: load one random test batch, restrict it to a
    # window of planes, and save four figures that progressively show the
    # hits, the true edges, and the edges attached to a few chosen hits.
    parser = ArgumentParser("Plot plots to explain the embedding.")
    add_predefined_arguments(
        parser, ["experiment_name", "test_dataset_name", "output_dir"]
    )
    parser.add_argument(
        "-s",
        "--seed",
        help="Random seed to use to randomly select file and hits.",
        type=int,
        default=13,
        required=False,
    )
    parser.add_argument(
        "-p",
        "--first_plane",
        help="First plane number to plot",
        type=int,
        default=10,
        required=False,
    )
    parser.add_argument(
        "-plane_range",
        "--plane_range",
        help="Longest plane distance of an edge.",
        type=int,
        default=2,
        required=False,
    )
    parser.add_argument(
        "-n",
        "--n_hits",
        help="Number of hits the focus on, in the first plane.",
        # Without an explicit type, argparse keeps a command-line value as a
        # string, and that string would then be passed as the ``size`` of
        # ``rng.choice`` below.
        type=int,
        default=1,
        required=False,
    )
    parsed_args = parser.parse_args()

    experiment_name: str = parsed_args.experiment_name
    test_dataset_name: str = parsed_args.test_dataset_name
    seed: int = parsed_args.seed
    first_plane: int = parsed_args.first_plane
    plane_range: int = parsed_args.plane_range
    n_hits: int = parsed_args.n_hits
    output_dir: str = (
        parsed_args.output_dir
        if parsed_args.output_dir is not None
        else op.join(cdirs.analysis_directory, "embedding_explanation")
    )
    # Appended to every output file name.
    suffix = f"_{experiment_name}_{test_dataset_name}"

    # A single generator drives both the file selection and the hit selection.
    rng = np.random.default_rng(seed=seed)

    batch = get_random_batch_file(
        experiment_name=experiment_name,
        stage="embedding_processed",
        test_dataset_name=test_dataset_name,
        rng=rng,
    )

    # One row per hit; ``hit_idx`` is the position of the hit in the batch.
    # NOTE(review): the ``un_*`` entries are presumably the un-normalised
    # coordinates -- confirm against the code producing the batch.
    df_hits = pd.DataFrame(
        {
            "hit_idx": np.arange(batch["x"].shape[0]),
            "plane": batch["plane"].numpy(),
            **{axis: batch[f"un_{axis}"] for axis in ["x", "y", "z"]},
        }
    )
    # Association between particle IDs and hit indices. It is not used by the
    # figures produced below.
    df_hits_particles = pd.DataFrame(
        {
            "particle_id": batch["particle_id_hit_idx"][:, 0].numpy(),
            "hit_idx": batch["particle_id_hit_idx"][:, 1].numpy(),
        }
    )

    df_edges = get_df_edges(edge_indices=batch["edge_index"].numpy(), df_hits=df_hits)
    df_true_edges = get_df_edges(
        edge_indices=batch["signal_true_edges"].numpy(), df_hits=df_hits
    )

    # Restrict the hits and both sets of edges to the requested plane window.
    df_hits_plane_range = df_hits[
        (df_hits["plane"] >= first_plane)
        & (df_hits["plane"] <= first_plane + plane_range)
    ].reset_index(drop=True)
    df_edges_plane_range = filter_edges_planes(
        df_edges=df_edges, first_plane=first_plane, plane_range=plane_range
    )
    df_true_edges_plane_range = filter_edges_planes(
        df_edges=df_true_edges, first_plane=first_plane, plane_range=plane_range
    )

    # Choose hits belonging to the first plane, randomly. Candidates are the
    # left ends of the true edges that start on the first plane.
    first_plane_hit_indices = df_true_edges_plane_range[
        df_true_edges_plane_range["plane_left"] == first_plane
    ]["hit_idx_left"].to_numpy()
    chosen_first_plane_hit_indices = rng.choice(first_plane_hit_indices, size=n_hits)

    # Edges whose left end is one of the chosen hits. Computed once because
    # the true-edge selection is drawn in two different figures.
    df_chosen_true_edges = df_true_edges_plane_range[
        df_true_edges_plane_range["hit_idx_left"].isin(chosen_first_plane_hit_indices)
    ]
    df_chosen_edges = df_edges_plane_range[
        df_edges_plane_range["hit_idx_left"].isin(chosen_first_plane_hit_indices)
    ]

    def _finish_figure(fig: Figure, ax: Axes, name: str) -> None:
        """Apply the shared decorations, save the figure, and close it.

        The figure is saved in ``output_dir`` under ``name`` followed by
        ``suffix``.
        """
        plotools.pad_on_top(ax)
        plotools.add_text(ax=ax, ha="right", va="top")
        plotools.save_fig(fig=fig, path=op.join(output_dir, f"{name}{suffix}"))
        plt.close(fig)

    # Figure 1: the hits alone.
    fig, ax = get_fig_ax()
    plot_hits(ax=ax, df_hits=df_hits_plane_range)
    _finish_figure(fig, ax, "hits_only")

    # Figure 2: the hits and all the true edges.
    fig, ax = get_fig_ax()
    plot_edges(ax=ax, df_edges=df_true_edges_plane_range, alpha=0.5)
    plot_hits(ax=ax, df_hits=df_hits_plane_range)
    _finish_figure(fig, ax, "hits_and_true_edges")

    # Figure 3: the hits and the true edges leaving the chosen hits.
    fig, ax = get_fig_ax()
    plot_edges(ax=ax, df_edges=df_chosen_true_edges, color="g")
    plot_hits(ax=ax, df_hits=df_hits_plane_range)
    _finish_figure(fig, ax, "chosen_hits_and_true_edges")

    # Figure 4: as figure 3, with all the edges leaving the chosen hits drawn
    # in grey underneath the true ones.
    fig, ax = get_fig_ax()
    plot_edges(ax=ax, df_edges=df_chosen_edges, color="grey", alpha=0.5)
    plot_edges(ax=ax, df_edges=df_chosen_true_edges, color="g")
    plot_hits(ax=ax, df_hits=df_hits_plane_range)
    _finish_figure(fig, ax, "chosen_hits_and_edges")