from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Optional
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndi
import skimage
from ipywidgets import interact
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
from matplotlib.transforms import Affine2D
from scipy.signal import windows
from PyLorentz.visualize.colorwheel import get_cmap, shift_cmap_center
# Optional array backends. ``_HAS_TORCH`` / ``_HAS_CUPY`` record whether the
# import succeeded at runtime. Under static type checking the imports are
# unconditional so that annotations resolve.
_HAS_TORCH = False
if TYPE_CHECKING:
    import torch
    from torch import Tensor
else:
    try:
        import torch
        from torch import Tensor

        _HAS_TORCH = True
    except Exception:
        # Any import-time failure (missing package, broken native libraries, ...)
        # just disables torch support. Catching Exception rather than using a
        # bare except lets KeyboardInterrupt/SystemExit propagate.
        _HAS_TORCH = False
        # Keep the name ``Tensor`` defined so that references to it still resolve.
        Tensor = np.ndarray
_HAS_CUPY = False
if TYPE_CHECKING:
    import cupy as cp
else:
    try:
        import cupy as cp

        _HAS_CUPY = True
    except Exception:
        _HAS_CUPY = False
# warnings.filterwarnings("error")  # plt.tight_layout() sometimes throws a UserWarning
# NOTE(review): this silences DeprecationWarnings process-wide, not just for this module.
warnings.filterwarnings("ignore", category=DeprecationWarning)
[docs]
def show_im(
    image: np.ndarray,
    title: str | None = None,
    scale: Optional[float] = None,
    simple: bool = False,
    save: Optional[str] = None,
    cmap: str | Colormap = "gray",
    figax: Optional[tuple[Figure, Axes]] = None,
    roi: Optional[dict | tuple | list] = None,
    cbar_title: Optional[str] = None,
    cbar: Optional[bool] = None,
    intensity_range: str = "minmax",
    origin: str = "upper",
    **kwargs,
) -> None:
    """
    Display an image with optional features like colorbar, title, and scale.

    Args:
        image (np.ndarray): 2D image to be displayed. A 3D array whose first
            and last axes are both not of length 3/4 is summed along its first
            axis; otherwise it is shown as RGB(A). cupy arrays and torch
            tensors are converted to numpy if ``np.array`` refuses them.
        title (str, optional): Title of the plot.
        scale (float, optional): Scale in nm/pixel for axis markers. A
            length-2 sequence is accepted, but only ``scale[0]`` is used.
        simple (bool): Simplified display without colorbar and labels.
        save (str, optional): Path to save the figure.
        cmap (str): Colormap for displaying the image.
        figax (tuple, optional): Figure and axis to plot on. If omitted, a new
            figure is created and shown.
        roi (dict | list | tuple, optional): Region of interest to outline, in
            pixels. Either a dict with keys "top", "bottom", "left", "right"
            (and optionally "rotation" in degrees) or a
            ``(top, bottom, left, right)`` sequence.
        cbar_title (str, optional): Title for the colorbar.
        cbar (bool, optional): Whether to display the colorbar. Defaults to
            ``not simple``.
        intensity_range (str): Method to set intensity range. 'minmax' (also
            'abs'/'absolute') uses the data min/max or the ``vmin``/``vmax``
            kwargs. 'ordered' ('o'/'ord') treats ``vmin``/``vmax`` as
            fractions of the sorted pixel values (defaults 0.01/0.99).
        origin (str): Image origin ('upper' or 'lower').
        **kwargs: Additional options read here: ``figsize``, ``vmin``,
            ``vmax``, ``cmap_midpoint``, ``white_to_transparent``,
            ``title_fontsize``, ``ticks_off``, ``show_bbox``,
            ``tick_direction``, ``scale_unit``, ``ticks_label_off``,
            ``roi_lw``, ``roi_pad``, ``roi_color``, ``dpi``. The kwargs are
            also passed to ``get_cmap``.

    Returns:
        None

    Raises:
        TypeError: If ``image`` cannot be converted to a numpy array, or if
            ``roi`` or ``scale`` has an unsupported type.
        ValueError: If ``intensity_range`` is unknown, or if ``scale`` is a
            sequence whose length is not 2.
    """
    try:
        image = np.array(image)  # dtype = float64?
    except (TypeError, RuntimeError):
        # np.array refuses e.g. cupy arrays and CUDA / grad-tracking torch
        # tensors; convert those explicitly.
        if _HAS_CUPY and isinstance(image, cp.ndarray):
            image = cp.asnumpy(image)
        if _HAS_TORCH and isinstance(image, torch.Tensor):
            image = image.cpu().detach().numpy()
    if not isinstance(image, np.ndarray):
        raise TypeError(f"Image should be np.ndarray, got {type(image)}")
    if image.dtype == "bool":
        image = image.astype("int")
    if cbar is None:
        cbar = not simple

    ndim = np.ndim(image)
    if ndim == 3:
        if image.shape[2] not in [3, 4] and image.shape[0] not in [3, 4]:
            # Not RGB(A): treat as a stack and collapse it.
            if image.shape[0] != 1:
                print("Summing along first axis")
            image = np.sum(image, axis=0)
        else:
            # Displayed as RGB(A); a colorbar would not be meaningful.
            cbar = False
    elif ndim != 2:
        print(f"Input image is of dimension {ndim}. Please input a 2D or 3D image.")
        return

    if figax is not None:
        fig, ax = figax
    else:
        aspect = image.shape[0] / image.shape[1]
        size = kwargs.pop("figsize", (5, 5 * aspect))
        if isinstance(size, (int, float)):
            size = (size, size)
        if simple and title is None:
            # Axes fill the whole figure, so nothing but the image is drawn.
            fig = plt.figure()
            fig.set_size_inches(size)
            ax = Axes(fig, (0.0, 0.0, 1.0, 1.0))
            fig.add_axes(ax)
        else:
            fig, ax = plt.subplots(figsize=size)

    mode = intensity_range.lower()
    if mode in ["minmax", "abs", "absolute"]:
        vmin = kwargs.get("vmin", None)
        vmax = kwargs.get("vmax", None)
        # For (near-)constant images, widen the range slightly so vmin < vmax.
        vm_check = not cbar and np.ptp(image) < 1e-12
        ptp_check = np.ptp(image) < 1e-15
        mode_check = mode == "minmax"
        if (vm_check or ptp_check) and mode_check:
            if vmin is None:
                vmin = np.min(image) - 1e-12
            if vmax is None:
                vmax = np.max(image) + 1e-12
        else:
            vmin = np.min(image) if vmin is None else vmin
            vmax = np.max(image) if vmax is None else vmax
    elif mode in ["ordered", "o", "ord"]:
        # vmin/vmax are fractions of the sorted pixel values. An explicit None
        # (as passed by show_stack) falls back to the defaults.
        vmin = kwargs.get("vmin")
        vmax = kwargs.get("vmax")
        vmin = 0.01 if vmin is None else vmin
        vmax = 0.99 if vmax is None else vmax
        vals = np.sort(image.ravel())
        ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int")
        ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int")
        ind_vmin = np.max([0, ind_vmin])
        ind_vmax = np.min([len(vals) - 1, ind_vmax])
        vmin = vals[ind_vmin]
        vmax = vals[ind_vmax]
        if vmax == vmin:
            print("vmax = vmin, setting intensity range to full")
            vmin = vals[0]
            vmax = vals[-1]
    else:
        raise ValueError(
            f"Unknown intensity_range, should be 'minmax' or 'ordered', got {intensity_range}"
        )

    cmap = get_cmap(cmap, **kwargs)
    midpoint = kwargs.get("cmap_midpoint")
    if midpoint is not None:
        # np.min of an integer image is an integer scalar; coerce so that
        # shift_cmap_center always receives floats.
        vmin, vmax = float(vmin), float(vmax)
        cmap = shift_cmap_center(cmap, midpointval=midpoint, vmin=vmin, vmax=vmax)
    if kwargs.get("white_to_transparent"):
        image = _white_to_transparent(image)
    im = ax.matshow(image, origin=origin, vmin=vmin, vmax=vmax, cmap=cmap)
    if title is not None:
        ax.set_title(str(title), fontsize=kwargs.pop("title_fontsize", 12))

    if simple or kwargs.pop("ticks_off", False):
        if (title is not None or cbar) or (kwargs.get("show_bbox", False) or not save):
            ax.set_xticks([])
            ax.set_yticks([])
        else:
            ax.set_axis_off()
        if not kwargs.pop("show_bbox", True):
            ax.set_axis_off()
    else:
        # Act on ``ax`` explicitly: plt.tick_params would modify pyplot's
        # *current* axes, which need not be ``ax`` when figax is supplied.
        ax.tick_params(axis="x", top=False)
        if not kwargs.get("show_bbox", True):
            for spine in ax.spines.values():
                spine.set_visible(False)
        ax.xaxis.tick_bottom()
        ax.tick_params(direction=kwargs.get("tick_direction", "out"))
        if scale is None:
            ticks_label = kwargs.get("scale_unit", "pixels")
        else:
            if isinstance(scale, (tuple, list, np.ndarray)):
                if len(scale) != 2:
                    raise ValueError(f"scale should be a number or have length 2, got {scale}")
                if scale[0] != scale[1]:
                    warnings.warn(
                        "show_im() does not currently support different x/y scales. Using scale[0]"
                    )
                scale = scale[0]
            if not isinstance(scale, (int, float, np.integer, np.floating)):
                raise TypeError(f"scale should be a number, got {type(scale)}")
            # About one tick per inch of axes (plus one), and at least 3.
            ax_ysize_inch = ax.get_position().height * fig.get_size_inches()[1]
            ax_xsize_inch = ax.get_position().width * fig.get_size_inches()[0]
            num_ticks_y = max(round(ax_ysize_inch + 1), 3)
            num_ticks_x = max(round(ax_xsize_inch + 1), 3)
            fov_y, fov_x = np.array(image.shape)[:2] * scale
            # Floor the field of view to its two leading digits so the last
            # tick falls on a round value.
            ylim = ax.get_ylim()
            ymax = ylim[0] if origin == "upper" else ylim[1]
            nround_y = len(str(round(ymax * scale))) - 2
            floor_fov_y = np.floor(fov_y / 10**nround_y) * 10**nround_y
            yticks = np.linspace(0, floor_fov_y / scale, int(num_ticks_y))
            if origin == "lower":
                yticks = yticks[1:]
            # Pixel centers are at integer coordinates; -0.5 puts ticks on pixel edges.
            ax.set_yticks(yticks - 0.5)
            ylabs, unit = tick_label_formatter(
                yticks, fov_y, scale, kwargs.get("scale_unit", "nm")
            )
            ax.set_yticklabels(ylabs)
            ticks_label = unit
            _, xmax = ax.get_xlim()
            nround_x = len(str(round(xmax * scale))) - 2
            floor_fov_x = np.floor(fov_x / 10**nround_x) * 10**nround_x
            # The 0 tick is dropped on x, presumably so it does not collide
            # with the y-axis origin label.
            xticks = np.linspace(0, floor_fov_x / scale, int(num_ticks_x))[1:]
            ax.set_xticks(xticks - 0.5)
            # NOTE(review): x labels are formatted with fov_y, not fov_x. This may
            # be deliberate (only the unit from the y call is displayed as
            # ticks_label), so it is left unchanged -- confirm.
            xlabs, unit = tick_label_formatter(
                xticks, fov_y, scale, kwargs.get("scale_unit", "nm")
            )
            ax.set_xticklabels(xlabs)
        if kwargs.pop("ticks_label_off", False):
            pass
        elif origin == "lower":
            ax.text(y=-0.5, x=-0.5, s=ticks_label, rotation=-45, va="top", ha="right")
        elif origin == "upper":
            ax.text(
                y=image.shape[0] - 0.5,
                x=-0.5,
                s=ticks_label,
                rotation=-45,
                va="top",
                ha="right",
            )

    if roi is not None:
        lw = kwargs.get("roi_lw", 2)
        pad = kwargs.get("roi_pad", 0)
        color = kwargs.get("roi_color", "white")
        # Only the spatial dims; image may be (H, W, 4) after white_to_transparent.
        dy, dx = image.shape[:2]
        if isinstance(roi, (list, tuple)):
            # Convert the (top, bottom, left, right) pixel sequence to the dict
            # form so that it goes through the same pixel -> axes conversion.
            top, bottom, left, right = roi
            roi = {"top": top, "bottom": bottom, "left": left, "right": right}
        elif not isinstance(roi, dict):
            raise TypeError(f"roi should be a dict, list or tuple, got {type(roi)}")
        # Convert pixels to axes fractions (transAxes). Row 0 maps to the top of
        # the axes, i.e. this assumes origin="upper".
        left = (roi["left"] - lw * pad) / dx
        bottom = (dy - roi["bottom"] - lw * pad) / dy
        width = (roi["right"] - roi["left"] + 2 * lw * pad) / dx
        height = (roi["bottom"] - roi["top"] + 2 * lw * pad) / dy
        p = Rectangle((left, bottom), width, height, fill=False, edgecolor=color, linewidth=lw)
        if "rotation" in roi:
            transform = Affine2D().rotate_deg_around(0.5, 0.5, -1 * roi["rotation"]) + ax.transAxes
        else:
            transform = ax.transAxes
        p.set_transform(transform)
        p.set_clip_on(False)
        ax.add_patch(p)

    if cbar:
        # Rows / cols of the displayed image (not shape[-2:], which is wrong
        # for an (H, W, 4) RGBA image).
        aspect = image.shape[0] / image.shape[1]
        cb = plt.colorbar(
            im,
            ax=ax,
            pad=0.02,
            format="%g",
            fraction=0.047 * aspect,
        )
        if cbar_title is not None:
            cb.set_label(str(cbar_title), labelpad=5)

    if save:
        print("saving: ", save)
        dpi = kwargs.get("dpi", 400)
        trns = kwargs.get("white_to_transparent", False)
        if simple and title is None:
            plt.savefig(save, dpi=dpi, bbox_inches=0, transparent=trns)
        else:
            plt.savefig(save, dpi=dpi, bbox_inches="tight", transparent=trns)
    if figax is None:
        plt.show()
