Source code for PyLorentz.visualize.vector_show

from __future__ import annotations
import os
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colors
from matplotlib.axes import Axes
from matplotlib.colors import Colormap
from matplotlib.figure import Figure

from .colorwheel import color_im, get_cmap
from .show import show_im


def show_2D(
    Vx: np.ndarray,
    Vy: np.ndarray,
    Vz: Optional[np.ndarray] = None,
    num_arrows: int = 0,
    arrow_size: Optional[float] = None,
    arrow_width: Optional[float] = None,
    title: Optional[str] = None,
    simple: bool = False,
    color: bool = True,
    cmap: str | Colormap = "hsv",
    origin: str = "upper",
    save: str | os.PathLike | None = None,
    figax: Optional[tuple[Figure, Axes]] = None,
    rad: Optional[int] = None,
    scale: Optional[float] = None,
    **kwargs,
) -> None:
    """
    Display a 2D vector field with arrows and optional color mapping.

    3D inputs are collapsed to 2D by summing along the first axis.

    Args:
        Vx (np.ndarray): X-component of the vector field, 2D or 3D.
        Vy (np.ndarray): Y-component of the vector field, same ndim as Vx.
        Vz (np.ndarray, optional): Z-component of the vector field. Only
            passed on to ``color_im``.
        num_arrows (int): Approximate number of arrows to plot along each
            axis. No arrows are drawn when this is 0 or negative.
        arrow_size (float, optional): Arrow length factor; the quiver scale
            is ``1 / abs(arrow_size)``. Must be non-zero when given.
        arrow_width (float, optional): Width of arrows.
        title (str, optional): Title of the plot.
        simple (bool): If True (and no title or figax is given) the image is
            drawn on a frameless axes filling the whole figure. Also
            forwarded to ``show_im``.
        color (bool): Whether to display a colormap underneath the arrows.
        cmap (str | Colormap): Colormap to use for color mapping.
        origin (str): Origin of the image coordinate system.
        save (str | os.PathLike, optional): Path to save the figure.
        figax (tuple, optional): Figure and axes to plot on.
        rad (int, optional): Radius for the color wheel. Defaults to
            ``max(height // 16, 16)``; 0 reserves no room for a wheel.
        scale (float, optional): Scale forwarded to ``show_im``. Ticks are
            always turned off when this is None.
        **kwargs: Additional options. ``figsize`` (number or sequence, first
            element used; default 5), ``show_bbox``, ``ticks_off``,
            ``arrow_color`` and ``dpi`` are read here; everything except
            ``figsize`` is also forwarded to ``get_cmap`` and ``color_im``,
            and what remains after ``show_bbox``/``ticks_off`` are removed
            is forwarded to ``show_im``.

    Returns:
        None

    Raises:
        ValueError: If Vx and Vy do not have the same number of dimensions.
    """
    # Explicit exception rather than `assert`, which disappears under -O.
    if Vx.ndim != Vy.ndim:
        raise ValueError(
            f"Vx and Vy must have the same number of dimensions, "
            f"got {Vx.ndim} and {Vy.ndim}"
        )

    if Vx.ndim == 3:
        if Vx.shape[0] != 1:
            print("Summing along first axis")
        Vx = np.sum(Vx, axis=0)
        Vy = np.sum(Vy, axis=0)
        if Vz is not None:
            Vz = np.sum(Vz, axis=0)

    dimy, dimx = Vx.shape

    # Stride between plotted arrows; -1 means "draw no arrows".
    if num_arrows > 0:
        stride = int(((dimy - 1) / num_arrows) + 1)
    else:
        stride = -1

    sz_inches = kwargs.pop("figsize", 5)
    if isinstance(sz_inches, (list, tuple, np.ndarray)):
        # Aspect ratio depends on rad, use one value
        sz_inches = sz_inches[0]

    if color:
        if rad is None:
            rad = max(dimy // 16, 16)
        if rad == 0:
            width = np.shape(Vy)[1]
        else:
            pad = 10  # pixels
            width = np.shape(Vy)[1] + 2 * rad + pad
        aspect = dimy / width
    else:
        aspect = dimy / dimx

    if figax is None:
        if simple and title is None:
            fig = plt.figure()
            fig.set_size_inches((sz_inches, sz_inches * aspect))
            ax = Axes(fig, (0.0, 0.0, 1.0, 1.0))
            fig.add_axes(ax)
            ax.set_aspect(aspect)
        else:
            fig, ax = plt.subplots()
    else:
        fig, ax = figax

    if color:
        cmap = get_cmap(cmap, **kwargs)
        cim = color_im(
            Vx,
            Vy,
            Vz,
            cmap=cmap,
            rad=rad,
            **kwargs,
        )
        show_bbox = kwargs.pop("show_bbox", False)
        # Always remove ticks_off from kwargs: popping it inside the
        # `scale is None or ...` expression short-circuits and would leave
        # it in kwargs, passing the keyword to show_im twice.
        user_ticks_off = kwargs.pop("ticks_off", False)
        show_im(
            cim,
            cmap=cmap,
            origin=origin,
            ticks_off=scale is None or user_ticks_off,
            scale=scale,
            simple=simple,
            figax=(fig, ax),
            title=title,
            show_bbox=show_bbox,
            **kwargs,
        )
        arrow_color = "white"
    else:
        arrow_color = "black"
    arrow_color = kwargs.get("arrow_color", arrow_color)

    if stride > 0:
        X = np.arange(0, dimx, 1)
        Y = np.arange(0, dimy, 1)
        # Offset the first arrow so the sampled grid is roughly centred.
        ashift = (dimx - 1) % stride // 2
        arrow_scale = 1 / abs(arrow_size) if arrow_size is not None else None
        ax.quiver(
            X[ashift::stride],
            Y[ashift::stride],
            Vx[ashift::stride, ashift::stride],
            Vy[ashift::stride, ashift::stride],
            units="xy",
            scale=arrow_scale,
            scale_units="xy",
            width=arrow_width,
            angles="xy",
            pivot="mid",
            color=arrow_color,
        )

    # Without an underlying image nothing has flipped the y axis yet.
    if not color and stride > 0 and origin == "upper":
        ax.invert_yaxis()

    if save is not None:
        print(f"Saving: {save}")
        # Act on this figure/axes explicitly; pyplot's "current" figure is
        # not necessarily the one supplied through figax.
        ax.axis("off")
        dpi = kwargs.get("dpi", max(dimy, dimx) * 5 / sz_inches)
        if title is None:
            fig.savefig(save, dpi=dpi, bbox_inches=0, transparent=True)
        else:
            fig.savefig(save, dpi=dpi, bbox_inches="tight", transparent=True)

    return
[docs] def show_3D( Vx: np.ndarray, Vy: np.ndarray, Vz: np.ndarray, num_arrows: int = 15, ay: Optional[int] = None, num_arrows_z: int = 15, arrow_size: float | None = None, show_all: bool = True, title: str | None = None, ) -> None: """ -- semi deprecated -- Display a 3D vector field with arrows, using color to represent vector direction. Arrow color is determined by direction, with in-plane mapping to an HSV color-wheel and out of plane to white (+z) and black (-z). Args: Vx (np.ndarray): (z, y, x) X-component of the vector field. Vy (np.ndarray): (z, y, x) Y-component of the vector field. Vz (np.ndarray): (z, y, x) Z-component of the vector field. num_arrows (int): Number of arrows to plot along the x-axis. ay (int, optional): Number of arrows to plot along the y-axis. num_arrows_z (int): Number of arrows to plot along the z-axis. arrow_size (float, optional): Scale factor for arrow length. show_all (bool): Whether to show all arrows with equal opacity. Returns: None """ bmax = max(Vx.max(), Vy.max(), Vz.max()) if arrow_size is None: arrow_size = Vx.shape[1] / (2 * bmax * num_arrows) assert isinstance(arrow_size, float) fig = plt.figure() ax = fig.add_subplot(111, projection="3d") if Vx.ndim == 3: dimz, dimy, dimx = Vx.shape if num_arrows_z > dimz: az = 1 else: az = ((dimz - 1) // num_arrows_z) + 1 else: az = 1 dimy, dimx = Vx.shape dimz = 1 Z, Y, X = np.meshgrid( np.arange(0, dimz, 1), np.arange(0, dimy, 1), np.arange(0, dimx, 1), indexing="ij", ) ay = ((dimy - 1) // num_arrows) + 1 if ay is None else ay axx = ((dimx - 1) // num_arrows) + 1 zeros = ~(Vx.astype("bool") | Vy.astype("bool") | Vz.astype("bool")) zinds = np.where(zeros) Vz[zinds] = bmax / 1e5 Vx[zinds] = bmax / 1e5 Vy[zinds] = bmax / 1e5 U = Vx.reshape((dimz, dimy, dimx)) V = Vy.reshape((dimz, dimy, dimx)) W = Vz.reshape((dimz, dimy, dimx)) phi = np.ravel(np.arctan2(V[::az, ::ay, ::axx], U[::az, ::ay, ::axx])) hue = phi / (2 * np.pi) + 0.5 theta = np.arctan2( W[::az, ::ay, ::axx], np.sqrt(U[::az, 
::ay, ::axx] ** 2 + V[::az, ::ay, ::axx] ** 2), ) value = np.ravel(np.where(theta < 0, 1 + 2 * theta / np.pi, 1)) sat = np.ravel(np.where(theta > 0, 1 - 2 * theta / np.pi, 1)) arrow_colors = np.squeeze(np.dstack((hue, sat, value))) arrow_colors = colors.hsv_to_rgb(arrow_colors) if show_all: alphas = np.ones((np.shape(arrow_colors)[0], 1)) tcolor = "k" else: tcolor = "w" alphas = np.minimum(value, sat).reshape(len(value), 1) value = np.ones(value.shape) sat = np.ravel(1 - abs(2 * theta / np.pi)) arrow_colors = np.squeeze(np.dstack((hue, sat, value))) arrow_colors = colors.hsv_to_rgb(arrow_colors) ax.set_facecolor("black") for axs in [ax.xaxis, ax.yaxis, ax.zaxis]: # type:ignore axs.set_pane_color((0, 0, 0, 1.0)) axs.pane.set_edgecolor(tcolor) [t.set_color(tcolor) for t in axs.get_ticklines()] [t.set_color(tcolor) for t in axs.get_ticklabels()] ax.grid(False) arrow_colors = np.array( [np.concatenate((arrow_colors[i], alphas[i])) for i in range(len(alphas))] ) arrow_colors = np.concatenate((arrow_colors, np.repeat(arrow_colors, 2, axis=0))) dim = max(dimx, dimy, dimz) ax.set_xlim(0, dim) ax.set_ylim(0, dimy) if az >= dimz: ax.set_zlim(-dim // 2, dim // 2) # type:ignore else: ax.set_zlim(0, dim) # type:ignore Z += (dim - dimz) // 2 ax.quiver( X[::az, ::ay, ::axx], Y[::az, ::ay, ::axx], Z[::az, ::ay, ::axx], U[::az, ::ay, ::axx], V[::az, ::ay, ::axx], W[::az, ::ay, ::axx], color=arrow_colors, length=float(arrow_size), pivot="middle", normalize=False, ) if title is not None: ax.set_title(title) ax.set_xlabel("x", c=tcolor) ax.set_ylabel("y", c=tcolor) ax.set_zlabel("z", c=tcolor) # type:ignore plt.show()