# Source code for pipeline.Preprocessing.hit_filtering
"""A module that implements the ability to filter hits, grouping them by
particles.
"""
from __future__ import annotations
import typing
from math import isclose
import numpy as np
import numpy.typing as npt
import pandas as pd
import numba as nb
from utils.tools.tgroupby import get_group_indices
@nb.jit(nopython=True, cache=True)
def cut_long_tracks_impl(
    array_mask: npt.NDArray[np.bool_],
    event_ids: npt.NDArray[np.int_],
    particle_ids: npt.NDArray[np.int_],
    track_sizes: npt.NDArray[np.int_],
    proportions: npt.NDArray[np.float64],
    rng: np.random.Generator,
) -> None:
    """Cut long tracks to get smaller tracks. The first hits are removed.

    For every event, each particle is assigned a required number of hits taken
    from ``track_sizes``, so that the share of particles assigned to each size
    follows ``proportions``. The assignment is shuffled with ``rng``. Then the
    last ``required`` hits of each particle are flagged as kept in
    ``array_mask``.

    The hits must be grouped contiguously by event and, within an event, by
    particle, because the groups are processed as contiguous slices.

    Args:
        array_mask: the array of mask that indicates the hits that are kept.
            Modified in place: entries of kept hits are set to ``True``, and
            other entries are left untouched.
        event_ids: array of event IDs
        particle_ids: array of particle IDs
        track_sizes: array of track sizes. A size of ``-1`` keeps all the hits
            of the particle. A size of ``0`` flags none of them.
        proportions: array of proportion of track sizes, corresponding to
            ``track_sizes``. Expected to sum to 1. Particles left uncovered by
            rounding get the last track size.
        rng: random generator for which track is cut to which size
    """
    # NOTE(review): ``get_group_indices`` is assumed to return the start index
    # of every group followed by the end of the last group, as implied by the
    # pairwise ``zip`` over ``[:-1]`` / ``[1:]`` below.
    indices_groupby_events = get_group_indices(event_ids)
    for event_start_idx, event_end_idx in zip(
        indices_groupby_events[:-1],
        indices_groupby_events[1:],
    ):
        event_array_mask = array_mask[event_start_idx:event_end_idx]
        event_particle_ids = particle_ids[event_start_idx:event_end_idx]
        indices_groupby_particles = get_group_indices(event_particle_ids)
        # Number of particles (not hits) in the event: one fewer than the
        # number of group boundaries.
        n_particles = indices_groupby_particles.shape[0] - 1
        # Required number of hits for every particle of the event, given
        # ``proportions``. The cumulative boundaries are rounded, rather than
        # each share independently, so that the shares add up to
        # ``n_particles``. Otherwise, leftover particles would keep a size of 0
        # and lose their whole track.
        event_required_n_hits = np.zeros(
            shape=n_particles,
            dtype=track_sizes.dtype,
        )
        cumulative_proportion = 0.0
        current_particle_idx = 0
        for track_size, proportion in zip(track_sizes, proportions):
            cumulative_proportion += proportion
            next_particle_idx = min(
                round(n_particles * cumulative_proportion), n_particles
            )
            event_required_n_hits[current_particle_idx:next_particle_idx] = track_size
            current_particle_idx = next_particle_idx
        # Particles not covered because of floating-point error in the
        # cumulative sum get the last track size.
        if track_sizes.shape[0] > 0:
            event_required_n_hits[current_particle_idx:] = track_sizes[-1]
        # Shuffle track size
        rng.shuffle(event_required_n_hits)
        for particle_start_idx, particle_end_idx, particle_required_n_hits in zip(
            indices_groupby_particles[:-1],
            indices_groupby_particles[1:],
            event_required_n_hits,
        ):
            particle_array_mask = event_array_mask[particle_start_idx:particle_end_idx]
            # ``-1`` must be tested first. Since ``-1 != 0`` as well, the
            # generic branch would compute ``[1:]`` and drop the first hit
            # instead of keeping all of them.
            if particle_required_n_hits == -1:  # keep all hits
                particle_array_mask[:] = True
            elif particle_required_n_hits != 0:
                # Keep the last hits, i.e., remove the first ones
                particle_array_mask[-particle_required_n_hits:] = True
[docs]def mask_long_into_small_tracks(
    hits_particles: pd.DataFrame,
    track_size_proportions: typing.Dict[int, float],
    seed: int | None = None,
) -> pd.Series:
    """Create a mask to remove the first hits of long tracks to match the proportions
    of track sizes given as input.
    Args:
        hits_particles: dataframe of hits-particles
        track_sizes: dictionary that associates a track size with the expected
            proportion after the cut.
        seed: Random seed for which track is cut to which size
    Returns:
        Pandas series indexed by `event`, `particle_id` and `hit_id`,
        which indicates which hits are kept
    """
    track_sizes = np.array(list(track_size_proportions.keys()))
    proportions = np.array(list(track_size_proportions.values()))
    sum_props = sum(track_size_proportions.values())
    assert isclose(
        sum_props, 1.0
    ), f"The sum of proportions in `track_sizes` is equal to {sum_props} != 1.0"
    sorted_hits_particles = hits_particles.sort_values(
        by=["event_id", "particle_id", "plane"]
    )
    array_mask = np.zeros(shape=(sorted_hits_particles.shape[0]), dtype=bool)
    rng = np.random.default_rng(seed=seed)
    cut_long_tracks_impl(
        array_mask=array_mask,
        event_ids=sorted_hits_particles["event_id"].to_numpy(),
        particle_ids=sorted_hits_particles["particle_id"].to_numpy(),
        track_sizes=track_sizes,
        proportions=proportions,
        rng=rng,
    )
    return pd.Series(
        array_mask,
        index=sorted_hits_particles.index,
        name="hits_mask",
    ).reindex(hits_particles.index)