Source code for pipeline.Embedding.embedding_base

"""A module that defines the embedding training and inference.
"""

from __future__ import annotations
import typing

import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data

from utils.modelutils.basemodel import ModelBase
from utils.loaderutils.dataiterator import LazyDatasetBase
from utils.modelutils.metrics import (
    compute_classification_efficiency_purity,
    compute_efficiency_purity,
)
from utils.commonutils.config import load_config
from utils.graphutils.truths import get_truths_exatrkx
from utils.graphutils.knn import build_edges_plane_by_plane
from utils.graphutils.edgeutils import remove_duplicate_edges
from utils.graphutils.edgebuilding import get_random_pairs_plane_by_plane

TensorOrNone = typing.TypeVar("TensorOrNone", torch.Tensor, None)

pd.set_option("chained_assignment", None)


class EmbeddingLazyDataSet(LazyDatasetBase):
    def __init__(
        self,
        *args,
        particle_requirement: str | None = None,
        query_particle_requirement: str | None = None,
        target_particle_requirement: str | None = None,
        particles_from_parquet: bool = False,
        **kwargs,
    ):
        """This special lazy dataset (used only for training and validation
        datasets!) handles the :py:attr:`EmbeddingBase.particle_requirement`,
        :py:attr:`EmbeddingBase.query_particle_requirement` and
        :py:attr:`EmbeddingBase.target_particle_requirement`.

        Writing this in the lazy dataset should, in theory, allow the
        requirements to be computed in parallel to the minimisation.

        The ``particles_from_parquet`` argument allows dealing with variables
        that are missing from the PyTorch data objects, by loading them
        directly from the preprocessed files. This is a bit messy and
        therefore not recommended.
        """
        super().__init__(*args, **kwargs)
        self.particle_requirement = particle_requirement
        self.query_particle_requirement = query_particle_requirement
        self.target_particle_requirement = target_particle_requirement
        self.particles_from_parquet = bool(particles_from_parquet)
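
    # The requirement strings are pandas ``DataFrame.eval`` expressions over
    # the particle columns. An illustrative (hypothetical) example:
    #   particle_requirement="has_velo and has_scifi and abs(eta) < 5"
    # Column names such as ``has_velo`` or ``eta`` must exist in the particle
    # dataframe for the expression to evaluate.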
    def fetch_dataset(self, input_path: str, map_location: str = "cpu", **kwargs):
        batch = super().fetch_dataset(input_path, map_location, **kwargs)
        if (
            self.particle_requirement is not None
            or self.query_particle_requirement is not None
            or self.target_particle_requirement is not None
        ):
            batch = self.apply_particle_requirement(batch=batch)
        return batch
    def apply_particle_requirement(self, batch: Data) -> Data:
        # Load dataframe of particles
        if self.particles_from_parquet:
            particle_columns = ["from_sdecay", "pid", "has_velo", "has_scifi", "eta"]
            df_particles = pd.read_parquet(
                f"{batch['truncated_path']}-particles.parquet",
                columns=["particle_id"] + particle_columns,
            )
            df_particles = df_particles.set_index("particle_id")
        else:
            particle_columns = [
                column[len("particle_") :]
                for column in batch.keys()
                if column.startswith("particle_") and column != "particle_id_hit_idx"
            ]
            df_particles = pd.DataFrame(
                {
                    column: batch[f"particle_{column}"].numpy()
                    for column in particle_columns
                },
                index=batch["unique_particle_id"].numpy(),
            )
            df_particles.index.name = "particle_id"

        # Evaluate whether to keep the particle or not
        if self.particle_requirement is not None:
            df_particles.eval(
                f"keep_mask = ({self.particle_requirement})", inplace=True
            )
        # Evaluate whether the hit can be a query point or not
        if self.query_particle_requirement is not None:
            df_particles.eval(
                f"query_mask = ({self.query_particle_requirement})", inplace=True
            )
        # Evaluate whether the particle is a target to optimize
        if self.target_particle_requirement is not None:
            df_particles.eval(
                f"target_mask = ({self.target_particle_requirement})", inplace=True
            )

        # Drop the particle columns now that the "mask" columns have been
        # computed (merging fewer columns = faster)
        if particle_columns:
            df_particles.drop(particle_columns, axis=1, inplace=True)

        # Define the dataframe of hits-particles with columns `particle_id`,
        # `hit_idx` and `hit_id`. `hit_id` is needed in case the preprocessed
        # dataframe was used.
        df_hits_particles = pd.DataFrame(
            {
                "particle_id": batch["particle_id_hit_idx"][:, 0].numpy(),
                "hit_idx": batch["particle_id_hit_idx"][:, 1].numpy(),
            },
        ).merge(
            pd.DataFrame(
                {
                    "hit_idx": np.arange(batch["hit_id"].shape[0]),
                    "hit_id": batch["hit_id"].numpy(),
                }
            ),
            on=["hit_idx"],
            how="left",
        )

        # Merge the dataframe of particles to add the columns `keep_mask`,
        # `query_mask` and `target_mask`. This merging removes noise.
        df_hits_particles = df_hits_particles.merge(
            df_particles, how="inner", on=["particle_id"]
        )

        # Apply the particle requirement mask
        if "keep_mask" in df_hits_particles:
            assert not df_hits_particles["keep_mask"].isnull().values.any()  # type: ignore
            df_hits_particles = df_hits_particles[df_hits_particles["keep_mask"]]

        # Compute overall query and target masks
        mask_columns = []
        if "query_mask" in df_hits_particles:
            mask_columns.append("query_mask")
        if "target_mask" in df_hits_particles:
            mask_columns.append("target_mask")

        dict_mask_columns = {}
        if mask_columns:
            df_hits = (
                df_hits_particles.groupby("hit_idx", sort=True)[mask_columns].sum() > 0
            ).reset_index()
            for mask_column in mask_columns:
                dict_mask_columns[mask_column] = torch.from_numpy(
                    df_hits[mask_column].to_numpy()
                )
            kept_hit_indices = torch.from_numpy(df_hits["hit_idx"].to_numpy())
        else:
            kept_hit_indices = torch.from_numpy(
                np.unique(df_hits_particles["hit_idx"].to_numpy())
            )

        # Tensor that allows converting original hit indices to filtered hit indices
        old_to_new_indices = torch.arange(batch["hit_id"].shape[0])
        old_to_new_indices[kept_hit_indices] = torch.arange(kept_hit_indices.shape[0])

        # Filter and reindex true edge indices
        true_edge_indices = batch["signal_true_edges"]
        true_edge_indices = true_edge_indices[
            :,
            torch.isin(true_edge_indices[0], kept_hit_indices)
            & torch.isin(true_edge_indices[1], kept_hit_indices),
        ]
        true_edge_indices = old_to_new_indices[true_edge_indices]

        new_batch = Data(
            plane=batch["plane"][kept_hit_indices],
            fake=batch["fake"][kept_hit_indices],
            signal_true_edges=true_edge_indices,
            **dict_mask_columns,
        )
        for hit_variable in ["x", "un_x", "un_y", "un_z", "zatyeq0", "xatyeq0"]:
            if hit_variable in batch.keys():
                new_batch[hit_variable] = batch[hit_variable][kept_hit_indices]
        return new_batch

class EmbeddingBase(ModelBase):
    """A class that implements the metric learning model."""
    def get_lazy_dataset(self, *args, **kwargs) -> EmbeddingLazyDataSet:
        return EmbeddingLazyDataSet(*args, **kwargs)
    def get_lazy_dataset_partition(
        self, partition: str, *args, **kwargs
    ) -> LazyDatasetBase:
        if partition in ["train", "val"]:
            return super().get_lazy_dataset_partition(
                partition=partition,
                *args,
                particle_requirement=self.hparams.get("particle_requirement"),
                query_particle_requirement=self.hparams.get(
                    "query_particle_requirement"
                ),
                target_particle_requirement=self.hparams.get(
                    "target_particle_requirement"
                ),
                **kwargs,
            )
        else:
            return super().get_lazy_dataset_partition(
                partition=partition, *args, **kwargs
            )
    @property
    def edgedir(self) -> str:
        """Edge direction: ``left`` or ``right``."""
        return self.hparams.get("edgedir", "right")

    @property
    def n_total_planes(self) -> int:
        """Total number of planes."""
        return self.hparams["n_planes"]

    @property
    def n_planes(self) -> int:
        """Number of unremoved planes (e.g., xz-scifi)."""
        n_removed_planes = (
            len(removed_planes)
            if (removed_planes := self.hparams.get("removed_planes")) is not None
            else 0
        )
        return self.n_total_planes - n_removed_planes

    @property
    def query_planes(self) -> torch.Tensor | None:
        """Planes that can be queried."""
        if (list_query_planes := self.hparams.get("query_planes")) is not None:
            return torch.tensor(list_query_planes, device=self.device)
        else:
            return None

    @property
    def last_plane(self) -> int:
        """Index of the last plane."""
        return self.n_planes - 1
    def validate_edges(self, edge_index: torch.Tensor, planes: torch.Tensor):
        """Check whether non-bidirectional edges all have the correct directions.

        Args:
            edge_index: edge indices
            planes: plane index of each hit
        """
        # Assert that the planes are sorted
        assert torch.all(planes[:-1] <= planes[1:])

        # Assert that the planes are in the right direction
        edgedir = self.edgedir
        left_planes = planes[edge_index[0]]
        right_planes = planes[edge_index[1]]
        if edgedir == "right":
            assert torch.all(left_planes < right_planes)
        elif edgedir == "left":
            assert torch.all(left_planes > right_planes)
        else:
            raise ValueError(f"`edgedir` has an invalid value {edgedir}")
    def remove_planes(
        self,
        features: torch.Tensor,
        planes: torch.Tensor,
        true_edge_index: TensorOrNone = None,
    ) -> typing.Tuple[torch.Tensor, TensorOrNone, torch.Tensor, torch.Tensor | None]:
        """Remove hits belonging to planes given by the hyperparameter
        ``removed_planes``.

        Args:
            features: hit features
            planes: hit plane indices
            true_edge_index: optionally, tensor of true edge indices

        Returns:
            Reindexed hit features, true edge indices, planes and original
            hit indices. If no plane is removed, this is indicated by the
            original hit indices being None.
        """
        n_hits = features.shape[0]
        if (
            removed_planes := self.hparams.get("removed_planes")
        ) is not None and removed_planes:
            # Filter out hits belonging to planes to remove
            removed_planes = torch.as_tensor(removed_planes, device=planes.device)
            mask_hits_to_keep = ~torch.isin(planes, removed_planes)
            if true_edge_index is not None:
                old_to_new_indices = torch.full(
                    size=planes.shape,
                    fill_value=-1,
                    device=planes.device,
                    dtype=torch.long,
                )
                old_to_new_indices[mask_hits_to_keep] = torch.arange(
                    mask_hits_to_keep.sum(),  # type: ignore
                    device=mask_hits_to_keep.device,
                )
                true_edge_index = true_edge_index[
                    :,
                    mask_hits_to_keep[true_edge_index[0]]
                    & mask_hits_to_keep[true_edge_index[1]],
                ]
                true_edge_index = old_to_new_indices[true_edge_index]
            planes = planes[mask_hits_to_keep]

            # Reindex plane number
            unique_planes = torch.arange(self.n_total_planes, device=planes.device)
            reindexed_unique_planes = torch.arange(self.n_planes, device=planes.device)
            old_to_new_planes = torch.full(
                size=(self.n_total_planes,),
                fill_value=-1,
                device=planes.device,
                dtype=torch.long,
            )
            old_to_new_planes[~torch.isin(unique_planes, removed_planes)] = (
                reindexed_unique_planes
            )
            reindexed_planes = old_to_new_planes[planes]
            return (
                features[mask_hits_to_keep],
                true_edge_index,
                reindexed_planes,
                torch.arange(n_hits, device=features.device)[mask_hits_to_keep],
            )
        else:
            return features, true_edge_index, planes, None
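
    # Note on the reindexing idiom used above (and in
    # ``apply_particle_requirement``): a lookup tensor maps old hit indices
    # to compacted ones, e.g. with hits 0..4 and hits 1 and 3 removed,
    #   old_to_new_indices = [0, -1, 1, -1, 2]
    # so indexing an edge tensor with this lookup rewrites both endpoints of
    # every edge in one vectorised operation; -1 flags hits that must already
    # have been filtered out of the edge tensor beforehand.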
    def get_query_points(
        self,
        embeddings: torch.Tensor,
        true_edge_indices: torch.Tensor,
        planes: torch.Tensor | None = None,
        query_mask: torch.Tensor | None = None,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Get the points the edges will be drawn from to generate the
        training set.

        Args:
            embeddings: point embeddings
            true_edge_indices: true edge indices
            planes: plane index of each point
            query_mask: optional boolean mask of the points allowed as queries

        Returns:
            1D tensor of query indices and 2D tensor of query embeddings
        """
        query_indices = true_edge_indices.unique()
        if query_mask is not None:
            query_indices = query_indices[query_mask[query_indices]]

        # Remove query points that belong to the last plane (for `right`
        # edges) or the first plane (for `left` edges), since no edge can
        # be drawn from them
        assert planes is not None
        last_plane = torch.max(planes)
        if (query_planes := self.query_planes) is not None:
            if self.edgedir == "left":
                assert 0 not in query_planes
            elif self.edgedir == "right":
                assert last_plane not in query_planes
            else:
                raise ValueError(f"`edgedir` has an invalid value {self.edgedir}")
            query_indices = query_indices[
                torch.isin(planes[query_indices], query_planes)
            ]
        else:
            if self.edgedir == "left":
                query_indices = query_indices[planes[query_indices] > 0]
            elif self.edgedir == "right":
                query_indices = query_indices[planes[query_indices] < last_plane]
            else:
                raise ValueError(f"`edgedir` has an invalid value {self.edgedir}")

        query_indices = query_indices[torch.randperm(len(query_indices))][
            : self.hparams["points_per_batch"]
        ]
        query_indices = torch.sort(query_indices).values
        query_embeddings = embeddings[query_indices]
        return query_indices, query_embeddings
    def build_edges(
        self,
        embeddings: torch.Tensor,
        planes: torch.Tensor,
        k_max: int,
        squared_distance_max: float,
        query_embeddings: torch.Tensor | None = None,
        query_indices: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Build edges by applying kNNs.

        Edges are built by looping over the planes, and drawing neighbours
        between a plane and the next ``plane_range`` planes, where
        ``plane_range`` is a hyperparameter.

        Args:
            embeddings: embeddings of all the points
            planes: planes of all the points
            k_max: maximum number of neighbours for the kNN
            squared_distance_max: maximum (embedded) distance for 2 points
                to be considered as neighbours
            query_embeddings: embeddings of the query points
            query_indices: indices of the query points (in ``embeddings``)

        Returns:
            Edges built by the kNN.
        """
        return build_edges_plane_by_plane(
            coords=embeddings,
            planes=planes,
            k_max=k_max,
            squared_distance_max=squared_distance_max,
            plane_range=self.hparams.get("plane_range"),
            start_coords=query_embeddings,
            start_indices=query_indices,
            n_planes=self.n_planes,
        )
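
    # ``build_edges_plane_by_plane`` (see ``utils.graphutils.knn``) is
    # expected to restrict the kNN search for a hit on plane ``p`` to hits on
    # the next ``plane_range`` planes, rather than searching the whole event.
    # This keeps the graph directional and bounds its size to at most
    # ``n_query_points * k_max`` edges.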
    def get_hnm_pairs(
        self,
        query_embeddings: torch.Tensor,
        query_indices: torch.Tensor,
        embeddings: torch.Tensor,
        planes: torch.Tensor,
    ) -> torch.Tensor:
        """Get the edges from hard-negative mining.

        Args:
            query_embeddings: embeddings of the query points
            query_indices: corresponding indices of the query points
            embeddings: embeddings of all the points
            planes: planes of all the points

        Returns:
            Edge indices of the hard-negative mined edges
        """
        knn_edge_indices = self.build_edges(
            query_embeddings=query_embeddings,
            embeddings=embeddings,
            query_indices=query_indices,
            squared_distance_max=self.hparams["squared_distance_max"],
            k_max=self.hparams["k_max"],
            planes=planes,
        )
        return knn_edge_indices
    def get_random_pairs(
        self,
        query_indices: torch.Tensor,
        planes: torch.Tensor,
    ) -> torch.Tensor:
        """Get random edges drawn from the query points.

        Args:
            query_indices: indices of the query points
            planes: planes of all the points; random pairs are only drawn
                from one plane to one of the next ``plane_range`` planes
                (where ``plane_range`` is a hyperparameter)

        Returns:
            Edge indices of random edges drawn from the query points
        """
        random_pairs = get_random_pairs_plane_by_plane(
            n_random=self.hparams["randomisation"],
            planes=planes,
            query_indices=query_indices,
            plane_range=self.hparams.get("plane_range"),
            n_planes=self.n_planes,
        )
        return random_pairs
    def append_true_pairs(
        self,
        training_edge_indices: torch.Tensor,
        y_truth: torch.Tensor,
        true_edge_indices: torch.Tensor,
        planes: torch.Tensor,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Append the true edges to the tensor of training edges.

        Args:
            training_edge_indices: training sample of edge indices
            y_truth: whether the edges in ``training_edge_indices`` are
                genuine or fake
            true_edge_indices: all the genuine edge indices
            planes: planes of all the points

        Returns:
            Training edge indices with the true edge indices added, and
            updated ``y_truth``.
        """
        plane_range = self.hparams["plane_range"]
        true_edge_indices = true_edge_indices[
            :,
            torch.abs(planes[true_edge_indices[1]] - planes[true_edge_indices[0]])
            <= plane_range,
        ]
        training_edge_indices = torch.cat(
            (
                training_edge_indices,
                true_edge_indices,
            ),
            dim=-1,
        )
        y_truth = torch.cat(
            [
                y_truth,
                torch.ones(
                    true_edge_indices.shape[1], device=y_truth.device, dtype=torch.bool
                ),
            ]
        )
        return training_edge_indices, y_truth
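
    # Design note: true edges are filtered to a plane separation of at most
    # ``plane_range`` before being appended, presumably because the
    # plane-by-plane kNN can never propose longer edges, so keeping them
    # would ask the loss to pull together pairs that the graph construction
    # cannot recover at inference time.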
    def get_squared_distances(
        self,
        embeddings: torch.Tensor,
        edge_indices: torch.Tensor,
    ) -> torch.Tensor:
        """Get the squared distances.

        Args:
            embeddings: embeddings of all the points
            edge_indices: edge indices

        Returns:
            ``squared_distances`` tensor corresponding to the squared L2
            distance between the embeddings of the hits of every edge.
        """
        reference = embeddings[edge_indices[1]]
        neighbors = embeddings[edge_indices[0]]
        distance = torch.sum((reference - neighbors) ** 2, dim=-1)
        return distance
    def get_truth(
        self, edge_indices: torch.Tensor, true_edge_indices: torch.Tensor
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Get the true label of each edge (whether it's genuine or fake).

        Args:
            edge_indices: edge indices
            true_edge_indices: the true edge indices

        Returns:
            2 one-dimensional torch tensors. The first tensor is the tensor
            of edge indices, whose columns may have been reordered. The
            second tensor contains, for each edge (column) in
            ``edge_indices``, whether this edge is genuine (1) or fake (0).
        """
        edge_indices, y_truth = get_truths_exatrkx(  # type: ignore
            edge_indices=edge_indices,
            true_edge_indices=true_edge_indices,
            device=self.device,
        )
        return edge_indices, y_truth
    def get_loss(
        self,
        embeddings: torch.Tensor,
        edge_indices: torch.Tensor,
        y_truth: torch.Tensor,
        weights: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Compute the loss for the given embeddings and edges.

        Args:
            embeddings: embeddings of all the points
            edge_indices: edge indices
            y_truth: for each edge (column) in ``edge_indices``, whether this
                edge is genuine (``True``) or fake (``False``)
            weights: edge weights

        Returns:
            Value of the siamese-like loss
        """
        squared_distances = self.get_squared_distances(
            embeddings=embeddings, edge_indices=edge_indices
        )
        negative_squared_distances = squared_distances[~y_truth]
        positive_squared_distances = squared_distances[y_truth]
        margin = self.hparams["margin"]

        if weights is None:
            negative_weights = None
            positive_weights = None
        else:
            negative_weights = weights[~y_truth]
            positive_weights = weights[y_truth]

        if negative_weights is None:
            negative_loss = torch.nn.functional.hinge_embedding_loss(
                negative_squared_distances,
                torch.full_like(input=negative_squared_distances, fill_value=-1.0),
                margin=margin,
                reduction="mean",
            )
        else:
            negative_losses = torch.nn.functional.hinge_embedding_loss(
                negative_squared_distances,
                torch.full_like(input=negative_squared_distances, fill_value=-1.0),
                margin=margin,
                reduction="none",
            )
            negative_loss = (negative_weights * negative_losses).mean()

        if positive_weights is None:
            positive_loss = torch.nn.functional.hinge_embedding_loss(
                positive_squared_distances,
                torch.full_like(input=positive_squared_distances, fill_value=1.0),
                margin=margin,
                reduction="mean",
            )
        else:
            positive_losses = torch.nn.functional.hinge_embedding_loss(
                positive_squared_distances,
                torch.full_like(input=positive_squared_distances, fill_value=1.0),
                margin=margin,
                reduction="none",
            )
            positive_loss = (positive_weights * positive_losses).mean()

        pos_weight: float = self.hparams.get("weight", 1.0)
        return negative_loss + pos_weight * positive_loss
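
    # For reference: with target +1, ``hinge_embedding_loss`` on a squared
    # distance d2 reduces to d2 itself (genuine pairs are pulled together);
    # with target -1 it reduces to max(0, margin - d2) (fake pairs are pushed
    # apart until their squared distance reaches ``margin``).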
    def get_training_edges(
        self,
        embeddings: torch.Tensor,
        true_edge_indices: torch.Tensor,
        planes: torch.Tensor,
        query_mask: torch.Tensor | None = None,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Get the edges used for the training.

        Args:
            embeddings: embeddings of all the points
            true_edge_indices: 2D tensor of genuine edge indices
            planes: tensor of planes for every point. Only used for
                one-directional graphs.
            query_mask: optional boolean mask of the points allowed as
                query points

        Returns:
            2D tensor of training edge indices and 1D tensor indicating
            whether the corresponding edge is genuine or fake.
        """
        # Sanity check
        if self.hparams.get("query_particle_requirement") is not None:
            assert query_mask is not None

        # Get the query points the edges will be drawn from
        query_indices, query_embeddings = self.get_query_points(
            embeddings=embeddings,
            true_edge_indices=true_edge_indices,
            planes=planes,
            query_mask=query_mask,
        )

        list_training_edge_indices = []
        # Append hard-negative mining (hnm) edges from the kNN graph
        list_training_edge_indices.append(
            self.get_hnm_pairs(
                query_embeddings, query_indices, embeddings, planes=planes
            )
        )
        # Append random edge pairs (rp) for stability
        list_training_edge_indices.append(
            self.get_random_pairs(query_indices, planes=planes)
        )

        # Remove true edge indices that cannot be queried
        if query_mask is not None:
            true_edge_indices = true_edge_indices[:, query_mask[true_edge_indices[0]]]

        training_edge_indices = torch.cat(list_training_edge_indices, dim=-1)

        # Calculate truth from intersection between prediction graph and
        # truth graph
        training_edge_indices, y_truth = self.get_truth(
            edge_indices=training_edge_indices, true_edge_indices=true_edge_indices
        )

        # Append all positive examples and their truth and weighting
        training_edge_indices, y_truth = self.append_true_pairs(
            training_edge_indices=training_edge_indices,
            y_truth=y_truth,
            true_edge_indices=true_edge_indices,
            planes=planes,
        )

        # Remove duplicate edge indices
        training_edge_indices, [y_truth] = remove_duplicate_edges(
            training_edge_indices, edge_tensors=[y_truth]
        )

        # Sanity check on the plane numbers of the edges
        assert planes is not None
        assert torch.all(
            planes[training_edge_indices[0]] < planes[training_edge_indices[1]]
        )
        return training_edge_indices, y_truth
    def training_validation_step(
        self, batch: Data, with_grad: bool = False
    ) -> typing.Dict[str, torch.Tensor]:
        """Common step for the training and validation steps.

        This encompasses selecting query hits, drawing edges from them,
        running the embedding inference and computing the loss.
        """
        features = self.get_features(batch)
        true_edge_indices = batch["signal_true_edges"]
        planes = batch["plane"]

        # Filter out fake hits
        if self.hparams.get("remove_noise", True):
            n_hits = features.shape[0]
            features = features[~batch["fake"]]
            planes = planes[~batch["fake"]]
            old_indices = torch.where(~batch["fake"])[0]
            old_to_new_indices = torch.full(
                size=(n_hits,),
                fill_value=-1,
                device=old_indices.device,
                dtype=torch.long,
            )
            old_to_new_indices[old_indices] = torch.arange(
                features.shape[0], device=old_to_new_indices.device
            )
            true_edge_indices = old_to_new_indices[true_edge_indices]
            assert not (
                true_edge_indices == -1
            ).any(), "Some true edges involve noise hits"

        (
            features,
            true_edge_indices,
            reindexed_planes,
            original_indices,
        ) = self.remove_planes(
            features=features, true_edge_index=true_edge_indices, planes=planes
        )
        self.validate_edges(edge_index=true_edge_indices, planes=reindexed_planes)

        # Get the training sample
        with torch.no_grad():
            embeddings = self(features)
        training_edge_indices, y_truth = self.get_training_edges(
            embeddings=embeddings,
            true_edge_indices=true_edge_indices,
            planes=reindexed_planes,
            query_mask=(batch["query_mask"] if "query_mask" in batch.keys() else None),
        )

        # Compute grad only for included hits
        if with_grad:
            included_hits = training_edge_indices.unique()
            embeddings[included_hits] = self(features[included_hits])

        # Compute hinge loss
        if "target_mask" in batch.keys():
            target_weight = self.hparams.get("target_weight")
            non_target_weight = self.hparams.get("non_target_weight")
            assert (
                target_weight is not None or non_target_weight is not None
            ), "`target_mask` is specified but no target weights are set."
            edge_target_mask = batch["target_mask"][training_edge_indices[0]]
            weights = torch.ones(
                training_edge_indices.shape[1], device=embeddings.device
            )
            if target_weight is not None:
                weights[edge_target_mask] = target_weight
            if non_target_weight is not None:
                weights[~edge_target_mask] = non_target_weight
        else:
            weights = None

        # Sanity check
        if self.hparams.get("target_particle_requirement") is not None:
            assert weights is not None

        loss = self.get_loss(
            embeddings=embeddings,
            edge_indices=training_edge_indices,
            y_truth=y_truth,
            weights=weights,
        )
        outputs = {
            "training_edge_indices": training_edge_indices,
            "y_truth": y_truth,
            "loss": loss,
            "embeddings": embeddings,
            "original_indices": original_indices,
        }
        return outputs
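
    # Note on the ``with_grad`` logic above: the first forward pass runs
    # under ``torch.no_grad()`` so that query selection and hard-negative
    # mining do not build an autograd graph; only the hits that actually
    # appear in training edges are then re-embedded with gradients enabled,
    # which keeps the memory footprint of the backward pass proportional to
    # the training sample rather than to the full event.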
    def training_step(self, batch, batch_idx):
        outputs = self.training_validation_step(batch=batch, with_grad=True)
        self.log(
            "train_loss",
            outputs["loss"],
            on_epoch=True,
            on_step=self.hparams.get("on_step", False),
            batch_size=outputs["training_edge_indices"].shape[1],
            prog_bar=True,
        )
        return outputs["loss"]
    def validation_step(self, batch: Data, batch_idx: int) -> torch.Tensor:
        outputs = self.training_validation_step(batch=batch, with_grad=False)
        training_edge_indices = outputs["training_edge_indices"]
        embeddings = outputs["embeddings"]

        # Compute the squared L2 distance between the edge endpoints
        training_edge_distances = self.get_squared_distances(
            embeddings=embeddings, edge_indices=training_edge_indices
        )
        y_train = training_edge_distances < self.hparams["squared_distance_max"]
        eff, pur = compute_classification_efficiency_purity(
            predictions=y_train, truths=outputs["y_truth"]
        )
        current_lr = self.optimizers().param_groups[0]["lr"]  # type: ignore
        self.log_dict(
            {
                "val_loss": outputs["loss"],
                "eff": eff,
                "pur": pur,
                "current_lr": current_lr,
            },
            on_epoch=True,
            on_step=False,
            batch_size=training_edge_indices.shape[1],
        )
        return outputs["loss"]
    def inference(
        self,
        batch: Data,
        squared_distance_max: float,
        k_max: int,
        evaluate: bool = False,
        overall: bool = False,
        log: bool = False,
    ) -> typing.Dict[str, torch.Tensor]:
        """Run the embedding inference + kNN to build edges of an event.

        Args:
            batch: event PyTorch data object
            squared_distance_max: squared maximal distance in the embedding
                space
            k_max: maximal number of neighbours
            evaluate: whether to also output the loss, efficiency and purity
            overall: if ``batch`` already contains ``edge_index``, whether to
                concatenate the new edges to the old edge indices instead of
                replacing them
            log: whether to add an entry to the log
        """
        features = self.get_features(batch)
        batch["x"] = features
        planes = batch["plane"]

        # Get the true edge indices
        true_edge_indices = batch["signal_true_edges"]
        (
            features,
            true_edge_indices,
            reindexed_planes,
            original_indices,
        ) = self.remove_planes(
            features=features,
            planes=planes,
            true_edge_index=true_edge_indices,
        )
        # Check that the true edge directions make sense
        self.validate_edges(edge_index=true_edge_indices, planes=reindexed_planes)

        # Run forward step to get hit embeddings
        embeddings = self(features)

        # If `query_planes` is specified, edges can only start from hits
        # belonging to these planes.
        query_planes = self.query_planes
        if query_planes is not None:
            hit_mask = torch.isin(planes, query_planes)
            query_indices = torch.arange(
                embeddings.shape[0], device=embeddings.device
            )[hit_mask]
            query_embeddings = embeddings[hit_mask]
        else:
            query_indices = None
            query_embeddings = None

        # Build whole kNN graph
        edge_indices = self.build_edges(
            embeddings=embeddings,
            squared_distance_max=squared_distance_max,
            k_max=k_max,
            planes=reindexed_planes,
            query_embeddings=query_embeddings,
            query_indices=query_indices,
        )

        # Put the hit indices back to what they were before any planes
        # were removed
        if original_indices is not None:
            edge_indices = original_indices[edge_indices]
            true_edge_indices = original_indices[true_edge_indices]

        if query_planes is not None:
            edge_mask = torch.isin(planes[true_edge_indices[0]], query_planes)
            true_edge_indices = true_edge_indices[:, edge_mask]
            true_edge_indices = true_edge_indices[
                :,
                planes[true_edge_indices[1]] == planes[true_edge_indices[0]] + 1,
            ]

        # Extract the edge targets
        edge_indices, y_truth = self.get_truth(edge_indices, true_edge_indices)

        # Check that the edge directions make sense
        self.validate_edges(edge_index=edge_indices, planes=planes)

        evaluation_outputs = {}

        # If batch already contains edge indices, concatenate the new edge
        # indices to this array. This allows stacking several embedding
        # networks + kNN to build the edges of the overall event.
        if (
            overall
            and "edge_index" in batch.keys()
            and (already_defined_edge_indices := batch["edge_index"]) is not None
        ):
            assert "y" in batch.keys()
            overall_edge_indices = torch.cat(
                (already_defined_edge_indices, edge_indices),
                dim=-1,
            )
            overall_y_truth = torch.cat((batch["y"], y_truth), dim=-1)
        else:
            overall_edge_indices = edge_indices
            overall_y_truth = y_truth

        # In evaluation mode, the loss, efficiency and purity are computed
        # and can also be logged.
        if evaluate:
            if original_indices is not None:
                original_embeddings = torch.zeros(
                    size=(planes.shape[0], embeddings.shape[1]),
                    device=embeddings.device,
                    dtype=embeddings.dtype,
                )
                original_embeddings[original_indices] = embeddings
            else:
                original_embeddings = embeddings
            evaluation_outputs["loss"] = self.get_loss(
                embeddings=original_embeddings,
                edge_indices=edge_indices,
                y_truth=y_truth,
            )
            eff, pur = compute_efficiency_purity(
                n_truths=true_edge_indices.shape[1],
                n_true_positives=overall_y_truth.sum().cpu().numpy(),
                n_positives=overall_edge_indices.shape[1],
            )
            evaluation_outputs["eff"] = eff
            evaluation_outputs["pur"] = pur
            if log:
                evaluation_outputs["lr"] = self.optimizers().param_groups[0]["lr"]  # type: ignore
                self.log_dict(
                    evaluation_outputs,
                    on_epoch=True,
                    on_step=False,
                    batch_size=edge_indices.shape[1],
                )
        return {
            "edge_indices": overall_edge_indices,
            "y_truth": overall_y_truth,
            "true_edge_indices": true_edge_indices,
            "overall_true_edge_indices": batch["signal_true_edges"],
            **evaluation_outputs,
        }
    @property
    def input_kwargs(self) -> typing.Dict[str, typing.Any]:
        return {
            **super(EmbeddingBase, self).input_kwargs,
            "input": dict(size=(1, 3), dtype=torch.float),
        }

    @property
    def subnetwork_to_outputs(self) -> typing.Dict[str, typing.List[str]]:
        return {
            **super(EmbeddingBase, self).subnetwork_to_outputs,
            "default": ["output"],
        }

    @property
    def input_to_dynamic_axes(self):
        return {
            **super(EmbeddingBase, self).input_to_dynamic_axes,
            "input": {0: "n_nodes"},
            "output": {0: "n_nodes"},
        }

    def _onnx_default(self, input: torch.Tensor) -> torch.Tensor:
        return self.forward(input)
    def to_onnx(
        self,
        outpath: str,
        mode: typing.Literal["default"] | None = None,
        options: typing.Iterable[str] | None = None,
    ) -> None:
        """Save the model to an ONNX file.

        Args:
            outpath: path where to save the ONNX model
            mode: only the ``default`` mode is supported
            options: options forwarded to the base implementation
        """
        super(EmbeddingBase, self).to_onnx(outpath=outpath, mode=mode, options=options)
def get_example_data(
    path_or_config: str | dict, idx: int = 0
) -> typing.Tuple[pd.DataFrame, Data]:
    """Load an example event from the validation set of the embedding model.

    Args:
        path_or_config: path to the configuration file, or the configuration
            dictionary itself
        idx: index of the event in the validation set

    Returns:
        Dataframe of the example hit inputs and the corresponding PyTorch
        data object.
    """
    configs = load_config(path_or_config)
    embedding_configs = configs["embedding"]
    model = EmbeddingBase(embedding_configs)
    training_example = model.valset[idx]
    example_hit_inputs = model.get_input_data(training_example.x)
    example_hit_df = pd.DataFrame(example_hit_inputs.numpy())
    return example_hit_df, training_example
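
# Example usage (illustrative; the config path is hypothetical):
#   df_hits, event = get_example_data("configs/embedding.yaml", idx=0)
#   print(df_hits.head())
#   print(event)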