"""Define some global utilies for plots.
"""
from __future__ import annotations
import typing
import os
import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from matplotlib.text import Text
# Typing overloads for ``get_figs_axes_on_grid``: the type of the first returned
# element depends on ``same_fig``. ``figsize`` is not listed in these stubs; it
# is accepted through ``**kwargs``.


# ``same_fig`` is literally ``True`` (or omitted): a single Figure is returned
# together with the axes.
@typing.overload
def get_figs_axes_on_grid(
    nrows: int, ncols: int, same_fig: typing.Literal[True] = True, **kwargs
) -> typing.Tuple[Figure, npt.NDArray]: ...


# ``same_fig`` is literally ``False``: an array of figures is returned together
# with the array of axes.
@typing.overload
def get_figs_axes_on_grid(
    nrows: int, ncols: int, same_fig: typing.Literal[False], **kwargs
) -> typing.Tuple[npt.NDArray, npt.NDArray]: ...


# ``same_fig`` is only known as a ``bool``: the first element may be either a
# Figure or an array of figures.
@typing.overload
def get_figs_axes_on_grid(
    nrows: int, ncols: int, same_fig: bool, **kwargs
) -> typing.Tuple[Figure | npt.NDArray, npt.NDArray]: ...
def get_figs_axes_on_grid(
    nrows: int,
    ncols: int,
    same_fig: bool = True,
    figsize: typing.Tuple[int, int] | typing.List[typing.Tuple[int, int]] = (8, 6),
    **kwargs,
) -> typing.Tuple[Figure | npt.NDArray, npt.NDArray]:
    """Create a grid of ``nrows`` x ``ncols`` axes, in one figure or in many.

    Args:
        nrows: Number of lines of the grid
        ncols: Number of columns of the grid
        same_fig: If ``True``, a single figure holding all the axes is created
            with one call to ``plt.subplots``. Otherwise, one figure is created
            per cell of the grid.
        figsize: Size of one cell. When ``same_fig`` is ``True``, the figure
            size is ``(figsize[0] * ncols, figsize[1] * nrows)``; otherwise
            ``figsize`` is passed unchanged to every ``plt.subplots`` call.
            NOTE(review): the annotation also allows a list of tuples, but the
            code treats ``figsize`` as a single pair -- confirm the intent.
        **kwargs: Passed to ``plt.subplots``

    Returns:
        When ``same_fig`` is ``True``, the ``(figure, axes)`` pair returned by
        ``plt.subplots``. Otherwise, an array of figures and an array of axes
        built from ``nrows`` lines of ``ncols`` elements; if ``nrows`` is 1,
        only the first (and only) line of each array is returned.
    """
    if same_fig:
        # A single figure, scaled so that every cell keeps the requested size
        return plt.subplots(
            nrows,
            ncols,
            figsize=(figsize[0] * ncols, figsize[1] * nrows),
            **kwargs,
        )

    # One independent figure per cell, created line by line
    cells = [
        [plt.subplots(figsize=figsize, **kwargs) for _ in range(ncols)]
        for _ in range(nrows)
    ]
    fig_grid = np.array([[cell[0] for cell in line] for line in cells])
    ax_grid = np.array([[cell[1] for cell in line] for line in cells])

    # A grid with a single line is returned without its leading dimension
    if nrows == 1:
        fig_grid, ax_grid = fig_grid[0], ax_grid[0]

    return np.array(fig_grid), np.array(ax_grid)
def pad_on_top(ax: Axes, factor: float = 1.09):
    """Add room at the top of an ax by stretching its y-range upwards.

    The lower y limit is kept as it is. The extent of the y-axis
    (``ymax - ymin``) is multiplied by ``factor`` to get the new upper limit.

    Args:
        ax: matplotlib ax
        factor: Multiplicative factor applied to the extent of the y-axis
    """
    lower, upper = ax.get_ylim()
    stretched_extent = (upper - lower) * factor
    ax.set_ylim(lower, stretched_extent + lower)
