# Source code for PyLorentz.phase.AD_phase

import os
import warnings
from datetime import datetime, timedelta
from pathlib import Path
from typing import List, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndi

try:
    import torch
    import torch.nn.functional as F
    import torchvision.transforms as TvT
except (ModuleNotFoundError, ImportError) as e:
    torch = None

from matplotlib.ticker import FormatStrFormatter
from torch import nn
from tqdm import tqdm

from PyLorentz.dataset.defocused_dataset import DefocusedDataset
from PyLorentz.io.write import write_json
from PyLorentz.phase.base_phase import BasePhaseReconstruction
from PyLorentz.utils import Microscope
from PyLorentz.visualize import show_2D, show_im

from .DIP_NN import weight_reset
from .sitie import SITIE


class ADPhase(BasePhaseReconstruction):
    """
    Phase reconstruction from defocused images by gradient-based optimization.

    The object phase (and optionally the amplitude) is optimized with PyTorch
    autograd so that images simulated through the dataset's transfer functions
    match the measured images. The unknowns are either optimized directly
    (run type "AD") or generated by deep image priors (DIPs, run type "DIP").
    With a single defocused image the mode is "SIPRAD", otherwise "ADPhase".
    """

    _default_sample_params = {
        "dirt_V0": 20,
        "dirt_xip0": 10,
    }

    _default_LRs = {
        "phase": 2e-4,  # 2e-4 for DIP, 0.2 otherwise
        "amp": 2e-4,  # 2e-4 for DIP, 0.02 otherwise
        "amp_scale": 0.1,  # used when not solve_amp
        "amp2phi_scale": 1,
        # Total-variation weights: added to the loss whenever > 0 (DIP runs
        # included); intended mainly for reconstructions without DIPs.
        "TV_phase_weight": 5e-3,
        "TV_amp_weight": 5e-3,
    }

    def __init__(
        self,
        dd: DefocusedDataset,
        device: Union[str, int],
        save_dir: Optional[os.PathLike] = None,
        name: Optional[str] = None,
        verbose: bool = 1,
        scope: Optional[Microscope] = None,
        sample_params: Optional[dict] = None,
        rng_seed: Optional[int] = None,
        LRs: Optional[dict] = None,
        scheduler_type: Optional[str] = None,
        noise_frac: float = 1 / 100,
        gaussian_sigma: float = 1,
    ):
        """
        Initialize the ADPhase object.

        Args:
            dd (DefocusedDataset): The defocused dataset. A dataset of length 1
                selects "SIPRAD" mode, otherwise "ADPhase" mode.
            device (Union[str, int]): The device to use ("cpu", "gpu", a GPU index,
                a "cuda[:N]" string, or a torch.device).
            save_dir (Optional[os.PathLike], optional): Directory to save results.
                Defaults to "<data_dir>/<mode>_outputs" when the dataset's data_dir
                exists.
            name (Optional[str], optional): Name for the results.
            verbose (bool, optional): Verbosity level.
            scope (Optional[Microscope], optional): Microscope object.
            sample_params (dict, optional): Sample parameters; entries override
                `_default_sample_params`. None keeps the defaults.
            rng_seed (Optional[int], optional): Random seed.
            LRs (dict, optional): Learning rates / loss weights; entries override
                `_default_LRs`. None keeps the defaults.
            scheduler_type (Optional[str], optional): Type of learning rate scheduler.
            noise_frac (float, optional): Scale of the Gaussian noise added to the
                DIP input at every iteration.
            gaussian_sigma (float, optional): Sigma of the Gaussian filter applied to
                reported phase/amplitude images (None or 0 disables blurring).
        """
        self.dd = dd
        self._mode = "SIPRAD" if len(dd) == 1 else "ADPhase"
        if save_dir is None and dd.data_dir is not None:
            topdir = Path(dd.data_dir)
            if topdir.exists():
                save_dir = topdir / f"{self._mode}_outputs"
        super().__init__(save_dir, name, dd.scale, verbose)

        # User-supplied entries override the class defaults.
        self.sample_params = self._default_sample_params | (sample_params or {})
        self.LRs = self._default_LRs | (LRs or {})
        # The device must be set before any tensor is created on it.
        self.device = device
        self.inp_ims = torch.tensor(self.dd.images, device=self.device, dtype=torch.float32)
        self.defvals = dd.defvals
        self.scope = scope
        self.shape = dd.shape
        self._rng = np.random.default_rng(rng_seed)
        self._noise_frac = noise_frac
        self._scheduler_type = scheduler_type

        # to be set later:
        self._guess_phase: Optional[torch.Tensor] = None
        self._guess_amp: Optional[torch.Tensor] = None
        self._input_DIP: Optional[torch.Tensor] = None
        self._model_input: Optional[torch.Tensor] = None
        self._runtype: Optional[str] = None
        self._recon_amp: Optional[torch.Tensor] = None
        self._recon_phase: Optional[torch.Tensor] = None
        self._use_DIP: Optional[bool] = None
        self._solve_amp: Optional[bool] = None
        self._solve_amp_scale: Optional[bool] = None
        self._best_phase: Optional[torch.Tensor] = None
        self._best_amp: Optional[torch.Tensor] = None
        self._best_iter: Optional[int] = None
        self._phase_iterations: List[torch.Tensor] = []
        self._amp_iterations: List[torch.Tensor] = []
        self.phase_iterations: Optional[np.ndarray] = None
        self.amp_iterations: Optional[np.ndarray] = None
        self.loss_iterations = []
        self.LR_iterations = []

        # The gaussian_sigma setter reads _best_phase and _phase_iterations, so it
        # must run after those are initialized.
        self.gaussian_sigma = gaussian_sigma
        self._TFs = self.get_TFs()

    def _filtered_array(
        self, tensor: Optional[torch.Tensor], zero_min: bool = False
    ) -> Optional[np.ndarray]:
        """
        Convert a tensor to a Gaussian-filtered numpy array.

        Args:
            tensor (Optional[torch.Tensor]): Tensor to convert; None passes through.
            zero_min (bool, optional): Shift the result so its minimum is 0.

        Returns:
            Optional[np.ndarray]: Filtered array (sigma = gaussian_sigma), or None.
        """
        if tensor is None:
            return None
        arr = ndi.gaussian_filter(tensor.cpu().detach().numpy(), self._gaussian_sigma)
        if zero_min:
            arr -= arr.min()
        return arr

    @property
    def recon_phase(self) -> Optional[np.ndarray]:
        """
        Returns the reconstructed phase after applying Gaussian filter.

        The phase is shifted so that its minimum is 0.

        Returns:
            Optional[np.ndarray]: Reconstructed phase image.
        """
        return self._filtered_array(self._recon_phase, zero_min=True)

    @property
    def best_phase(self) -> Optional[np.ndarray]:
        """
        Returns the best (lowest-loss) phase after applying Gaussian filter.

        The phase is shifted so that its minimum is 0.

        Returns:
            Optional[np.ndarray]: Best phase image.
        """
        return self._filtered_array(self._best_phase, zero_min=True)

    def set_best_phase(self, iter_ind: int = -1) -> None:
        """
        Sets the best phase from a stored iteration.

        Args:
            iter_ind (int, optional): Index into the stored phase iterations
                (see `store_iters_every` in `reconstruct`). Defaults to the last.
        """
        stored_phase, stored_iter = self._phase_iterations[iter_ind]
        self._best_phase = stored_phase
        self._best_iter = stored_iter
        self.phase_B = self.best_phase

    @property
    def best_amp(self) -> Optional[np.ndarray]:
        """
        Returns the best amplitude after applying Gaussian filter.

        Returns:
            Optional[np.ndarray]: Best amplitude image.
        """
        return self._filtered_array(self._best_amp)

    @property
    def gaussian_sigma(self) -> float:
        """
        Returns the Gaussian sigma value.

        Returns:
            float: Gaussian sigma value.
        """
        return self._gaussian_sigma

    @gaussian_sigma.setter
    def gaussian_sigma(self, val: Optional[float]) -> None:
        """
        Sets the Gaussian sigma value and updates the blurring transformation.

        None or 0 disables the blur used by the amplitude constraints. The cached
        best phase and stored iterations are refreshed with the new sigma.

        Args:
            val (Optional[float]): Gaussian sigma value.

        Raises:
            TypeError: If `val` is not a numeric type.
            ValueError: If `val` is less than 0.
        """
        if val is None:
            self._gaussian_sigma = 0
            self._blurrer = None
        elif not isinstance(val, (float, int, np.floating, np.integer)):
            raise TypeError(f"gaussian_sigma must be numeric, received {type(val)}")
        elif val < 0:
            raise ValueError(f"gaussian_sigma must be >= 0 or None, received {val}")
        else:
            # torchvision's GaussianBlur requires sigma > 0, so 0 means "no blur".
            if val > 0:
                self._blurrer = TvT.GaussianBlur(
                    kernel_size=(9, 9), sigma=(float(val), float(val))
                )
            else:
                self._blurrer = None
            self._gaussian_sigma = val

        if self.best_phase is not None:
            self.phase_B = self.best_phase
        self._set_recon_iterations()

    @property
    def recon_amp(self) -> Optional[np.ndarray]:
        """
        Returns the reconstructed amplitude after applying Gaussian filter.

        Returns:
            Optional[np.ndarray]: Reconstructed amplitude image.
        """
        return self._filtered_array(self._recon_amp)

    @property
    def _TFs(self) -> torch.Tensor:
        """
        Returns the transfer functions.

        Returns:
            torch.Tensor: Transfer functions, one per defocused image.
        """
        return self.transfer_functions

    @_TFs.setter
    def _TFs(self, arr: Union[torch.Tensor, np.ndarray]) -> None:
        """
        Sets the transfer functions.

        Non-tensor input is converted to a complex64 tensor on `self.device`.

        Args:
            arr (Union[torch.Tensor, np.ndarray]): Transfer functions array.

        Raises:
            AssertionError: If `arr` is not 3D or its length differs from the
                number of images in the dataset.
        """
        if not isinstance(arr, torch.Tensor):
            arr = torch.tensor(arr, device=self.device, dtype=torch.complex64)
        assert len(arr.shape) == 3, f"Bad TF shape: {arr.shape}, should be {self.dd.images.shape}"
        assert len(arr) == len(self.dd), f"len TFs {len(arr)} != # defvals {len(self.dd)}"
        self.transfer_functions = arr

    @property
    def scope(self) -> Optional[Microscope]:
        """
        Returns the microscope object.

        Returns:
            Optional[Microscope]: Microscope object.
        """
        return self._scope

    @scope.setter
    def scope(self, microscope: Optional[Microscope]) -> None:
        """
        Sets the microscope object.

        Args:
            microscope (Optional[Microscope]): Microscope object, or None (the
                `__init__` default) to leave it unset.

        Raises:
            TypeError: If `microscope` is neither None nor a `Microscope`.
        """
        if microscope is not None and not isinstance(microscope, Microscope):
            raise TypeError(
                f"microscope must be a PyLorentz.utils.microscopes.Microscope object, received {type(microscope)}"
            )
        self._scope = microscope

    @property
    def device(self) -> torch.device:
        """
        Returns the device used for computation.

        Returns:
            torch.device: Device for computation.
        """
        return self._device

    @device.setter
    def device(self, dev: Union[int, str, torch.device]) -> None:
        """
        Sets the device for computation.

        Accepted values: a torch.device of type "cuda" or "cpu"; the strings
        "cpu"/"gpu" (any case) or "cuda"/"cuda:N"; or a GPU index (int-like).

        Args:
            dev (Union[int, str, torch.device]): Device identifier.

        Raises:
            TypeError: If `dev` is not a supported type.
            ValueError: If the GPU index exceeds the number of available GPUs.
            AssertionError: If a GPU is requested but CUDA is unavailable.
        """
        if isinstance(dev, torch.device):
            if dev.type == "cuda":
                self._device = dev
            elif dev.type == "cpu":
                warnings.warn("Setting device to cpu, this will be slow and might fail.")
                self._device = dev
            else:
                raise TypeError(
                    f"Unknown device type: {dev.type} This can likely be fixed easily."
                )
        elif isinstance(dev, (str, int, float, np.integer)):
            dev_name = dev.lower() if isinstance(dev, str) else None
            if dev_name == "cpu":
                warnings.warn("Setting device to cpu, this will be slow and might fail.")
                ind = "cpu"
            else:
                assert torch.cuda.is_available(), "No GPUs available"
                if dev_name == "gpu":
                    ind = 0
                elif dev_name is not None and dev_name.startswith("cuda"):
                    # "cuda" -> index None (default GPU 0), "cuda:N" -> N
                    parsed = torch.device(dev_name).index
                    ind = 0 if parsed is None else parsed
                else:
                    ind = int(dev)
                if ind >= torch.cuda.device_count():
                    raise ValueError(
                        f"device must be < num_devices, which is {torch.cuda.device_count()}. "
                        f"Received device {ind}"
                    )
            self._device = torch.device(ind)
            if isinstance(ind, int):
                self.vprint(f"Proceeding with GPU {ind}: {torch.cuda.get_device_name(ind)}")
        else:
            raise TypeError(f"Device should be int, str, or torch.device. Received {type(dev)}")

    @property
    def model_input(self) -> Optional[torch.Tensor]:
        """
        Returns the model input tensor (None until one is set).

        Returns:
            Optional[torch.Tensor]: Model input tensor.
        """
        return getattr(self, "_model_input", None)

    @model_input.setter
    def model_input(self, arr: torch.Tensor) -> None:
        """
        Sets the model input tensor.

        NOTE(review): not implemented yet; assignments are currently ignored.

        Args:
            arr (torch.Tensor): Input tensor to set.
        """
        # set to tensor on device
        # make sure is same size as image
        return

    @property
    def guess_phase(self) -> Optional[torch.Tensor]:
        """
        Returns the guess phase used to pre-train the DIP.

        Returns:
            Optional[torch.Tensor]: Guess phase tensor.
        """
        return self._guess_phase

    @guess_phase.setter
    def guess_phase(self, im: Union[np.ndarray, torch.Tensor]) -> None:
        """
        Sets the guess phase (as float32 on `self.device` for non-tensor input).

        Args:
            im (Union[np.ndarray, torch.Tensor]): Guess phase image.

        Raises:
            ValueError: If `im` shape does not match input image shape.
        """
        if not isinstance(im, torch.Tensor):
            im = torch.tensor(im, dtype=torch.float32)
        im = im.to(self.device)
        if im.shape != self.shape:
            raise ValueError(
                f"Guess phase shape, {im.shape} should match input image shape, {self.shape}"
            )
        self._guess_phase = im

    @property
    def guess_amp(self) -> Optional[torch.Tensor]:
        """
        Returns the guess amplitude used to pre-train the DIP or if `solve_amp` is False.

        Returns:
            Optional[torch.Tensor]: Guess amplitude tensor, or None if unset.
        """
        return getattr(self, "_guess_amp", None)

    @guess_amp.setter
    def guess_amp(self, im: Union[np.ndarray, torch.Tensor]) -> None:
        """
        Sets the guess amplitude (as float32 on `self.device` for non-tensor input).

        Args:
            im (Union[np.ndarray, torch.Tensor]): Guess amplitude image.

        Raises:
            ValueError: If `im` shape does not match input image shape.
        """
        if not isinstance(im, torch.Tensor):
            im = torch.tensor(im, dtype=torch.float32)
        im = im.to(self.device)
        if im.shape != self.shape:
            raise ValueError(
                f"Guess amp shape, {im.shape} should match input image shape, {self.shape}"
            )
        self._guess_amp = im

    @property
    def input_DIP(self) -> Optional[torch.Tensor]:
        """
        Returns the noise used as input for one or both DIPs.

        Returns:
            Optional[torch.Tensor]: Input noise tensor, or None if unset.
        """
        return getattr(self, "_input_DIP", None)

    @input_DIP.setter
    def input_DIP(self, im: Union[np.ndarray, torch.Tensor]) -> None:
        """
        Sets the input noise for DIPs.

        Args:
            im (Union[np.ndarray, torch.Tensor]): Input noise image; must have the
                same shape as the dataset's image stack.

        Raises:
            ValueError: If `im` shape does not match the image stack shape.
        """
        if not isinstance(im, torch.Tensor):
            im = torch.tensor(im, device=self.device, dtype=torch.float32)
        # TODO not sure if this is correct for multiple images
        if im.shape != self.dd.images.shape:
            raise ValueError(
                f"Input noise shape, {im.shape} should match input image shape, {self.shape}"
            )
        self._input_DIP = im

    def reconstruct(
        self,
        num_iter: int,
        model: Optional[Union[nn.Module, List[nn.Module]]] = None,
        num_pretrain_iter: int = 0,
        solve_amp: bool = False,
        solve_amp_scale: bool = True,
        guess_amp: Optional[Union[float, np.ndarray]] = None,
        LRs: Optional[dict] = None,
        scheduler_type: Optional[str] = None,
        save: bool = False,
        name: Optional[str] = None,
        save_dir: Optional[os.PathLike] = None,
        noise_frac: Optional[float] = None,
        guess_phase: Union[str, np.ndarray, None] = "SITIE",
        reset: bool = True,
        print_every: int = -1,
        verbose: int = 1,
        store_iters_every: int = -1,
        qc: Optional[object] = None,
        **kwargs,  # scheduler params
    ) -> "ADPhase":
        """
        Performs the reconstruction process.

        Args:
            num_iter (int): Number of iterations for reconstruction.
            model (Optional[Union[nn.Module, List[nn.Module]]], optional): A phase DIP,
                or a (phase DIP, amplitude DIP) pair. None optimizes the phase (and
                amplitude) directly.
            num_pretrain_iter (int, optional): Number of pretraining iterations.
            solve_amp (bool, optional): Whether to solve for amplitude.
            solve_amp_scale (bool, optional): Whether to solve the amplitude scale;
                only used when `solve_amp` is False.
            guess_amp (Optional[Union[float, np.ndarray]], optional): Guess amplitude.
            LRs (dict, optional): Learning rates / loss weights overriding `self.LRs`.
            scheduler_type (Optional[str], optional): Type of learning rate scheduler;
                "continue" keeps the current optimizer and scheduler when not resetting.
            save (bool, optional): Whether to save results.
            name (Optional[str], optional): Name for the saved results.
            save_dir (Optional[os.PathLike], optional): Directory to save results.
            noise_frac (Optional[float], optional): Scale of noise added to the DIP input.
            guess_phase (Union[str, np.ndarray, None], optional): Guess phase or method
                to obtain it.
            reset (bool, optional): Whether to reset the reconstruction; forced True
                when no iterations have been run yet.
            print_every (int, optional): Print progress every this many iterations
                (<= 0 disables).
            verbose (int, optional): Verbosity level.
            store_iters_every (int, optional): Store the phase (and amplitude) every
                this many iterations (<= 0 disables).
            qc (Optional[object], optional): Quality control object.
            **kwargs: Additional keyword arguments for scheduler parameters.

        Returns:
            ADPhase: self, to allow chaining.
        """
        ### SETUP
        self._start_time = datetime.now()
        self._num_pretrain_iter = num_pretrain_iter
        if noise_frac is not None:
            self._noise_frac = noise_frac
        if verbose is not None:
            self._verbose = verbose
        self.LRs = self.LRs | (LRs or {})
        # The global amplitude scale is only a free parameter when the amplitude is fixed.
        self._solve_amp_scale = solve_amp_scale if not solve_amp else False

        if isinstance(model, nn.Module):
            self._use_DIP = True
            DIP_phase = model
            DIP_amp = None
        elif model is not None:
            self._use_DIP = True
            DIP_phase, DIP_amp = model
        else:
            self._use_DIP = False
            DIP_phase = DIP_amp = None

        if scheduler_type is not None:
            self._scheduler_type = scheduler_type
        self._qc = qc

        if save:
            if name is None and self.name is None:
                now = self._start_time.strftime("%y%m%d-%H%M%S")
                if len(self.dd) == 1:
                    mode = "SIPRAD"
                    self._results["input_image"] = self.dd.images[0]
                else:
                    mode = f"N{len(self.dd)}AD"
                self._check_save_name(save_dir, name=f"{now}_{mode}")

        ### Initialization/reset
        # A first call always needs a full initialization.
        reset = True if len(self.loss_iterations) == 0 else reset
        if reset:
            self._set_guess_phase(guess_phase)
            self._set_guess_amp(guess_amp)
            self.loss_iterations = []
            self.LR_iterations = []
            self._phase_iterations = []
            self._amp_iterations = []
            self._amp_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
            self._amp2phi_scale = torch.tensor(
                [self._get_amp2phi_scale()], dtype=torch.float32, device=self.device
            )

            if self._use_DIP:
                self._runtype = "DIP"
                self._set_input_DIP()
                DIP_phase = DIP_phase.to(self.device)
                self.optimizer = torch.optim.Adam(
                    [{"params": DIP_phase.parameters(), "lr": self.LRs["phase"]}],
                )
                DIP_phase.apply(weight_reset)
                if solve_amp:
                    self._runtype += "-amp"
                    DIP_amp.apply(weight_reset)
                    DIP_amp = DIP_amp.to(self.device)
                    self.optimizer.add_param_group(
                        {"params": DIP_amp.parameters(), "lr": self.LRs["amp"]},
                    )
                ### pretrain DIP
                # NOTE(review): self._solve_amp is only assigned further below, so on
                # the first call it is still None here; confirm _pretrain_DIP does not
                # rely on it.
                self._pretrain_DIP(DIP_phase, DIP_amp)
            else:
                self._runtype = "AD"
                self._recon_phase = self.guess_phase.clone()
                if solve_amp:
                    self._runtype += "-amp"
                    self._recon_amp = self.guess_amp.clone()
        else:
            # maybe should check that more things exist, but shouldn't have to
            # (histories are numpy arrays after a previous run; append needs lists)
            self.LR_iterations = list(self.LR_iterations)
            self.loss_iterations = list(self.loss_iterations)

        # reinitializing optimizer here so have chance to change scheduler, LRs, etc.
        if reset or scheduler_type != "continue":
            if self._use_DIP:
                DIP_phase = DIP_phase.to(self.device)
                self.optimizer = torch.optim.Adam(
                    [{"params": DIP_phase.parameters(), "lr": self.LRs["phase"]}],
                )
                if solve_amp:
                    self._solve_amp = True
                    DIP_amp = DIP_amp.to(self.device)
                    self.optimizer.add_param_group(
                        {"params": DIP_amp.parameters(), "lr": self.LRs["amp"]},
                    )
                else:
                    self._solve_amp = False
                    DIP_amp = None
                    self._recon_amp = self.guess_amp.clone()
                    self._recon_amp.requires_grad = False
            else:
                DIP_phase = DIP_amp = None
                self._recon_phase.requires_grad = True
                self.optimizer = torch.optim.Adam(
                    [{"params": self._recon_phase, "lr": self.LRs["phase"]}]
                )
                if solve_amp:
                    self._solve_amp = True
                    self._recon_amp.requires_grad = True
                    # add_param_group mutates the optimizer in place and returns None,
                    # so its result must not be assigned back to self.optimizer.
                    self.optimizer.add_param_group(
                        {"params": self._recon_amp, "lr": self.LRs["amp"]},
                    )
                else:
                    self._solve_amp = False
                    self._recon_amp = self.guess_amp.clone()
                    self._recon_amp.requires_grad = False

            if self._solve_amp_scale:
                self.optimizer.add_param_group(
                    {"params": self._amp_scale, "lr": self.LRs["amp_scale"]}
                )
                self._amp_scale.requires_grad = True
                # amp2phi only matters if the amplitude is not uniform
                if self._recon_amp.min() != self._recon_amp.max():
                    amp2phi_LR = self.LRs.get("amp2phi_scale", self.LRs["amp_scale"])
                    self.optimizer.add_param_group(
                        {"params": self._amp2phi_scale, "lr": amp2phi_LR}
                    )
                    self._amp2phi_scale.requires_grad = True
            else:
                self._amp_scale.requires_grad = False

            self.scheduler = self._get_scheduler(**kwargs)

        if save_dir is not None or self.save_dir is not None:
            self._check_save_name(save_dir, name, mode=f"{self._mode}_{self._runtype}")

        ### Recon
        self.vprint("Reconstructing")
        self._recon_loop(num_iter, print_every, DIP_phase, DIP_amp, save, store_iters_every)

        elapsed = datetime.now() - self._start_time
        ttime = timedelta(seconds=int(elapsed.total_seconds()))
        print(f"total time (h:m:s) = {ttime}")

        # The shifts below must not be tracked by autograd: in non-DIP runs
        # _recon_phase is a leaf tensor with requires_grad=True, and in-place ops
        # on such leaves raise a RuntimeError outside of no_grad.
        with torch.no_grad():
            self._recon_phase -= self._recon_phase.min()
            if self._best_phase is not None:
                self._best_phase -= self._best_phase.min()
        if self._best_phase is not None:
            self.phase_B = self.best_phase
        else:
            warnings.warn("Was unable to find a best_phase, recon likely failed.")

        self.LR_iterations = np.array(self.LR_iterations)
        self.loss_iterations = np.array(self.loss_iterations)
        self._set_recon_iterations()
        return self

    def _recon_loop(
        self,
        num_iter: int,
        print_every: int,
        DIP_phase: Optional[nn.Module],
        DIP_amp: Optional[nn.Module],
        save: bool,
        store_iters_every: int,
    ) -> None:
        """
        Runs the reconstruction loop for a given number of iterations.

        Args:
            num_iter (int): Number of iterations to run.
            print_every (int): Print progress every this many iterations (<= 0 disables).
            DIP_phase (Optional[nn.Module]): DIP module for phase reconstruction.
            DIP_amp (Optional[nn.Module]): DIP module for amplitude reconstruction.
            save (bool): Whether to save the best reconstruction (not implemented;
                raises NotImplementedError when a new best is found).
            store_iters_every (int): Store intermediate iterations every this many
                iterations (<= 0 disables).
        """
        stime = self._start_time
        # Running minimum of all previously recorded losses, so the best-iteration
        # check is O(1) per step instead of rescanning the whole loss history.
        prev_min_loss = min(self.loss_iterations) if len(self.loss_iterations) > 0 else float("inf")

        for a0 in tqdm(range(num_iter)):
            if DIP_phase is not None:
                # Perturb the DIP input each step; input_DIP only exists for DIP runs.
                if self._noise_frac >= 0:
                    self.input_DIP += self._noise_frac * torch.randn(
                        self.input_DIP.shape, device=self.device
                    )
                self._recon_phase = DIP_phase(self.input_DIP)[0]

            if self._solve_amp and DIP_amp is not None:
                self._recon_amp = torch.abs(DIP_amp(self.input_DIP)[0])
                if (a0 + 1) % 100 == 0:
                    print(f"a0 {a0} applying amp constraints")
                    self._apply_amp_constraints()

            loss = self._compute_loss()
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            loss_val = loss.item()
            self.loss_iterations.append(loss_val)
            self.LR_iterations.append([pg["lr"] for pg in self.optimizer.param_groups])

            # Check the > 0 guard first so a 0 interval does not divide by zero.
            if print_every > 0 and (a0 == 0 or (a0 + 1) % print_every == 0):
                lrs = [f"{pg['lr']:.2e}" for pg in self.optimizer.param_groups]
                lrsp = ", ".join(lrs)
                ctime = timedelta(seconds=(datetime.now() - stime).seconds)
                self.vprint(f"{a0+1}/{num_iter} | {ctime} | loss {loss_val:.3e} | LR {lrsp}")
                stime = datetime.now()

            if store_iters_every > 0 and (a0 == 0 or (a0 + 1) % store_iters_every == 0):
                self._phase_iterations.append([self._recon_phase.detach().clone(), a0 + 1])
                if self._solve_amp:
                    self._amp_iterations.append([self._recon_amp.detach().clone(), a0 + 1])
                if self._verbose >= 2 and a0 != 0:
                    self.show_final()

            # Track the best reconstruction, ignoring the first 100 iterations.
            if len(self.loss_iterations) > 100 and loss_val < prev_min_loss:
                self._best_phase = self._recon_phase.detach().clone()
                self._best_iter = len(self.loss_iterations)
                self.phase_B = self.best_phase
                if self._solve_amp:
                    self._best_amp = self._recon_amp.detach().clone()
                if save:
                    # TODO maybe add a checkpoint save function (once made save function)
                    raise NotImplementedError
            prev_min_loss = min(prev_min_loss, loss_val)

            if self.scheduler is not None:
                if hasattr(self.scheduler, "cooldown"):  # is plateau
                    self.scheduler.step(loss_val)
                else:
                    self.scheduler.step()

    def _apply_amp_constraints(self, mode: str = "binary") -> None:
        """
        Apply amplitude constraints based on the specified mode.

        "binary": pixels above 3/4 of the most common (2-decimal rounded) amplitude
        are set to that value, all others to the 10th-percentile amplitude; the
        result is then blurred if a blurrer is configured.

        Args:
            mode (str): The mode of constraints. Currently only "binary" is implemented.

        Raises:
            NotImplementedError: For any other mode.
        """
        if mode == "binary":
            ampr = self._recon_amp.ravel()
            highval = torch.mode(torch.round(ampr, decimals=2)).values
            threshval = highval * 3 / 4
            # torch.sort returns (values, indices); index the sorted values.
            lowval = torch.sort(ampr).values[len(ampr) // 10]
            amp = torch.where(self._recon_amp > threshval, highval, lowval)
            if self._blurrer is not None:
                amp = self._blurrer(amp[None])[0]
            self._recon_amp = amp
        else:
            raise NotImplementedError

    def _sim_images(self) -> torch.Tensor:
        """
        Simulate images from the current amplitude and phase reconstructions.

        The object wave is amp * amp_scale * exp(i * (phase - amp2phi_scale * amp)),
        propagated by multiplying its FFT with each transfer function.

        Returns:
            torch.Tensor: The simulated intensity images.
        """
        obj_waves = (
            self._recon_amp
            * self._amp_scale
            * torch.exp(1.0j * (self._recon_phase - self._amp2phi_scale * self._recon_amp))
        )
        img_waves = torch.fft.ifft2(torch.fft.fft2(obj_waves) * self._TFs)
        images = torch.abs(img_waves) ** 2
        return images

    def _compute_loss(
        self, return_seperate: bool = False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """
        Compute the loss for the current reconstruction.

        Args:
            return_seperate (bool): Whether to return separate MSE and TV losses.

        Returns:
            Union[torch.Tensor, tuple[torch.Tensor, Optional[torch.Tensor]]]: The total
            loss (MSE + TV), or the (MSE, TV) pair where TV is None if no TV term applies.
        """
        guess_ims = self._sim_images()
        MSE_loss = torch.mean((guess_ims - self.inp_ims) ** 2)

        if self.LRs["TV_phase_weight"] == 0 and (
            self.LRs["TV_amp_weight"] == 0 or not self._solve_amp
        ):
            if return_seperate:
                return MSE_loss, None
            return MSE_loss

        TV_loss = self._calc_TV_loss_PBC()
        if return_seperate:
            return MSE_loss, TV_loss
        return MSE_loss + TV_loss

    def _calc_TV_loss_PBC(self) -> Optional[torch.Tensor]:
        """
        Calculate the (squared-difference) total variation loss with periodic
        boundary conditions.

        Each term is weighted by its TV weight and normalized by the number of
        pixels; when both phase and amplitude terms apply they are averaged.

        Returns:
            Optional[torch.Tensor]: The total variation loss, or None if no term applies.
        """
        assert self._recon_phase.ndim == 2 and self._recon_amp.ndim == 2
        dy, dx = self.shape

        def periodic_tv(arr: torch.Tensor) -> torch.Tensor:
            # Circular padding by one row/column makes the wrap-around difference
            # part of the sum.
            pad_h = F.pad(arr[None, None], (0, 0, 0, 1), mode="circular")[0, 0]
            pad_w = F.pad(arr[None, None], (0, 1, 0, 0), mode="circular")[0, 0]
            tv_h = torch.pow(pad_h[1:, :] - pad_h[:-1, :], 2).sum()
            tv_w = torch.pow(pad_w[:, 1:] - pad_w[:, :-1], 2).sum()
            return tv_h + tv_w

        if self.LRs["TV_phase_weight"] > 0:
            TV_phase = self.LRs["TV_phase_weight"] * periodic_tv(self._recon_phase) / (dy * dx)
        else:
            TV_phase = None

        if self._solve_amp and self.LRs["TV_amp_weight"] > 0:
            TV_amp = self.LRs["TV_amp_weight"] * periodic_tv(self._recon_amp) / (dy * dx)
            if TV_phase is None:
                return TV_amp
            return (TV_amp + TV_phase) / 2
        return TV_phase

    def _set_recon_iterations(self) -> None:
        """
        Update phase_iterations and amp_iterations with filtered tensors converted
        to numpy arrays, as (array, iteration) pairs. Phases are shifted to min 0.
        """
        if len(self._phase_iterations) > 0:
            self.phase_iterations = [
                (self._filtered_array(phase, zero_min=True), it)
                for phase, it in self._phase_iterations
            ]
        if len(self._amp_iterations) > 0:
            self.amp_iterations = [
                (self._filtered_array(amp), it) for amp, it in self._amp_iterations
            ]

    def _get_amp2phi_scale(self) -> float:
        """
        Calculate the scale factor for converting amplitude to phase.

        Computed as dirt_V0 * scope.sigma * dirt_l0 / scope.W0.

        Returns:
            float: The scale factor.

        Raises:
            KeyError: If `sample_params` has no "dirt_l0" entry.
        """
        # NOTE(review): "dirt_l0" is not among _default_sample_params (which define
        # "dirt_V0" and "dirt_xip0"), so callers must supply it; confirm intent.
        try:
            dirt_l0 = self.sample_params["dirt_l0"]
        except KeyError as err:
            raise KeyError(
                "sample_params is missing 'dirt_l0', which is required to compute "
                "the amplitude-to-phase scale"
            ) from err
        return self.sample_params["dirt_V0"] * self.scope.sigma * dirt_l0 / (self.scope.W0)

    def show_final(self, crop: int = 5, **kwargs) -> None:
        """Show the phase and induction of the final iteration.

        Args:
            crop (int, optional): Amount to crop off of the phase and induction
                maps before displaying; often necessary in order to avoid edge
                artifacts. Defaults to 5.
        """
        ph = self.recon_phase
        By, Bx = self.induction_from_phase(ph)
        if crop > 0:
            crop = int(crop)
            ph = ph[crop:-crop, crop:-crop]
            Bx = Bx[crop:-crop, crop:-crop]
            By = By[crop:-crop, crop:-crop]

        n_iter = len(self.loss_iterations)
        if self._solve_amp:
            fig, axs = plt.subplots(ncols=3, figsize=(12, 4))
            show_im(
                ph,
                title=f"Recon phase: iter {n_iter}",
                figax=(fig, axs[0]),
                cbar_title="rad",
            )
            show_2D(Bx, By, title=f"Recon B: iter {n_iter}", figax=(fig, axs[1]))
            show_im(
                self.recon_amp,
                title=f"Recon amp: iter {n_iter}",
                figax=(fig, axs[2]),
            )
        else:
            fig, axs = plt.subplots(ncols=2, figsize=(8, 4))
            show_im(
                ph,
                title=f"Recon phase: iter {n_iter}",
                scale=self.scale,
                cbar_title="rad",
                figax=(fig, axs[0]),
            )
            show_2D(Bx, By, title=f"Recon B: iter {n_iter}", figax=(fig, axs[1]))
        plt.tight_layout()
        plt.show()

    def visualize(self, crop=5):
        """Plot the best reconstructed phase and induction maps, plus loss/LR curves.

        When the amplitude was solved for, the best amplitude is shown as well.

        Args:
            crop (int, optional): Amount to crop off of induction maps before
                displaying; often necessary in order to avoid edge artifacts.
                Defaults to 5.
        """
        if self._solve_amp:
            fig = plt.figure(figsize=(12, 8))
            ncols = 3
        else:
            fig = plt.figure(figsize=(8, 8))
            ncols = 2

        ax1 = fig.add_subplot(2, ncols, 1)
        self.show_phase_B(figax=(fig, ax1), cbar_title=None, crop=crop)
        ax2 = fig.add_subplot(2, ncols, 2)
        self.show_B(figax=(fig, ax2), crop=crop)
        if self._solve_amp:
            ax_amp = fig.add_subplot(2, ncols, 3)
            show_im(self.best_amp, figax=(fig, ax_amp), ticks_off=True, title="amp")

        # bottom row: loss on the left axis, phase learning rate on a twin axis
        ax3 = fig.add_subplot(2, 1, 2)
        l1 = ax3.semilogy(self.loss_iterations, color="tab:blue", label="loss")
        ax3.set_xlabel("iterations")
        ax3.set_ylabel("loss")
        ax4 = ax3.twinx()
        LRs = np.array(self.LR_iterations)
        l2 = ax4.semilogy(LRs[:, 0], color="tab:orange", label="phase LR")
        lns = l1 + l2
        labs = [ln.get_label() for ln in lns]
        ax4.legend(lns, labs, loc=0)
        ax4.set_ylabel("LR")
        ax4.yaxis.set_major_formatter(FormatStrFormatter("%.2e"))
        plt.show()
        return

    def __len__(self):
        """Return the number of recorded loss iterations."""
        return len(self.loss_iterations)

    def save_results(
        self,
        iter_ind: Optional[int] = None,
        save_dir: Optional[os.PathLike] = None,
        name: Optional[str] = None,
        overwrite: bool = False,
    ):
        """
        Save the reconstructed phase, Bx, By, and color images, plus a log file.

        # TODO add saving amplitude and phase_E

        Args:
            iter_ind (int, optional): Index into stored iterations to save. Defaults
                to None which saves the best phase.
            save_dir (os.PathLike, optional): Directory to save in. Defaults to None
                which saves in self.save_dir.
            name (str, optional): Name to prepend saved files. Defaults to None.
            overwrite (bool, optional): Whether or not to overwrite files. Defaults
                to False.

        Raises:
            NotImplementedError: If the dataset has more than one image.
        """
        if save_dir is not None or name is not None:
            # don't want to overwrite original name with timestamp
            self._check_save_name(save_dir, name=name, default_name=False)

        if iter_ind is not None:
            phase, iteration = self.phase_iterations[iter_ind]
            By, Bx = self.induction_from_phase(phase)
        else:
            iteration = self._best_iter
            phase = self.best_phase
            Bx = self.Bx
            By = self.By

        results = {
            "phase_B": phase,
            "By": By,
            "Bx": Bx,
            "color": None,
        }
        if self._mode == "SIPRAD":
            results["input_image"] = self.dd.image
        else:
            raise NotImplementedError("need to figure out how saving defocus vals")

        # keep the first three "_"-separated fields of the name and tag the iteration
        save_name_no_iter = "_".join(self._save_name.split("_")[:3])
        self._save_name = save_name_no_iter + f"_i{iteration}"

        self.save_dir.mkdir(exist_ok=True)
        self._save_keys(list(results.keys()), self.defvals[0], overwrite, res_dict=results)
        self._save_log(overwrite)

    def _save_log(self, overwrite: Optional[bool] = None):
        """
        Save the reconstruction log as JSON in self.save_dir.

        Args:
            overwrite (bool, optional): Whether to overwrite existing files. Default
                is None, which falls back to self._overwrite.
        """
        log_dict = {
            "name": self.name,
            "_save_name": self._save_name,
            "defval": self.defvals.squeeze(),
            "scale": self.scale,
            "transforms": self.dd.transforms,
            "filters": self.dd._filters,
            "beam_energy": self.dd.beam_energy,
            "simulated": self.dd._simulated,
            "data_dir": self.dd.data_dir,
            "data_files": self.dd.data_files,
            "save_dir": self._save_dir,
        }
        ovr = overwrite if overwrite is not None else self._overwrite
        name = f"{self._save_name}_{self._fmt_defocus(self.defvals[0])}_log.json"
        write_json(log_dict, self.save_dir / name, overwrite=ovr, v=self._verbose)