Source code for pipeline.utils.modelutils.build

"""Define the base class to infer on data.
"""
import typing
from types import ModuleType
import abc
import os
import logging
from functools import partial
from pathlib import Path

from joblib import Parallel, delayed
from tqdm.auto import tqdm
import torch
from pytorch_lightning import LightningModule
from torch_geometric.data import Data

from utils.tools.tfiles import delete_directory
from utils.scriptutils.loghandler import configure_logger

configure_logger()


class BuilderBase(abc.ABC):
    """Base class for looping over input files located in a directory,
    processing them and saving the output in a different directory.
    """

    def __init__(self) -> None:
        pass
    def infer(
        self,
        input_dir: str,
        output_dir: str,
        reproduce: bool = True,
        processing: str | typing.List[str] | None = None,
        file_names: typing.List[str] | None = None,
        n_workers: int = 1,
    ):
        """Load the torch datasets located in ``input_dir``, run the model
        inference and save the output in ``output_dir``.

        Args:
            input_dir: input directory path
            output_dir: output directory path
            reproduce: whether to delete the output directory if it already
                exists, and run the inference again
            processing: name(s) of supplementary function(s) that process the
                event after :py:func:`BuilderBase.construct_downstream`
            file_names: list of file names to run the inference on. If not
                specified, the inference is run on all the datasets located
                in the input directory.
            n_workers: number of workers used to run the inference in
                parallel. Parallel inference seems quite unstable...
        """
        # List the names of the input files
        if file_names is None:
            file_names = [
                # Keep only the file names: they are joined with
                # ``input_dir`` in ``infer_one_step``
                file_.name
                for file_ in os.scandir(input_dir)
                if file_.name != "done"
            ]
        assert len(file_names) > 0, f"No input files in {input_dir}"

        if reproduce:
            delete_directory(output_dir)
        os.makedirs(output_dir, exist_ok=True)

        if os.path.exists(os.path.join(output_dir, "done")):
            logging.info(
                f"Output folder is not empty so the inference was not run: {output_dir}"
            )
        else:
            logging.info(f"Inference from {input_dir} to {output_dir}")
            with torch.no_grad():
                infer_one_step_partial = partial(
                    self.infer_one_step,
                    input_dir=input_dir,
                    output_dir=output_dir,
                    processing=processing,
                )
                if n_workers == 1:
                    for file_name in tqdm(file_names):
                        infer_one_step_partial(file_name=file_name)
                else:
                    self._parallel_run(
                        n_workers=n_workers,
                        infer_one_step_partial=infer_one_step_partial,
                        file_names=file_names,
                    )
            # Mark the output directory as complete
            Path(os.path.join(output_dir, "done")).touch()
    def _parallel_run(
        self,
        n_workers: int,
        infer_one_step_partial: typing.Callable[[str], None],
        file_names: typing.List[str],
    ):
        return Parallel(n_jobs=n_workers)(
            delayed(infer_one_step_partial)(file_name)
            for file_name in tqdm(file_names)
        )
    def infer_one_step(
        self,
        file_name: str,
        input_dir: str,
        output_dir: str,
        processing: str | typing.List[str] | None = None,
    ):
        """Run the inference on a single file and save the output in another
        file.

        Args:
            file_name: input file name
            input_dir: input directory path
            output_dir: output directory path
            processing: name(s) of supplementary function(s) that process the
                event after :py:func:`BuilderBase.construct_downstream`
        """
        input_path = os.path.join(input_dir, file_name)
        batch = self.load_batch(input_path)
        batch = self.process_one_step(
            batch=batch,
            processing=processing,
        )
        self.save_downstream(batch, os.path.join(output_dir, batch.event_str))
    def process_one_step(
        self,
        batch: Data,
        processing: str | typing.List[str] | None = None,
    ) -> Data:
        """Process one event.

        Args:
            batch: event stored in a PyTorch Geometric data object
            processing: name(s) of supplementary function(s) that process the
                event after :py:func:`BuilderBase.construct_downstream`

        Returns:
            Processed event, first by :py:func:`BuilderBase.construct_downstream`,
            then by the filtering and building functions provided as inputs.
        """
        batch = self.construct_downstream(batch)
        if processing is not None:
            # Apply processing functions (building or filtering)
            processing_fct_names = (
                [processing] if isinstance(processing, str) else processing
            )
            for processing_fct_name in processing_fct_names:
                processing_fct = getattr(
                    self._get_building_custom_module(), str(processing_fct_name)
                )
                batch = processing_fct(batch)
        return batch
    def _get_building_custom_module(self) -> ModuleType:
        """Return the module where the building and filtering functions are
        defined."""
        raise NotImplementedError()
    def load_batch(self, input_path: str) -> Data:
        """Load a PyTorch Data object from its path.

        Might apply necessary pre-processing.
        """
        return torch.load(input_path, map_location=torch.device("cpu"))
    def filter_batch(self, batch: Data) -> Data:
        """Filter the batch.

        This should only be performed on the train and val sets.

        Args:
            batch: PyTorch Geometric Data object

        Returns:
            filtered batch
        """
        return batch
    def build_weights(self, batch: Data) -> Data:
        """Build weights in the batch for training.

        This should only be needed in the train and val sets.

        Args:
            batch: PyTorch Geometric Data object

        Returns:
            batch with weights built
        """
        return batch
    def build_features(self, batch: Data) -> Data:
        """Build features in the batch. Identity by default."""
        return batch
    @abc.abstractmethod
    def construct_downstream(self, batch: Data):
        """Run the inference on a PyTorch Data. In-place."""
        raise NotImplementedError
    def save_downstream(self, batch: Data, output_path: str):
        """Save the PyTorch data object ``batch`` in ``output_path``."""
        with open(output_path, "wb") as pickle_file:
            torch.save(batch, pickle_file)
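# --- Usage sketch (illustrative, not part of this module) ---
# A minimal sketch of a concrete builder, assuming a hypothetical module
# ``my_processing`` that exposes event-level building/filtering functions
# such as ``drop_isolated_nodes``. Every name below is an assumption made
# for illustration; note that input ``Data`` objects must carry an
# ``event_str`` attribute, which ``infer_one_step`` uses as the output
# file name.

import my_processing  # hypothetical module with building/filtering functions


class EdgeScoreBuilder(BuilderBase):
    def _get_building_custom_module(self) -> ModuleType:
        # Names passed via ``processing`` are looked up as attributes
        # of this module by ``process_one_step``
        return my_processing

    def construct_downstream(self, batch: Data):
        # Toy construction step: attach a constant score to every edge
        batch.scores = torch.ones(batch.edge_index.shape[1])
        return batch


EdgeScoreBuilder().infer(
    input_dir="data/stage1",           # illustrative paths
    output_dir="data/stage2",
    processing="drop_isolated_nodes",  # resolved in my_processing
)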
class ModelBuilderBase(BuilderBase):
    """Base class for model inference."""

    def __init__(self, model: LightningModule) -> None:
        self.model = model
        model.eval()
    def load_batch(self, input_path: str) -> Data:
        """Load a PyTorch Data object from its path.

        Might apply necessary pre-processing. The batch is mapped to the
        model's device.
        """
        return torch.load(input_path, map_location=self.model.device)
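# --- Usage sketch (illustrative, not part of this module) ---
# A minimal sketch of a model-based builder, assuming a hypothetical
# ``LightningModule`` subclass ``MyLightningModel`` and checkpoint path;
# both are assumptions made for illustration only.


class GNNScoreBuilder(ModelBuilderBase):
    def construct_downstream(self, batch: Data):
        # Score the event with the wrapped model. ``infer`` already runs
        # under ``torch.no_grad()`` and ``ModelBuilderBase.__init__``
        # has put the model in eval mode
        batch.scores = self.model(batch)
        return batch


model = MyLightningModel.load_from_checkpoint("checkpoints/best.ckpt")
GNNScoreBuilder(model).infer(input_dir="data/stage1", output_dir="data/stage2")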