Source code for pipeline.Embedding.build_embedding
from __future__ import annotations
from types import ModuleType
import typing
from tqdm.auto import tqdm
import torch
from torch_geometric.data import Data
import torch.multiprocessing as tm
from Embedding.embedding_base import EmbeddingBase
from utils.modelutils.build import ModelBuilderBase
from utils.commonutils.config import load_config
from . import process_custom
def get_squared_distance_max_from_config(path_or_config: str | dict) -> float:
"""Get the value of ``squared_distance_max_inference``, or fall back
to ``squared_distance_max`` if ``squared_distance_max_inference``
was set to ``None``.
Args:
path_or_config: configuration dictionary, or path to the YAML file that contains
the configuration
Returns:
Squared maximal distance used for the embedding inference
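
    Example (illustrative; ``config.yaml`` is an assumed path with an
    ``embedding`` section)::

        squared_distance_max = get_squared_distance_max_from_config("config.yaml")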
"""
config = load_config(path_or_config=path_or_config)
if (
squared_distance_max_inference := config["embedding"][
"squared_distance_max_inference"
]
) is not None:
squared_distance_max = squared_distance_max_inference
else:
squared_distance_max = config["embedding"]["squared_distance_max"]
    assert isinstance(
        squared_distance_max, float
    ), f"The maximal squared distance is {squared_distance_max}, which is not a float."
return squared_distance_max
class EmbeddingInferenceBuilder(ModelBuilderBase):
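    """Builder that runs embedding inference and a kNN search on each event and
    stores the resulting candidate edges and their truth labels in the event
    ``Data`` object (see :meth:`construct_downstream`).
    """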
def __init__(
self,
model: EmbeddingBase,
k_max: int = 1000,
squared_distance_max: float = 0.1,
max_plane_diff: int | None = None,
):
        super().__init__(model=model)
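        # ``k_max`` and ``squared_distance_max`` are forwarded to
        # ``EmbeddingBase.inference`` in ``construct_downstream``;
        # ``max_plane_diff`` is stored here for later use.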
self.k_max = k_max
self.squared_distance_max = squared_distance_max
self.max_plane_diff = max_plane_diff
def _parallel_run(
self,
n_workers: int,
infer_one_step_partial: typing.Callable[[str], None],
file_names: typing.List[str],
):
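        """Run ``infer_one_step_partial`` over ``file_names`` with a pool of
        ``n_workers`` processes.
        """
        # Move the model's parameters to shared memory so the worker processes
        # below can run inference on the same weights without copying the model.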
self.model.share_memory()
with tm.Pool(n_workers) as pool:
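            # ``imap`` is lazy; wrapping it in ``list`` drains the iterator so every
            # file gets processed, while ``tqdm`` reports the overall progress.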
list(
tqdm(
pool.imap(infer_one_step_partial, file_names, chunksize=64),
total=len(file_names),
)
)
    def construct_downstream(self, batch: Data):
"""Run embedding inference and kNN. Add the edges and their targets to
the event data object.
"""
self.model: EmbeddingBase
outputs = self.model.inference(
batch=batch,
squared_distance_max=self.squared_distance_max,
k_max=self.k_max,
evaluate=False,
log=False,
)
edge_indices = outputs["edge_indices"]
y_truth = outputs["y_truth"]
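        # If the batch already contains edges (and labels), append the newly
        # inferred edges and their truth labels, and remember how many edges
        # were there originally.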
if (
"edge_index" in batch
and (original_edge_indices := batch["edge_index"]) is not None
):
assert "y" in batch
batch["edge_index"] = torch.cat(
(original_edge_indices, edge_indices),
dim=-1,
)
batch["y"] = torch.cat(
(batch["y"], y_truth),
dim=-1,
)
batch["n_original_edge_index"] = original_edge_indices.shape[1]
else:
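            # No pre-existing edges: the inferred edges and labels define the graph.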
batch["edge_index"] = edge_indices
batch["y"] = y_truth
return batch
def _get_building_custom_module(self) -> ModuleType:
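        # ``process_custom`` (imported above) provides the custom processing
        # used when building the dataset with this builder.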
return process_custom
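
# Illustrative usage sketch: ``embedding_model`` and ``events`` are placeholder
# names, and ``ModelBuilderBase`` may provide a higher-level entry point; here
# ``construct_downstream`` is called directly on each event.
#
#   builder = EmbeddingInferenceBuilder(
#       model=embedding_model, k_max=1000, squared_distance_max=0.1
#   )
#   for batch in events:
#       batch = builder.construct_downstream(batch)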