Source code for pipeline.Embedding.build_embedding

from __future__ import annotations
from types import ModuleType
import typing

from tqdm.auto import tqdm
import torch
from torch_geometric.data import Data
import torch.multiprocessing as tm

from Embedding.embedding_base import EmbeddingBase
from utils.modelutils.build import ModelBuilderBase
from utils.commonutils.config import load_config

from . import process_custom


def get_squared_distance_max_from_config(path_or_config: str | dict) -> float:
    """Get the value of ``squared_distance_max_inference``, or fall back to
    ``squared_distance_max`` if ``squared_distance_max_inference`` was set to
    ``None``.

    Args:
        path_or_config: configuration dictionary, or path to the YAML file
            that contains the configuration

    Returns:
        Squared maximal distance used for the embedding inference
    """
    config = load_config(path_or_config=path_or_config)
    if (
        squared_distance_max_inference := config["embedding"][
            "squared_distance_max_inference"
        ]
    ) is not None:
        squared_distance_max = squared_distance_max_inference
    else:
        squared_distance_max = config["embedding"]["squared_distance_max"]

    assert isinstance(
        squared_distance_max, float
    ), f"The distance max is {squared_distance_max}, which is not a float."

    return squared_distance_max
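
# Illustrative sketch (not part of the original module): shows the fallback
# behaviour of get_squared_distance_max_from_config when no inference-specific
# value is configured. The nested dict only mirrors the keys read above, the
# value 0.015 is made up for the example, and it is assumed (as the docstring
# states) that load_config accepts a configuration dictionary directly.
def _example_squared_distance_max_fallback() -> float:
    example_config = {
        "embedding": {
            # No dedicated inference value: fall back to squared_distance_max.
            "squared_distance_max_inference": None,
            "squared_distance_max": 0.015,
        }
    }
    # Returns 0.015, since the inference entry is None.
    return get_squared_distance_max_from_config(example_config)
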
class EmbeddingInferenceBuilder(ModelBuilderBase):
    def __init__(
        self,
        model: EmbeddingBase,
        k_max: int = 1000,
        squared_distance_max: float = 0.1,
        max_plane_diff: int | None = None,
    ):
        super(EmbeddingInferenceBuilder, self).__init__(model=model)
        self.k_max = k_max
        self.squared_distance_max = squared_distance_max
        self.max_plane_diff = max_plane_diff

    def _parallel_run(
        self,
        n_workers: int,
        infer_one_step_partial: typing.Callable[[str], None],
        file_names: typing.List[str],
    ):
        """Run ``infer_one_step_partial`` on every file name using a
        multiprocessing pool of ``n_workers`` workers.
        """
        # Move the model parameters to shared memory so the worker processes
        # do not each hold their own copy.
        self.model.share_memory()
        with tm.Pool(n_workers) as pool:
            list(
                tqdm(
                    pool.imap(infer_one_step_partial, file_names, chunksize=64),
                    total=len(file_names),
                )
            )

    def construct_downstream(self, batch: Data):
        """Run embedding inference and kNN. Add the edges and their targets
        to the event data object.
        """
        self.model: EmbeddingBase
        outputs = self.model.inference(
            batch=batch,
            squared_distance_max=self.squared_distance_max,
            k_max=self.k_max,
            evaluate=False,
            log=False,
        )
        edge_indices = outputs["edge_indices"]
        y_truth = outputs["y_truth"]

        if (
            "edge_index" in batch
            and (original_edge_indices := batch["edge_index"]) is not None
        ):
            assert "y" in batch
            batch["edge_index"] = torch.cat(
                (original_edge_indices, edge_indices),
                dim=-1,
            )
            batch["y"] = torch.cat(
                (batch["y"], y_truth),
                dim=-1,
            )
            batch["n_original_edge_index"] = original_edge_indices.shape[1]
        else:
            batch["edge_index"] = edge_indices
            batch["y"] = y_truth

        return batch

    def _get_building_custom_module(self) -> ModuleType:
        return process_custom
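

# Usage sketch (illustrative, not part of the original module): how the pieces
# above could be wired together for a single event. The config path
# "config.yaml" is hypothetical and the trained model is assumed to be loaded
# elsewhere; EmbeddingInferenceBuilder, get_squared_distance_max_from_config
# and construct_downstream are the module's own API as defined above.
def _example_build_embedding(model: EmbeddingBase, batch: Data) -> Data:
    builder = EmbeddingInferenceBuilder(
        model=model,
        k_max=1000,
        squared_distance_max=get_squared_distance_max_from_config("config.yaml"),
    )
    # Adds ``edge_index`` and ``y`` (and ``n_original_edge_index`` when edges
    # were already present) to the event data object.
    return builder.construct_downstream(batch)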