from __future__ import annotations
import os
import warnings
from datetime import datetime, timedelta
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Self, Union
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndi
from matplotlib.ticker import FormatStrFormatter
from torch import nn
from tqdm import tqdm
from PyLorentz.dataset.defocused_dataset import DefocusedDataset
from PyLorentz.io.write import write_json
from PyLorentz.phase.base_phase import BasePhaseReconstruction
from PyLorentz.utils import Microscope
from PyLorentz.visualize import show_2D, show_im
from .DIP_NN import weight_reset
from .sitie import SITIE
_HAS_TORCH = False
if TYPE_CHECKING:
import torch
import torch.nn.functional as F
import torchvision.transforms as TvT
from torch import Tensor
else:
try:
import torch
import torch.nn.functional as F
import torchvision.transforms as TvT
from torch import Tensor
_HAS_TORCH = True
except:
_HAS_TORCH = False
[docs]
class ADPhase(BasePhaseReconstruction):
"""
ADPhase class for phase reconstruction using defocused datasets and DIPs.
"""
_default_sample_params = {
"dirt_V0": 20,
"dirt_xip0": 10,
}
_default_LRs = {
"phase": 2e-4, # 2e-4 for DIP, 0.2 otherwise
"amp": 2e-4, # 2e-4 for DIP, 0.02 otherwise
"amp_scale": 0.1, # used when not solve_amp
"amp2phi_scale": 1,
"TV_phase_weight": 5e-3, # only used if reconstructing without DIPs
"TV_amp_weight": 5e-3, # only used if reconstructing without DIPs
}
[docs]
def __init__(
self,
dd: DefocusedDataset,
device: Union[str, int, torch.device],
save_dir: Optional[os.PathLike] = None,
name: Optional[str] = None,
verbose: bool = True,
scope: Optional[Microscope] = None,
sample_params: dict = {},
rng_seed: Optional[int] = None,
LRs: dict = {},
scheduler_type: Optional[str] = None,
noise_frac: float = 1 / 100,
gaussian_sigma: float = 1,
):
"""
Initialize the ADPhase object.
Args:
dd (DefocusedDataset): The defocused dataset.
device (Union[str, int]): The device to use (CPU or GPU).
save_dir (Optional[os.PathLike], optional): Directory to save results.
name (Optional[str], optional): Name for the results.
verbose (bool, optional): Verbosity level.
scope (Optional[Microscope], optional): Microscope object.
sample_params (dict, optional): Sample parameters.
rng_seed (Optional[int], optional): Random seed.
LRs (dict, optional): Learning rates for optimization.
scheduler_type (Optional[str], optional): Type of learning rate scheduler.
noise_frac (float, optional): Fraction of noise to add.
gaussian_sigma (float, optional): Sigma value for Gaussian filter.
"""
self.dd = dd
if len(dd) == 1:
self._mode = "SIPRAD"
else:
self._mode = "ADPhase"
if save_dir is None and dd.data_dir is not None:
topdir = Path(dd.data_dir)
if topdir.exists():
save_dir = topdir / f"{self._mode}_outputs"
super().__init__(save_dir, name, dd.scale, verbose)
self.sample_params = self._default_sample_params | sample_params
self.LRs = self._default_LRs | LRs
self.device = device
self.inp_ims = torch.tensor(self.dd.images, device=self.device, dtype=torch.float32)
self.defvals = dd.defvals
if scope is not None:
self.scope = scope
self._rng = np.random.default_rng(rng_seed)
self._noise_frac = noise_frac
self._scheduler_type = scheduler_type
self.pad = (0, 0)
# to be set later:
# self._guess_phase: Tensor = torch.tensor(0)
self._runtype: str = ""
self._recon_amp: Tensor = torch.tensor(0)
self._recon_phase: Tensor = torch.tensor(0)
self._use_DIP: Optional[bool] = None
self._solve_amp: Optional[bool] = None
self._solve_amp_scale: Optional[bool] = None
self._best_phase: Tensor = torch.tensor(0)
self._best_amp: Tensor = torch.tensor(0)
self._best_iter: Optional[int] = None
self._phase_iterations: list[tuple[Tensor, int]] = []
self._amp_iterations: list[tuple[Tensor, int]] = []
self.phase_iterations: list[tuple[np.ndarray, int]] = []
self.amp_iterations: list[tuple[np.ndarray, int]] = []
self.loss_iterations: list | np.ndarray = [] # TODO clean this up with property
self.LR_iterations: list | np.ndarray = []
self.gaussian_sigma = gaussian_sigma
self._TFs = self.get_TFs()
@property
def shape(self):
return self.dd.shape
@property
def shape_full(self):
"""shape with padding"""
return self.dd.shape[0] + 2 * self.pad[0], self.dd.shape[1] + 2 * self.pad[1]
@property
def pad(self):
"""
Padding around all items in pixels, equal padding before/after each axis. (pad_y, pad_x)
"""
return self._pad
@pad.setter
def pad(self, pad: tuple):
if len(pad) != 2:
raise ValueError(f"Bad shape, pad should be (pad_y, pad_x)")
else:
self._pad = (int(round(pad[0])), int(round(pad[1])))
def _detach_and_crop(self, im: np.ndarray | Tensor) -> np.ndarray:
if _HAS_TORCH:
if isinstance(im, torch.Tensor):
out = im.cpu().detach().numpy()
else:
out = im.copy()
else:
assert isinstance(im, np.ndarray)
out = im.copy()
if self._pad[0] > 0:
out = out[self._pad[0] : -self._pad[0]]
if self._pad[1] > 0:
out = out[:, self._pad[1] : -self._pad[1]]
return out
@property
def recon_phase(self) -> np.ndarray:
"""
Returns the reconstructed phase after applying Gaussian filter.
This is the cropped recon phase without padding.
Returns:
Optional[np.ndarray]: Reconstructed phase image.
"""
if self._recon_phase.ndim == 0:
raise AttributeError(f"recon_phase has not yet been set")
else:
ph = self._detach_and_crop(self._recon_phase)
ph = ndi.gaussian_filter(ph, self._gaussian_sigma)
ph -= ph.min()
return ph
@property
def recon_phase_full(self) -> np.ndarray:
"""
Returns the reconstructed phase after applying Gaussian filter.
This includes any padding.
Returns:
Optional[np.ndarray]: Reconstructed phase image.
"""
if self._recon_phase.ndim == 0:
raise AttributeError(f"recon_phase has not yet been set")
else:
ph = self._recon_phase.cpu().detach().numpy()
ph = ndi.gaussian_filter(ph, self._gaussian_sigma)
ph -= ph.min()
return ph
@property
def best_phase(self) -> np.ndarray:
"""
Returns the best phase after applying Gaussian filter.
This is the cropped best phase without padding.
Returns:
np.ndarray: Best phase image.
"""
if self._best_phase.ndim == 0:
raise AttributeError(f"best_phase has not yet been set")
else:
ph = self._detach_and_crop(self._best_phase)
ph = ndi.gaussian_filter(ph, self._gaussian_sigma)
ph -= ph.min()
return ph
@property
def best_phase_full(self) -> np.ndarray:
"""
Returns the best phase after applying Gaussian filter.
This is the full phase with any padding.
Returns:
np.ndarray: Best phase image.
"""
if self._best_phase.ndim == 0:
raise AttributeError(f"best_phase has not yet been set")
else:
ph = self._best_phase.cpu().detach().numpy()
ph = ndi.gaussian_filter(ph, self._gaussian_sigma)
ph -= ph.min()
return ph
@property
def phase_B_full(self) -> Optional[np.ndarray]:
return self.best_phase_full
[docs]
def set_best_iter(self, iter_ind: int = -1) -> None:
"""
Sets the best phase from the specified iteration index.
Args:
iter_ind (int, optional): Index of the iteration to use for the best phase.
"""
self._best_phase, iter = self._phase_iterations[iter_ind]
if len(self._amp_iterations) > 0:
self._best_amp, _ = self._amp_iterations[iter_ind]
self._best_iter = iter
self.phase_B = self.best_phase
@property
def recon_amp(self) -> np.ndarray:
"""
Returns the reconstructed amplitude after applying Gaussian filter.
This is the cropped recon amplitude.
Returns:
Optional[np.ndarray]: Reconstructed amplitude image.
"""
if self._recon_amp.ndim == 0:
raise AttributeError(f"recon_amp has not yet been set")
else:
amp = self._detach_and_crop(self._recon_amp)
amp = ndi.gaussian_filter(amp, self._gaussian_sigma)
return amp
@property
def recon_amp_full(self) -> np.ndarray:
"""
Returns the reconstructed amplitude after applying Gaussian filter.
This is the full recon amplitude with any padding.
Returns:
Optional[np.ndarray]: Reconstructed amplitude image.
"""
if self._recon_amp.ndim == 0:
raise AttributeError(f"recon_amp has not yet been set")
else:
amp = self._recon_amp.cpu().detach().numpy()
amp = ndi.gaussian_filter(amp, self._gaussian_sigma)
return amp
@property
def best_amp(self) -> np.ndarray:
"""
Returns the best amplitude after applying Gaussian filter.
This is the cropped best amplitude without padding.
Returns:
np.ndarray: Best amplitude image.
"""
if self._best_amp.ndim == 0:
raise AttributeError(f"best_amp has not yet been set")
else:
amp = self._detach_and_crop(self._best_amp)
amp = ndi.gaussian_filter(amp, self._gaussian_sigma)
return amp
@property
def best_amp_full(self) -> np.ndarray:
"""
Returns the best amplitude after applying Gaussian filter.
This is the full amplitude with padding.
Returns:
np.ndarray: Best amplitude image.
"""
if self._best_amp.ndim == 0:
raise AttributeError(f"best_amp has not yet been set")
else:
amp = self._best_amp.cpu().detach().numpy()
amp = ndi.gaussian_filter(amp, self._gaussian_sigma)
return amp
@property
def gaussian_sigma(self) -> float:
"""
Returns the Gaussian sigma value.
Returns:
float: Gaussian sigma value.
"""
return self._gaussian_sigma
@gaussian_sigma.setter
def gaussian_sigma(self, val: Optional[float]) -> None:
"""
Sets the Gaussian sigma value and updates the blurring transformation.
Args:
val (Optional[float]): Gaussian sigma value.
Raises:
TypeError: If `val` is not a numeric type.
ValueError: If `val` is less than 0.
"""
if val is None:
self._gaussian_sigma = 0
self._blurrer = None
elif not isinstance(val, (float, int)):
raise TypeError(f"gaussian_sigma must be numeric, received {type(val)}")
elif val < 0:
raise ValueError(f"gaussian_sigma must be >= 0 or None, received {val}")
else:
self._blurrer = TvT.GaussianBlur(kernel_size=(9, 9), sigma=(val, val))
self._gaussian_sigma = val
try:
if self.best_phase is not None:
self.phase_B = self.best_phase
except AttributeError:
pass
self._set_recon_iterations()
@property
def _TFs(self) -> Tensor:
"""
Returns the transfer functions.
Returns:
Tensor: Transfer functions.
"""
return self.transfer_functions
@_TFs.setter
def _TFs(self, arr: Union[Tensor, np.ndarray]) -> None:
"""
Sets the transfer functions.
Args:
arr (Union[Tensor, np.ndarray]): Transfer functions array.
Raises:
AssertionError: If `arr` does not match expected shape.
"""
if not isinstance(arr, torch.Tensor):
arr = torch.tensor(arr, device=self.device, dtype=torch.complex64)
assert len(arr.shape) == 3, f"Bad TF shape: {arr.shape}, should be {self.dd.images.shape}"
assert len(arr) == len(self.dd), f"len TFs {len(arr)} != # defvals {len(self.dd)}"
self.transfer_functions = arr
@property
def scope(self) -> Microscope:
"""
Returns the microscope object.
Returns:
Optional[Microscope]: Microscope object.
"""
if self._scope is None:
raise AttributeError(f"self.scope has not been set and is None")
return self._scope
@scope.setter
def scope(self, microscope: Microscope) -> None:
"""
Sets the microscope object.
Args:
microscope (Microscope): Microscope object to set.
Raises:
TypeError: If `microscope` is not of type `Microscope`.
"""
if not isinstance(microscope, Microscope):
raise TypeError(
f"microscope must be a PyLorentz.utils.microscopes.Microscope object, received {type(microscope)}"
)
else:
self._scope = microscope
@property
def device(self) -> str | torch.device:
"""
Returns the device used for computation.
Returns:
str: Device for computation.
"""
return self._device
@device.setter
def device(self, dev: Union[int, str, torch.device]) -> None:
"""
Sets the device for computation.
Args:
dev (Union[int, str]): Device identifier.
Raises:
TypeError: If `dev` is not of type int or str.
ValueError: If `dev` exceeds available GPUs.
"""
if isinstance(dev, torch.device):
if dev.type == "gpu":
self._device = dev
ind = dev.index
elif dev.type == "cuda":
self._device = dev
elif dev.type == "cpu":
self._device = dev
else:
raise TypeError(
f"Unknown device type: {dev.type} This can likely be fixed easily."
)
elif isinstance(dev, (str, int, float)):
if dev in ["cpu", "CPU"]:
warnings.warn("Setting device to cpu, this will be slow and might fail.")
ind = "cpu"
elif dev in ["gpu", "GPU"]:
assert torch.cuda.is_available(), f"No GPUs available"
ind = 0
else:
assert torch.cuda.is_available(), f"No GPUs available"
ind = int(dev)
if ind >= torch.cuda.device_count():
raise ValueError(
f"device must be < num_devices, which is {torch.cuda.device_count()}. Received device {ind}"
)
self._device = torch.device(ind)
if isinstance(ind, int):
self.vprint(f"Proceeding with GPU {ind}: {torch.cuda.get_device_name(ind)}")
else:
raise TypeError(f"Device should be int, str, or torch.device. Received {type(dev)}")
@property
def guess_phase(self) -> Tensor:
"""
Returns the guess phase used to pre-train the DIP.
Returns:
Optional[Tensor]: Guess phase tensor.
"""
if self._guess_phase.ndim == 0:
raise AttributeError(f"recon_amp has not yet been set")
return self._guess_phase
@guess_phase.setter
def guess_phase(self, im: Union[np.ndarray, Tensor]) -> None:
"""
Sets the guess phase.
Args:
im (Union[np.ndarray, Tensor]): Guess phase image.
Raises:
ValueError: If `im` shape does not match input image shape.
"""
if not isinstance(im, torch.Tensor):
im = torch.tensor(im, dtype=torch.float32)
im = im.to(self.device)
if im.shape != self.shape_full:
raise ValueError(
f"Guess phase shape, {im.shape} should match full reconstruction shape, {self.shape_full}"
)
self._guess_phase = im
@property
def guess_amp(self) -> Tensor:
"""
Returns the guess amplitude used to pre-train the DIP or if `solve_amp` is False.
Returns:
Tensor: Guess amplitude tensor.
"""
return self._guess_amp
@guess_amp.setter
def guess_amp(self, im: Union[np.ndarray, Tensor]) -> None:
"""
Sets the guess amplitude.
Args:
im (Union[np.ndarray, Tensor]): Guess amplitude image.
Raises:
ValueError: If `im` shape does not match input image shape.
"""
if not isinstance(im, torch.Tensor):
im = torch.tensor(im, dtype=torch.float32)
im = im.to(self.device)
if im.shape != self.shape_full:
raise ValueError(
f"Guess amp shape, {im.shape} should match full reconstruction shape, {self.shape_full}"
)
self._guess_amp = im
@property
def input_DIP(self) -> Optional[Tensor]:
"""
Returns the noise used as input for one or both DIPs.
Returns:
Optional[Tensor]: Input noise tensor.
"""
return self._input_DIP
@input_DIP.setter
def input_DIP(self, im: Union[np.ndarray, Tensor]) -> None:
"""
Sets the input noise for DIPs.
Args:
im (Union[np.ndarray, Tensor]): Input noise image.
Raises:
ValueError: If `im` shape does not match input image shape.
"""
if not isinstance(im, torch.Tensor):
im = torch.tensor(im, device=self.device, dtype=torch.float32)
if im.shape[1:] != self.shape_full:
# TODO not sure if this is correct for multiple images
raise ValueError(
f"DIP input shape, {im.shape} should match full reconstruction shape, {self.shape_full}"
)
self._input_DIP = im
[docs]
def reconstruct(
self,
num_iter: int,
model: Optional[Union[nn.Module, list[nn.Module]]] = None,
num_pretrain_iter: int = 0,
solve_amp: bool = False,
solve_amp_scale: bool = True,
guess_amp: Optional[Union[float, np.ndarray]] = None,
LRs: dict = {},
scheduler_type: Optional[str] = None,
save: bool = False,
name: Optional[str] = None,
save_dir: Optional[os.PathLike] = None,
noise_frac: Optional[float] = None,
guess_phase: Union[str, np.ndarray] = "SITIE",
input_DIP: str | np.ndarray | None = "SITIE",
reset: bool = True,
print_every: int = -1,
verbose: int = 1,
store_iters_every: int = -1,
qc: Optional[float] = None,
pad: tuple | None = None,
**kwargs, # scheduler params
) -> Self:
"""
Performs the reconstruction process.
Args:
num_iter (int): Number of iterations for reconstruction.
model (Optional[Union[nn.Module, list[nn.Module]]], optional): Model or list of models for DIP.
num_pretrain_iter (int, optional): Number of pretraining iterations.
solve_amp (bool, optional): Whether to solve for amplitude.
solve_amp_scale (bool, optional): Whether to solve amplitude scale.
guess_amp (Optional[Union[float, np.ndarray]], optional): Guess amplitude.
LRs (dict, optional): Learning rates for optimization.
scheduler_type (Optional[str], optional): Type of learning rate scheduler.
save (bool, optional): Whether to save results.
name (Optional[str], optional): Name for the saved results.
save_dir (Optional[os.PathLike], optional): Directory to save results.
noise_frac (Optional[float], optional): Fraction of noise to add.
guess_phase (Union[str, np.ndarray, None], optional): Guess phase or method to obtain it.
reset (bool, optional): Whether to reset the model.
print_every (int, optional): Frequency of printing progress.
verbose (int, optional): Verbosity level.
store_iters_every (int, optional): Frequency of storing iterations.
qc (Optional[any], optional): Quality control object.
**kwargs: Additional keyword arguments for scheduler parameters.
"""
### SETUP
self._start_time = datetime.now()
self._num_pretrain_iter = num_pretrain_iter
if pad is not None:
self.pad = pad
self._TFs = self.get_TFs()
if noise_frac is not None:
self._noise_frac = noise_frac
if verbose is not None:
self._verbose = verbose
self.LRs = self.LRs | LRs
self._solve_amp_scale = solve_amp_scale if not solve_amp else False
if isinstance(model, nn.Module):
self._use_DIP = True
DIP_phase = model
DIP_amp = None
elif model is not None:
self._use_DIP = True
DIP_phase, DIP_amp = model
else:
self._use_DIP = False
DIP_phase = DIP_amp = None
if scheduler_type is not None:
self._scheduler_type = scheduler_type
self._qc = qc
if save:
if name is None:
if self.name is not None:
name = self.name
else:
now = self._start_time.strftime("%y%m%d-%H%M%S")
if len(self.dd) == 1:
mode = "SIPRAD"
self._results["input_image"] = self.dd.images[0]
else:
mode = f"N{len(self.dd)}AD"
name = f"{now}_{mode}"
self._check_save_name(save_dir, name=name)
self._noise_frac = noise_frac if noise_frac is not None else self._noise_frac
### Initialization/reset
reset = True if len(self.loss_iterations) == 0 else reset
if reset:
# guess phase is what the DIP is trained to output during pre-training, so is distinct
# from the DIP_input
self._set_guess_phase(guess_phase)
self._set_guess_amp(guess_amp)
self.loss_iterations = []
self.LR_iterations = []
self._phase_iterations = []
self._amp_iterations = []
self._amp_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
self._amp2phi_scale = torch.tensor(
[self._get_amp2phi_scale()], dtype=torch.float32, device=self.device
)
if self._use_DIP:
assert isinstance(DIP_phase, nn.Module)
self._runtype = "DIP"
self._set_input_DIP(input_DIP=input_DIP)
DIP_phase = DIP_phase.to(self.device)
self.optimizer = torch.optim.Adam( # type:ignore
[{"params": DIP_phase.parameters(), "lr": self.LRs["phase"]}],
)
DIP_phase.apply(weight_reset)
if solve_amp:
assert DIP_amp is not None
self._runtype += "-amp"
DIP_amp.apply(weight_reset)
DIP_amp = DIP_amp.to(self.device)
self.optimizer.add_param_group(
{"params": DIP_amp.parameters(), "lr": self.LRs["amp"]},
)
### pretrain DIP
self._pretrain_DIP(DIP_phase, DIP_amp)
else:
self._runtype = "AD"
self._recon_phase = self.guess_phase.clone()
if solve_amp:
self._runtype += "-amp"
self._recon_amp = self.guess_amp.clone()
else:
# maybe should check that more things exist, but shouldn't have to
self.LR_iterations = list(self.LR_iterations)
self.loss_iterations = list(self.loss_iterations)
# reinitializing optimizer here so have chance to change scheduler, LRs, etc.
if reset or scheduler_type != "continue":
if self._use_DIP:
assert DIP_phase is not None
DIP_phase = DIP_phase.to(self.device)
self.optimizer = torch.optim.Adam( # type:ignore
[{"params": DIP_phase.parameters(), "lr": self.LRs["phase"]}],
)
if solve_amp:
assert DIP_amp is not None
self._solve_amp = True
DIP_amp = DIP_amp.to(self.device)
self.optimizer.add_param_group(
{"params": DIP_amp.parameters(), "lr": self.LRs["amp"]},
)
else:
self._solve_amp = False
DIP_amp = None
self._recon_amp = self.guess_amp.clone()
self._recon_amp.requires_grad = False
else:
DIP_phase = DIP_amp = None
assert self._recon_phase is not None
self._recon_phase.requires_grad = True
self.optimizer = torch.optim.Adam( # type:ignore
[{"params": self._recon_phase, "lr": self.LRs["phase"]}]
)
if solve_amp:
self._solve_amp = True
self._recon_amp.requires_grad = True
self.optimizer.add_param_group(
{"params": self._recon_amp, "lr": self.LRs["amp"]},
)
else:
self._solve_amp = False
self._recon_amp = self.guess_amp.clone()
self._recon_amp.requires_grad = False
if self._solve_amp_scale:
self.optimizer.add_param_group(
{"params": self._amp_scale, "lr": self.LRs["amp_scale"]}
)
self._amp_scale.requires_grad = True
if self._recon_amp.min() != self._recon_amp.max():
amp2phi_LR = self.LRs.get("amp2phi_scale", self.LRs["amp_scale"])
self.optimizer.add_param_group(
{"params": self._amp2phi_scale, "lr": amp2phi_LR}
)
self._amp2phi_scale.requires_grad = True
else:
self._amp_scale.requires_grad = False
self.scheduler = self._get_scheduler(**kwargs)
if save_dir is not None or self.save_dir is not None:
self._check_save_name(save_dir, name, mode=f"{self._mode}_{self._runtype}")
### Recon
self.vprint("Reconstructing")
self._recon_loop(num_iter, print_every, DIP_phase, DIP_amp, save, store_iters_every)
ttime = timedelta(seconds=(datetime.now() - self._start_time).seconds)
print(f"total time (h:m:s) = {ttime}")
self._recon_phase -= self._recon_phase.min()
if self._best_phase is not None:
self._best_phase -= self._best_phase.min()
self.phase_B = self.best_phase
else:
warnings.warn("Was unable to find a best_phase, recon likely failed.")
self.LR_iterations = np.array(self.LR_iterations)
self.loss_iterations = np.array(self.loss_iterations)
self._set_recon_iterations()
return self
def _recon_loop(
self,
num_iter: int,
print_every: int,
DIP_phase: Optional[nn.Module],
DIP_amp: Optional[nn.Module],
save: bool,
store_iters_every: int,
) -> None:
"""
Runs the reconstruction loop for a given number of iterations.
Args:
num_iter (int): Number of iterations to run.
print_every (int): Frequency of printing progress information.
DIP_phase (Optional[nn.Module]): DIP module for phase reconstruction.
DIP_amp (Optional[nn.Module]): DIP module for amplitude reconstruction.
save (bool): Whether to save the best reconstruction.
store_iters_every (int): Frequency of storing intermediate iterations.
"""
assert isinstance(self.input_DIP, Tensor)
assert isinstance(self.loss_iterations, list)
assert isinstance(self.LR_iterations, list)
stime = self._start_time
for a0 in tqdm(range(num_iter)):
if self._noise_frac >= 0:
self.input_DIP = self.input_DIP + self._noise_frac * torch.randn(
self.input_DIP.shape, device=self.device
)
if DIP_phase is not None:
self._recon_phase = DIP_phase(self.input_DIP)[0]
if self._solve_amp and DIP_amp is not None:
self._recon_amp = torch.abs(DIP_amp(self.input_DIP)[0])
if (a0 + 1) % 100 == 0:
print(f"a0 {a0} applying amp constraints")
self._apply_amp_constraints()
loss = self._compute_loss()
assert isinstance(loss, Tensor) # remove ned by make a compute_loss_sep function
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
self.loss_iterations.append(loss.item())
self.LR_iterations.append([pg["lr"] for pg in self.optimizer.param_groups])
if (a0 == 0 or (a0 + 1) % print_every == 0) and print_every > 0:
lrs = [f"{pg['lr']:.2e}" for pg in self.optimizer.param_groups]
lrsp = ", ".join(lrs)
ctime = timedelta(seconds=(datetime.now() - stime).seconds)
self.vprint(f"{a0+1}/{num_iter} | {ctime} | loss {loss.item():.3e} | LR {lrsp}")
stime = datetime.now()
if (a0 == 0 or (a0 + 1) % store_iters_every == 0) and store_iters_every > 0:
self._phase_iterations.append((self._recon_phase.detach().clone(), a0 + 1))
if self._solve_amp:
self._amp_iterations.append((self._recon_amp.detach().clone(), a0 + 1))
if self._verbose >= 2 and a0 != 0:
self.show_final()
if len(self.loss_iterations) > 100 and loss.item() < min(self.loss_iterations[:-1]):
self._best_phase = self._recon_phase.detach().clone()
self._best_iter = len(self.loss_iterations)
self.phase_B = self.best_phase
if self._solve_amp:
self._best_amp = self._recon_amp.detach().clone()
if save:
# TODO maybe add a checkpoint save function (once made save function)
raise NotImplementedError
if self.scheduler is not None:
if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
self.scheduler.step(loss.item())
else:
self.scheduler.step()
def _apply_amp_constraints(self, mode: str = "binary") -> None:
"""
Apply amplitude constraints based on the specified mode.
Args:
mode (str): The mode of constraints. Currently only "binary" is implemented.
"""
if mode == "binary":
ampr = self._recon_amp.ravel()
highval = torch.mode(torch.round(ampr, decimals=2))[0]
threshval = highval * 3 / 4
lowval = torch.sort(ampr)[len(ampr) // 10]
amp = torch.where(self._recon_amp > threshval, highval, lowval)
if self._blurrer is not None:
amp = self._blurrer(amp[None])[0]
self._recon_amp = amp
else:
raise NotImplementedError
def _sim_images(self) -> Tensor:
"""
Simulate images from the current amplitude and phase reconstructions.
Returns:
Tensor: The simulated images.
"""
obj_waves = (
self._recon_amp
* self._amp_scale
* torch.exp(1.0j * (self._recon_phase - self._amp2phi_scale * self._recon_amp))
)
img_waves = torch.fft.ifft2(torch.fft.fft2(obj_waves) * self._TFs)
images = torch.abs(img_waves) ** 2
if self.pad[0] > 0:
images = images[:, self.pad[0] : -self.pad[0]]
if self.pad[1] > 0:
images = images[:, :, self.pad[1] : -self.pad[1]]
return images
def _compute_loss(
self, return_seperate: bool = False
) -> Union[Tensor, tuple[Tensor, Optional[Tensor]]]:
"""
Compute the loss for the current reconstruction.
Args:
return_seperate (bool): Whether to return separate MSE and TV losses.
Returns:
Union[Tensor, tuple[Tensor, Optional[Tensor]]]: The total loss, and optionally the MSE and TV losses.
"""
guess_ims = self._sim_images()
MSE_loss = torch.mean((guess_ims - self.inp_ims) ** 2)
if self.LRs["TV_phase_weight"] == 0 and (
self.LRs["TV_amp_weight"] == 0 or not self._solve_amp
):
if return_seperate:
return MSE_loss, None
else:
return MSE_loss
else:
TV_loss = self._calc_TV_loss_PBC()
if return_seperate:
return MSE_loss, TV_loss
else:
return MSE_loss + TV_loss
def _calc_TV_loss_PBC(self) -> Optional[Tensor]:
"""
Calculate the total variation loss with periodic boundary conditions.
Returns:
Optional[Tensor]: The total variation loss.
"""
assert self._recon_phase.ndim == 2 and self._recon_amp.ndim == 2
dy, dx = self.shape_full
if self.LRs["TV_phase_weight"] > 0:
phase_pad_h = F.pad(self._recon_phase[None, None], (0, 0, 0, 1), mode="circular")[0, 0]
phase_pad_w = F.pad(self._recon_phase[None, None], (0, 1, 0, 0), mode="circular")[0, 0]
TV_phase_h = torch.pow(phase_pad_h[1:, :] - phase_pad_h[:-1, :], 2).sum()
TV_phase_w = torch.pow(phase_pad_w[:, 1:] - phase_pad_w[:, :-1], 2).sum()
TV_phase = self.LRs["TV_phase_weight"] * (TV_phase_h + TV_phase_w) / (dy * dx)
else:
TV_phase = None
if self._solve_amp and self.LRs["TV_amp_weight"] > 0:
amp_pad_h = F.pad(self._recon_amp[None, None], (0, 0, 0, 1), mode="circular")[0, 0]
amp_pad_w = F.pad(self._recon_amp[None, None], (0, 1, 0, 0), mode="circular")[0, 0]
TV_amp_h = torch.pow(amp_pad_h[1:, :] - amp_pad_h[:-1, :], 2).sum()
TV_amp_w = torch.pow(amp_pad_w[:, 1:] - amp_pad_w[:, :-1], 2).sum()
TV_amp = self.LRs["TV_amp_weight"] * (TV_amp_h + TV_amp_w) / (dy * dx)
if TV_phase is None:
return TV_amp
else:
return (TV_amp + TV_phase) / 2
else:
return TV_phase
def _set_recon_iterations(self) -> None:
"""
Update phase_iterations and amp_iterations with filtered tensors converted to np arrays.
"""
if len(self._phase_iterations) > 0:
phase_iterations = []
for iter in self._phase_iterations:
ph = ndi.gaussian_filter(iter[0].cpu().detach().numpy(), self._gaussian_sigma)
ph -= ph.min()
phase_iterations.append((ph, iter[1]))
self.phase_iterations = phase_iterations
if len(self._amp_iterations) > 0:
amp_iterations = []
for iter in self._amp_iterations:
amp = ndi.gaussian_filter(iter[0].cpu().detach().numpy(), self._gaussian_sigma)
amp_iterations.append((amp, iter[1]))
self.amp_iterations = amp_iterations
def _get_amp2phi_scale(self) -> float:
"""
Calculate the scale factor for converting amplitude to phase.
Returns:
float: The scale factor.
"""
if self.scope is None:
raise AttributeError(f"self.scope has not been set and is None")
return self.sample_params["dirt_V0"] * self.scope.sigma * self.sample_params["dirt_xip0"]
def _pretrain_DIP(self, DIP_phase: nn.Module, DIP_amp: nn.Module | None):
"""Perform the pre-training of the DIP model.
Args:
DIP_phase (nn.Module): Guess phase (also used as input) to train towards
DIP_amp (nn.Module | None): Guess amp (also used as input) to train towards
"""
if self._num_pretrain_iter > 0:
self.vprint(f"Pre-training")
for _ in tqdm(range(self._num_pretrain_iter)):
loss = self._compute_loss_pretrain(DIP_phase, DIP_amp)
loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
if self._verbose >= 2:
ph = DIP_phase.forward(self.input_DIP).squeeze().cpu().detach().numpy()
ph -= ph.min()
show_im(
ph,
title=f"Recon phase after pre-training DIP for {self._num_pretrain_iter} iters",
)
if self._solve_amp:
assert DIP_amp is not None
show_im(
DIP_amp.forward(self.input_DIP).squeeze().cpu().detach().numpy(),
title=f"Recon amp after pre-training DIP for {self._num_pretrain_iter} iters",
)
def _compute_loss_pretrain(self, DIP_phase: nn.Module, DIP_amp: nn.Module | None):
"""Helper function for `self._pretrain_DIP`"""
pred_phase = DIP_phase.forward(self.input_DIP).squeeze()
loss = torch.mean((pred_phase - self.guess_phase) ** 2)
if self._solve_amp:
assert DIP_amp is not None
pred_amp = DIP_amp.forward(self.input_DIP).squeeze()
loss += torch.mean((pred_amp - self.guess_amp) ** 2)
return loss
def _get_scheduler(self, **kwargs):
"""Return a torch scheduler according to `self._scheduler_type`"""
mode = str(self._scheduler_type).lower()
LR = self.LRs["phase"]
if mode == "none":
scheduler = None
elif mode == "cyclic":
scheduler = torch.optim.lr_scheduler.CyclicLR(
self.optimizer,
base_lr=kwargs.get("scheduler_base_lr", LR / 4),
max_lr=kwargs.get("scheduler_max_lr", LR * 4),
step_size_up=kwargs.get("scheduler_step_size_up", 100),
mode=kwargs.get("scheduler_mode", "triangular2"),
cycle_momentum=kwargs.get("scheduler_momentum", False),
)
elif mode.startswith("plat"):
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer,
mode="min",
factor=kwargs.get("scheduler_factor", 0.75),
patience=kwargs.get("scheduler_patience", 100),
threshold=kwargs.get("scheduler_threshold", 1e-4),
min_lr=kwargs.get("scheduler_min_lr", LR / 20),
)
elif mode in ["exp", "gamma"]:
gamma = kwargs.get("scheduler_gamma", 0.9997)
scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=gamma)
else:
raise ValueError(f"Unknown scheduler type: {mode}")
return scheduler
def _set_guess_amp(self, guess_amp=None):
"""Setting the guess amplitude used in AD reconstruction (no DIP?)"""
if isinstance(guess_amp, (np.ndarray, torch.Tensor)):
if guess_amp.shape and np.any(self.pad):
# guess amp does not account for padded shape. add padding
guess_amp = np.pad(
guess_amp,
((self.pad[0], self.pad[0]), (self.pad[1], self.pad[1])),
mode="constant",
constant_values=guess_amp.max(),
)
self.guess_amp = guess_amp
else:
im = self._padded_dd().images[len(self.dd) // 2]
if isinstance(guess_amp, (float, int)):
thresh = guess_amp # TODO update this to be percent saturated?
guess_amp = np.where(im >= thresh, 1, 0).astype(np.float32)
elif self._solve_amp:
thresh = im.min() + np.ptp(im) / 10
guess_amp = np.where(im >= thresh, 1, 0).astype(np.float32)
else:
guess_amp = np.ones(self.shape_full).astype(np.float32)
guess_amp *= np.sqrt(im.mean())
guess_amp = torch.tensor(guess_amp, device=self.device, dtype=torch.float32)
self.guess_amp = guess_amp
def _padded_dd(self):
dd2 = self.dd.copy()
dd2.images = np.pad(
self.dd.images,
((0, 0), (self.pad[0], self.pad[0]), (self.pad[1], self.pad[1])),
mode="constant",
constant_values=self.dd.images.mean(),
)
return dd2
def _set_guess_phase(self, guess_phase: str | np.ndarray):
"""Setting the guess phase used in AD reconstruction (no DIP?)"""
if isinstance(guess_phase, str):
guess_phase = guess_phase.lower()
# if guess_phase == "none":
# guess_phase = None
# self._num_pretrain_iter = 0
# return
if guess_phase == "uniform":
guess_phase = np.zeros(self.shape_full)
elif guess_phase == "sitie":
sitie = SITIE(self._padded_dd(), verbose=0)
sitie.reconstruct(qc=self._qc)
if self._verbose >= 2:
print("SITIE guess phase:")
sitie.visualize(cbar=True)
guess_phase = sitie.phase_B
self.guess_phase = torch.tensor(guess_phase, dtype=torch.float32)
return
def _set_input_DIP(self, input_DIP: str | np.ndarray | None = None):
"""
Generate the input to the DIP. Currently only a SITIE is available as that is frankly the
best option, but this could be easily expanded to using input noise (like for a traditional
DIP) or really anything
"""
if isinstance(input_DIP, str):
if input_DIP.lower() == "sitie":
sitie = SITIE(self._padded_dd(), verbose=0)
sitie.reconstruct(qc=self._qc)
inp = sitie.phase_B
elif input_DIP.lower() in ["random", "rand"]:
inp = self._rng.random(self.shape_full) * 2 - 1
else:
raise ValueError(
f"Input mode string should be 'SITIE' or 'random'. Got {input_DIP}"
)
elif isinstance(input_DIP, np.ndarray):
inp = np.squeeze(input_DIP)
else:
raise TypeError(f"input_DIP should be str or np.ndarray, got {type(input_DIP)}")
self.input_DIP = torch.tensor(
inp[None, ...], device=self.device, dtype=torch.float32, requires_grad=False
)
[docs]
def get_TFs(self):
"""
Returns a tensor containing the transfer functions according to self.scope and self.defavls
"""
tfs = np.array([self._get_TF(df) for df in self.defvals])
return torch.tensor(tfs, device=self.device, dtype=torch.complex64)
def _get_TF(self, defocus):
"""Returns a single transfer function for a given defocus value in nm"""
self.scope.defocus = defocus
return self.scope.get_transfer_function(self.scale, self.shape_full)
[docs]
def show_best(self, crop=5):
minloss_iter = self._best_iter # np.argmin(self.loss_iterations)
ph = self.best_phase
By, Bx = self.induction_from_phase(ph)
if crop > 0:
crop = int(crop)
ph = ph[crop:-crop, crop:-crop]
Bx = Bx[crop:-crop, crop:-crop]
By = By[crop:-crop, crop:-crop]
if self._solve_amp:
fig, axs = plt.subplots(ncols=3, figsize=(12, 4))
show_im(
ph,
title=f"best phase: iter {minloss_iter} / {len(self.loss_iterations)}",
figax=(fig, axs[0]),
cbar_title="rad",
)
show_2D(
Bx,
By,
title=f"Best B: iter {minloss_iter} / {len(self.loss_iterations)}",
figax=(fig, axs[1]),
)
show_im(
self.best_amp,
title=f"best amp: iter {minloss_iter} / {len(self.loss_iterations)}",
figax=(fig, axs[2]),
)
plt.tight_layout()
plt.show()
else:
fig, axs = plt.subplots(ncols=2, figsize=(8, 4))
show_im(
ph,
title=f"best phase: iter {minloss_iter} / {len(self.loss_iterations)}",
scale=self.scale,
cbar_title="rad",
figax=(fig, axs[0]),
)
show_2D(
Bx,
By,
title=f"Best B: iter {minloss_iter} / {len(self.loss_iterations)}",
figax=(fig, axs[1]),
)
plt.tight_layout()
plt.show()
[docs]
def show_prediction(self) -> None:
"""
Show the predicted image that is compared to the input image.
"""
if len(self.inp_ims) > 1:
raise NotImplementedError
pred_image = self._sim_images().squeeze()
fig, axs = plt.subplots(ncols=3, figsize=(12, 4))
show_im(
self.inp_ims,
title="Input image",
figax=(fig, axs[0]),
scale=self.scale,
)
show_im(
pred_image,
title="Predicted image",
figax=(fig, axs[1]),
ticks_off=True,
)
show_im(
self.inp_ims - pred_image,
title="Input - predicted",
figax=(fig, axs[2]),
ticks_off=True,
)
plt.tight_layout()
plt.show()
[docs]
def show_final(self, crop: int = 5) -> None:
"""Show the phase and induction of the final iteration.
Args:
crop (int, optional): Amount to crop off of induction maps before displaying; often
necessary in order to avoid edge artifacts. Defaults to 5.
"""
ph = self.recon_phase
By, Bx = self.induction_from_phase(ph)
if crop > 0:
crop = int(crop)
ph = ph[crop:-crop, crop:-crop]
Bx = Bx[crop:-crop, crop:-crop]
By = By[crop:-crop, crop:-crop]
if self._solve_amp:
fig, axs = plt.subplots(ncols=3, figsize=(12, 4))
show_im(
ph,
title=f"Recon phase: iter {len(self.loss_iterations)}",
scale=self.scale,
figax=(fig, axs[0]),
cbar_title="rad",
)
show_2D(
Bx,
By,
title=f"Recon B: iter {len(self.loss_iterations)}",
figax=(fig, axs[1]),
)
show_im(
self.recon_amp,
title=f"Recon amp: iter {len(self.loss_iterations)}",
scale=self.scale,
figax=(fig, axs[2]),
)
plt.tight_layout()
plt.show()
else:
fig, axs = plt.subplots(ncols=2, figsize=(8, 4))
show_im(
ph,
title=f"Recon phase: iter {len(self.loss_iterations)}",
scale=self.scale,
cbar_title="rad",
figax=(fig, axs[0]),
)
show_2D(
Bx,
By,
title=f"Recon B: iter {len(self.loss_iterations)}",
figax=(fig, axs[1]),
)
plt.tight_layout()
plt.show()
[docs]
def visualize(self, crop=5):
"""Plot the best reconstructed phase and induction maps.
Args:
crop (int, optional): Amount to crop off of induction maps before displaying; often
necessary in order to avoid edge artifacts. Defaults to 5.
"""
if self._solve_amp:
fig = plt.figure(figsize=(12, 8))
ax1 = fig.add_subplot(231)
self.show_phase_B(figax=(fig, ax1), cbar_title=None, crop=crop)
ax2 = fig.add_subplot(232)
self.show_B(figax=(fig, ax2), crop=crop)
ax2p5 = fig.add_subplot(233)
show_im(self.best_amp, figax=(fig, ax2p5), ticks_off=True, title="amp")
ax3 = fig.add_subplot(212)
l1 = ax3.semilogy(self.loss_iterations, color="tab:blue", label="loss")
ax3.set_xlabel("iterations")
ax3.set_ylabel("loss")
ax4 = ax3.twinx()
LRs = np.array(self.LR_iterations)
l2 = ax4.semilogy(LRs[:, 0], color="tab:orange", label="phase LR")
lns = l1 + l2
labs = [l.get_label() for l in lns]
ax4.legend(lns, labs, loc=0) # type:ignore
ax4.set_ylabel("LR")
ax4.yaxis.set_major_formatter(FormatStrFormatter("%.2e"))
else:
fig = plt.figure(figsize=(8, 8))
ax1 = fig.add_subplot(221)
ax2 = fig.add_subplot(222)
if self.phase_B is not None:
self.show_phase_B(figax=(fig, ax1), cbar_title=None, crop=crop)
self.show_B(figax=(fig, ax2), crop=crop)
ax3 = fig.add_subplot(212)
l1 = ax3.semilogy(self.loss_iterations, color="tab:blue", label="loss")
ax3.set_xlabel("iterations")
ax3.set_ylabel("loss")
ax4 = ax3.twinx()
LRs = np.array(self.LR_iterations)
l2 = ax4.semilogy(LRs[:, 0], color="tab:orange", label="phase LR")
lns = l1 + l2
labs = [l.get_label() for l in lns]
ax4.legend(lns, labs, loc=0) # type:ignore
ax4.set_ylabel("LR")
ax4.yaxis.set_major_formatter(FormatStrFormatter("%.2e"))
plt.show()
return
def __len__(self):
return len(self.loss_iterations)
[docs]
def save_results(
self,
iter_ind: int | None = None,
save_dir: Optional[os.PathLike] = None,
name: Optional[str] = None,
overwrite: bool = False,
):
"""Save the recontructed phase, Bx, By, and color images.
# TODO add saving amplitude and phase_E
Args:
iter_ind (int, optional): Index to save. Defaults to None which saves best phase.
save_dir (os.PathLik], optional): Directory to save in. Defaults to None
which saves in self.save_dir.
name (str, optional): Name to prepend saved files. Defaults to None.
overwrite (bool, optional): Whether or not to overwrite files. Defaults to False.
"""
if save_dir is not None or name is not None:
# don't want to overwrite original name with timestamp
self._check_save_name(save_dir, name=name, default_name=False)
if iter_ind is not None:
assert len(self.phase_iterations) > 0
phase, iter = self.phase_iterations[iter_ind]
By, Bx = self.induction_from_phase(np.array(phase))
else:
iter = self._best_iter
phase = self.best_phase
Bx = self.Bx
By = self.By
results = {
"phase_B": phase,
"By": By,
"Bx": Bx,
"color": None,
}
if self._mode == "SIPRAD":
results["input_image"] = self.dd.image
else:
raise NotImplementedError("need to figure out how saving defocus vals")
save_name_no_iter = "_".join(self._save_name.split("_")[:3])
self._save_name = save_name_no_iter + f"_i{iter}"
self.save_dir.mkdir(exist_ok=True)
self._save_keys(list(results.keys()), self.defvals[0], overwrite, res_dict=results)
self._save_log(overwrite)
def _save_log(self, overwrite: Optional[bool] = None):
"""
Save the reconstruction log.
Args:
overwrite (bool, optional): Whether to overwrite existing files. Default is None.
"""
log_dict = {
"name": self.name,
"_save_name": self._save_name,
"defval": self.defvals.squeeze(),
"scale": self.scale,
"transforms": self.dd.transforms,
"filters": self.dd._filters,
"beam_energy": self.dd.beam_energy,
"simulated": self.dd._simulated,
"data_dir": self.dd.data_dir,
"data_files": self.dd.data_files,
"save_dir": self._save_dir,
}
ovr = overwrite if overwrite is not None else self._overwrite
name = f"{self._save_name}_{self._fmt_defocus(self.defvals[0])}_log.json"
write_json(log_dict, self.save_dir / name, overwrite=ovr, v=self._verbose)