#!/usr/bin/env python3
"""A script that runs the performance evaluation of the ETX4VELO pipeline, using
MonteTracko.
"""
import os
import os.path as op
import typing
import logging
import time
from argparse import ArgumentParser
import numpy as np
import pandas as pd
import montetracko as mt
from Preprocessing.particle_fitting_metrics import (
compute_particle_line_metrics_dataframe,
)
from Evaluation.matching import perform_matching
from Evaluation.plotting import plot_evaluation_categories
from Evaluation.reporting import report_evaluation
from utils.commonutils.config import (
load_config,
get_performance_directory_experiment,
cdirs,
get_detector_from_pipeline_config,
)
from utils.commonutils.cdetector import get_coordinate_names
from utils.scriptutils import configure_logger, headline
from utils.scriptutils.parser import add_predefined_arguments
from utils.loaderutils.tracks import load_tracks_preprocessed_dataframes_given_partition
# Set up logging through the project helper as soon as this module is
# imported, so that the ``logging.info`` calls below are emitted.
configure_logger()
def compute_plane_stats(
    df_hits_particles: pd.DataFrame, df_particles: pd.DataFrame
) -> pd.DataFrame:
    """Compute per-particle statistics about how its hits are spread over planes.

    For every ``(event_id, particle_id)`` group of ``df_hits_particles``:

    * ``n_unique_planes``: number of distinct planes hit;
    * ``n_repeated_planes``: number of hits minus number of distinct planes;
    * ``n_skipped_planes``: number of planes between the first and last plane
      hit (inclusive) that were not hit.

    Args:
        df_hits_particles: Dataframe of hits-particles association. Must have
            the columns ``event_id``, ``particle_id`` and ``plane``.
        df_particles: Dataframe of particles. Must have the columns ``event_id``
            and ``particle_id``.

    Returns:
        ``df_particles`` left-merged with the three new columns; particles that
        have no hits in ``df_hits_particles`` get NaN in those columns.
    """
    keys = ["event_id", "particle_id"]
    # One grouped pass gives everything needed; "size" counts every hit row.
    per_particle = df_hits_particles.groupby(keys)["plane"].agg(
        ["min", "max", "nunique", "size"]
    )
    n_unique = per_particle["nunique"]
    plane_stats = pd.DataFrame(
        {
            "n_unique_planes": n_unique,
            "n_repeated_planes": per_particle["size"] - n_unique,
            "n_skipped_planes": (
                per_particle["max"] - per_particle["min"] + 1 - n_unique
            ),
        }
    )
    return df_particles.merge(plane_stats.reset_index(), how="left", on=keys)
def compute_n_shared_hits(
    df_hits_particles: pd.DataFrame, df_particles: pd.DataFrame
) -> pd.DataFrame:
    """Count, for each particle, how many of its hits are shared.

    A hit ``(event_id, hit_id)`` is considered shared when at least two rows
    of ``df_hits_particles`` for that hit carry a non-null ``particle_id``.

    Args:
        df_hits_particles: Dataframe of hits-particles association. Must have
            the columns ``event_id``, ``hit_id`` and ``particle_id``.
        df_particles: Dataframe of particles. Must have the columns
            ``event_id`` and ``particle_id``.

    Returns:
        ``df_particles`` left-merged with a new ``n_shared_hits`` column;
        particles absent from ``df_hits_particles`` get NaN.
    """
    hit_keys = ["event_id", "hit_id"]
    particle_keys = ["event_id", "particle_id"]

    # Number of (non-null) particles attached to each hit, broadcast back onto
    # every hits-particles row of that hit.
    particles_on_same_hit = df_hits_particles.groupby(hit_keys)[
        "particle_id"
    ].transform("count")
    flagged_rows = df_hits_particles.assign(
        more_than_1_particle=(particles_on_same_hit >= 2).to_numpy()
    )

    shared_counts = (
        flagged_rows.groupby(particle_keys)["more_than_1_particle"]
        .sum()
        .rename("n_shared_hits")
    )
    return df_particles.merge(shared_counts, how="left", on=particle_keys)