def add_text(
    ax: Axes,
    ha: str | None = None,
    va: str | None = None,
    x: float | None = None,
    y: float | None = None,
    text: str = "LHCb Run 3 Simulation",
    fontsize: float | None = 20,
    **kwargs,
) -> Text:
    """Add text inside a matplotlib figure.

    For each direction, either the alignment or the relative position must be
    given. The missing one is deduced from the other.

    Args:
        ax: matplotlib ax
        ha: Horizontal alignment: ``left``, ``center`` or ``right``
        va: vertical alignment: ``top``, ``center`` or ``bottom``
        x: Relative position along ``x``. If not provided, it is deduced from
            ``ha``. If provided without ``ha``, the text is aligned to the
            ``right`` when ``x > 0.5`` and to the ``left`` otherwise.
        y: Relative position along ``y``. If not provided, it is deduced from
            ``va``. If provided without ``va``, the text is aligned to the
            ``top`` when ``y > 0.5`` and to the ``bottom`` otherwise.
        text: Text to write
        fontsize: Font size of the text
        **kwargs: Passed to ``ax.text``

    Returns:
        The object returned by ``ax.text``

    Raises:
        ValueError: if neither ``ha`` nor ``x`` is provided, or if neither
            ``va`` nor ``y`` is provided
        KeyError: if the position has to be deduced from ``ha`` (or ``va``)
            and the latter is not one of the supported values
    """
    ha_to_x = {
        "left": 0.02,
        "right": 0.98,
        "center": 0.50,
    }
    va_to_y = {
        "top": 0.95,
        "bottom": 0.02,
        "center": 0.50,
    }

    if x is None:
        if ha is None:
            raise ValueError("If `ha` is not provided, `x` should be")
        x = ha_to_x[ha]
    elif ha is None:
        # Anchor the text on the side it is closest to, so that it extends
        # towards the inside of the ax. This mirrors `ha_to_x` and the vertical
        # rule below.
        ha = "right" if x > 0.5 else "left"

    if y is None:
        if va is None:
            raise ValueError("If `va` is not provided, `y` should be")
        y = va_to_y[va]
    elif va is None:
        va = "top" if y > 0.5 else "bottom"

    return ax.text(
        x,
        y,
        text,
        verticalalignment=va,
        horizontalalignment=ha,
        transform=ax.transAxes,
        fontsize=fontsize,
        **kwargs,
    )
def set_same_y_lim_for_all_axes(
    axes: typing.Iterable[Axes], ymin: float | None = None, ymax: float | None = None
):
    """Set the same (most extended) y limit for all the axes given as input.

    Nothing is done if ``axes`` is empty.

    Args:
        axes: an iterable of axes (a one-shot iterator or generator is allowed)
        ymin: Enforced min value of the y-axis. If not provided, the lowest
            lower limit among the axes is used.
        ymax: Enforced max value of the y-axis. If not provided, the largest
            upper limit among the axes is used.
    """
    # The axes are traversed up to three times below, so materialise them once:
    # a generator or iterator would be exhausted after the first pass.
    axes = list(axes)
    if not axes:
        return

    # First find the lowest min and the largest max
    if ymin is None:
        ymin = min(ax.get_ylim()[0] for ax in axes)
    if ymax is None:
        ymax = max(ax.get_ylim()[1] for ax in axes)

    # Apply these values to all the axes
    for ax in axes:
        ax.set_ylim(ymin, ymax)
def _hide_x_labels(ax: Axes) -> None:
    """Hide the x tick labels and the x-axis label of an ax."""
    ax.tick_params(axis="x", labelbottom=False)
    ax.xaxis.label.set_visible(False)


def _hide_y_labels(ax: Axes) -> None:
    """Hide the y tick labels and the y-axis label of an ax."""
    ax.tick_params(axis="y", labelleft=False)
    ax.yaxis.label.set_visible(False)


def hide_repetitive_labels_in_grid(axes: typing.Collection[typing.Collection[Axes]]):
    """Hide the tick and axis labels that are repeated inside a grid of axes.

    The x labels are kept only on the last line of the grid, and the y labels
    only on the first column.

    Args:
        axes: Grid of axes, given as a collection of lines of axes
    """
    for line_number, line in enumerate(axes):
        for column_number, ax in enumerate(line):
            on_last_line = line_number == len(axes) - 1
            if not on_last_line:
                _hide_x_labels(ax)
            if column_number > 0:
                _hide_y_labels(ax)
def save_fig(
    fig: Figure,
    path: str,
    exts: typing.Sequence[str] = (".pdf", ".png"),
    bbox_inches: str | None = "tight",
    **kwargs,
):
    """Save a figure, once per requested extension.

    Args:
        fig: Matplotlib figure to save
        path: path where to save the figure. If it ends with one of the
            extensions in ``exts``, this extension is stripped before the
            extensions in ``exts`` are appended. Otherwise the extensions are
            appended to ``path`` as it is.
        exts: Extensions (each starting with a dot) the figure is saved with.
            A single extension may also be given as a string.
        bbox_inches: Passed to ``fig.savefig``
        **kwargs: Passed to ``fig.savefig``

    Raises:
        ValueError: if one of the extensions does not start with a dot.
            This is checked before anything is written.
    """
    # A bare string would otherwise be iterated character by character
    exts = [exts] if isinstance(exts, str) else list(exts)
    invalid_exts = [ext for ext in exts if not ext.startswith(".")]
    if invalid_exts:
        raise ValueError(f"Extensions must start with a dot, got {invalid_exts}")

    # `os.makedirs("")` raises, so only create the directory if `path` has one
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    fig.tight_layout()

    path_without_ext, current_ext = os.path.splitext(path)
    if current_ext not in exts:
        # Not one of the requested extensions: treat it as part of the name
        path_without_ext = path

    for ext in exts:
        overall_path = path_without_ext + ext
        fig.savefig(overall_path, format=ext[1:], bbox_inches=bbox_inches, **kwargs)
        print("Figure was saved in", overall_path)