import os
from pathlib import Path
from typing import Optional, Union, List
import matplotlib.pyplot as plt
import numpy as np
import scipy.ndimage as ndi
from tqdm import tqdm
from PyLorentz.io import read_image, read_json
from PyLorentz.utils.filter import filter_hotpix
from PyLorentz.visualize import show_im
from .base_dataset import BaseDataset
import copy
class DefocusedDataset(BaseDataset):
    """
    A dataset class for handling defocused images and related metadata.

    The image stack is stored as a float64 array whose first axis indexes the
    images; a single 2D image is promoted to a stack of length one. A copy of
    the images given at construction is kept in ``_orig_images`` so that
    ``preprocess`` and ``filter`` can start again from unmodified data.

    Several helpers used here (``vprint``, ``apply_transforms``, ``scale``,
    ``_select_ROI``, ``_bandpass_filter``, ``_fmt_defocus``, ``_parse_mdata``
    and the ``_filters`` dict) are not defined in this class; they are
    presumably provided by ``BaseDataset`` -- TODO confirm.

    Parameters:
        images (np.ndarray): The set of defocused images.
        scale (Optional[float]): The scale of the images.
        defvals (Optional[np.ndarray]): The defocus values corresponding to the images.
        beam_energy (Optional[float]): The beam energy used during imaging.
        data_files (Optional[Union[os.PathLike, str, List[os.PathLike]]]): File
            path(s) of the data files. ``None`` (the default) means no files.
        simulated (bool): Indicates if the data is simulated.
        verbose (Union[int, bool]): Verbosity level for logging.
    """

    def __init__(
        self,
        images: np.ndarray,
        scale: Optional[float] = None,
        defvals: Optional[np.ndarray] = None,
        beam_energy: Optional[float] = None,
        data_files: Optional[Union[os.PathLike, str, List[os.PathLike]]] = None,
        simulated: bool = False,
        verbose: Union[int, bool] = 1,
    ):
        images = np.array(images).astype(np.float64)
        if np.ndim(images) == 2:
            # Promote a single image to a one-element stack.
            images = images[None,]
        if isinstance(defvals, (float, int)):
            defvals = np.array([defvals])

        # Use a fresh list per instance: a mutable default argument would be
        # shared (and stored) across every dataset created without data_files.
        if data_files is None:
            data_files = []

        if isinstance(data_files, (list, np.ndarray)) and np.size(data_files) > 0:
            self.data_files = [Path(f).absolute() for f in data_files]
            self.data_dirs = [f.parents[0] for f in self.data_files]
        elif isinstance(data_files, (os.PathLike, str)):
            self.data_files = [Path(data_files).absolute()]
            self.data_dirs = [f.parents[0] for f in self.data_files]
        else:
            # Empty (or unrecognised) input: no directory can be derived.
            self.data_files = data_files
            self.data_dirs = [None]

        super().__init__(
            imshape=images.shape[1:],
            data_dir=self.data_dirs[0],
            scale=scale,
            verbose=verbose,
        )

        self.images = images
        self._orig_images = images.copy()
        self._orig_shape = images.shape[1:]
        self._orig_images_preprocessed = None
        self._images_cropped = None
        self._images_filtered = None
        self.defvals = defvals
        self.beam_energy = beam_energy
        self._simulated = simulated
        self._verbose = verbose
        self.mask = None  # no mask by default
        self._preprocessed = False
        self._cropped = False
        self._filtered = False

    @classmethod
    def from_TFS(cls):
        """
        Convert a Through Focal Series (TFS) to a DefocusedDataset.

        Placeholder: no conversion is performed yet; the class itself is
        returned unchanged.
        """
        return cls

    @classmethod
    def load(
        cls,
        images: Union[np.ndarray, os.PathLike, List[os.PathLike]],
        metadata: Optional[Union[os.PathLike, dict]] = None,
        **kwargs,
    ) -> "DefocusedDataset":
        """
        Load images and metadata to create a DefocusedDataset instance.

        Values passed through ``kwargs`` (``defvals``, ``scale``,
        ``beam_energy``, ``data_files``, ``simulated``) override the
        corresponding metadata entries; any remaining ``kwargs`` are forwarded
        to the constructor.

        Args:
            images (Union[np.ndarray, os.PathLike, List[os.PathLike]]): Image data or paths.
            metadata (Optional[Union[os.PathLike, dict]]): Metadata as a path or dict.

        Returns:
            DefocusedDataset: The created dataset instance.

        Raises:
            NotImplementedError: If ``images`` is a sequence of file paths.
            ValueError: If ``images`` is a path and ``metadata`` is neither a
                dict nor a path, or if no defocus values can be found.
        """
        # Parsed once up front; the in-memory branch below uses it as-is.
        mdata = cls._parse_mdata(metadata) if metadata is not None else {}

        if isinstance(images, (list, np.ndarray)):
            if isinstance(images[0], (os.PathLike, str)):
                raise NotImplementedError(
                    "write method for reading list of files and collecting defocus values"
                )
        elif isinstance(images, (os.PathLike, str)):
            # Metadata stored with the image file is the baseline ...
            images, mdata = read_image(images)
            if metadata is not None:
                if isinstance(metadata, dict):
                    mdata_l = metadata
                elif isinstance(metadata, (os.PathLike, str)):
                    mdata_l = read_json(metadata)
                else:
                    raise ValueError(f"metadata must be dict or PathLike, not {type(metadata)}")
                # ... and explicitly supplied metadata takes priority over it.
                mdata = mdata | mdata_l

        defvals = kwargs.pop("defvals", mdata.get("defocus_values"))
        if defvals is None:
            raise ValueError("defvals must be specified, is None.")

        filepaths = mdata.get("data_files", mdata.get("filepath"))
        return cls(
            images=images,
            scale=kwargs.pop("scale", mdata.get("scale")),
            defvals=defvals,
            beam_energy=kwargs.pop("beam_energy", mdata.get("beam_energy")),
            data_files=kwargs.pop("data_files", filepaths),
            simulated=kwargs.pop("simulated", mdata.get("simulated", False)),
            **kwargs,
        )

    @property
    def images(self) -> np.ndarray:
        """np.ndarray: The current image stack (first axis indexes images)."""
        return self._images

    @images.setter
    def images(self, ims: np.ndarray) -> None:
        ims = np.array(ims)
        if np.ndim(ims) == 2:
            ims = ims[None]
        # Only cross-check against defocus values once they have been set;
        # they may legitimately be None (the constructor default).
        current_defvals = getattr(self, "_defvals", None)
        if current_defvals is not None and len(ims) != len(current_defvals):
            raise ValueError(
                f"Len images, {len(ims)} must equal len defvals, {len(current_defvals)}"
            )
        self._images = ims

    @property
    def image(self) -> np.ndarray:
        """np.ndarray: The first image of the stack.

        A property getter cannot receive arguments, so this always returns
        index 0; index ``images`` directly for any other entry.
        """
        return self.images[0]

    @property
    def defvals(self) -> Optional[np.ndarray]:
        """Optional[np.ndarray]: Defocus values, one per image, or None if unset."""
        return self._defvals

    @defvals.setter
    def defvals(self, dfs: Optional[np.ndarray]) -> None:
        if dfs is None:
            # None is the constructor default; nothing to validate.
            self._defvals = None
            return
        if isinstance(dfs, (float, int, np.number)):
            dfs = np.array([dfs])
        if hasattr(self, "_images"):
            if len(dfs) != len(self.images):
                raise ValueError(
                    f"Number defocus vals, {len(dfs)} must equal number images, {len(self.images)}"
                )
        self._defvals = dfs

    def __len__(self) -> int:
        """Return the number of images in the stack."""
        return len(self.images)

    @property
    def shape(self) -> tuple:
        """tuple: Shape of a single image (the stack shape without its first axis)."""
        return self.images.shape[1:]

    @property
    def energy(self) -> Optional[float]:
        """Optional[float]: The beam energy.

        The constructor stores its value on ``beam_energy`` while the setter
        stores on ``_beam_energy``; fall back to the former so that reading
        ``energy`` before assigning it does not raise ``AttributeError``.
        """
        # NOTE(review): ``beam_energy`` and ``_beam_energy`` are not kept in
        # sync by this class -- confirm whether BaseDataset links them.
        return getattr(self, "_beam_energy", getattr(self, "beam_energy", None))

    @energy.setter
    def energy(self, val: float) -> None:
        if not isinstance(val, (float, int)):
            raise TypeError(f"energy must be numeric, found {type(val)}")
        if val <= 0:
            raise ValueError(f"energy must be > 0, not {val}")
        self._beam_energy = float(val)

    def select_ROI(self, idx: int = 0, image: Optional[np.ndarray] = None) -> None:
        """
        Select a Region of Interest (ROI) for processing.

        If ``image`` is given it is used directly. Otherwise the image at
        ``idx`` is taken from the preprocessed originals (when ``preprocess``
        has been run) or the raw originals, and re-filtered with the stored
        bandpass settings when ``filter`` has been run.

        Args:
            idx (int): Index of the image to use for ROI selection.
            image (Optional[np.ndarray]): Specific image to use for ROI selection.

        Raises:
            ValueError: If ``image`` does not have the original image shape.
        """
        if image is not None:
            image = np.array(image)
            if image.shape != self._orig_shape:
                raise ValueError(
                    f"Shape of image for choosing ROI, {image.shape}, must match "
                    + f"orig_images shape, {self._orig_shape}"
                )
        else:
            if self._preprocessed:
                image = self._orig_images_preprocessed[idx]
            else:
                image = self._orig_images[idx]
            if self._filtered:
                image = self._bandpass_filter(
                    image,
                    self._filters["q_lowpass"],
                    self._filters["q_highpass"],
                    self._filters["filter_type"],
                    self._filters["butterworth_order"],
                )
        self._select_ROI(image)

    def preprocess(
        self,
        hotpix: bool = True,
        median_filter_size: Optional[int] = None,
        fast: bool = True,
        **kwargs,
    ) -> None:
        """
        Preprocess the images by filtering hot pixels and applying a median filter.

        Processing always restarts from a copy of the original images. The
        result is stored as the new preprocessed baseline, after which
        ``apply_transforms`` is called and the settings are recorded in
        ``_filters``.

        Args:
            hotpix (bool): Whether to filter hot pixels.
            median_filter_size (Optional[int]): Size of the median filter.
            fast (bool): Whether to use a fast filtering method.
            **kwargs: Forwarded to ``filter_hotpix``.
        """
        self.images = self._orig_images.copy()

        if hotpix:
            self.vprint("Filtering hot/dead pixels")
            for i in tqdm(range(len(self))):
                self.images[i] = filter_hotpix(self.images[i], fast=fast, **kwargs)

        if median_filter_size is not None:
            # Size 1 along the stack axis keeps each image filtered independently.
            self.images = ndi.median_filter(
                self.images, size=(1, median_filter_size, median_filter_size)
            )

        self._preprocessed = True
        self._orig_images_preprocessed = self.images.copy()
        self.apply_transforms()
        self._filters["hotpix"] = hotpix
        self._filters["median"] = median_filter_size

    def __str__(self) -> str:
        """Return a short summary with the number of images and their shape."""
        return f"DefocusedDataset containing {len(self)} image(s) of shape {self.shape}"

    def show_im(self, idx: int = 0, **kwargs) -> None:
        """
        Display an image with optional parameters.

        Args:
            idx (int): Index of the image to display.
            **kwargs: Forwarded to the module-level ``show_im``; ``title``,
                ``scale`` and ``cbar`` override the defaults chosen here.
        """
        defocus = self._fmt_defocus(self.defvals[idx])
        if len(self) > 1:
            default_title = f"index {idx} / {len(self)-1} | defocus: {defocus}"
        else:
            default_title = f"defocus: {defocus}"
        title = kwargs.pop("title", default_title)

        show_im(
            self.images[idx],
            scale=kwargs.pop("scale", self.scale),
            title=title,
            cbar=kwargs.pop("cbar", False),
            **kwargs,
        )

    def show_all(self, **kwargs) -> None:
        """
        Display all images.

        A single-image dataset is shown through ``show_im``; otherwise the
        images are drawn side by side in one row, each titled with its
        defocus value.

        Args:
            **kwargs: Forwarded to the underlying display call.
        """
        ncols = len(self)
        if ncols == 1:
            self.show_im(**kwargs)
            return

        fig, axs = plt.subplots(nrows=1, ncols=ncols, figsize=(2 * ncols, 2))
        for a0, df in enumerate(self.defvals):
            show_im(
                self.images[a0],
                figax=(fig, axs[a0]),
                title=f"{self._fmt_defocus(df)}",
                simple=True,
                **kwargs,
            )

    def filter(
        self,
        q_lowpass: Optional[float] = None,
        q_highpass: Optional[float] = None,
        filter_type: str = "butterworth",
        butterworth_order: int = 2,
        idx: Optional[Union[int, List[int]]] = None,
        show: bool = False,
        v: Optional[int] = None,
    ) -> None:
        """
        Apply bandpass filtering to the images.

        The input for filtering is taken from the cropped images if the
        dataset has been cropped, else from the preprocessed originals if
        ``preprocess`` has been run, else from the raw originals. The filtered
        result is written into ``images`` and the settings are recorded in
        ``_filters``.

        Args:
            q_lowpass (Optional[float]): Lowpass filter cutoff.
            q_highpass (Optional[float]): Highpass filter cutoff.
            filter_type (str): Type of filter ('butterworth' or 'gaussian').
            butterworth_order (int): Order of the Butterworth filter.
            idx (Optional[Union[int, List[int]]]): Indices of images to filter.
                ``None`` filters every image.
            show (bool): Whether to display the filtered images.
            v (Optional[int]): Verbosity level. Accepted for interface
                compatibility; not used by this method.

        Raises:
            TypeError: If ``idx`` is not an integer, a list/array, or None.
        """
        if idx is None:
            indices = np.arange(len(self))
        elif isinstance(idx, (int, np.integer)):
            indices = [idx]
        elif isinstance(idx, (list, np.ndarray)):
            indices = idx
        else:
            raise TypeError(
                f"idx must be an integer index, list of indices, or None. Got type {type(idx)}"
            )

        if self._cropped:
            input_ims = self._images_cropped[indices].copy()
        elif self._preprocessed:
            input_ims = self._orig_images_preprocessed[indices].copy()
        else:
            input_ims = self._orig_images[indices].copy()

        filtered_ims = np.zeros_like(input_ims)
        for i, im in enumerate(input_ims):
            filtered_ims[i] = self._bandpass_filter(
                im, q_lowpass, q_highpass, filter_type, butterworth_order
            )

        if show:
            fig, axs = plt.subplots(ncols=3, nrows=len(indices), figsize=(12, 4 * len(indices)))
            if len(indices) == 1:
                # A single row comes back as a 1D axes array; wrap it so the
                # axs[row][col] indexing below works uniformly.
                axs = [axs]
            for a0 in range(len(indices)):
                show_im(
                    input_ims[a0],
                    figax=(fig, axs[a0][0]),
                    title="original image",
                    scale=self.scale,
                    ticks_off=a0 != 0,
                )
                show_im(
                    filtered_ims[a0],
                    figax=(fig, axs[a0][1]),
                    title="filtered image",
                    ticks_off=True,
                )
                show_im(
                    input_ims[a0] - filtered_ims[a0],
                    figax=(fig, axs[a0][2]),
                    title="orig - filtered",
                    ticks_off=True,
                )

        self._filtered = True
        self._filters["q_lowpass"] = q_lowpass
        self._filters["q_highpass"] = q_highpass
        self._filters["filter_type"] = filter_type
        self._filters["butterworth_order"] = butterworth_order

        if self._images_filtered is None:
            # ``_images_cropped`` is None until a crop has been stored; in
            # that case size the buffer from the current image stack instead
            # of calling zeros_like on None.
            template = self._images_cropped if self._images_cropped is not None else self.images
            self._images_filtered = np.zeros_like(template)
        self._images_filtered[indices] = filtered_ims
        self.images[indices] = filtered_ims

    def copy(self) -> "DefocusedDataset":
        """
        Create a deep copy of the dataset.

        Returns:
            DefocusedDataset: A deep copy of the current dataset.
        """
        return copy.deepcopy(self)