[docs]
def show_stack(
    images: list[np.ndarray] | np.ndarray,
    titles: Optional[list[str]] = None,
    scale_each: bool = True,
    **kwargs,
) -> None:
    """
    Display a stack of images interactively using a slider.

    Args:
        images (list | np.ndarray): Sequence of 2D images (stacked with np.array).
        titles (list, optional): One title per image. If omitted, the ``title``
            kwarg (if any) is used for every frame.
        scale_each (bool): If True, each frame uses its own intensity range,
            unless ``vmin``/``vmax`` kwargs are given. If False, all frames
            share one range: ``vmin``/``vmax`` if given, else the stack min/max.
        **kwargs: Forwarded to show_im().

    Returns:
        None

    Raises:
        ValueError: If ``titles`` does not have exactly one entry per image.
    """
    images = np.array(images)
    N = images.shape[0]
    # Pop these up front: they are passed to show_im explicitly below, so
    # leaving them in kwargs would raise "got multiple values for keyword".
    vmin = kwargs.pop("vmin", None)
    vmax = kwargs.pop("vmax", None)
    title = kwargs.pop("title", None)
    if not scale_each:
        # One shared intensity range for the whole stack.
        vmin = np.min(images) if vmin is None else vmin
        vmax = np.max(images) if vmax is None else vmax
    if titles is None:
        titles = [title] * N
    elif len(titles) != N:
        raise ValueError(f"Got {len(titles)} titles for {N} images")

    _fig, _ax = plt.subplots()
    show_im(
        images[0],
        figax=(_fig, _ax),
        title=titles[0],
        cbar=False,
        vmin=vmin,
        vmax=vmax,
        **kwargs,
    )

    def view_image(i=0):
        # Redraw frame i on the shared axes; the unit label was already drawn above.
        show_im(
            images[i],
            figax=(_fig, _ax),
            title=titles[i],
            ticks_label_off=True,
            cbar=False,
            vmin=vmin,
            vmax=vmax,
            **kwargs,
        )

    interact(view_image, i=(0, N - 1))
