"""A module that defines the embedding training and inference.
"""
from __future__ import annotations
import typing
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from utils.modelutils.basemodel import ModelBase
from utils.loaderutils.dataiterator import LazyDatasetBase
from utils.modelutils.metrics import (
compute_classification_efficiency_purity,
compute_efficiency_purity,
)
from utils.commonutils.config import load_config
from utils.graphutils.truths import get_truths_exatrkx
from utils.graphutils.knn import build_edges_plane_by_plane
from utils.graphutils.edgeutils import remove_duplicate_edges
from utils.graphutils.edgebuilding import get_random_pairs_plane_by_plane
TensorOrNone = typing.TypeVar("TensorOrNone", torch.Tensor, None)
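# Constrained TypeVar: in `remove_planes`, it records in the type hints whether
# the optional true-edge tensor was passed (torch.Tensor) or omitted (None).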
pd.set_option("mode.chained_assignment", None)
class EmbeddingLazyDataSet(LazyDatasetBase):
def __init__(
self,
*args,
particle_requirement: str | None = None,
query_particle_requirement: str | None = None,
target_particle_requirement: str | None = None,
particles_from_parquet: bool = False,
**kwargs,
):
"""This special lazy dataset (used only for training and validation datasets!)
allow to handle the :py:attr:`EmbeddingBase.particle_requirement`,
:py:attr:`EmbeddingBase.query_particle_requirement`
and :py:attr:`EmbeddingBase.target_particle_requirement`.
Writting this in this lazy dataset would allow in theory to compute
the requirements in parallel to the minimisation.
The ``particle_from_parquet`` argument allow to deal with missing variables
in the PyTorch data objects. The latter are directly loaded
from the preprocessed files. This is a bit messy and therefore not
recommended.
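
        Illustrative construction (a sketch: the requirement strings are
        pandas ``eval`` expressions over the particle columns, and the
        values below, as well as ``input_dir``, are assumptions)::

            dataset = EmbeddingLazyDataSet(
                input_dir,  # positional arguments of `LazyDatasetBase`
                particle_requirement="has_velo and has_scifi",
                query_particle_requirement="abs(eta) < 5",
            )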
"""
super().__init__(*args, **kwargs)
self.particle_requirement = particle_requirement
self.query_particle_requirement = query_particle_requirement
self.target_particle_requirement = target_particle_requirement
self.particles_from_parquet = bool(particles_from_parquet)
    def fetch_dataset(self, input_path: str, map_location: str = "cpu", **kwargs):
batch = super().fetch_dataset(input_path, map_location, **kwargs)
if (
self.particle_requirement is not None
or self.query_particle_requirement is not None
):
batch = self.apply_particle_requirement(batch=batch)
return batch
    def apply_particle_requirement(self, batch: Data) -> Data:
# Load dataframe of particles
if self.particles_from_parquet:
particle_columns = ["from_sdecay", "pid", "has_velo", "has_scifi", "eta"]
df_particles = pd.read_parquet(
f"{batch['truncated_path']}-particles.parquet",
columns=["particle_id"] + particle_columns,
)
            df_particles = df_particles.set_index("particle_id")
else:
particle_columns = [
column[len("particle_") :]
for column in batch.keys()
if column.startswith("particle_") and column != "particle_id_hit_idx"
]
df_particles = pd.DataFrame(
{
column: batch[f"particle_{column}"].numpy()
for column in particle_columns
},
index=batch["unique_particle_id"].numpy(),
)
df_particles.index.name = "particle_id"
# Evaluate whether to keep the particle or not
if self.particle_requirement is not None:
df_particles.eval(
f"keep_mask = ({self.particle_requirement})", inplace=True
)
# Evaluate whether the hit can be a query point or not
if self.query_particle_requirement is not None:
df_particles.eval(
f"query_mask = {self.query_particle_requirement}", inplace=True
)
# Evaluate whether the particle is a target to optimize
if self.target_particle_requirement is not None:
df_particles.eval(
f"target_mask = {self.target_particle_requirement}", inplace=True
)
# Drop the particle columns now that the "mask" columns have been computed
# (merging less columns = faster)
        if particle_columns:
            df_particles.drop(particle_columns, axis=1, inplace=True)
# Define the dataframe of hits-particles
# with columns `particle_id`, `hit_idx` and `hit_id`.
# `hit_id` is needed in case the preprocessed dataframe was used.
df_hits_particles = pd.DataFrame(
{
"particle_id": batch["particle_id_hit_idx"][:, 0].numpy(),
"hit_idx": batch["particle_id_hit_idx"][:, 1].numpy(),
},
).merge(
pd.DataFrame(
{
"hit_idx": np.arange(batch["hit_id"].shape[0]),
"hit_id": batch["hit_id"].numpy(),
}
),
on=["hit_idx"],
how="left",
)
# Merge the dataframe of particles to add the columns
# `keep_mask`, `query_mask` and `target_mask`
# This merging removes noise.
df_hits_particles = df_hits_particles.merge(
df_particles, how="inner", on=["particle_id"]
)
# Apply the particle requirement mask
if "keep_mask" in df_hits_particles:
assert not df_hits_particles["keep_mask"].isnull().values.any() # type: ignore
df_hits_particles = df_hits_particles[df_hits_particles["keep_mask"]]
# Compute overall query and target masks
mask_columns = []
if "query_mask" in df_hits_particles:
mask_columns.append("query_mask")
if "target_mask" in df_hits_particles:
mask_columns.append("target_mask")
dict_mask_columns = {}
if mask_columns:
df_hits = (
df_hits_particles.groupby("hit_idx", sort=True)[mask_columns].sum() > 0
).reset_index()
for mask_column in mask_columns:
dict_mask_columns[mask_column] = torch.from_numpy(
df_hits[mask_column].to_numpy()
)
kept_hit_indices = torch.from_numpy(df_hits["hit_idx"].to_numpy())
else:
kept_hit_indices = torch.from_numpy(
np.unique(df_hits_particles["hit_idx"].to_numpy())
)
        #: Tensor that allows converting original hit indices to filtered hit indices
old_to_new_indices = torch.arange(batch["hit_id"].shape[0])
old_to_new_indices[kept_hit_indices] = torch.arange(kept_hit_indices.shape[0])
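        # Entries for dropped hits keep stale values, but they are never
        # looked up: the true edges are filtered to kept hits just below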
#: Filter and reindex true edge indices
true_edge_indices = batch["signal_true_edges"]
true_edge_indices = true_edge_indices[
:,
torch.isin(true_edge_indices[0], kept_hit_indices)
& torch.isin(true_edge_indices[1], kept_hit_indices),
]
true_edge_indices = old_to_new_indices[true_edge_indices]
new_batch = Data(
plane=batch["plane"][kept_hit_indices],
fake=batch["fake"][kept_hit_indices],
signal_true_edges=true_edge_indices,
**dict_mask_columns,
)
for hit_variable in ["x", "un_x", "un_y", "un_z", "zatyeq0", "xatyeq0"]:
if hit_variable in batch.keys():
new_batch[hit_variable] = batch[hit_variable][kept_hit_indices]
return new_batch
class EmbeddingBase(ModelBase):
"""A class that implements the metric learning model."""
    def get_lazy_dataset(self, *args, **kwargs) -> EmbeddingLazyDataSet:
return EmbeddingLazyDataSet(*args, **kwargs)
    def get_lazy_dataset_partition(
self, partition: str, *args, **kwargs
) -> LazyDatasetBase:
if partition in ["train", "val"]:
return super().get_lazy_dataset_partition(
partition=partition,
*args,
particle_requirement=self.hparams.get("particle_requirement"),
query_particle_requirement=self.hparams.get(
"query_particle_requirement"
),
target_particle_requirement=self.hparams.get(
"target_particle_requirement"
),
**kwargs,
)
else:
return super().get_lazy_dataset_partition(
partition=partition, *args, **kwargs
)
@property
def edgedir(self) -> str:
"""Edge direction: ``left`` or ``right``."""
return self.hparams.get("edgedir", "right")
@property
def n_total_planes(self) -> int:
"""Total number of planes."""
return self.hparams["n_planes"]
@property
def n_planes(self) -> int:
"""Number of unremoved planes (e.g., xz-scifi)."""
n_removed_planes = (
len(removed_planes)
if (removed_planes := self.hparams.get("removed_planes")) is not None
else 0
)
return self.n_total_planes - n_removed_planes
@property
def query_planes(self) -> torch.Tensor | None:
"""Planes that can be queried."""
if (list_query_planes := self.hparams.get("query_planes")) is not None:
return torch.tensor(list_query_planes, device=self.device)
else:
return None
@property
def last_plane(self) -> int:
"""Index of the last plane."""
return self.n_planes - 1
    def validate_edges(self, edge_index: torch.Tensor, planes: torch.Tensor):
"""Check whether non-bidirectional edges all have the correct directions.
Args:
edge_index: edge indices
planes: plane index of each hit
"""
# Assert that the planes are sorted
assert torch.all(planes[:-1] <= planes[1:])
        # Assert that the planes are in the right direction
edgedir = self.edgedir
left_planes = planes[edge_index[0]]
right_planes = planes[edge_index[1]]
if edgedir == "right":
assert torch.all(left_planes < right_planes)
elif edgedir == "left":
assert torch.all(left_planes > right_planes)
else:
            raise ValueError(f"`edgedir` has an invalid value {edgedir}")
    def remove_planes(
self,
features: torch.Tensor,
planes: torch.Tensor,
true_edge_index: TensorOrNone = None,
) -> typing.Tuple[torch.Tensor, TensorOrNone, torch.Tensor, torch.Tensor | None]:
"""Remove hits belonging to planes given by the hyperparameter
``removed_planes``.
Args:
features: hit features
planes: hit plane indices
            true_edge_index: optionally, tensor of true edge indices
Returns:
Reindexed hit features, true edge indices, planes and original hit indices.
If no plane is removed, this is indicated by original hit indices
being None.
"""
n_hits = features.shape[0]
if (
removed_planes := self.hparams.get("removed_planes")
) is not None and removed_planes:
# Filter out hits belonging to planes to remove
removed_planes = torch.as_tensor(removed_planes, device=planes.device)
mask_hits_to_keep = ~torch.isin(planes, removed_planes)
if true_edge_index is not None:
old_to_new_indices = torch.full(
size=planes.shape,
fill_value=-1,
device=planes.device,
dtype=torch.long,
)
old_to_new_indices[mask_hits_to_keep] = torch.arange(
mask_hits_to_keep.sum(), # type: ignore
device=mask_hits_to_keep.device,
)
true_edge_index = true_edge_index[
:,
mask_hits_to_keep[true_edge_index[0]]
& mask_hits_to_keep[true_edge_index[1]],
]
true_edge_index = old_to_new_indices[true_edge_index]
planes = planes[mask_hits_to_keep]
# Reindex plane number
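            # e.g., with n_total_planes=5 and removed_planes=[2], the kept
            # planes (0, 1, 3, 4) are relabelled (0, 1, 2, 3)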
unique_planes = torch.arange(self.n_total_planes, device=planes.device)
reindexed_unique_planes = torch.arange(self.n_planes, device=planes.device)
old_to_new_planes = torch.full(
size=(self.n_total_planes,),
fill_value=-1,
device=planes.device,
dtype=torch.long,
)
old_to_new_planes[~torch.isin(unique_planes, removed_planes)] = (
reindexed_unique_planes
)
reindexed_planes = old_to_new_planes[planes]
return (
features[mask_hits_to_keep],
true_edge_index,
reindexed_planes,
torch.arange(n_hits, device=features.device)[mask_hits_to_keep],
)
else:
return features, true_edge_index, planes, None
    def get_query_points(
self,
embeddings: torch.Tensor,
true_edge_indices: torch.Tensor,
planes: torch.Tensor | None = None,
query_mask: torch.Tensor | None = None,
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
"""Get the points the edges will be drawn from to generate the training set.
Args:
embeddings: point embeddings
true_edge_indices: true edge indices
            planes: plane index of each point in ``embeddings``
            query_mask: optional boolean mask of the points allowed as queries
Returns:
1D tensor of query indices and 2D tensor of query embeddings
"""
query_indices = true_edge_indices.unique()
if query_mask is not None:
query_indices = query_indices[query_mask[query_indices]]
        # Remove query points from which no edge can start: the first plane
        # for ``left`` edges, the last plane for ``right`` edges
assert planes is not None
last_plane = torch.max(planes)
if (query_planes := self.query_planes) is not None:
if self.edgedir == "left":
assert 0 not in query_planes
elif self.edgedir == "right":
assert last_plane not in query_planes
else:
                raise ValueError(f"`edgedir` has an invalid value {self.edgedir}")
query_indices = query_indices[
torch.isin(planes[query_indices], query_planes)
]
else:
if self.edgedir == "left":
query_indices = query_indices[planes[query_indices] > 0]
elif self.edgedir == "right":
query_indices = query_indices[planes[query_indices] < last_plane]
else:
                raise ValueError(f"`edgedir` has an invalid value {self.edgedir}")
query_indices = query_indices[torch.randperm(len(query_indices))][
: self.hparams["points_per_batch"]
]
query_indices = torch.sort(query_indices).values
query_embeddings = embeddings[query_indices]
return query_indices, query_embeddings
    def build_edges(
self,
embeddings: torch.Tensor,
planes: torch.Tensor,
k_max: int,
squared_distance_max: float,
query_embeddings: torch.Tensor | None = None,
query_indices: torch.Tensor | None = None,
) -> torch.Tensor:
"""Build edges by applying kNNs.
Edges are built by looping over the planes,
and drawing neighbours between a plane and the next ``plane_range`` plane,
where ``plane_range`` is an hyperparameter.
Args:
embeddings: embeddings of all the points
planes: planes of all the points.
k_max: maximum number of neigbhours for the kNN
squared_distance_max: maximum (embedded) distance for 2 points
to be considered as neighbours
query_embeddings: embeddings of the query points
query_indices: indices of the query points (in ``embeddings``)
Returns:
Edges build by the kNN.
"""
return build_edges_plane_by_plane(
coords=embeddings,
planes=planes,
k_max=k_max,
squared_distance_max=squared_distance_max,
plane_range=self.hparams.get("plane_range"),
start_coords=query_embeddings,
start_indices=query_indices,
n_planes=self.n_planes,
)
    def get_hnm_pairs(
self,
query_embeddings: torch.Tensor,
query_indices: torch.Tensor,
embeddings: torch.Tensor,
planes: torch.Tensor,
) -> torch.Tensor:
"""Get the edges from hard-negative mining.
Args:
query_embeddings: Embeddings of the query points
query_indices: Corresponding indices of the query points
embeddings: Embeddings of all the points
planes: planes of all the points
Returns:
Edge indices of the hard-negative mined edges
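        Note:
            Neighbours that lie close to a query point in the embedding space
            but are not truly connected to it are precisely the "hard"
            negatives; the truth labels are only assigned later,
            in :py:meth:`get_truth`.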
"""
knn_edge_indices = self.build_edges(
query_embeddings=query_embeddings,
embeddings=embeddings,
query_indices=query_indices,
squared_distance_max=self.hparams["squared_distance_max"],
k_max=self.hparams["k_max"],
planes=planes,
)
return knn_edge_indices
    def get_random_pairs(
self,
query_indices: torch.Tensor,
planes: torch.Tensor,
) -> torch.Tensor:
"""Get random edges drawn from the query points.
Args:
query_indices: indices of the query points
            planes: planes of all the points; random pairs are only drawn
                from one plane to one of the next ``plane_range`` planes
                (where ``plane_range`` is a hyperparameter)
Returns:
Edge indices of random edges drawn from the query points
"""
random_pairs = get_random_pairs_plane_by_plane(
n_random=self.hparams["randomisation"],
planes=planes,
query_indices=query_indices,
plane_range=self.hparams.get("plane_range"),
n_planes=self.n_planes,
)
return random_pairs
    def append_true_pairs(
self,
training_edge_indices: torch.Tensor,
y_truth: torch.Tensor,
true_edge_indices: torch.Tensor,
planes: torch.Tensor,
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
"""Append the true edges to the tensor of training edges.
Args:
training_edge_indices: training sample of edge indices
y_truth: whether the edges in ``training_edge_indices`` are genuine
or fake
            true_edge_indices: all the genuine edge indices
            planes: planes of all the points, used to drop true edges that
                span more than ``plane_range`` planes
Returns:
Training edge indices with the true edge indices added,
and updated ``y_truth``.
"""
plane_range = self.hparams["plane_range"]
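        # True edges spanning more than ``plane_range`` planes can never be
        # proposed by the plane-by-plane kNN, so they are excluded from the
        # training targets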
true_edge_indices = true_edge_indices[
:,
torch.abs(planes[true_edge_indices[1]] - planes[true_edge_indices[0]])
<= plane_range,
]
training_edge_indices = torch.cat(
(
training_edge_indices,
true_edge_indices,
),
dim=-1,
)
y_truth = torch.cat(
[
y_truth,
torch.ones(
true_edge_indices.shape[1], device=y_truth.device, dtype=torch.bool
),
]
)
return training_edge_indices, y_truth
    def get_squared_distances(
self,
embeddings: torch.Tensor,
edge_indices: torch.Tensor,
) -> torch.Tensor:
"""Get the squared distances
Args:
embeddings: Embeddings of all the points
edge_indices: edge indices
Returns:
``squared_distances`` tensor corresponding to the squred L2 distance between
the embeddings of the hits of every edge.
"""
reference = embeddings[edge_indices[1]]
neighbors = embeddings[edge_indices[0]]
distance = torch.sum((reference - neighbors) ** 2, dim=-1)
return distance
    def get_truth(
self, edge_indices: torch.Tensor, true_edge_indices: torch.Tensor
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
"""Get the true label of each edge (whether it's genuine or fake).
Args:
edge_indices: edge indices
true_edge_indices: the true edge indices
Returns:
            2 torch tensors. The first tensor is the tensor of
            edge indices, whose columns may have been shuffled a bit.
The second tensor contains, for each edge (column) in ``edge_indices``,
whether this edge is genuine (1) or fake (0).
"""
edge_indices, y_truth = get_truths_exatrkx( # type: ignore
edge_indices=edge_indices,
true_edge_indices=true_edge_indices,
device=self.device,
)
return edge_indices, y_truth
    def get_loss(
self,
embeddings: torch.Tensor,
edge_indices: torch.Tensor,
y_truth: torch.Tensor,
weights: torch.Tensor | None = None,
) -> torch.Tensor:
"""Compute the loss for the given embeddings and edges.
Args:
embeddings: embeddings of all the points
edge_indices: edge indices
y_truth: for each edge (column) in ``edge_indices``, whether this edge is
genuine (``True``) or fake (``False``)
weights: edge weights
Returns:
Value of the siamese-like loss
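        Note:
            With :py:func:`torch.nn.functional.hinge_embedding_loss`,
            a fake edge (target ``-1``) contributes ``max(0, margin - d2)``
            and a genuine edge (target ``1``) contributes ``d2``,
            where ``d2`` is the squared embedding distance: fake pairs are
            pushed at least ``margin`` apart while genuine pairs are pulled
            together.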
"""
squared_distances = self.get_squared_distances(
embeddings=embeddings, edge_indices=edge_indices
)
negative_squared_distances = squared_distances[~y_truth]
positive_squared_distances = squared_distances[y_truth]
margin = self.hparams["margin"]
if weights is None:
negative_weights = None
positive_weights = None
else:
negative_weights = weights[~y_truth]
positive_weights = weights[y_truth]
if negative_weights is None:
negative_loss = torch.nn.functional.hinge_embedding_loss(
negative_squared_distances,
torch.full_like(input=negative_squared_distances, fill_value=-1.0),
margin=margin,
reduction="mean",
)
else:
negative_losses = torch.nn.functional.hinge_embedding_loss(
negative_squared_distances,
torch.full_like(input=negative_squared_distances, fill_value=-1.0),
margin=margin,
reduction="none",
)
negative_loss = (negative_weights * negative_losses).mean()
if positive_weights is None:
positive_loss = torch.nn.functional.hinge_embedding_loss(
positive_squared_distances,
torch.full_like(input=positive_squared_distances, fill_value=1.0),
margin=margin,
reduction="mean",
)
else:
positive_losses = torch.nn.functional.hinge_embedding_loss(
positive_squared_distances,
torch.full_like(input=positive_squared_distances, fill_value=1.0),
margin=margin,
reduction="mean",
)
positive_loss = (positive_weights * positive_losses).mean()
pos_weight: float = self.hparams.get("weight", 1.0)
return negative_loss + pos_weight * positive_loss
    def get_training_edges(
self,
embeddings: torch.Tensor,
true_edge_indices: torch.Tensor,
planes: torch.Tensor,
query_mask: torch.Tensor | None = None,
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
"""Get the edges used for the training.
Args:
embeddings: Embeddings of all the points
true_edge_indices: 2D tensor of genuine edge indices
            planes: tensor of planes for every point. Only used for
                one-directional graphs.
            query_mask: optional boolean mask of the hits allowed as
                query points
Returns:
2D tensor of training edge indices and 1D tensor indicating whether
the corresponding edge is genuine or fake.
"""
# Sanity check
if self.hparams.get("query_particle_requirement") is not None:
assert query_mask is not None
# Get the query points the edges will be drawn from
query_indices, query_embeddings = self.get_query_points(
embeddings=embeddings,
true_edge_indices=true_edge_indices,
planes=planes,
query_mask=query_mask,
)
list_training_edge_indices = []
# Append Hard Negative Mining (hnm) with KNN graph
list_training_edge_indices.append(
self.get_hnm_pairs(
query_embeddings, query_indices, embeddings, planes=planes
)
)
# Append random edges pairs (rp) for stability
list_training_edge_indices.append(
self.get_random_pairs(query_indices, planes=planes)
)
# Remove true edge indices that cannot be queried
if query_mask is not None:
true_edge_indices = true_edge_indices[:, query_mask[true_edge_indices[0]]]
training_edge_indices = torch.cat(list_training_edge_indices, dim=-1)
# Calculate truth from intersection between prediction graph and truth graph
training_edge_indices, y_truth = self.get_truth(
edge_indices=training_edge_indices, true_edge_indices=true_edge_indices
)
# Append all positive examples and their truth and weighting
training_edge_indices, y_truth = self.append_true_pairs(
training_edge_indices=training_edge_indices,
y_truth=y_truth,
true_edge_indices=true_edge_indices,
planes=planes,
)
# Remove duplicate edge indices
training_edge_indices, [y_truth] = remove_duplicate_edges(
training_edge_indices, edge_tensors=[y_truth]
)
# Sanity check on the plane numbers of the edges
assert planes is not None
assert torch.all(
planes[training_edge_indices[0]] < planes[training_edge_indices[1]]
)
return training_edge_indices, y_truth
    def training_validation_step(
self, batch: Data, with_grad: bool = False
) -> typing.Dict[str, torch.Tensor]:
"""Common step for the training and validation steps.
This encompasses selecting query hits, drawing edges from them,
running the embedding inference and computing the loss.
"""
features = self.get_features(batch)
true_edge_indices = batch["signal_true_edges"]
planes = batch["plane"]
# Filter out fake hits
if self.hparams.get("remove_noise", True):
n_hits = features.shape[0]
features = features[~batch["fake"]]
planes = planes[~batch["fake"]]
old_indices = torch.where(~batch["fake"])[0]
old_to_new_indices = torch.full(
size=(n_hits,),
fill_value=-1,
device=old_indices.device,
dtype=torch.long,
)
old_to_new_indices[old_indices] = torch.arange(
features.shape[0], device=old_to_new_indices.device
)
true_edge_indices = old_to_new_indices[true_edge_indices]
assert not (
true_edge_indices == -1
).any(), "Existing edges between noise hits"
(
features,
true_edge_indices,
reindexed_planes,
original_indices,
) = self.remove_planes(
features=features, true_edge_index=true_edge_indices, planes=planes
)
self.validate_edges(edge_index=true_edge_indices, planes=reindexed_planes)
# Get the training sample
with torch.no_grad():
embeddings = self(features)
training_edge_indices, y_truth = self.get_training_edges(
embeddings=embeddings,
true_edge_indices=true_edge_indices,
planes=reindexed_planes,
query_mask=(batch["query_mask"] if "query_mask" in batch.keys() else None),
)
# Compute grad only for included hits
if with_grad:
included_hits = training_edge_indices.unique()
embeddings[included_hits] = self(features[included_hits])
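            # The full forward pass above ran under `torch.no_grad` and was only
            # used to pick the training edges; gradients flow only through this
            # second pass over the hits that actually appear in those edges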
# Compute hinge loss
if "target_mask" in batch.keys():
target_weight = self.hparams.get("target_weight")
non_target_weight = self.hparams.get("non_target_weight")
assert (
target_weight is not None or non_target_weight is not None
), "`target_mask` is specified but no target weights are set."
edge_target_mask = batch["target_mask"][training_edge_indices[0]]
weights = torch.ones(
training_edge_indices.shape[1], device=embeddings.device
)
if target_weight is not None:
weights[edge_target_mask] = target_weight
if non_target_weight is not None:
weights[~edge_target_mask] = non_target_weight
else:
weights = None
# Sanity check
if self.hparams.get("target_particle_requirement") is not None:
assert weights is not None
loss = self.get_loss(
embeddings=embeddings,
edge_indices=training_edge_indices,
y_truth=y_truth,
weights=weights,
)
outputs = {
"training_edge_indices": training_edge_indices,
"y_truth": y_truth,
"loss": loss,
"embeddings": embeddings,
"original_indices": original_indices,
}
return outputs
    def training_step(self, batch, batch_idx):
outputs = self.training_validation_step(batch=batch, with_grad=True)
self.log(
"train_loss",
outputs["loss"],
on_epoch=True,
on_step=self.hparams.get("on_step", False),
batch_size=outputs["training_edge_indices"].shape[1],
prog_bar=True,
)
return outputs["loss"]
    def validation_step(self, batch: Data, batch_idx: int) -> torch.Tensor:
outputs = self.training_validation_step(batch=batch, with_grad=False)
training_edge_indices = outputs["training_edge_indices"]
embeddings = outputs["embeddings"]
# Compute L2 distance between the edges
training_edge_distances = self.get_squared_distances(
embeddings=embeddings, edge_indices=training_edge_indices
)
y_train = training_edge_distances < self.hparams["squared_distance_max"]
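        # `y_train` marks edges that would survive the inference-time radius
        # cut, so the efficiency and purity below reflect that working point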
eff, pur = compute_classification_efficiency_purity(
predictions=y_train, truths=outputs["y_truth"]
)
current_lr = self.optimizers().param_groups[0]["lr"] # type: ignore
self.log_dict(
{
"val_loss": outputs["loss"],
"eff": eff,
"pur": pur,
"current_lr": current_lr,
},
on_epoch=True,
on_step=False,
batch_size=training_edge_indices.shape[1],
)
return outputs["loss"]
    def inference(
self,
batch: Data,
squared_distance_max: float,
k_max: int,
evaluate: bool = False,
overall: bool = False,
log: bool = False,
) -> typing.Dict[str, torch.Tensor]:
"""Run the embedding inference + kNN to build edges of an event.
Args:
batch: event PyTorch data object
squared_distance_max: squared maximal distance in the embedding space
k_max: maximal number of neighbours
evaluate: whether to also output the loss, efficiency and purity
            overall: if ``batch`` already contains ``edge_index``, whether to
                concatenate the new edges to the old edge indices instead of
                replacing them
log: whether to add an entry to the log
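
        Illustrative call (a sketch; the cut values below are assumptions,
        not tuned defaults)::

            outputs = model.inference(
                batch=event,
                squared_distance_max=0.01,
                k_max=50,
                evaluate=True,
            )
            edge_index = outputs["edge_indices"]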
"""
features = self.get_features(batch)
batch["x"] = features
planes = batch["plane"]
# Get the true edge indices
true_edge_indices = batch["signal_true_edges"]
(
features,
true_edge_indices,
reindexed_planes,
original_indices,
) = self.remove_planes(
features=features,
planes=planes,
true_edge_index=true_edge_indices,
)
# Check that the true edge directions make sense
self.validate_edges(edge_index=true_edge_indices, planes=reindexed_planes)
# Run forward step to get hit embeddings
embeddings = self(features)
# If `query_planes` is specified, edges can only start from hits
# belonging to these planes.
query_planes = self.query_planes
if query_planes is not None:
hit_mask = torch.isin(planes, query_planes)
query_indices = torch.arange(embeddings.shape[0], device=embeddings.device)[
hit_mask
]
query_embeddings = embeddings[hit_mask]
else:
query_indices = None
query_embeddings = None
# Build whole KNN graph
edge_indices = self.build_edges(
embeddings=embeddings,
squared_distance_max=squared_distance_max,
k_max=k_max,
planes=reindexed_planes,
query_embeddings=query_embeddings,
query_indices=query_indices,
)
        # Map the hit indices back to what they were before any planes were removed
if original_indices is not None:
edge_indices = original_indices[edge_indices]
true_edge_indices = original_indices[true_edge_indices]
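        # If only some planes can be queried, the truth definition is
        # restricted below so that efficiency is measured on reachable
        # edges only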
if query_planes is not None:
edge_mask = torch.isin(planes[true_edge_indices[0]], query_planes)
true_edge_indices = true_edge_indices[:, edge_mask]
true_edge_indices = true_edge_indices[
:,
planes[true_edge_indices[1]] == planes[true_edge_indices[0]] + 1,
]
# Extract the edge targets
edge_indices, y_truth = self.get_truth(edge_indices, true_edge_indices)
# Check that the edge directions make sense
self.validate_edges(edge_index=edge_indices, planes=planes)
evaluation_outputs = {}
        # If the batch already contains edge indices, concatenate the new edge
        # indices to this tensor.
        # This allows stacking several embedding networks + kNN to build the
        # edges of the overall event.
if (
overall
and "edge_index" in batch.keys()
and (already_defined_edge_indices := batch["edge_index"]) is not None
):
assert "y" in batch.keys()
overall_edge_indices = torch.cat(
(already_defined_edge_indices, edge_indices),
dim=-1,
)
overall_y_truth = torch.cat((batch["y"], y_truth), dim=-1)
else:
overall_edge_indices = edge_indices
overall_y_truth = y_truth
# In evaluation mode, the loss, efficiency and purity are computed
# and can also be logged.
if evaluate:
if original_indices is not None:
original_embeddings = torch.zeros(
size=(planes.shape[0], embeddings.shape[1]),
device=embeddings.device,
dtype=embeddings.dtype,
)
original_embeddings[original_indices] = embeddings
else:
original_embeddings = embeddings
evaluation_outputs["loss"] = self.get_loss(
embeddings=original_embeddings,
edge_indices=edge_indices,
y_truth=y_truth,
)
eff, pur = compute_efficiency_purity(
n_truths=true_edge_indices.shape[1],
n_true_positives=overall_y_truth.sum().cpu().numpy(),
n_positives=overall_edge_indices.shape[1],
)
evaluation_outputs["eff"] = eff
evaluation_outputs["pur"] = pur
if log:
evaluation_outputs["lr"] = self.optimizers().param_groups[0]["lr"] # type: ignore
self.log_dict(
evaluation_outputs,
on_epoch=True,
on_step=False,
batch_size=edge_indices.shape[1],
)
return {
"edge_indices": overall_edge_indices,
"y_truth": overall_y_truth,
"true_edge_indices": true_edge_indices,
"overall_true_edge_indices": batch["signal_true_edges"],
**evaluation_outputs,
}
@property
def input_kwargs(self) -> typing.Dict[str, typing.Any]:
return {
**super(EmbeddingBase, self).input_kwargs,
"input": dict(size=(1, 3), dtype=torch.float),
}
@property
def subnetwork_to_outputs(self) -> typing.Dict[str, typing.List[str]]:
return {
**super(EmbeddingBase, self).subnetwork_to_outputs,
"default": ["output"],
}
@property
def input_to_dynamic_axes(self):
return {
**super(EmbeddingBase, self).input_to_dynamic_axes,
"input": {0: "n_nodes"},
"output": {0: "n_nodes"},
}
def _onnx_default(self, input: torch.Tensor) -> torch.Tensor:
return self.forward(input)
    def to_onnx(
self,
outpath: str,
mode: typing.Literal["default"] | None = None,
options: typing.Iterable[str] | None = None,
) -> None:
"""Save model to an ONNX file.
Args:
outpath: Path where to save the ONNX model.
mode: only ``default`` mode is supported.
"""
super(EmbeddingBase, self).to_onnx(outpath=outpath, mode=mode, options=options)
def get_example_data(
path_or_config: str | dict, idx: int = 0
) -> typing.Tuple[pd.DataFrame, Data]:
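    """Build an example hit-input dataframe and PyTorch data object
    from the validation set of an embedding model.
    Args:
        path_or_config: path to a configuration file, or an already-loaded
            configuration dictionary, with an ``embedding`` section
        idx: index of the example event in the validation set
    Returns:
        The dataframe of example hit inputs and the corresponding
        PyTorch data object.
    """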
configs = load_config(path_or_config)
embedding_configs = configs["embedding"]
model = EmbeddingBase(embedding_configs)
training_example = model.valset[idx]
example_hit_inputs = model.get_input_data(training_example.x)
example_hit_df = pd.DataFrame(example_hit_inputs.numpy())
return example_hit_df, training_example
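

if __name__ == "__main__":
    # Minimal usage sketch, assuming a configuration file with an "embedding"
    # section; the path below is an assumption, not a shipped default.
    example_hit_df, example_event = get_example_data("config.yaml", idx=0)
    print(example_hit_df.head())
    print(example_event)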