# Source code for pipeline.utils.plotutils.plotools

"""Define some global utilities for plots.
"""

from __future__ import annotations
import typing
import os

import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.text import Text


@typing.overload
def get_figs_axes_on_grid(
    nrows: int, ncols: int, same_fig: typing.Literal[True] = True, **kwargs
) -> typing.Tuple[Figure, npt.NDArray]: ...


@typing.overload
def get_figs_axes_on_grid(
    nrows: int, ncols: int, same_fig: typing.Literal[False], **kwargs
) -> typing.Tuple[npt.NDArray, npt.NDArray]: ...


@typing.overload
def get_figs_axes_on_grid(
    nrows: int, ncols: int, same_fig: bool, **kwargs
) -> typing.Tuple[Figure | npt.NDArray, npt.NDArray]: ...


def get_figs_axes_on_grid(
    nrows: int,
    ncols: int,
    same_fig: bool = True,
    figsize: typing.Tuple[int, int] | typing.List[typing.Tuple[int, int]] = (8, 6),
    **kwargs,
) -> typing.Tuple[Figure | npt.NDArray, npt.NDArray]:
    """Create an ``nrows`` x ``ncols`` grid of axes, in one figure or one figure per cell.

    Args:
        nrows: number of rows of the grid.
        ncols: number of columns of the grid.
        same_fig: if ``True``, all axes live in a single figure created by
            ``plt.subplots(nrows, ncols, ...)``. Otherwise a separate figure,
            holding a single axes, is created for every cell of the grid.
        figsize: ``(width, height)`` of one cell. With ``same_fig=True`` the
            whole figure measures ``(width * ncols, height * nrows)``;
            otherwise every individual figure uses this size.
        **kwargs: forwarded to every ``plt.subplots`` call.

    Returns:
        With ``same_fig=True``: the figure and axes exactly as returned by
        ``plt.subplots`` (matplotlib squeezes singleton dimensions by default).
        With ``same_fig=False``: an array of figures and an array of axes, both
        of shape ``(nrows, ncols)``, or ``(ncols,)`` when ``nrows == 1``.
        NOTE: in this mode the column dimension is never squeezed, unlike
        ``plt.subplots``.
    """
    if not same_fig:
        # One (figure, axes) pair per cell, created row by row.
        cells = [
            [plt.subplots(figsize=figsize, **kwargs) for _ in range(ncols)]
            for _ in range(nrows)
        ]
        fig_rows = [[cell[0] for cell in row] for row in cells]
        ax_rows = [[cell[1] for cell in row] for row in cells]
        if nrows == 1:
            # Drop the row dimension for a single-row grid.
            return np.array(fig_rows[0]), np.array(ax_rows[0])
        return np.array(fig_rows), np.array(ax_rows)

    cell_width, cell_height = figsize[0], figsize[1]
    whole_fig, grid = plt.subplots(
        nrows,
        ncols,
        figsize=(cell_width * ncols, cell_height * nrows),
        **kwargs,
    )
    return whole_fig, grid
def pad_on_top(ax: Axes, factor: float = 1.09):
    """Stretch the y-range of ``ax`` upwards by a multiplicative factor.

    The lower y-limit is left untouched while the span ``ymax - ymin`` is
    multiplied by ``factor``. This leaves empty room above the plotted content.

    Args:
        ax: matplotlib axes whose y-limits are modified in place.
        factor: multiplier applied to the current y-span; values above 1 add
            space at the top.
    """
    bottom, top = ax.get_ylim()
    new_top = bottom + (top - bottom) * factor
    ax.set_ylim(bottom, new_top)
def add_text(
    ax: Axes,
    ha: str | None = None,
    va: str | None = None,
    x: float | None = None,
    y: float | None = None,
    text: str = "LHCb Run 3 Simulation",
    fontsize: float | None = 20,
    **kwargs,
) -> Text:
    """Add text inside a matplotlib axes, positioned in axes coordinates.

    Along each direction, the position is given either as an alignment
    keyword (``ha`` / ``va``), which is mapped to a default relative position,
    or as an explicit relative coordinate (``x`` / ``y``). In the latter case
    the alignment is inferred so that the text grows towards the centre of the
    axes. When both are provided, both are used as given.

    Args:
        ax: matplotlib ax
        ha: Horizontal alignment: ``left``, ``center`` or ``right``
        va: vertical alignment: ``top``, ``center`` or ``bottom``
        x: Relative position along ``x`` (axes coordinates: 0 is the left
            edge, 1 the right edge)
        y: Relative position along ``y`` (axes coordinates: 0 is the bottom
            edge, 1 the top edge)
        text: Text to display.
        fontsize: Font size forwarded to ``ax.text``.
        **kwargs: Extra keyword arguments forwarded to ``ax.text``.

    Returns:
        The text artist created by ``ax.text``.

    Raises:
        ValueError: if neither ``ha`` nor ``x`` is given, or neither ``va``
            nor ``y`` is given.
        KeyError: if ``ha`` / ``va`` is not one of the keywords above while
            the matching coordinate is not given.
    """
    # Default relative positions associated with each alignment keyword.
    ha_to_x = {
        "left": 0.02,
        "right": 0.98,
        "center": 0.50,
    }
    va_to_y = {
        "top": 0.95,
        "bottom": 0.02,
        "center": 0.50,
    }

    if x is None:
        if ha is None:
            raise ValueError("If `ha` is not provided, `x` should be")
        x = ha_to_x[ha]
    if y is None:
        if va is None:
            raise ValueError("If `va` is not provided, `y` should be")
        y = va_to_y[va]

    # From here on x and y are always set. A missing alignment is derived from
    # the position so the text is anchored on the side of the nearest edge and
    # extends towards the centre (mirrors ``ha_to_x`` / ``va_to_y``). The
    # previous horizontal rule was inverted and pushed right-hand text out of
    # the axes.
    if ha is None:
        if x > 0.5:
            ha = "right"
        elif x < 0.5:
            ha = "left"
        else:
            ha = "center"
    if va is None:
        if y > 0.5:
            va = "top"
        elif y < 0.5:
            va = "bottom"
        else:
            va = "center"

    return ax.text(
        x,
        y,
        text,
        verticalalignment=va,
        horizontalalignment=ha,
        transform=ax.transAxes,
        fontsize=fontsize,
        **kwargs,
    )