[docs]
def show_sims(
    phi: np.ndarray,
    im_un: np.ndarray,
    im_in: np.ndarray,
    im_ov: np.ndarray,
    title: Optional[str] = None,
    save: Optional[str] = None,
) -> None:
    """
    Plot a phase map next to its underfocus, in-focus and overfocus images.

    The phase panel has its own grey scale, padded by 0.05 on each side. The
    three defocus panels share a single intensity range so their contrast can
    be compared directly.

    Args:
        phi (np.ndarray): Image of phase shift.
        im_un (np.ndarray): Underfocus image.
        im_in (np.ndarray): Infocus image.
        im_ov (np.ndarray): Overfocus image.
        title (str, optional): Figure-level title.
        save (str, optional): Output path. ".png" is appended unless the path
            already ends in ".png", ".tiff" or ".tif".

    Returns:
        None
    """
    phase_hi = np.max(phi) + 0.05
    phase_lo = np.min(phi) - 0.05
    fig = plt.figure(figsize=(12, 3))
    fig.add_subplot(1, 4, 1).imshow(
        phi, cmap="gray", origin="upper", vmax=phase_hi, vmin=phase_lo
    )
    plt.axis("off")
    plt.title("Phase")

    defocus = [im_un, im_in, im_ov]
    shared_hi = np.max(defocus)
    shared_lo = np.min(defocus)
    labels = ("Underfocus", "In-focus", "Overfocus")
    for slot, (frame, label) in enumerate(zip(defocus, labels), start=2):
        fig.add_subplot(1, 4, slot).imshow(
            frame, cmap="gray", origin="upper", vmax=shared_hi, vmin=shared_lo
        )
        plt.axis("off")
        plt.title(label)

    if title is not None:
        fig.suptitle(str(title))
    if save is not None:
        if not save.endswith((".png", ".tiff", ".tif")):
            save = save + ".png"
        plt.savefig(save, dpi=300, bbox_inches="tight")
    plt.show()
