Source code for pipeline.utils.modelutils.basemodel

"""Define a base model for GNN and Embedding, to avoid copy of functions.
"""

from __future__ import annotations
import typing
import logging
import inspect
import os
import os.path as op

from tqdm.auto import tqdm
import numpy as np
import numpy.typing as npt
import torch
from torch.utils.data import Subset
from pytorch_lightning import LightningModule
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from utils.commonutils.cfeatures import get_input_features, get_number_input_features
from utils.loaderutils.dataiterator import LazyDatasetBase


def check_and_discard(s: typing.Set[typing.Any], element: typing.Any) -> bool:
    """Discard ``element`` from the set ``s`` if it is present.

    Returns:
        ``True`` if the element was present (and has been discarded), ``False`` otherwise.
    """
    if element in s:
        s.discard(element)
        return True
    return False

#: Associates a column name with a lambda function that takes as input the batch
#: object and returns the computed column
feature_to_compute_fct: typing.Dict[str, typing.Callable[[Data], torch.Tensor]] = {
    "r": lambda batch: torch.sqrt(batch["un_x"] ** 2 + batch["un_y"] ** 2),
    "phi": lambda batch: torch.arctan2(batch["un_y"], batch["un_x"]),
}

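# Illustrative sketch (not part of the original module): how `feature_to_compute_fct`
# derives polar coordinates from a batch that only stores the Cartesian columns
# `un_x` and `un_y`. The toy values below are assumptions made for this example.
def _example_computed_features() -> typing.Tuple[torch.Tensor, torch.Tensor]:
    batch = Data(un_x=torch.tensor([3.0, 0.0]), un_y=torch.tensor([4.0, 2.0]))
    r = feature_to_compute_fct["r"](batch)  # sqrt(un_x**2 + un_y**2) -> tensor([5., 2.])
    phi = feature_to_compute_fct["phi"](batch)  # arctan2(un_y, un_x)
    return r, phi
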
class ModelBase(LightningModule):
    """Base model shared by the GNN and Embedding models."""

    def __init__(self, hparams):
        super().__init__()
        self._trainset = None
        self._valset = None
        self._testset: typing.List[Data] | None = None
        self.save_hyperparameters(hparams)
        self._idx_trainset_split: int | None = None
        self._trainset_split_indices: typing.List[npt.NDArray] | None = None

    def setup(self, stage):
        """Load the train and validation partitions. The test set is only loaded on demand."""
        self.load_partition("train")
        self.load_partition("val")
        self._testset = None

    @property
    def lazy(self) -> bool:
        """Whether to load the training set and validation set into memory only
        when needed.
        """
        return self.hparams["lazy"]

    @property
    def feature_names(self) -> typing.List[str] | None:
        """List of node feature names."""
        return self.hparams.get("feature_names")

    @property
    def feature_means(self) -> typing.List[float] | None:
        """List of feature means corresponding to the feature names listed in
        :py:attr:`~feature_names`. They are used for normalising the node features.
        """
        return self.hparams.get("feature_means")

    @property
    def feature_scales(self) -> typing.List[float] | None:
        """List of feature scales corresponding to the feature names listed in
        :py:attr:`~feature_names`. They are used for normalising the node features.
        """
        return self.hparams.get("feature_scales")

    @property
    def feature_indices(self) -> int | typing.List[int] | None:
        """Indices of the input features selected from ``batch["x"]``
        (kept for backward compatibility; to be deprecated).
        """
        return self.hparams.get("feature_indices")

    @property
    def on_step(self) -> bool:
        """Whether to log on step."""
        return self.hparams.get("on_step", False)

    @property
    def trainset(self) -> typing.List[Data] | LazyDatasetBase:
        """Training dataset, loaded on first access."""
        if self._trainset is None:
            self.load_partition(partition="train")
        assert self._trainset is not None
        return self._trainset

    @trainset.setter
    def trainset(self, batches: typing.List[Data]):
        self._trainset = batches

    @property
    def valset(self) -> typing.List[Data]:
        """Validation dataset, loaded on first access."""
        if self._valset is None:
            self.load_partition(partition="val")
        assert self._valset is not None
        return self._valset

    @valset.setter
    def valset(self, batches: typing.List[Data]):
        self._valset = batches

    @property
    def testset(self) -> typing.List[Data]:
        """Test dataset. It must be loaded explicitly beforehand."""
        if self._testset is None:
            raise ValueError(
                "Test set not loaded. Please load it with `load_partition` "
                "or `load_testset_from_directory`."
            )
        return self._testset

    def load_trainset_split_indices(self, trainset_split: int):
        """Randomly split the training set indices into ``trainset_split`` chunks,
        each of which is used for one epoch.
        """
        data_indices = np.arange(len(self.trainset))
        np.random.shuffle(data_indices)
        sub_data_indices = np.array_split(data_indices, trainset_split)
        self._trainset_split_indices = sub_data_indices
        self._idx_trainset_split = 0

    def train_dataloader(self):
        """Train dataloader, with optional random splitting of the training set
        across epochs.
        """
        print("Load train dataloader.")
        trainset = self.trainset
        if len(trainset) > 0:
            if (trainset_split := self.hparams.get("trainset_split")) is not None:
                if not isinstance(trainset, LazyDatasetBase):
                    raise TypeError(
                        "In order to use the `trainset_split` property, "
                        "the trainset should be loaded in a lazy way. "
                        "Please consider switching `lazy` to `True`."
                    )
                if self._trainset_split_indices is None:
                    print("Define random splitting of epochs")
                    self.load_trainset_split_indices(trainset_split)
                assert self._idx_trainset_split is not None
                assert self._trainset_split_indices is not None
                print("Load subset number", self._idx_trainset_split)
                trainset = Subset(
                    trainset,
                    self._trainset_split_indices[self._idx_trainset_split],  # type: ignore
                )
                # Already prepare the next subset
                self._idx_trainset_split += 1
                if self._idx_trainset_split == len(self._trainset_split_indices):
                    self._trainset_split_indices = None
                    self._idx_trainset_split = None
                shuffle = False
            else:
                trainset = self.trainset
                shuffle = True
            return DataLoader(
                trainset,  # type: ignore
                batch_size=1,
                num_workers=8,
                shuffle=shuffle,
            )
        else:
            return None

    def val_dataloader(self):
        """Validation dataloader."""
        if len(self.valset) > 0:
            return DataLoader(self.valset, batch_size=1, num_workers=0)
        else:
            return None

    def test_dataloader(self):
        """Test dataloader."""
        if self._testset is not None and len(self._testset) > 0:
            return DataLoader(self.testset, batch_size=1, num_workers=8)
        else:
            return None

    def get_lazy_dataset(
        self,
        input_dir: str,
        n_events: int | None = None,
        shuffle: bool = False,
        seed: int | None = None,
        **kwargs,
    ) -> LazyDatasetBase:
        """Get the lazy dataset object.

        Args:
            input_dir: input directory
            n_events: number of events to load
            shuffle: whether to shuffle the input paths (applied before selecting
                the first ``n_events``)
            seed: seed for the shuffling
            **kwargs: other keyword arguments passed to the
                :py:class:`utils.loaderutils.dataiterator.LazyDatasetBase` constructor

        Returns:
            :py:class:`utils.loaderutils.dataiterator.LazyDatasetBase` object
        """
        return LazyDatasetBase(
            input_dir=input_dir,
            n_events=n_events,
            shuffle=shuffle,
            seed=seed,
            **kwargs,
        )

    def fetch_datasets(self, lazy_dataset: LazyDatasetBase) -> typing.List[Data]:
        """Load into memory all the events of a lazy dataset.

        Args:
            lazy_dataset: lazy dataset whose events are loaded into memory

        Returns:
            List of loaded PyTorch Geometric ``Data`` objects
        """
        logging.info(
            f"Load {len(lazy_dataset)} files located in {lazy_dataset.input_dir}"
        )
        return [event for event in tqdm(iter(lazy_dataset), total=len(lazy_dataset))]

    def load_testset_from_directory(self, input_dir: str, **kwargs):
        """Load a test dataset from a path to a directory.

        Args:
            input_dir: path to the directory that contains the PyTorch Geometric
                ``Data`` pickle files
            **kwargs: other keyword arguments passed to
                :py:func:`ModelBase.get_lazy_dataset`
        """
        lazy_dataset = self.get_lazy_dataset(input_dir=input_dir, **kwargs)
        self._testset = self.fetch_datasets(lazy_dataset=lazy_dataset)

    def get_lazy_dataset_partition(
        self,
        partition: str,
        n_events: int | None = None,
        shuffle: bool = False,
        seed: int | None = None,
        **kwargs,
    ) -> LazyDatasetBase:
        """Get the lazy dataset of a partition.

        Args:
            partition: ``train``, ``val`` or name of the test dataset
            n_events: number of events to load
            shuffle: whether to shuffle the input paths (applied before selecting
                the first ``n_events``)
            seed: seed for the shuffling
            **kwargs: other keyword arguments passed to
                :py:func:`ModelBase.get_lazy_dataset`

        Returns:
            Lazy dataset of the ``partition``
        """
        if partition in ["train", "val"]:
            lazy_dataset = self.get_lazy_dataset(
                input_dir=op.join(self.hparams["input_dir"], partition),
                n_events=(
                    self.hparams.get(f"n_{partition}_events")
                    if n_events is None
                    else n_events
                ),
                shuffle=shuffle,
                seed=seed,
                **kwargs,
            )
        else:
            lazy_dataset = self.get_lazy_dataset(
                input_dir=op.join(self.hparams["input_dir"], "test", partition),
                n_events=n_events,
                shuffle=shuffle,
                seed=seed,
                **kwargs,
            )
        return lazy_dataset

    def fetch_partition(
        self,
        partition: str,
        n_events: int | None = None,
        shuffle: bool = False,
        seed: int | None = None,
        **kwargs,
    ) -> typing.List[Data] | LazyDatasetBase:
        """Load a partition.

        Args:
            partition: ``train``, ``val`` or name of the test dataset
            n_events: number of events to load for this partition
            shuffle: whether to shuffle the input paths (applied before selecting
                the first ``n_events``)
            seed: seed for the shuffling
            **kwargs: other keyword arguments passed to
                :py:func:`ModelBase.get_lazy_dataset_partition`

        Returns:
            The partition as a list of ``Data`` objects, or as a lazy dataset for
            the training partition when :py:attr:`~lazy` is ``True``.
        """
        lazy_dataset = self.get_lazy_dataset_partition(
            partition=partition,
            n_events=n_events,
            shuffle=shuffle,
            seed=seed,
            **kwargs,
        )
        if partition == "train" and self.lazy:
            return lazy_dataset
        else:
            return self.fetch_datasets(lazy_dataset=lazy_dataset)

    def load_partition(
        self,
        partition: str,
        n_events: int | None = None,
        shuffle: bool = False,
        seed: int | None = None,
    ):
        """Load the datasets of a partition and store them on the model.

        Args:
            partition: ``train``, ``val`` or name of the test dataset
            n_events: number of events to load for this partition
            shuffle: whether to shuffle the input paths (applied before selecting
                the first ``n_events``)
            seed: seed for the shuffling
        """
        datasets = self.fetch_partition(
            partition=partition,
            n_events=n_events,
            shuffle=shuffle,
            seed=seed,
        )
        if partition == "train":
            self._trainset = datasets
        elif partition == "val":
            assert not isinstance(datasets, LazyDatasetBase)  # shouldn't be the case
            self._valset = datasets
        else:
            assert not isinstance(datasets, LazyDatasetBase)  # shouldn't be the case
            self._testset = datasets

    def get_features(self, batch: Data) -> torch.Tensor:
        """Get the features of a batch, used as input for inference.

        If :py:attr:`~feature_names` is provided, it is used to build the tensor
        of node features, normalising them using :py:attr:`~feature_means` and
        :py:attr:`~feature_scales`. Otherwise, ``batch["x"]`` is returned.
        An illustrative configuration sketch is given after the class definition.

        Args:
            batch: batch of nodes (typically an event)

        Returns:
            Tensor of node features.

        Notes:
            No gradient is recorded.
        """
        with torch.no_grad():
            if self.feature_names is not None:
                feature_names = self.feature_names
                feature_means = self.feature_means
                feature_scales = self.feature_scales
                if feature_means is None:
                    raise ValueError(
                        "`feature_names` was specified but `feature_means` was not"
                    )
                elif feature_scales is None:
                    raise ValueError(
                        "`feature_names` was specified but `feature_scales` was not"
                    )
                elif not (
                    len(feature_names) == len(feature_means) == len(feature_scales)
                ):
                    raise ValueError(
                        "`feature_names`, `feature_means` and `feature_scales` have "
                        "different sizes: "
                        f"{len(feature_names)} ; {len(feature_means)} ; "
                        f"{len(feature_scales)}"
                    )
                else:
                    list_normalised_features = []
                    for feature_name, feature_mean, feature_scale in zip(
                        feature_names, feature_means, feature_scales
                    ):
                        if feature_name in batch:
                            # The feature is already stored in the batch
                            feature: torch.Tensor = batch[feature_name]
                        else:
                            # Not stored -> try to compute it
                            computation_fct = feature_to_compute_fct.get(feature_name)
                            if computation_fct is None:
                                raise ValueError(
                                    f"Feature {feature_name} was not found in batch "
                                    "and was not found in `feature_to_compute_fct`"
                                )
                            feature = computation_fct(batch)
                        normalised_feature = (feature - feature_mean) / feature_scale
                        list_normalised_features.append(normalised_feature.float())
                    return torch.stack(list_normalised_features, dim=-1)
            else:
                # Fall back to the default behaviour
                return get_input_features(
                    batch["x"], feature_indices=self.feature_indices
                )

    def get_n_features(self) -> int:
        """Number of input features of the network."""
        if self.feature_names is not None:
            return len(self.feature_names)
        else:
            # Fall back to the default behaviour (to be deprecated)
            assert self.feature_indices is not None
            return get_number_input_features(self.feature_indices)

    def configure_optimizers(self):
        """Use the AdamW optimizer and a step learning rate scheduler."""
        optimizer = [
            torch.optim.AdamW(
                self.parameters(),
                lr=self.hparams["lr"],
                betas=(0.9, 0.999),
                eps=1e-08,
                amsgrad=True,
            )
        ]
        scheduler = [
            {
                "scheduler": torch.optim.lr_scheduler.StepLR(
                    optimizer[0],
                    step_size=self.hparams["patience"],
                    gamma=self.hparams["factor"],
                ),
                "interval": "epoch",
                "frequency": 1,
            }
        ]
        return optimizer, scheduler

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_closure,
    ):
        """Modified version of the optimizer step that implements warm-up and
        properly enforces the learning rate.
        """
        # Warm up the learning rate during the first `warmup` epochs
        if (self.hparams["warmup"] is not None) and (
            self.current_epoch < self.hparams["warmup"]
        ):
            lr_scale = min(1.0, float(self.current_epoch + 1) / self.hparams["warmup"])
            for pg in optimizer.param_groups:  # type: ignore
                pg["lr"] = lr_scale * self.hparams["lr"]
        else:
            for pg in optimizer.param_groups:  # type: ignore
                pg["lr"] = self.lr_schedulers().get_last_lr()[0]  # type: ignore

        # Update the parameters
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad(set_to_none=True)  # type: ignore

    @classmethod
    def get_model_from_checkpoint(
        cls,
        checkpoint: LightningModule | str | None,
        default_checkpoint: str | None = None,
        **kwargs,
    ):
        """Helper function to get a model at the inference step.

        Args:
            checkpoint: the model already loaded, or a path to a checkpoint
            default_checkpoint: path to fall back to if ``checkpoint`` is ``None``
            **kwargs: other parameters passed to ``cls.load_from_checkpoint``

        Returns:
            Loaded model
        """
        if isinstance(checkpoint, cls):
            model = checkpoint
        elif checkpoint is None:
            # Default loading mode: fall back to the default checkpoint
            assert (
                default_checkpoint is not None
            ), "Both `checkpoint` and `default_checkpoint` are None."
            checkpoint = default_checkpoint
            model = cls.load_from_checkpoint(
                default_checkpoint,
                **kwargs,
            )
            logging.info(f"Load model from {checkpoint}.")
        elif isinstance(checkpoint, str):
            model = cls.load_from_checkpoint(
                checkpoint_path=checkpoint,
                **kwargs,
            )
            logging.info(f"Load model from {checkpoint}.")
        else:
            raise TypeError(
                f"Type of checkpoint is {type(checkpoint).__name__}, "
                "which is not recognised."
            )
        return model

    @property
    def n_trainable_params(self) -> int:
        """Number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @property
    def input_kwargs(self) -> typing.Dict[str, typing.Any]:
        """Associates an input name with a dictionary of the keyword arguments
        used to build a dummy tensor representing this input. This dictionary
        basically gives the ``size`` and ``dtype`` of the tensor.
        """
        return {}

    @property
    def subnetwork_groups(self) -> typing.Dict[str, typing.List[str]]:
        """A dictionary that associates the name of a group of subnetworks with
        the list of subnetworks it contains.
        """
        return {}

    @property
    def subnetwork_to_outputs(self) -> typing.Dict[str, typing.List[str]]:
        """A dictionary that associates a subnetwork name with the list of its
        output names.
        """
        return {}

    @property
    def subnetworks(self) -> typing.List[str]:
        """List of available subnetworks, derived from
        :py:attr:`subnetwork_to_outputs`.
        """
        return list(self.subnetwork_to_outputs.keys())

    @property
    def input_to_dynamic_axes(self):
        """A dictionary that associates an input name with its dynamic axis
        specification for the ONNX export.
        """
        return {}

    def get_subnetwork_inputs(self, subnetwork: str) -> typing.List[str]:
        """Find the input names of a subnetwork by looking at the signature of its
        ONNX forward method ``_onnx_{subnetwork}``.

        Args:
            subnetwork: subnetwork name

        Returns:
            List of the input names of the subnetwork.
        """
        forward_func = self._get_subnetwork_forward_func(subnetwork=subnetwork)
        return list(inspect.signature(forward_func).parameters.keys())

    def get_subnetwork_outputs(self, subnetwork: str) -> typing.List[str]:
        """Get the outputs of a subnetwork, as configured by the
        :py:attr:`subnetwork_to_outputs` property.

        Args:
            subnetwork: subnetwork name

        Returns:
            List of the output names of the subnetwork.

        Raises:
            KeyError: if the outputs of the subnetwork were not specified in the
                :py:attr:`subnetwork_to_outputs` property.
        """
        outputs = self.subnetwork_to_outputs.get(subnetwork)
        if outputs is None:
            raise KeyError(
                f"The outputs for the subnetwork {subnetwork} were not defined. "
                "To define it, you can modify the property `subnetwork_to_outputs` "
                f"of the {self.__class__.__name__} class."
            )
        else:
            return outputs

    def _get_subnetwork_forward_func(self, subnetwork: str):
        """Get the forward function of a given subnetwork.

        Args:
            subnetwork: subnetwork name

        Returns:
            Method ``_onnx_{subnetwork}`` of this class.

        Raises:
            AttributeError: if the method is missing.
        """
        try:
            forward_func = getattr(self, f"_onnx_{subnetwork}")
        except AttributeError:
            raise AttributeError(
                f"The forward method `_onnx_{subnetwork}` "
                f"for the subnetwork {subnetwork} was not defined."
            )
        return forward_func

    def to_onnx(
        self,
        outpath: str,
        mode: str | None = None,
        options: typing.Iterable[str] | None = None,
    ) -> None:
        """Export a model to ONNX.

        An illustrative end-to-end export sketch is given at the end of this module.

        Args:
            outpath: where to save the ONNX file
            mode: subnetwork or group of subnetworks to export
                (defaults to the first subnetwork)
            options: ONNX export options (e.g., ``"fp16"``)
        """
        options = set() if options is None else set(options)
        # Default mode is the first subnetwork
        if mode is None:
            mode = self.subnetworks[0]
        if (subnetworks := self.subnetwork_groups.get(mode)) is not None:
            # If the `mode` corresponds to a group of subnetworks,
            # the output path must contain the placeholder `subnetwork`,
            # which will be replaced by the subnetwork name.
            assert "{subnetwork}" in outpath, (
                f"In `{mode}` mode, the output path should contain "
                "the placeholder {subnetwork}."
            )
            print(f"Mode {mode} contains the following subnetworks:", subnetworks)
            for subnetwork in subnetworks:
                self.to_onnx(
                    outpath=outpath.format(subnetwork=subnetwork),
                    mode=subnetwork,
                    options=options,
                )
        else:
            use_fp16 = check_and_discard(options, "fp16")
            if options:
                raise ValueError(
                    "The following options are not recognised: " + ", ".join(options)
                )
            subnetwork = mode

            # Input names of the subnetwork
            input_names = self.get_subnetwork_inputs(subnetwork)
            input_kwargs = self.input_kwargs

            def _extract_input_kwargs(
                input_kwargs: typing.Dict[str, typing.Any], input_name: str
            ) -> typing.Dict[str, typing.Any]:
                """Extract the keyword parameters used to build the input tensor
                ``input_name``. The only reason this function exists is to raise
                an error if the input keyword parameters are not in ``input_kwargs``.
                """
                if (kwargs := input_kwargs.get(input_name)) is not None:
                    return kwargs
                raise KeyError(
                    f"The subnetwork `{subnetwork}` needs the input `{input_name}` "
                    "but the latter was not defined in `input_kwargs`"
                )

            # Dummy input tensors of the subnetwork
            dummy_inputs = {
                input_name: torch.zeros(
                    **_extract_input_kwargs(input_kwargs, input_name), device="cuda"
                )
                for input_name in input_names
            }
            # Output names of the subnetwork
            output_names = self.get_subnetwork_outputs(subnetwork)
            # Add `_out` to the output names that clash with input names
            output_names_named_as_input = list(
                set(output_names).intersection(input_names)
            )
            modified_output_names = [
                (
                    f"{output_name}_out"
                    if output_name in output_names_named_as_input
                    else output_name
                )
                for output_name in output_names
            ]

            os.makedirs(os.path.dirname(outpath), exist_ok=True)

            print(f"{subnetwork} input names:", ", ".join(input_names))
            print(f"{subnetwork} output names:", ", ".join(modified_output_names))

            torch.onnx.export(
                model=ModelONNXExport(model=self, subnetwork=subnetwork),
                args=tuple(dummy_inputs[input_name] for input_name in input_names),
                f=outpath,
                verbose=False,
                # Names to assign to the input nodes of the graph, in order
                input_names=input_names,
                # Names to assign to the output nodes of the graph, in order
                output_names=modified_output_names,
                # Apply the constant-folding optimisation: replace some of the ops
                # that have all constant inputs with pre-computed constant nodes
                do_constant_folding=True,
                opset_version=17,
                dynamic_axes={
                    modified_name: self.input_to_dynamic_axes[name]
                    for (name, modified_name) in zip(
                        input_names + output_names,
                        input_names + modified_output_names,
                    )
                },
            )
            print("Model was exported to", os.path.abspath(outpath))

            if use_fp16:
                from utils.modelutils.export import convert_model_to_fp16

                convert_model_to_fp16(outpath)

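# Illustrative sketch (not part of the original module): how `ModelBase.get_features`
# builds and normalises the input tensor when `feature_names`, `feature_means` and
# `feature_scales` are set in the hyperparameters. The feature names and values below
# are assumptions made for this example; `r` is not stored in the batch, so it is
# computed on the fly via `feature_to_compute_fct`.
def _example_get_features(model: ModelBase) -> torch.Tensor:
    # Hypothetical hyperparameters: each feature is normalised as (x - mean) / scale
    model.hparams["feature_names"] = ["un_x", "un_y", "r"]
    model.hparams["feature_means"] = [0.0, 0.0, 1.0]
    model.hparams["feature_scales"] = [1.0, 1.0, 2.0]
    batch = Data(un_x=torch.tensor([3.0]), un_y=torch.tensor([4.0]))
    # Shape: (n_nodes, n_features) = (1, 3); the last column is (5 - 1) / 2 = 2
    return model.get_features(batch)
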
class ModelONNXExport(torch.nn.Module):
    """Class used to export the forward pass of a subnetwork within a
    :py:class:`ModelBase` model (e.g., a triplet GNN).

    Attributes:
        model: model whose subnetwork is exported
        subnetwork: name of the subnetwork to export
    """

    def __init__(self, model: ModelBase, subnetwork: str):
        super().__init__()
        self.model = model
        self.subnetwork = str(subnetwork)

    def forward(self, *args) -> typing.Any:
        """Forward pass to use when the model is exported to ONNX."""
        forward_func = self.model._get_subnetwork_forward_func(
            subnetwork=self.subnetwork
        )
        return forward_func(*args)
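

# Illustrative sketch (not part of the original module): a minimal `ModelBase`
# subclass wired for the ONNX export machinery. The subnetwork name `encoder`, its
# shapes and the output path are assumptions made for this example; the input names
# of `_onnx_encoder` are discovered through `get_subnetwork_inputs`, and `to_onnx`
# builds the dummy inputs from `input_kwargs`.
class _ExampleEncoderModel(ModelBase):
    def __init__(self, hparams):
        super().__init__(hparams)
        self.encoder = torch.nn.Linear(3, 8)

    def _onnx_encoder(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    @property
    def subnetwork_to_outputs(self) -> typing.Dict[str, typing.List[str]]:
        return {"encoder": ["embedding"]}

    @property
    def input_kwargs(self) -> typing.Dict[str, typing.Any]:
        return {"x": {"size": (1, 3), "dtype": torch.float32}}

    @property
    def input_to_dynamic_axes(self):
        # The first axis (number of nodes) is dynamic for both input and output
        return {"x": {0: "n_nodes"}, "embedding": {0: "n_nodes"}}


# Usage (requires a CUDA device, since `to_onnx` builds the dummy inputs on "cuda"):
#     model = _ExampleEncoderModel({"lazy": False}).cuda()
#     model.to_onnx("exports/encoder.onnx", mode="encoder", options=["fp16"])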