"""Define the base class to infer on data.
"""
import typing
from types import ModuleType
import abc
import os
import logging
from functools import partial
from pathlib import Path
from joblib import Parallel, delayed
from tqdm.auto import tqdm
import torch
from pytorch_lightning import LightningModule
from torch_geometric.data import Data
from utils.tools.tfiles import delete_directory
from utils.scriptutils.loghandler import configure_logger
# Module-level side effect: this call runs once, when the module is imported.
# NOTE(review): presumably sets up the project's logging handlers/format; the
# actual configuration lives in ``utils.scriptutils.loghandler`` — confirm there.
configure_logger()
class BuilderBase(abc.ABC):
    """Base class for looping over input files located in a directory, processing
    them and saving the output in a different directory.

    Subclasses must implement :py:func:`BuilderBase.construct_downstream`. They
    must also implement :py:func:`BuilderBase._get_building_custom_module` if the
    ``processing`` argument is used.
    """

    def __init__(self) -> None:
        pass

    def infer(
        self,
        input_dir: str,
        output_dir: str,
        reproduce: bool = True,
        processing: str | typing.List[str] | None = None,
        file_names: typing.List[str] | None = None,
        n_workers: int = 1,
    ) -> None:
        """Load the torch datasets located in ``input_dir``, run the model inference
        and save the output in ``output_dir``.

        Once all the files have been processed, an empty file called ``done`` is
        created in ``output_dir``. If this file already exists (and ``reproduce``
        is ``False``), the inference is skipped.

        Args:
            input_dir: input directory path
            output_dir: output directory path
            reproduce: whether to delete the output directory if it exists,
                and run again the inference
            processing: name(s) of supplementary function(s) that process the event.
                after :py:func:`BuilderBase.construct_downstream`.
            file_names: list of file names (relative to ``input_dir``) to run the
                inference on. If not specified, the inference is run on all the
                datasets located in the input directory.
            n_workers: number of workers. With ``1``, the files are processed
                one after the other in the current process. Any other value is
                forwarded as ``n_jobs`` to :py:class:`joblib.Parallel`.
                Running in parallel seems quite unstable...

        Raises:
            AssertionError: if there is no input file to process.
        """
        # List the names of the input files.
        # The bare names are collected (not ``DirEntry.path``) because
        # ``infer_one_step`` joins each name with ``input_dir`` itself: a path
        # that already contains a relative ``input_dir`` would get it twice.
        if file_names is None:
            with os.scandir(input_dir) as entries:
                file_names = [entry.name for entry in entries if entry.name != "done"]

        assert len(file_names) > 0, f"No input files in {input_dir}"

        if reproduce:
            delete_directory(output_dir)
        os.makedirs(output_dir, exist_ok=True)

        done_marker_path = os.path.join(output_dir, "done")
        if os.path.exists(done_marker_path):
            logging.info(
                f"Output folder is not empty so the inference was not run: {output_dir}"
            )
        else:
            logging.info(f"Inference from {input_dir} to {output_dir}")
            with torch.no_grad():
                infer_one_step_partial = partial(
                    self.infer_one_step,
                    input_dir=input_dir,
                    output_dir=output_dir,
                    processing=processing,
                )
                if n_workers == 1:
                    for file_name in tqdm(file_names):
                        infer_one_step_partial(file_name=file_name)
                else:
                    self._parallel_run(
                        n_workers=n_workers,
                        infer_one_step_partial=infer_one_step_partial,
                        file_names=file_names,
                    )
            # Mark the output directory as complete so that a later call with
            # ``reproduce=False`` does not run the inference again.
            Path(done_marker_path).touch()

    def _parallel_run(
        self,
        n_workers: int,
        infer_one_step_partial: typing.Callable[[str], None],
        file_names: typing.List[str],
    ):
        """Run ``infer_one_step_partial`` on every file name using joblib.

        Args:
            n_workers: value forwarded as ``n_jobs`` to :py:class:`joblib.Parallel`
            infer_one_step_partial: callable that takes a file name as its only
                positional argument
            file_names: file names to process

        Returns:
            The list returned by :py:class:`joblib.Parallel`
            (one entry per file name).
        """
        return Parallel(n_jobs=n_workers)(
            delayed(self._call_without_grad)(infer_one_step_partial, file_name)
            for file_name in tqdm(file_names)
        )

    @staticmethod
    def _call_without_grad(
        infer_one_step_partial: typing.Callable[[str], None],
        file_name: str,
    ):
        """Call ``infer_one_step_partial(file_name)`` with autograd disabled.

        ``torch.no_grad()`` is thread-local, so the context entered by the caller
        of :py:func:`BuilderBase.infer` does not apply inside joblib workers.
        It is therefore entered again here, within the worker itself.
        """
        with torch.no_grad():
            return infer_one_step_partial(file_name)

    def infer_one_step(
        self,
        file_name: str,
        input_dir: str,
        output_dir: str,
        processing: str | typing.List[str] | None = None,
    ):
        """Run the inference on a single file and save the output in another file.

        The output file is written in ``output_dir`` and is named after the
        ``event_str`` attribute of the processed batch.

        Args:
            file_name: input file name
            input_dir: input directory path
            output_dir: output directory path
            processing: name(s) of supplementary function(s) that process the event.
                after :py:func:`BuilderBase.construct_downstream`.
        """
        input_path = os.path.join(input_dir, file_name)
        batch = self.load_batch(input_path)
        batch = self.process_one_step(
            batch=batch,
            processing=processing,
        )
        self.save_downstream(batch, os.path.join(output_dir, batch.event_str))

    def process_one_step(
        self,
        batch: Data,
        processing: str | typing.List[str] | None = None,
    ) -> Data:
        """Process one event.

        Args:
            batch: event stored in a PyTorch Geometric data object
            processing: name(s) of supplementary function(s) that process the event.
                after :py:func:`BuilderBase.construct_downstream`. Each name is
                looked up as an attribute of the module returned by
                :py:func:`BuilderBase._get_building_custom_module`.

        Returns:
            Processed event, first by :py:func:`BuilderBase.construct_downstream`,
            then by the filtering and building functions provided as inputs.
        """
        batch = self.construct_downstream(batch)
        if processing is not None:
            # Apply processing functions (building or filtering)
            processing_fct_names = (
                [processing] if isinstance(processing, str) else processing
            )
            for processing_fct_name in processing_fct_names:
                # The module is only requested when there is a function to apply,
                # so builders that never use ``processing`` do not need to
                # implement ``_get_building_custom_module``.
                processing_fct = getattr(
                    self._get_building_custom_module(), str(processing_fct_name)
                )
                batch = processing_fct(batch)
        return batch

    def _get_building_custom_module(self) -> ModuleType:
        """Return the module where the building and filtering functions are.

        Raises:
            NotImplementedError: in this base class; to be overridden.
        """
        raise NotImplementedError()

    def load_batch(self, input_path: str) -> Data:
        """Load a PyTorch Data object from its path, onto the CPU.

        Might apply necessary pre-processing.
        """
        # NOTE: ``torch.load`` unpickles the file; only load trusted inputs.
        return torch.load(input_path, map_location=torch.device("cpu"))

    def filter_batch(self, batch: Data) -> Data:
        """Filter the batch. This should only be performed in the train and val
        sets.

        The base implementation returns the batch unchanged.

        Args:
            batch: PyTorch Data Geometric object

        Returns:
            filtered batch
        """
        return batch

    def build_weights(self, batch: Data) -> Data:
        """Build weights in the batch for training.
        This should only be needed in the train and val sets.

        The base implementation returns the batch unchanged.

        Args:
            batch: PyTorch Data Geometric object

        Returns:
            batch with the weights
        """
        return batch

    def build_features(self, batch: Data) -> Data:
        """Build features in the batch.

        The base implementation returns the batch unchanged.

        Args:
            batch: PyTorch Data Geometric object

        Returns:
            batch with the features
        """
        return batch

    @abc.abstractmethod
    def construct_downstream(self, batch: Data) -> Data:
        """Run the inference on a PyTorch Data. In-place.

        The batch must also be returned, because
        :py:func:`BuilderBase.process_one_step` uses the returned value.
        """
        raise NotImplementedError

    def save_downstream(self, batch: Data, output_path: str):
        """Save the PyTorch data object ``batch`` in ``output_path``."""
        with open(output_path, "wb") as pickle_file:
            torch.save(batch, pickle_file)
class ModelBuilderBase(BuilderBase):
    """Base class for model inference.

    Args:
        model: model used for the inference. It is switched to evaluation mode
            when the builder is created.
    """

    def __init__(self, model: LightningModule) -> None:
        super().__init__()
        self.model = model
        model.eval()

    def load_batch(self, input_path: str) -> Data:
        """Load a PyTorch Data object from its path, onto the device of the model.

        Might apply necessary pre-processing.
        """
        # NOTE: ``torch.load`` unpickles the file; only load trusted inputs.
        return torch.load(input_path, map_location=self.model.device)