[docs]
def show_im_peaks(
    im: Optional[np.ndarray] = None,
    peaks: Optional[np.ndarray] = None,
    peaks2: Optional[np.ndarray] = None,
    title: Optional[str] = None,
    cbar: bool = False,
    **kwargs,
) -> None:
    """
    Show image with overlaid peaks.

    Args:
        im (np.ndarray, optional): Image to display.
        peaks (np.ndarray, optional): Peak positions to overlay, as an (N, 2)
            or (2, N) array of (row, col) coordinates. A (2, 2) array is read
            as (2, N). Empty input is skipped.
        peaks2 (np.ndarray, optional): Second set of peaks, same format.
        title (str, optional): Title for the plot.
        cbar (bool): Whether to display a colorbar.
        **kwargs: ``color1``/``color2`` (marker colors, default "r"/"b"),
            ``alpha``, ``ms``, ``marker``. All kwargs are also forwarded to
            show_im() when ``im`` is given.

    Returns:
        None

    Raises:
        ValueError: If a non-empty peak array is not 2D.
    """
    fig, ax = plt.subplots()
    if im is not None:
        show_im(im, title=title, cbar=cbar, figax=(fig, ax), **kwargs)
    elif title is not None:
        # Without an image, show_im is not called, so set the title here.
        ax.set_title(str(title))
    if peaks is not None and np.size(peaks) != 0:
        _overlay_peaks(ax, peaks, kwargs.get("color1", "r"), kwargs)
    if peaks2 is not None and np.size(peaks2) != 0:
        _overlay_peaks(ax, peaks2, kwargs.get("color2", "b"), kwargs)
    plt.show()


