# Source code for pipeline.Preprocessing.hit_filtering

"""A module that implements the ability to filter hits, grouping them by
particles.
"""
from __future__ import annotations
import typing
from math import isclose

import numpy as np
import numpy.typing as npt
import pandas as pd
import numba as nb

from utils.tools.tgroupby import get_group_indices


@nb.jit(nopython=True, cache=True)
def cut_long_tracks_impl(
    array_mask: npt.NDArray[np.bool_],
    event_ids: npt.NDArray[np.int_],
    particle_ids: npt.NDArray[np.int_],
    track_sizes: npt.NDArray[np.int_],
    proportions: npt.NDArray[np.float_],
    rng: np.random.Generator,
) -> None:
    """Cut long tracks to get smaller tracks. The first hits are removed.

    ``array_mask`` is filled in place. Within each event, every particle is
    assigned one of ``track_sizes`` so that the fraction of particles per
    track size follows ``proportions`` (up to rounding). The assignment is
    shuffled with ``rng``. The last ``n`` hits of each particle are then
    flagged as kept, where ``n`` is the assigned track size.

    Hits are expected to be grouped contiguously by event and, within an
    event, by particle. The caller ``mask_long_into_small_tracks`` sorts them
    by ``event_id``, ``particle_id`` and ``plane``, so "last hits" means the
    hits on the highest planes.

    Args:
        array_mask: boolean array modified in place; ``True`` marks a kept
            hit. Hits whose entry is never set are left untouched.
        event_ids: array of event IDs, one per hit
        particle_ids: array of particle IDs, one per hit
        track_sizes: array of track sizes. A size of ``-1`` keeps all the
            hits of the particle.
        proportions: array of proportion of track sizes, corresponding to
            ``track_sizes``
        rng: random generator for which track is cut to which size

    Notes:
        Because of rounding, some particles may be left with a required
        size of 0. None of their hits are flagged as kept.
    """
    indices_groupby_events = get_group_indices(event_ids)

    for event_start_idx, event_end_idx in zip(
        indices_groupby_events[:-1],
        indices_groupby_events[1:],
    ):
        # Slices are views: writing into them updates ``array_mask``.
        event_array_mask = array_mask[event_start_idx:event_end_idx]

        event_particle_ids = particle_ids[event_start_idx:event_end_idx]
        indices_groupby_particles = get_group_indices(event_particle_ids)

        # The group indices hold consecutive (start, end) boundaries, so there
        # is one particle fewer than there are boundaries. (Previously this
        # used the number of *hits* of the event, which skewed proportions.)
        n_particles = indices_groupby_particles.shape[0] - 1

        # Array that contains for every particle of the event how many hits
        # it should keep, given ``proportions``
        event_required_n_hits = np.zeros(
            shape=n_particles,
            dtype=indices_groupby_particles.dtype,
        )

        current_particle_idx = 0
        for track_size, proportion in zip(track_sizes, proportions):
            # how many particles with the given track size
            n_particles_for_track_size = round(n_particles * proportion)

            event_required_n_hits[
                current_particle_idx : current_particle_idx + n_particles_for_track_size
            ] = track_size
            current_particle_idx += n_particles_for_track_size

        # Randomly decide which particle gets which track size
        rng.shuffle(event_required_n_hits)

        for particle_start_idx, particle_end_idx, particle_required_n_hits in zip(
            indices_groupby_particles[:-1],
            indices_groupby_particles[1:],
            event_required_n_hits,
        ):
            particle_array_mask = event_array_mask[particle_start_idx:particle_end_idx]
            # ``-1`` must be tested before the generic positive case: the
            # original ``!= 0`` test caught it first, making "keep all hits"
            # unreachable (it kept all hits but the first one instead).
            if particle_required_n_hits == -1:  # keep all hits
                particle_array_mask[:] = True
            elif particle_required_n_hits > 0:
                # keep the last hits, i.e. remove the first ones
                particle_array_mask[-particle_required_n_hits:] = True


def mask_long_into_small_tracks(
    hits_particles: pd.DataFrame,
    track_size_proportions: typing.Dict[int, float],
    seed: int | None = None,
) -> pd.Series:
    """Create a mask to remove the first hits of long tracks to match the
    proportions of track sizes given as input.

    Args:
        hits_particles: dataframe of hits-particles. It must have at least the
            columns ``event_id``, ``particle_id`` and ``plane``.
        track_size_proportions: dictionary that associates a track size with
            the expected proportion after the cut. The proportions must sum
            to 1.
        seed: Random seed for which track is cut to which size

    Returns:
        Boolean pandas series named ``hits_mask``, with the same index and
        row order as ``hits_particles``, which indicates which hits are kept.

    Raises:
        AssertionError: if the proportions do not sum to 1.
    """
    track_sizes = np.array(list(track_size_proportions.keys()))
    proportions = np.array(list(track_size_proportions.values()))
    sum_props = sum(track_size_proportions.values())
    assert isclose(
        sum_props, 1.0
    ), f"The sum of proportions in `track_sizes` is equal to {sum_props} != 1.0"

    sort_columns = ["event_id", "particle_id", "plane"]
    # Positions (not labels) of the rows once sorted by event, particle and
    # plane. Working with positions lets the mask be mapped back onto
    # ``hits_particles`` even when its index has duplicate labels, which
    # ``Series.reindex`` would reject.
    sorted_positions = (
        hits_particles[sort_columns]
        .reset_index(drop=True)
        .sort_values(by=sort_columns)
        .index.to_numpy()
    )

    array_mask = np.zeros(shape=len(sorted_positions), dtype=bool)
    rng = np.random.default_rng(seed=seed)
    cut_long_tracks_impl(
        array_mask=array_mask,
        event_ids=hits_particles["event_id"].to_numpy()[sorted_positions],
        particle_ids=hits_particles["particle_id"].to_numpy()[sorted_positions],
        track_sizes=track_sizes,
        proportions=proportions,
        rng=rng,
    )

    # Undo the sort: the i-th hit processed above is row
    # ``sorted_positions[i]`` of ``hits_particles``.
    hits_mask = np.empty_like(array_mask)
    hits_mask[sorted_positions] = array_mask
    return pd.Series(hits_mask, index=hits_particles.index, name="hits_mask")