def _add_detector_specific_particle_columns(
    df_hits_particles: pd.DataFrame,
    df_particles: pd.DataFrame,
    detector: str,
    plotted_groups: typing.List[str] | None,
) -> pd.DataFrame:
    """Add the detector-specific truth variables used by the evaluation.

    * ``velo``: only when ``"geometry"`` is in ``plotted_groups``, merge the
      particle line metrics ``distance_to_line``, ``distance_to_z_axis``,
      ``xz_angle`` and ``yz_angle``.
    * ``scifi`` / ``scifi_xz``: merge the absolute ``quadratic_coeff`` of a
      ``quadpoly_2d`` fit of the hits with ``dxdy == 0`` and add the momentum
      components ``px``, ``py`` (absolute values) and ``pz``.

    Args:
        df_hits_particles: Dataframe of hits-particles association.
        df_particles: Dataframe of particles, with ``event_id`` and
            ``particle_id`` columns (plus ``p``, ``pt`` and ``phi`` for SciFi).
        detector: ``velo``, ``scifi`` or ``scifi_xz``.
        plotted_groups: Pre-configured plot groups requested by the caller.

    Returns:
        Dataframe of particles with the extra columns.

    Raises:
        ValueError: if ``detector`` is not recognised.
    """
    if detector == "velo":
        if plotted_groups is not None and "geometry" in plotted_groups:
            logging.info("Compute particle line metrics")
            new_distances = compute_particle_line_metrics_dataframe(
                hits=df_hits_particles,
                metric_names=[
                    "distance_to_line",
                    "distance_to_z_axis",
                    "xz_angle",
                    "yz_angle",
                ],
                event_id_column="event_id",
            )
            df_particles = df_particles.merge(
                new_distances, how="left", on=["event_id", "particle_id"]
            )
        return df_particles

    if detector in ("scifi", "scifi_xz"):
        # The quadratic fit only uses hits with dxdy == 0 (presumably the
        # non-stereo x-layers — confirm), in the scifi_xz coordinate system.
        df_quadratic_coeffs = compute_particle_line_metrics_dataframe(
            hits=df_hits_particles[df_hits_particles["dxdy"] == 0.0],
            metric_names=["quadratic_coeff"],
            coord_names=get_coordinate_names(detector="scifi_xz"),
            line_type="quadpoly_2d",
        )
        df_particles = df_particles.merge(
            df_quadratic_coeffs.reset_index(),
            how="left",
            on=["event_id", "particle_id"],
        )
        df_particles["quadratic_coeff"] = df_particles["quadratic_coeff"].abs()
        df_particles["px"] = np.abs(df_particles["pt"] * np.cos(df_particles["phi"]))
        df_particles["py"] = np.abs(df_particles["pt"] * np.sin(df_particles["phi"]))
        # Floating-point rounding can make p**2 - pt**2 slightly negative for
        # particles with pz ~ 0, which np.sqrt would turn into NaN.
        pz_squared = df_particles["p"] ** 2 - df_particles["pt"] ** 2
        df_particles["pz"] = np.sqrt(pz_squared.clip(lower=0.0))
        return df_particles

    raise ValueError(f"Detector {detector} is not recognised.")


def evaluate(
    df_hits_particles: pd.DataFrame,
    df_particles: pd.DataFrame,
    df_tracks: pd.DataFrame,
    allen_report: bool = True,
    table_report: bool = True,
    plot_categories: typing.Iterable[mt.requirement.Category] | None = None,
    # The list default is kept for backward compatibility; it is copied below
    # so that it can never be mutated across calls.
    plotted_groups: typing.List[str] | None = ["basic"],  # noqa: B006
    min_track_length: int = 3,
    matching_fraction: float = 0.7,
    output_dir: str | None = None,
    detector: str | None = None,
    suffix: str | None = None,
    cure_clones: bool = False,
    timestamp: bool = False,
) -> mt.TrackEvaluator:
    """Run the truth-based tracking evaluation.

    The particle dataframe is first enriched with plane statistics, the number
    of shared hits and detector-specific variables. Tracks are then matched
    to particles and the evaluation is reported.

    Args:
        df_hits_particles: Dataframe of hits-particles association (columns
            ``event_id``, ``hit_id``, ``particle_id``, ``plane``, and ``dxdy``
            for SciFi).
        df_particles: Dataframe of particles.
        df_tracks: Dataframe of reconstructed tracks, forwarded to
            ``perform_matching``.
        allen_report: Whether to report in Allen categories using the Allen
            reporter. Forwarded to ``perform_evaluation``.
        table_report: Forwarded to ``perform_evaluation``.
        plot_categories: Categories to plot on, forwarded to
            ``perform_evaluation``. An empty list is meant to disable plotting.
        plotted_groups: Pre-configured metrics and columns to plot
            (``basic``, ``geometry`` or ``challenging``). For the VELO, the
            ``geometry`` group also triggers the computation of the particle
            line metrics.
        min_track_length: Minimum length of a track to be considered in the
            evaluation. Forwarded to ``perform_matching``.
        matching_fraction: Forwarded to ``perform_matching``.
        output_dir: Directory where reports and figures are saved. Forwarded to
            ``perform_evaluation``.
        detector: ``velo``, ``scifi`` or ``scifi_xz``. Defaults to
            ``cdirs.detectors[0]``.
        suffix: Suffix of the produced files. Forwarded to
            ``perform_evaluation``.
        cure_clones: Forwarded to ``perform_matching``.
        timestamp: Forwarded to ``perform_evaluation``.

    Returns:
        The track evaluator returned by ``perform_matching``.

    Raises:
        ValueError: if ``detector`` is not recognised.
    """
    if detector is None:
        detector = cdirs.detectors[0]
    # Fail fast, before the expensive per-particle computations below.
    if detector not in ("velo", "scifi", "scifi_xz"):
        raise ValueError(f"Detector {detector} is not recognised.")
    if plotted_groups is not None:
        plotted_groups = list(plotted_groups)

    logging.info("Compute plane stats")
    df_particles = compute_plane_stats(
        df_hits_particles=df_hits_particles,
        df_particles=df_particles,
    )
    df_particles = compute_n_shared_hits(
        df_hits_particles=df_hits_particles,
        df_particles=df_particles,
    )
    df_particles = _add_detector_specific_particle_columns(
        df_hits_particles=df_hits_particles,
        df_particles=df_particles,
        detector=detector,
        plotted_groups=plotted_groups,
    )

    logging.info("2) Matching")
    trackEvaluator = perform_matching(
        df_tracks=df_tracks,
        df_hits_particles=df_hits_particles,
        df_particles=df_particles,
        min_track_length=min_track_length,
        matching_fraction=matching_fraction,
        cure_clones=cure_clones,
    )

    logging.info("3) Evaluation")
    # NOTE(review): ``perform_evaluation`` is neither defined nor imported in
    # the visible part of this module (``report_evaluation`` and
    # ``plot_evaluation_categories`` are imported but unused) — confirm where
    # it comes from, otherwise this line raises NameError.
    perform_evaluation(
        trackEvaluator=trackEvaluator,
        allen_report=allen_report,
        table_report=table_report,
        plot_categories=plot_categories,
        plotted_groups=plotted_groups,
        output_dir=output_dir,
        suffix=suffix,
        timestamp=timestamp,
        detector=detector,
    )
    return trackEvaluator