def _overlay_peaks(ax, peaks, color, style):
    """Plot one set of (row, col) peak positions on ``ax`` as unfilled markers."""
    peaks = np.array(peaks)
    if peaks.ndim != 2:
        raise ValueError(f"Peaks dimension {peaks.ndim} != 2")
    # Accept (N, 2) as well as (2, N); afterwards row 0 is y and row 1 is x.
    if peaks.shape[1] == 2 and peaks.shape[0] != 2:
        peaks = peaks.T
    ax.plot(
        peaks[1],
        peaks[0],
        c=color,
        alpha=style.get("alpha", 0.9),
        ms=style.get("ms", None),
        marker=style.get("marker", "o"),
        fillstyle="none",
        linestyle="none",
    )
[docs]
def show_fft(
    im: np.ndarray,
    window: bool = True,
    log=False,
    alpha=1,
    gaussian_sigma1=0,
    gaussian_sigma2=0,
    **kwargs,
) -> None:
    """
    Compute and display the magnitude of the centered 2D FFT of an image.

    Args:
        im (np.ndarray): Image (array-like) whose FFT is computed and displayed.
            Integer/bool images are converted to float64 first.
        window (bool): Multiply the image by a 2D Tukey window (tukey2D) before
            the FFT, tapering the image edges. Default True.
        log (bool): Display log(1 + |FFT|) instead of |FFT|. Default False.
        alpha (float): Tukey window shape parameter passed to tukey2D. Only
            used when ``window`` is True.
        gaussian_sigma1 (float): If > 0, Gaussian-blur the image with this
            sigma before the FFT.
        gaussian_sigma2 (float): If > 0, Gaussian-blur the FFT magnitude with
            this sigma.
        **kwargs: Additional keyword arguments passed to show_im.

    Returns:
        None
    """
    im = np.asarray(im)
    if not np.issubdtype(im.dtype, np.inexact):
        # gaussian_filter returns the input dtype, so an integer image would be
        # truncated after blurring.
        im = im.astype(np.float64)
    if window:
        win = tukey2D(im.shape, alpha=alpha)  # type: ignore
    else:
        win = np.ones_like(im)
    if gaussian_sigma1 > 0:
        im = ndi.gaussian_filter(im, gaussian_sigma1)
    mag = np.abs(np.fft.fftshift(np.fft.fft2(win * im)))
    if gaussian_sigma2 > 0:
        mag = ndi.gaussian_filter(mag, gaussian_sigma2)
    if log:
        # log1p(x) == log(1 + x): finite at zero magnitude and accurate when small.
        mag = np.log1p(mag)
    show_im(mag, **kwargs)