def set_same_y_lim_for_all_axes(
    axes: typing.Iterable[Axes], ymin: float | None = None, ymax: float | None = None
):
    """Set the same (most extended) y limit for all the axes given as input.

    Args:
        axes: an iterable of axes. One-shot iterators such as generators are
            supported. A NumPy array of axes of any shape (e.g. a 2D grid) is
            flattened first.
        ymin: Enforced min value of the y-axis. If ``None``, the smallest
            lower limit among ``axes`` is used.
        ymax: Enforced max value of the y-axis. If ``None``, the largest
            upper limit among ``axes`` is used.

    Does nothing when ``axes`` is empty.
    """
    # A 2D array would otherwise be iterated row by row, yielding arrays
    # instead of axes.
    if isinstance(axes, np.ndarray):
        axes = axes.ravel()
    # Materialise once: the axes are traversed up to three times below, which
    # would silently skip work (or crash) for a one-shot iterator.
    axes = list(axes)
    if not axes:
        return

    # First find the lowest min and the largest max
    if ymin is None:
        ymin = min(ax.get_ylim()[0] for ax in axes)
    if ymax is None:
        ymax = max(ax.get_ylim()[1] for ax in axes)

    # Apply these limits to every axes
    for ax in axes:
        ax.set_ylim(ymin, ymax)
def hide_repetitive_labels_in_grid(axes: typing.Collection[typing.Collection[Axes]]):
    """Hide tick labels and axis labels that repeat across a grid of axes.

    x tick labels and the x-axis label are kept only on the last row; y tick
    labels and the y-axis label are kept only on the first column. They are
    hidden on every other axes.

    Args:
        axes: grid of axes indexed as ``axes[row][column]``.
    """
    last_row = len(axes) - 1
    for row_idx, row in enumerate(axes):
        keep_x = row_idx == last_row
        for col_idx, ax in enumerate(row):
            if not keep_x:
                # Only the bottom row shows the x tick labels and x label
                ax.tick_params(axis="x", labelbottom=False)
                ax.xaxis.label.set_visible(False)
            if col_idx > 0:
                # Only the leftmost column shows the y tick labels and y label
                ax.tick_params(axis="y", labelleft=False)
                ax.yaxis.label.set_visible(False)
def save_fig(
    fig: Figure,
    path: str,
    exts: typing.Sequence[str] = (".pdf", ".png"),
    bbox_inches: str | None = "tight",
    **kwargs,
):
    """Save a figure in one or several file formats.

    ``path`` may be given with or without an extension. If its extension is
    one of ``exts``, it is stripped. Otherwise the whole ``path`` is used as
    the stem. The figure is then written once per extension in ``exts``, with
    the file format taken from the extension. Missing parent directories are
    created, and ``fig.tight_layout()`` is called before saving.

    Args:
        fig: Matplotlib figure to save
        path: path where to save the figure
        exts: extensions to save the figure with, each starting with ``"."``.
            A single string is treated as one extension.
        bbox_inches: forwarded to ``fig.savefig``.
        **kwargs: extra keyword arguments forwarded to ``fig.savefig``.

    Raises:
        ValueError: if an extension does not start with ``"."``. The check
            runs before anything is written.
    """
    if isinstance(exts, str):
        exts = [exts]

    # Validate up front so an invalid extension cannot leave a partial set of
    # files (or an empty directory) behind.
    for ext in exts:
        if not ext.startswith("."):
            raise ValueError(f"Extensions must start with '.', got {ext!r}")

    directory = os.path.dirname(path)
    # `os.makedirs("")` raises, so only create a directory when `path` has one.
    if directory:
        os.makedirs(directory, exist_ok=True)

    fig.tight_layout()

    stem, path_ext = os.path.splitext(path)
    if path_ext not in exts:
        # Unknown (or no) extension: keep it as part of the file stem.
        stem = path

    for ext in exts:
        overall_path = stem + ext
        fig.savefig(overall_path, format=ext[1:], bbox_inches=bbox_inches, **kwargs)
        print("Figure was saved in", overall_path)