"""Define a base model for GNN and Embedding, to avoid copy of functions.
"""
from __future__ import annotations
import typing
import logging
import inspect
import os
import os.path as op
from tqdm.auto import tqdm
import numpy as np
import numpy.typing as npt
import torch
from torch.utils.data import Subset
from pytorch_lightning import LightningModule
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from utils.commonutils.cfeatures import get_input_features, get_number_input_features
from utils.loaderutils.dataiterator import LazyDatasetBase
def check_and_discard(s: typing.Set[typing.Any], element: typing.Any) -> bool:
    """Remove ``element`` from the set ``s`` if present, and return whether it was there."""
    if element in s:
        s.discard(element)
        return True
    return False
#: Associates a column name with a lambda function that takes the batch
#: object as input and returns the computed column
feature_to_compute_fct: typing.Dict[str, typing.Callable[[Data], torch.Tensor]] = {
"r": lambda batch: torch.sqrt(batch["un_x"] ** 2 + batch["un_y"] ** 2),
"phi": lambda batch: torch.arctan2(batch["un_y"], batch["un_x"]),
}
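# Illustrative sketch (an assumption, not part of the original mapping): additional
# derived features could be registered in the same way, provided the batch carries
# the columns they rely on, e.g.
#
#     feature_to_compute_fct["r2"] = lambda batch: batch["un_x"] ** 2 + batch["un_y"] ** 2
#
# where "r2" is a hypothetical feature name used purely for illustration.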
class ModelBase(LightningModule):
    """Base Lightning module shared by the GNN and Embedding models."""
def __init__(self, hparams):
super().__init__()
self._trainset = None
self._valset = None
self._testset: typing.List[Data] | None = None
self.save_hyperparameters(hparams)
self._idx_trainset_split: int | None = None
self._trainset_split_indices: typing.List[npt.NDArray] | None = None
    def setup(self, stage):
self.load_partition("train")
self.load_partition("val")
self._testset = None
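    # Hyper-parameter keys read throughout this class include "input_dir", "lazy",
    # "feature_names", "feature_means", "feature_scales", "feature_indices",
    # "on_step", "trainset_split", "n_train_events", "n_val_events", "lr" and
    # "warmup". A minimal illustrative configuration (values are assumptions, not
    # defaults) could look like:
    #
    #     hparams = {
    #         "input_dir": "/path/to/preprocessed/data",
    #         "lazy": True,
    #         "feature_names": ["r", "phi"],
    #         "feature_means": [0.0, 0.0],
    #         "feature_scales": [1.0, 1.0],
    #         "lr": 1e-3,
    #         "warmup": 5,
    #     }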
@property
def lazy(self) -> bool:
"""Whether to load the training set and val set into memory only when
needed.
"""
return self.hparams["lazy"]
@property
def feature_names(self) -> typing.List[str] | None:
"""List of node feature names."""
return self.hparams.get("feature_names")
@property
def feature_means(self) -> typing.List[float] | None:
"""List of feature means corresponding to the feature names listed in
        :py:attr:`~feature_names`. They are used for normalising the node features.
"""
return self.hparams.get("feature_means")
@property
def feature_scales(self) -> typing.List[float] | None:
"""List of feature scales corresponding to the feature names listed in
        :py:attr:`~feature_names`. They are used for normalising the node features.
"""
return self.hparams.get("feature_scales")
@property
def feature_indices(self) -> int | typing.List[int] | None:
"""I want to deprecate this..."""
return self.hparams.get("feature_indices")
@property
def on_step(self) -> bool:
"""Whether to log on step."""
return self.hparams.get("on_step", False)
@property
def trainset(self) -> typing.List[Data] | LazyDatasetBase:
if self._trainset is None:
self.load_partition(partition="train")
assert self._trainset is not None
return self._trainset
@trainset.setter
def trainset(self, batches: typing.List[Data]):
self._trainset = batches
@property
def valset(self) -> typing.List[Data]:
if self._valset is None:
self.load_partition(partition="val")
assert self._valset is not None
return self._valset
@property
def testset(self) -> typing.List[Data]:
if self._testset is None:
raise ValueError(
"Test set not loaded. Please load it with `fetch_partition` "
"or `load_testset_from_directory`."
)
else:
return self._testset
@valset.setter
def valset(self, batches: typing.List[Data]):
self._valset = batches
    def load_trainset_split_indices(self, trainset_split: int):
        """Randomly split the training set indices into ``trainset_split`` disjoint subsets."""
data_indices = np.arange(len(self.trainset))
np.random.shuffle(data_indices)
sub_data_indices = np.array_split(data_indices, trainset_split)
self._trainset_split_indices = sub_data_indices
self._idx_trainset_split = 0
    def train_dataloader(self):
        """Train dataloader, optionally serving a different random subset of the training set at each epoch."""
print("Load train dataloader.")
trainset = self.trainset
if len(trainset) > 0:
if (trainset_split := self.hparams.get("trainset_split")) is not None:
if not isinstance(trainset, LazyDatasetBase):
raise TypeError(
"In order to use the `trainset_split` property, "
"the trainset should be loaded in a lazy way. "
"Please consider switching `lazy` to `True`."
)
if self._trainset_split_indices is None:
print("Define random splitting of epochs")
self.load_trainset_split_indices(trainset_split)
assert self._idx_trainset_split is not None
assert self._trainset_split_indices is not None
print("Load subset number", self._idx_trainset_split)
trainset = Subset(
trainset,
self._trainset_split_indices[self._idx_trainset_split], # type: ignore
)
                # Move on to the next subset for the following epoch
self._idx_trainset_split += 1
if self._idx_trainset_split == len(self._trainset_split_indices):
self._trainset_split_indices = None
self._idx_trainset_split = None
shuffle = False
else:
trainset = self.trainset
shuffle = True
return DataLoader(
trainset, # type: ignore
batch_size=1,
num_workers=8,
shuffle=shuffle,
)
else:
return None
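    # Illustrative sketch of the epoch-splitting behaviour (assumed values): with
    # `lazy: True` and `trainset_split: 4`, the training indices are shuffled once
    # and split into 4 disjoint subsets; each call to `train_dataloader` serves the
    # next subset, so the trainer presumably needs to rebuild the dataloader every
    # epoch (e.g. `Trainer(reload_dataloaders_every_n_epochs=1)`) for the rotation
    # to take effect.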
    def val_dataloader(self):
"""Validation dataloader."""
if len(self.valset) > 0:
return DataLoader(self.valset, batch_size=1, num_workers=0)
else:
return None
    def test_dataloader(self):
"""Test dataloader."""
if self._testset is not None and len(self._testset) > 0:
return DataLoader(self.testset, batch_size=1, num_workers=8)
else:
return None
    def get_lazy_dataset(
self,
input_dir: str,
n_events: int | None = None,
shuffle: bool = False,
seed: int | None = None,
**kwargs,
) -> LazyDatasetBase:
"""Get the lazy dataset object.
Args:
input_dir: input directory
n_events: number of events to load
shuffle: whether to shuffle the input paths (applied before
                selecting the first ``n_events``)
seed: seed for the shuffling
**kwargs: Other keyword arguments passed to the
:py:class:`utils.loaderutils.dataiterator.LazyDatasetBase` constructor.
Returns:
:py:class:`utils.loaderutils.dataiterator.LazyDatasetBase` object
"""
return LazyDatasetBase(
input_dir=input_dir,
n_events=n_events,
shuffle=shuffle,
seed=seed,
**kwargs,
)
    def fetch_datasets(self, lazy_dataset: LazyDatasetBase) -> typing.List[Data]:
        """Load into memory all the events referenced by a lazy dataset.
        Args:
            lazy_dataset: lazy dataset whose events are iterated over and collected
Returns:
List of loaded PyTorch Geometric Data objects
"""
logging.info(
f"Load {len(lazy_dataset)} files located in {lazy_dataset.input_dir}"
)
return [event for event in tqdm(iter(lazy_dataset), total=len(lazy_dataset))]
    def load_testset_from_directory(self, input_dir: str, **kwargs):
        """Load a test dataset from a path to a directory.
        Args:
            input_dir: path to the directory that contains the PyTorch Geometric Data
                pickle files.
            **kwargs: other keyword arguments passed to :py:func:`ModelBase.get_lazy_dataset`
"""
lazy_dataset = self.get_lazy_dataset(input_dir=input_dir, **kwargs)
self._testset = self.fetch_datasets(lazy_dataset=lazy_dataset)
    def get_lazy_dataset_partition(
self,
partition: str,
n_events: int | None = None,
shuffle: bool = False,
seed: int | None = None,
**kwargs,
) -> LazyDatasetBase:
"""Get the lazy dataset of a partition.
Args:
partition: ``train``, ``val`` or name of the test dataset
n_events: number of events to load
shuffle: whether to shuffle the input paths (applied before
                selecting the first ``n_events``)
seed: seed for the shuffling
**kwargs: Other keyword arguments passed to
:py:func:`ModelBase.get_lazy_dataset`
Returns:
Lazy dataset of the ``partition``
"""
if partition in ["train", "val"]:
lazy_dataset = self.get_lazy_dataset(
input_dir=op.join(self.hparams["input_dir"], partition),
n_events=(
self.hparams.get(f"n_{partition}_events")
if n_events is None
else n_events
),
shuffle=shuffle,
seed=seed,
**kwargs,
)
else:
lazy_dataset = self.get_lazy_dataset(
input_dir=op.join(self.hparams["input_dir"], "test", partition),
n_events=n_events,
shuffle=shuffle,
seed=seed,
**kwargs,
)
return lazy_dataset
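    # Expected directory layout, as implied by the path construction above: the
    # "train" and "val" partitions live in `<input_dir>/train` and `<input_dir>/val`,
    # while any other partition name is resolved as a test sample under
    # `<input_dir>/test/<partition>`.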
    def fetch_partition(
self,
partition: str,
n_events: int | None = None,
shuffle: bool = False,
seed: int | None = None,
**kwargs,
) -> typing.List[Data] | LazyDatasetBase:
"""Load a partition.
Args:
partition: ``train``, ``val`` or name of the test dataset
n_events: number of events to load for this partition
shuffle: whether to shuffle the input paths (applied before
                selecting the first ``n_events``)
seed: seed for the shuffling
            **kwargs: Other keyword arguments passed to
                :py:func:`ModelBase.get_lazy_dataset_partition`
        Returns:
            The lazy dataset itself for the ``train`` partition when :py:attr:`~lazy`
            is enabled, otherwise the list of loaded Data objects
        """
lazy_dataset = self.get_lazy_dataset_partition(
partition=partition,
n_events=n_events,
shuffle=shuffle,
seed=seed,
**kwargs,
)
if partition == "train" and self.lazy:
return lazy_dataset
else:
return self.fetch_datasets(lazy_dataset=lazy_dataset)
    def load_partition(
self,
partition: str,
n_events: int | None = None,
shuffle: bool = False,
seed: int | None = None,
):
"""Load datasets of a partition.
Args:
partition: ``train``, ``val`` or name of the test dataset
n_events: number of events to load for this partition
shuffle: whether to shuffle the input paths (applied before
                selecting the first ``n_events``)
seed: seed for the shuffling
"""
datasets = self.fetch_partition(
partition=partition,
n_events=n_events,
shuffle=shuffle,
seed=seed,
)
if partition == "train":
self._trainset = datasets
elif partition == "val":
assert not isinstance(datasets, LazyDatasetBase) # shouldn't be the case
self._valset = datasets
else:
assert not isinstance(datasets, LazyDatasetBase) # shouldn't be the case
self._testset = datasets
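    # Illustrative usage (the partition name below is hypothetical): calling
    # `model.load_partition("val")` fills the validation set, while a name that is
    # neither "train" nor "val", e.g. `model.load_partition("minbias")`, is treated
    # as a test sample and stored as the test set.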
    def get_features(self, batch: Data) -> torch.Tensor:
        """Get the features of a batch, used as input for inference.
If :py:attr:`~feature_names` is provided, they are used to build
the tensor of node features, normalising them using :py:attr:`~feature_means`
and :py:attr:`~feature_scales`.
Otherwise, ``batch["x"]`` is returned.
Args:
batch: batch of nodes (typically an event)
Returns:
tensor of node features
Notes:
No gradient is recorded.
"""
with torch.no_grad():
            if self.feature_names is not None:  # build normalised features from the configured names
feature_names = self.feature_names
feature_means = self.feature_means
feature_scales = self.feature_scales
if feature_means is None:
raise ValueError(
"`features` was specified but `feature_means` was not"
)
elif feature_scales is None:
raise ValueError(
"`features` was specified but `feature_scales` was not"
)
elif not (
len(feature_names) == len(feature_means) == len(feature_scales)
):
raise ValueError(
"`feature_names`, `feature_means` and `feature_scales` have "
"different sizes: "
f"{len(feature_names)} ; {len(feature_means) } ; { len(feature_scales)}"
)
else:
list_normalised_features = []
for feature_name, feature_mean, feature_scale in zip(
feature_names, feature_means, feature_scales
):
if feature_name in batch: # feature is already in batch
feature: torch.Tensor = batch[feature_name]
else: # not there -> try to compute it
computation_fct = feature_to_compute_fct.get(feature_name)
if computation_fct is None:
raise ValueError(
f"Feature {feature_name} was not found in batch "
"and was not found in `feature_to_compute_fct`"
)
else:
feature = computation_fct(batch)
normalised_feature = (feature - feature_mean) / feature_scale
list_normalised_features.append(normalised_feature.float())
return torch.stack(list_normalised_features, dim=-1)
else: # fall back to default behaviour
return get_input_features(
batch["x"], feature_indices=self.feature_indices
)
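    # Illustrative sketch (the hyper-parameter values below are assumptions): with
    #     feature_names  = ["r", "phi"]
    #     feature_means  = [33.0, 0.0]
    #     feature_scales = [17.0, 3.15]
    # the "r" and "phi" columns are looked up in the batch (or computed via
    # `feature_to_compute_fct`), normalised as (feature - mean) / scale, and stacked
    # into an (n_nodes, 2) float tensor.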
    def get_n_features(self) -> int:
"""Number of input features of the network."""
if self.feature_names is not None:
return len(self.feature_names)
        else:  # fall back to default (to be deprecated)
assert self.feature_indices is not None
return get_number_input_features(self.feature_indices)
    def optimizer_step(
self,
epoch,
batch_idx,
optimizer,
optimizer_closure,
):
"""Modified version of the optimizer step that implements warm up
        and properly enforces the learning rate.
"""
# warm up lr
if (self.hparams["warmup"] is not None) and (
self.current_epoch < self.hparams["warmup"]
):
lr_scale = min(1.0, float(self.current_epoch + 1) / self.hparams["warmup"])
for pg in optimizer.param_groups: # type: ignore
pg["lr"] = lr_scale * self.hparams["lr"]
else:
for pg in optimizer.param_groups: # type: ignore
pg["lr"] = self.lr_schedulers().get_last_lr()[0] # type: ignore
# update params
optimizer.step(closure=optimizer_closure)
optimizer.zero_grad(set_to_none=True) # type: ignore
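    # Worked example (assumed values): with `warmup = 5` and `lr = 1e-3`, the
    # learning rate over epochs 0..4 is scaled to 0.2e-3, 0.4e-3, 0.6e-3, 0.8e-3 and
    # 1.0e-3; from epoch 5 onwards it follows the scheduler returned by
    # `self.lr_schedulers()`.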
    @classmethod
def get_model_from_checkpoint(
cls,
checkpoint: LightningModule | str | None,
default_checkpoint: str | None = None,
**kwargs,
):
"""Helper function to get a model at inference step.
Args:
checkpoint: the model already loaded, or path to it
Mode: Model class
default_checkpoint: path to fall back to if ``checkpoint`` is None.
**kwargs: other parameters passed to :py:func:`Model.load_from_checkpoint`
Return:
Loaded model
"""
if isinstance(checkpoint, cls):
model = checkpoint
elif checkpoint is None: # Default loading mode from last artifact
assert (
default_checkpoint is not None
), "Both `checkpoint` and `default_checkpoint` are None."
checkpoint = default_checkpoint
model = cls.load_from_checkpoint(
default_checkpoint,
**kwargs,
)
logging.info(f"Load model from {checkpoint}.")
elif isinstance(checkpoint, str):
model = cls.load_from_checkpoint(
checkpoint_path=checkpoint,
**kwargs,
)
logging.info(f"Load model from {checkpoint}.")
else:
raise TypeError(
f"Type of checkpoint is {type(checkpoint).__name__} "
"which is not recognised"
)
return model
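    # Illustrative usage (class name and paths are placeholders): a subclass such as
    # `MyGNN` could be restored with
    #
    #     model = MyGNN.get_model_from_checkpoint("checkpoints/last.ckpt")
    #
    # or, when `checkpoint` may be None, with a fallback:
    #
    #     model = MyGNN.get_model_from_checkpoint(
    #         None, default_checkpoint="checkpoints/best.ckpt"
    #     )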
@property
def n_trainable_params(self) -> int:
"""Number of trainable parameters."""
return sum(p.numel() for p in self.parameters() if p.requires_grad)
@property
def input_kwargs(self) -> typing.Dict[str, typing.Any]:
"""Associates an input name with a dictionary corresponding to
the keyword arguments used to build a dummy tensor representing the input.
        The dictionary typically specifies the ``size`` and ``dtype`` of the tensor.
"""
return {}
@property
def subnetwork_groups(self) -> typing.Dict[str, typing.List[str]]:
"""A dictionary that associates a subnetwork actually corresponding
to a list of subnetworks, with this list of subnetworks.
"""
return {}
@property
def subnetwork_to_outputs(self) -> typing.Dict[str, typing.List[str]]:
"""A dictionary that associates a subnetwork name with the list of its
output names."""
return {}
@property
def subnetworks(self) -> typing.List[str]:
"""List of subnetworks available. It is derived from
:py:attr:`subnetwork_to_outputs`.
"""
return list(self.subnetwork_to_outputs.keys())
@property
def input_to_dynamic_axes(self):
"""A dictionary that associates an input name
with the dynamic axis specification.
"""
return {}
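    # Illustrative sketch of the export contract (names and shapes below are
    # assumptions): a subclass exporting an "encoder" subnetwork could override the
    # properties above as
    #
    #     input_kwargs = {"x": {"size": (1, 3), "dtype": torch.float32}}
    #     subnetwork_to_outputs = {"encoder": ["embedding"]}
    #     input_to_dynamic_axes = {"x": {0: "n_nodes"}, "embedding": {0: "n_nodes"}}
    #
    # together with a matching `_onnx_encoder(self, x)` forward method (see
    # `_get_subnetwork_forward_func` below).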
    def get_subnetwork_outputs(self, subnetwork: str) -> typing.List[str]:
"""Get the outputs of a subnetwork, as configured
by the :py:attr:`subnetwork_to_outputs` property.
Args:
subnetwork: subnetwork name
Returns:
List of the output names of the subnetwork.
Raises:
KeyError: if the outputs of the subnetwork were not specified
in the :py:attr:`subnetwork_to_outputs` property.
"""
outputs = self.subnetwork_to_outputs.get(subnetwork)
if outputs is None:
raise KeyError(
f"The outputs for the subnetwork {subnetwork} were not defined. "
"To define it, you can modify the property `subnetwork_to_outputs` "
f"of the {self.__class__.__name__} class."
)
else:
return outputs
def _get_subnetwork_forward_func(self, subnetwork: str):
"""Get the forward function of a given subnetwork.
Args:
            subnetwork: subnetwork name
        Returns:
            Method ``_onnx_{subnetwork}`` of this class
        Raises:
            AttributeError: if the method is missing.
"""
try:
forward_func = getattr(self, f"_onnx_{subnetwork}")
except AttributeError:
raise AttributeError(
f"The forward method `_onnx_{subnetwork}` "
f"for the subnetwork {subnetwork} was not defined."
)
return forward_func
    def to_onnx(
self,
outpath: str,
mode: str | None = None,
options: typing.Iterable[str] | None = None,
) -> None:
"""Export a model to ONNX.
Args:
outpath: where to save the ONNX file
options: ONNX export options
"""
options = set() if options is None else set(options)
# Default mode is the first subnetwork
if mode is None:
mode = self.subnetworks[0]
if (subnetworks := self.subnetwork_groups.get(mode)) is not None:
            # If the `mode` corresponds to a group of subnetworks,
            # the output path must contain the placeholder `{subnetwork}`,
            # which will be replaced by the subnetwork name.
assert "{subnetwork}" in outpath, (
f"In `{mode}` mode, the output path should contain "
"the placeholder {subnetwork}."
)
print(f"Mode {mode} contains the following subnetworks:", subnetworks)
for subnetwork in subnetworks:
self.to_onnx(
outpath=outpath.format(subnetwork=subnetwork),
mode=subnetwork,
options=options,
)
else:
use_fp16 = check_and_discard(options, "fp16")
if options:
raise ValueError(
"The following options are not recognised: " + ", ".join(options)
)
subnetwork = mode
# Input names of the subnetwork
input_names = self.get_subnetwork_inputs(subnetwork)
input_kwargs = self.input_kwargs
def _extract_input_kwargs(
input_kwargs: typing.Dict[str, typing.Any], input_name: str
) -> typing.Dict[str, typing.Any]:
"""Extract the keyword parameters used to build the input tensor
``input_path``.
The only reason the function was written is to raise an error
if the input keyword parameters are not in ``input_kwargs``.
"""
if (kwargs := input_kwargs.get(input_name)) is not None:
return kwargs
else:
raise KeyError(
f"The subnetwork `{subnetwork}` needs the input `{input_name}` "
"but the latter was not defined in `input_kwargs`"
)
# Dummy input tensors of the subnetwork
dummy_inputs = {
input_name: torch.zeros(
**_extract_input_kwargs(input_kwargs, input_name), device="cuda"
)
for input_name in input_names
}
# Output names of the subnetwork
output_names = self.get_subnetwork_outputs(subnetwork)
# Add `_out` to the output names that are the same as the input names
output_names_named_as_input = list(
set(output_names).intersection(input_names)
)
modified_output_names = [
(
f"{output_name}_out"
if output_name in output_names_named_as_input
else output_name
)
for output_name in output_names
]
os.makedirs(os.path.dirname(outpath), exist_ok=True)
print(f"{subnetwork} input names:", ", ".join(input_names))
print(f"{subnetwork} output names:", ", ".join(modified_output_names))
torch.onnx.export(
model=ModelONNXExport(model=self, subnetwork=subnetwork),
args=tuple(dummy_inputs[input_name] for input_name in input_names),
f=outpath,
verbose=False,
# Names to assign to the input nodes of the graph, in order
input_names=input_names,
# Names to assign to the output nodes of the graph, in order
output_names=modified_output_names,
# Apply the constant-folding optimisation:
# replace some of the ops that have all constant inputs with pre-computed
# constant nodes
do_constant_folding=True,
opset_version=17,
dynamic_axes={
modified_name: self.input_to_dynamic_axes[name]
for (name, modified_name) in zip(
input_names + output_names, input_names + modified_output_names
)
},
)
print("Model was exported to", os.path.abspath(outpath))
if use_fp16:
from utils.modelutils.export import convert_model_to_fp16
convert_model_to_fp16(outpath)
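    # Illustrative usage (paths and mode names below are assumptions): exporting a
    # single subnetwork or a whole group could look like
    #
    #     model.to_onnx("export/encoder.onnx", mode="encoder", options=["fp16"])
    #     model.to_onnx("export/{subnetwork}.onnx", mode="full_chain")
    #
    # where "full_chain" would be a key of `subnetwork_groups` and the `{subnetwork}`
    # placeholder is filled in for each member of the group.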
class ModelONNXExport(torch.nn.Module):
    """Class used to export the forward pass of a subnetwork within
    a :py:class:`ModelBase` model.
    Attributes:
        model: model whose subnetwork is exported
        subnetwork: name of the subnetwork to export
"""
def __init__(self, model: ModelBase, subnetwork: str):
        super().__init__()
self.model = model
self.subnetwork = str(subnetwork)
    def forward(self, *args) -> typing.Any:
"""Forward pass to use when the model is exported to ONNX."""
forward_func = self.model._get_subnetwork_forward_func(
subnetwork=self.subnetwork
)
return forward_func(*args)
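# Illustrative sketch (hypothetical subclass, shown only for documentation): the
# export machinery expects the parent model to expose one `_onnx_<subnetwork>` method
# per subnetwork, e.g.
#
#     class MyGNN(ModelBase):
#         def _onnx_encoder(self, x: torch.Tensor) -> torch.Tensor:
#             return self.encoder(x)
#
# so that `ModelONNXExport(model, "encoder")` traces exactly that forward pass.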