[docs]
def lineplot_im(
    image: np.ndarray,
    center: Optional[tuple[int, int]] = None,
    phi: float = 0,
    linewidth: int = 1,
    line_len: int = -1,
    show: bool = False,
    **kwargs,
) -> np.ndarray:
    """
    Sample an intensity profile along a straight line through an image.

    The line passes through ``center`` at angle ``phi``. Its endpoints come
    from ``_box_intercepts``, and ``skimage.measure.profile_line`` samples it
    with ``linewidth`` and ``reduce_func=np.mean``.

    Args:
        image (np.ndarray): Image to analyze. Arrays with more than 2
            dimensions are summed along the first axis.
        center (tuple, optional): Point (cy, cx) the line passes through.
            Defaults to ``(dy // 2, dx // 2)``.
        phi (float): Angle for the line plot in degrees.
        linewidth (int): Width of the sampled band, in pixels.
        line_len (int): If > 0, limits the line length (passed to
            _box_intercepts), and a longer profile is trimmed symmetrically to
            ``line_len`` samples. -1 means no limit.
        show (bool): Whether to display the plot.
        **kwargs: Read here: ``show_scan`` (default True; also plot the
            profile next to the image), ``figsize``, ``crop`` (trims only the
            plotted curve), ``save``, ``cbar``, ``ticks_off``, ``color``,
            ``title``, ``dpi``. After ``save``, ``cbar``, ``ticks_off`` (and a
            truthy ``crop``) are popped, the remaining kwargs are forwarded to
            show_im().

    Returns:
        np.ndarray: Line plot data.
    """
    im = np.array(image)
    if np.ndim(im) > 2:
        print("More than 2 dimensions given, collapsing along first axis")
        im = np.sum(im, axis=0)
    dy, dx = im.shape
    if center is None:
        center = (dy // 2, dx // 2)
    # Rounded center; only used below to offset the band edges when linewidth > 1.
    cy, cx = round(center[0]), round(center[1])
    sp, ep = _box_intercepts((dy, dx), center, phi, line_len)
    profile = skimage.measure.profile_line(
        im, sp, ep, linewidth=linewidth, mode="constant", reduce_func=np.mean
    )
    if line_len > 0 and len(profile) > line_len:
        lp = int(len(profile))
        # `-(lp - line_len) // 2` parses as `(-(lp - line_len)) // 2`, which floors
        # toward -inf: the start loses floor(d/2) samples and the end ceil(d/2),
        # leaving exactly line_len samples.
        profile = profile[(lp - line_len) // 2 : -(lp - line_len) // 2]
    if show:
        show_scan = kwargs.get("show_scan", True)
        if show_scan:
            # Left panel: the profile; right panel: the image with the scan line.
            fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=kwargs.get("figsize"))
            if kwargs.get("crop"):
                # Crop affects only the plotted curve, not the returned profile.
                crp = kwargs.pop("crop")
                p = profile[crp:-crp]
            else:
                p = profile
            ax0.plot(p)
            # Square plot box regardless of the data range.
            ax0.set_aspect(1 / ax0.get_data_ratio(), adjustable="box")
            ax0.set_ylabel("intensity")
            ax0.set_xlabel("pixels")
        else:
            fig, ax1 = plt.subplots()
        # Popped so that show_im() below does not save the figure itself.
        save = kwargs.pop("save", False)
        # NOTE(review): "title" is read with .get() further down, so it is still
        # in kwargs here and show_im() also receives it as its title -- it may
        # appear both as the axes title and as the suptitle.
        show_im(
            im,
            figax=(fig, ax1),
            cbar=kwargs.pop("cbar", False),
            ticks_off=kwargs.pop("ticks_off", True),
            **kwargs,
        )
        if linewidth > 1:
            # Direction of the scan line in (row, col) space.
            th = np.arctan2((ep[0] - sp[0]), (ep[1] - sp[1]))
            # Parallel lines through points offset by +/- linewidth/2
            # perpendicular to the scan line outline the averaged band.
            spp, epp = _box_intercepts(
                (dy, dx),
                (cy + np.cos(th) * linewidth / 2, cx - np.sin(th) * linewidth / 2),
                phi,
                line_len,
            )
            spm, epm = _box_intercepts(
                (dy, dx),
                (cy - np.cos(th) * linewidth / 2, cx + np.sin(th) * linewidth / 2),
                phi,
                line_len,
            )
            color = kwargs.get("color", "r")
            # Shaded band plus a thin center line.
            ax1.fill(
                [spp[1], epp[1], epm[1], spm[1]],
                [spp[0], epp[0], epm[0], spm[0]],
                alpha=0.3,
                facecolor=color,
                edgecolor=None,
            )
            ax1.plot([sp[1], ep[1]], [sp[0], ep[0]], color=color, linewidth=0.5)
        else:
            ax1.plot(
                [sp[1], ep[1]],
                [sp[0], ep[0]],
                color=kwargs.get("color", "r"),
                linewidth=1,
            )
        # Pin the view to the image extent, with row 0 at the top.
        ax1.set_xlim((0, dx - 1))
        ax1.set_ylim((dy - 1, 0))
        title = kwargs.get("title")
        if title is not None:
            plt.suptitle(title)
        if save:
            print("saving: ", save)
            dpi = kwargs.get("dpi", 400)
            plt.savefig(save, dpi=dpi, bbox_inches="tight")
        plt.show()
    return profile
def _box_intercepts(
    dims: tuple[int, int], center: tuple[int, int], phi: float, line_len: int = -1
) -> tuple[tuple[float, float], tuple[float, float]]:
    """
    Calculate the endpoints of a line through ``center`` clipped to a box.

    Args:
        dims (tuple): Dimensions of the box (dy, dx).
        center (tuple): Point (cy, cx) the line passes through.
        phi (float): Angle of the line in degrees.
        line_len (int): If > 0, each endpoint is placed ``line_len / 2`` from
            the center along the line, unless that point lies outside the box.
            In that case the box intercept on the same side of the center is
            used instead.

    Returns:
        tuple: Start and end points ``((y0, x0), (y1, x1))``. Box intercepts
        are rounded to ints; ``line_len`` endpoints are floats.

    Raises:
        ValueError: If ``phi == 0`` and the center row lies below the box.
    """
    dy, dx = dims
    cy, cx = center
    phir = np.deg2rad(phi)
    tphi = np.tan(phir)
    tphi2 = np.tan(phir - np.pi / 2)

    # End point: right edge if the line crosses it inside the box, else the
    # top / bottom edge.
    epy = round((dx - cx) * tphi + cy)
    if 0 <= epy < dy:
        epx = dx - 1
    elif epy < 0:
        epy = 0
        epx = round(cx + cy * tphi2)
    else:
        if phir == 0:
            raise ValueError(f"Center y = {cy} and dimy = {dy} with phi == 0")
        epy = dy - 1
        epx = round(cx + (dy - cy) / tphi)

    # Start point: left edge if the line crosses it inside the box, else the
    # bottom / top edge.
    spy = round(cy - cx * tphi)
    if 0 <= spy < dy:
        spx = 0
    elif spy >= dy:
        spy = dy - 1
        spx = round(cx - (dy - cy) * tphi2)
    else:
        spy = 0
        spx = round(cx - cy / tphi)

    sp = (spy, spx)  # start point
    ep = (epy, epx)  # end point
    if line_len > 0:
        half = line_len / 2
        step_y = np.sin(phir) * half
        step_x = np.cos(phir) * half
        intercepts = (sp, ep)

        def _limit(cand_y, cand_x):
            # Keep the whole point if it is inside the box. Otherwise use the box
            # intercept on the same side of the center, so that the result stays
            # on the line. Per-coordinate clamping could leave the line.
            if 0 <= cand_y <= dy - 1 and 0 <= cand_x <= dx - 1:
                return (cand_y, cand_x)
            return max(
                intercepts,
                key=lambda pt: (pt[0] - cy) * (cand_y - cy) + (pt[1] - cx) * (cand_x - cx),
            )

        sp = _limit(cy - step_y, cx - step_x)
        ep = _limit(cy + step_y, cx + step_x)
    return sp, ep
def _white_to_transparent(image):
    """
    Convert a greyscale image to RGBA, mapping intensity to transparency.

    The image is rescaled to [0, 1]. The RGB channels hold the rescaled
    intensity and alpha is ``1 - intensity``, so the brightest pixels become
    fully transparent and the darkest fully opaque.

    Args:
        image (np.ndarray): 2D greyscale image of any real dtype (ints and
            bools are accepted). The input is not modified.

    Returns:
        np.ndarray: float32 array of shape (H, W, 4) with values in [0, 1].
        A constant image maps to intensity 0 (opaque) instead of NaN.
    """
    # Work in float: in-place true division on an integer array would raise.
    im_scaled = np.asarray(image, dtype=np.float64)
    im_scaled = im_scaled - im_scaled.min()
    span = im_scaled.max()
    if span > 0:
        im_scaled = im_scaled / span
    im_scaled = im_scaled.astype(np.float32)
    rgba_image = np.zeros((im_scaled.shape[0], im_scaled.shape[1], 4), dtype=np.float32)
    rgba_image[:, :, :3] = im_scaled[..., None]
    rgba_image[:, :, 3] = 1 - im_scaled
    return rgba_image
[docs]
def tukey2D(shape: tuple[int, int], alpha: float = 0.5) -> np.ndarray:
    """
    Build a separable 2D Tukey (tapered cosine) window.

    Args:
        shape (tuple): Window size; only the first two entries (rows, cols)
            are used.
        alpha (float): Tukey shape parameter forwarded to
            ``scipy.signal.windows.tukey`` for both axes.

    Returns:
        np.ndarray: Array of shape ``(shape[0], shape[1])`` equal to the outer
        product of the row and column 1D windows.
    """
    n_rows, n_cols = shape[0], shape[1]
    row_taper = windows.tukey(n_rows, alpha=alpha)
    col_taper = windows.tukey(n_cols, alpha=alpha)
    return np.outer(row_taper, col_taper)