"""PyLorentz.phase.sitie: phase reconstruction using the SITIE method."""

import os
from pathlib import Path
from typing import List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np

from PyLorentz.dataset.defocused_dataset import DefocusedDataset
from PyLorentz.io.write import write_json
from PyLorentz.phase.base_tie import BaseTIE
from PyLorentz.visualize import show_2D, show_im


class SITIE(BaseTIE):
    """
    Class for phase reconstruction using the SITIE method.

    A single image from a ``DefocusedDataset`` is reconstructed: the in-focus
    image is approximated by a uniform image at the mean intensity of the
    selected image, and the intensity derivative is taken from the difference
    between the two.
    """

    def __init__(
        self,
        dd: DefocusedDataset,
        save_dir: Optional[os.PathLike] = None,
        name: Optional[str] = None,
        sym: bool = False,
        qc: Optional[float] = None,
        verbose: int = 1,
    ):
        """
        Initialize the SITIE object.

        Args:
            dd (DefocusedDataset): Defocused dataset.
            save_dir (Optional[os.PathLike], optional): Directory to save results.
                If None and ``dd.data_dir`` exists, defaults to
                ``<data_dir>/SITIE_outputs``. Default is None.
            name (Optional[str], optional): Name for the reconstruction. Default is None.
            sym (bool, optional): Whether to symmetrize the images. Default is False.
            qc (Optional[float], optional): Tikhonov regularization parameter. Default is None.
            verbose (int, optional): Verbosity level. Default is 1.

        Raises:
            ValueError: If ``dd`` is neither preprocessed nor simulated.
        """
        # Reject an unusable dataset before any base-class setup runs.
        if not dd._preprocessed and not dd._simulated:
            raise ValueError("dataset has not been preprocessed")

        self.dd = dd
        if save_dir is None and dd.data_dir is not None:
            topdir = Path(dd.data_dir)
            if topdir.exists():
                save_dir = topdir / "SITIE_outputs"

        super().__init__(
            save_dir=save_dir,
            scale=dd.scale,
            beam_energy=dd.beam_energy,
            name=name,
            verbose=verbose,
        )
        self.qc = qc  # for type checking
        self.sym = sym
        self.scale = dd.scale
        self._results["input_image"] = None
        # Set by reconstruct(): the defocus value used and its index in dd.
        self._recon_defval = None
        self._recon_defval_index = None

    @classmethod
    def from_array(
        cls,
        image: np.ndarray,
        scale: Union[float, int, None] = None,
        defval: Optional[List[float]] = None,
        beam_energy: Optional[float] = None,
        name: Optional[str] = None,
        sym: bool = False,
        qc: Optional[float] = None,
        save_dir: Optional[os.PathLike] = None,
        simulated: bool = False,
        verbose: Union[int, bool] = 1,
    ) -> "SITIE":
        """
        Create SITIE object from a numpy array.

        The array is wrapped in a new ``DefocusedDataset`` which is then passed
        to the regular constructor.

        Args:
            image (np.ndarray): Input image array.
            scale (Union[float, int, None], optional): Scale factor for the dataset. Default is None.
            defval (Optional[List[float]], optional): Defocus value(s) for the image(s). Default is None.
            beam_energy (Optional[float], optional): Beam energy for the reconstruction. Default is None.
            name (Optional[str], optional): Name for the reconstruction. Default is None.
            sym (bool, optional): Whether to symmetrize the images. Default is False.
            qc (Optional[float], optional): Tikhonov regularization parameter. Default is None.
            save_dir (Optional[os.PathLike], optional): Directory to save results. Default is None.
            simulated (bool, optional): Whether the data is simulated. Default is False.
            verbose (Union[int, bool], optional): Verbosity level. Default is 1.

        Returns:
            SITIE: An instance of the SITIE class.
        """
        dd = DefocusedDataset(
            images=image,
            defvals=defval,
            scale=scale,
            beam_energy=beam_energy,
            simulated=simulated,
            verbose=verbose,
        )
        sitie = cls(
            dd=dd,
            save_dir=save_dir,
            name=name,
            sym=sym,
            qc=qc,
            verbose=verbose,
        )
        return sitie

    def reconstruct(
        self,
        index: Optional[int] = None,
        name: Optional[str] = None,
        sym: bool = False,
        qc: Optional[float] = None,
        save: Union[bool, str, List[str]] = False,
        save_dir: Optional[os.PathLike] = None,
        verbose: Optional[int] = None,
        pbcs: Optional[bool] = None,
        overwrite: bool = False,
    ) -> "SITIE":
        """
        Perform SITIE reconstruction.

        Results are stored in ``self._results`` under ``"input_image"``,
        ``"phase_B"`` (shifted so its minimum is 0), ``"By"`` and ``"Bx"``.

        Args:
            index (Optional[int], optional): Index of the image to reconstruct.
                Negative indices count from the end. Default is None (index 0).
            name (Optional[str], optional): Name for the reconstruction. Default is None.
            sym (bool, optional): Whether to symmetrize the images. Note this
                always replaces the value given to the constructor. Default is False.
            qc (Optional[float], optional): Tikhonov regularization parameter.
                None keeps the current value. Default is None.
            save (Union[bool, str, List[str]], optional): Whether and what to save
                (see ``save_results``). Default is False.
            save_dir (Optional[os.PathLike], optional): Directory to save results. Default is None.
            verbose (Optional[int], optional): Verbosity level. None keeps the
                current value. Default is None.
            pbcs (Optional[bool], optional): Whether to apply periodic boundary
                conditions. None keeps the current value. Default is None.
            overwrite (bool, optional): Whether to overwrite existing files. Default is False.

        Returns:
            SITIE: The SITIE instance after reconstruction.

        Raises:
            TypeError: If ``index`` is not an integer.
            IndexError: If ``index`` is out of range for the dataset.
        """
        n_images = len(self)
        if index is None:
            index = 0
        elif not isinstance(index, (int, np.integer)):
            # Explicit check instead of ``assert`` (which is stripped under -O),
            # done before any numeric comparison on ``index``.
            raise TypeError(f"index must be an integer, got {type(index).__name__}")
        elif index >= n_images or index < -n_images:
            raise IndexError(f"Index {index} not allowed for images of length {n_images}")

        if self.dd._transforms_modified:
            self.vprint("DD has unapplied transforms, applying now.")
            self.dd.apply_transforms()

        self._recon_defval_index = index
        self._recon_defval = self.dd.defvals[index]
        # NOTE: unlike qc/pbcs/verbose, sym is always overwritten here.
        self.sym = sym
        if qc is not None:
            self.qc = qc
        if pbcs is not None:
            self._pbcs = pbcs
        self._verbose = verbose if verbose is not None else self._verbose
        if save:
            # Set up the save name/location before doing the reconstruction.
            self._check_save_name(save_dir, name, mode="SITIE")
        self._overwrite = overwrite if overwrite is not None else self._overwrite

        self.vprint(
            "Performing SITIE reconstruction with defocus "
            f"{self._fmt_defocus(self._recon_defval, spacer=' ')}, index = {index}"
        )

        # setup data
        dimy, dimx = self.dd.shape

        # select image
        recon_image = self.dd.images[index].copy()
        self._results["input_image"] = recon_image.copy()
        if self.sym:
            # The symmetrized image is treated as twice the size in each dimension.
            dimy *= 2
            dimx *= 2
            recon_image = self._symmetrize(recon_image)
        self._make_qi((dimy, dimx))

        # construct the "infocus" image (uniform, at the mean intensity) and
        # get derivatives
        infocus_im = np.ones(np.shape(recon_image)) * np.mean(recon_image)
        dIdZ_B = 2 * (recon_image - infocus_im)
        # remove any residual mean from the derivative
        dIdZ_B -= np.sum(dIdZ_B) / np.size(infocus_im)

        phase_B = self._reconstruct_phase(infocus_im, dIdZ_B, self._recon_defval)
        self._results["phase_B"] = phase_B - phase_B.min()
        By, Bx = self.induction_from_phase(phase_B)
        self._results["By"] = By
        self._results["Bx"] = Bx

        if save:
            # Keyword arguments are required here: passed positionally,
            # ``overwrite`` would land in save_results' ``save_dir`` parameter.
            self.save_results(save, save_dir=save_dir, name=name, overwrite=overwrite)
        return self

    def save_results(
        self,
        save_mode: Union[bool, str, List[str]] = True,
        save_dir: Optional[os.PathLike] = None,
        name: Optional[str] = None,
        overwrite: bool = False,
    ) -> "SITIE":
        """
        Save the reconstruction results.

        Args:
            save_mode (Union[bool, str, List[str]], optional): Keys to save.
                True or "all": phase_B, Bx, By, color, input_image;
                "b"/"induction": Bx, By, color; "phase": phase_B;
                an iterable: those keys (converted to str); False: nothing.
                String modes are case-insensitive. Default is True.
            save_dir (Optional[os.PathLike], optional): Directory to save results. Default is None.
            name (Optional[str], optional): Name for the reconstruction. Default is None.
            overwrite (bool, optional): Whether to overwrite existing files. Default is False.

        Returns:
            SITIE: The SITIE instance.

        Raises:
            ValueError: If ``save_mode`` is an unrecognized string.
            TypeError: If ``save_mode`` is not a bool, str, or iterable of keys.
        """
        self._check_save_name(save_dir, name=name, mode="SITIE")
        # "color" is listed explicitly because it is not a key in self._results.
        all_keys = ["phase_B", "Bx", "By", "color", "input_image"]

        if isinstance(save_mode, (bool, np.bool_)):
            if not save_mode:
                return self
            save_keys = all_keys
        elif isinstance(save_mode, str):
            mode = save_mode.lower()
            if mode in ["b", "induction"]:
                save_keys = ["Bx", "By", "color"]
            elif mode == "phase":
                save_keys = ["phase_B"]
            elif mode == "all":
                save_keys = all_keys
            else:
                raise ValueError(
                    f"Unrecognized save_mode {save_mode!r}; expected 'all', "
                    "'phase', 'b'/'induction', a bool, or a list of keys."
                )
        elif hasattr(save_mode, "__iter__"):
            save_keys = [str(k) for k in save_mode]
        else:
            raise TypeError(
                f"save_mode must be a bool, str, or iterable of keys, "
                f"got {type(save_mode).__name__}"
            )

        self.save_dir.mkdir(parents=True, exist_ok=True)
        self._save_keys(save_keys, self.recon_defval, overwrite)
        self._save_log(overwrite)
        return self

    def _save_log(self, overwrite: Optional[bool] = None):
        """
        Save the reconstruction log as JSON in ``self.save_dir``.

        Args:
            overwrite (Optional[bool], optional): Whether to overwrite existing files.
                None falls back to ``self._overwrite``. Default is None.
        """
        log_dict = {
            "name": self.name,
            "_save_name": self._save_name,
            "defval": self.recon_defval,
            "sym": self.sym,
            "qc": self.qc,
            "scale": self.scale,
            "transforms": self.dd.transforms,
            "filters": self.dd._filters,
            "beam_energy": self.dd.beam_energy,
            "simulated": self.dd._simulated,
            "data_dir": self.dd.data_dir,
            "data_files": self.dd.data_files,
            "save_dir": self._save_dir,
        }
        ovr = overwrite if overwrite is not None else self._overwrite
        name = f"{self._save_name}_{self._fmt_defocus(self.recon_defval)}_log.json"
        write_json(log_dict, self.save_dir / name, overwrite=ovr, v=self._verbose)

    def __len__(self) -> int:
        """
        Get the number of images in the dataset.

        Returns:
            int: Number of images in the dataset.
        """
        return len(self.dd.images)

    @property
    def recon_defval(self) -> Optional[float]:
        """
        Get the defocus value used for reconstruction.

        Prints a notice when no value has been set yet (i.e. before
        ``reconstruct`` has been called).

        Returns:
            Optional[float]: Defocus value.
        """
        if self._recon_defval is None:
            print("defval is None or has not yet been specified with an index")
        return self._recon_defval

    def visualize(self, cbar: bool = False, plot_scale: Union[bool, str] = True) -> "SITIE":
        """
        Visualize the phase and induction maps side by side.

        Args:
            cbar (bool, optional): Whether to display a colorbar on the phase image.
                Default is False.
            plot_scale (Union[bool, str], optional): Which panels get axis
                ticks/scale (passed as ``ticks_off`` to each panel).
                True or "phase": phase image only; "all": both;
                "color"/"induction"/"b"/"ind": induction map only;
                False: neither. String modes are case-insensitive and
                unrecognized strings behave like True. Default is True.

        Returns:
            SITIE: The SITIE instance.
        """
        fig, axs = plt.subplots(ncols=2, figsize=(8, 4))

        # ticks_off flags for the (phase, induction) panels
        if isinstance(plot_scale, str):
            mode = plot_scale.lower()
            if mode == "all":
                phase_ticks_off, ind_ticks_off = False, False
            elif mode in ["color", "induction", "b", "ind"]:
                phase_ticks_off, ind_ticks_off = True, False
            else:  # "phase" or unrecognized
                phase_ticks_off, ind_ticks_off = False, True
        elif plot_scale:
            phase_ticks_off, ind_ticks_off = False, True
        else:
            phase_ticks_off, ind_ticks_off = True, True

        show_im(
            self.phase_B,
            title="Magnetic phase shift",
            scale=self.scale,
            figax=(fig, axs[0]),
            ticks_off=phase_ticks_off,
            cbar=cbar,
            cbar_title="rad",
        )
        show_2D(
            self.Bx,
            self.By,
            figax=(fig, axs[1]),
            scale=self.scale,
            ticks_off=ind_ticks_off,
            title="Integrated induction map",
        )
        # NOTE(review): this hides the induction panel's axes regardless of
        # ind_ticks_off — confirm this is intended.
        axs[-1].axis("off")
        plt.tight_layout()
        plt.show()
        return self