def evaluate_partition(
    path_or_config: str | dict,
    partition: str,
    suffix: str | None = None,
    output_dir: str | None = None,
    **kwargs,
) -> mt.TrackEvaluator:
    """Evaluate the track finding performance in a given partition.

    Args:
        path_or_config: Pipeline configuration dictionary or path to a YAML
            file that contains it.
        partition: ``train``, ``val`` or the name of a test dataset.
        suffix: Suffix used to load the tracks and added to the end of the
            files that are produced (``None`` is treated as ``""``). The
            partition name is appended to it.
        output_dir: Directory where to save the reports and figures. Defaults
            to the ``track_building`` sub-directory of the experiment
            performance directory.
        **kwargs: Other keyword arguments passed to :py:func:`evaluate`.

    Returns:
        The track evaluator returned by :py:func:`evaluate`.
    """
    config = load_config(path_or_config=path_or_config)
    logging.info(headline(f"Evaluation for {partition}"))
    if suffix is None:
        suffix = ""

    (
        df_tracks,
        df_hits_particles,
        df_particles,
    ) = load_tracks_preprocessed_dataframes_given_partition(
        path_or_config=path_or_config, partition=partition, suffix=suffix
    )
    detector = get_detector_from_pipeline_config(path_or_config=path_or_config)
    # Lazy %-formatting: no eager string concatenation, and no TypeError if the
    # helper ever returns a non-string value.
    logging.info("Detector: %s", detector)

    if output_dir is None:
        output_dir = op.join(
            get_performance_directory_experiment(config), "track_building"
        )

    return evaluate(
        df_hits_particles=df_hits_particles,
        df_particles=df_particles,
        df_tracks=df_tracks,
        suffix=f"{suffix}_{partition}",
        detector=detector,
        output_dir=output_dir,
        **kwargs,
    )
if __name__ == "__main__":
    # Command-line entry point: evaluate one partition of a pipeline.
    arg_parser = ArgumentParser(description="Run the evaluation")
    add_predefined_arguments(arg_parser, ["pipeline_config", "partition"])
    arg_parser.add_argument(
        "-s",
        "--suffix",
        required=False,
        help="Suffix put in the name of the figures and report saved to disk.",
    )
    arg_parser.add_argument(
        "--time_stamp",
        action="store_true",
        help="Put a time stamp in the name of the file where the report is saved.",
    )
    cli_args = arg_parser.parse_args()

    # Both reports are requested; no pre-configured plot group is selected
    # (plotted_groups=None).
    evaluate_partition(
        path_or_config=cli_args.pipeline_config,
        partition=cli_args.partition,
        suffix=cli_args.suffix,
        timestamp=cli_args.time_stamp,
        allen_report=True,
        table_report=True,
        plotted_groups=None,
    )