# Source code for TIE_helper

"""Helper functions for TIE. 

An assortment of helper functions that load images, pass data, and generally 
are used in the reconstruction. Additionally, a couple of functions used for 
displaying images and stacks.

Author: Arthur McCray, ANL, Summer 2019.
"""
 
import matplotlib.pyplot as plt 
import matplotlib as mpl
import numpy as np
import hyperspy.api as hs
import sys
from cv2 import resize, imwrite
from skimage import io
from scipy.ndimage.filters import median_filter
from scipy import ndimage
from ipywidgets import interact
import hyperspy # just for checking type in show_stack. 
from copy import deepcopy
from TIE_params import TIE_params
import textwrap
import os


# ============================================================= #
#      Functions used for loading and passing the TIE data      # 
# ============================================================= #

def load_data(path=None, fls_file='', al_file='', flip=None, flip_fls_file=None, filtersize=3):
    """Load files in a directory (from a .fls file) using hyperspy.

    For more information on how to organize the directory and load the data, as
    well as how to setup the .fls file please refer to the README or the
    TIE_template.ipynb notebook.

    Args:
        path (str): Location of data directory. Required, despite the ``None``
            default.
        fls_file (str): Name of the .fls file which contains the image names and
            defocus values. ``'.fls'`` is appended if missing. The file is looked
            for in ``path``, then ``path/unflip``, then (only when ``flip`` is
            falsy) ``path/tfs``.
        al_file (str): Name of the aligned stack image file, relative to
            ``path``.
        flip (Bool): True if using a flip stack, False otherwise. Uniformly
            thick films can be reconstructed without a flip stack. The
            electrostatic phase shift will not be reconstructed.
        flip_fls_file (str): Name of the .fls file for the flip images if they
            are not named the same as the unflip files. Will only be applied to
            the /flip/ directory. Looked for in ``path`` and ``path/flip``.
        filtersize (int): (`optional`) The images are processed with a median
            filter to remove hot pixels which occur in experimental data. Set
            this to 0 (or None) to skip the filter, e.g. for simulated data.

    Returns:
        tuple: Tuple of length 3, containing the following items:

        - imstack: list of hyperspy signal2D objects (one per image)
        - flipstack: list of hyperspy signal2D objects, empty list if flip is
          falsy
        - ptie: TIE_params object built from imstack, flipstack, the defocus
          values, ``flip`` and ``path``.

    Raises:
        TypeError: If ``path`` is None.
        SystemExit: If an fls file cannot be found, or if the slice axis of
            the aligned stack cannot be determined.
        FileNotFoundError: If the aligned stack file does not exist.
        AssertionError: If the two fls files list different image counts, or
            the number of defocus values does not match the image count.
    """
    def _first_existing(candidates):
        # Return the first candidate that is a file, or None.
        return next((c for c in candidates if os.path.isfile(c)), None)

    def _read_entries(fname):
        # One entry per non-blank line; blank lines carry no information and
        # would otherwise be read as an empty filename or defocus value.
        with open(fname) as file:
            return [entry for entry in (line.strip() for line in file) if entry]

    if path is None:
        raise TypeError("load_data() requires `path`, the data directory location.")
    path = os.path.abspath(path)

    # Finding the unflip fls file
    if not fls_file.endswith('.fls'):
        fls_file += '.fls'
    candidates = [os.path.join(path, fls_file),
                  os.path.join(path, 'unflip', fls_file)]
    if not flip:
        candidates.append(os.path.join(path, 'tfs', fls_file))
    fls_full = _first_existing(candidates)
    if fls_full is None:
        print("fls file could not be found.")
        sys.exit(1)

    flip_fls_full = None
    if flip_fls_file is not None:  # there are 2 fls files given
        if not flip:
            print(textwrap.dedent("""
                You probably made a mistake.
                You're defining both unflip and flip fls files but have flip=False.
                Proceeding anyways, will only load unflip stack (if it doesnt break).\n"""))
        # find the flip fls file
        if not flip_fls_file.endswith('.fls'):
            flip_fls_file += '.fls'
        flip_fls_full = _first_existing([os.path.join(path, flip_fls_file),
                                         os.path.join(path, 'flip', flip_fls_file)])
        if flip_fls_full is None:
            # Previously this fell through and crashed with a NameError.
            print("flip fls file could not be found.")
            sys.exit(1)

    fls = _read_entries(fls_full)
    num_files = int(fls[0])
    names = fls[1:num_files + 1]
    flip_files = []

    if flip_fls_full is None:  # one fls file given
        if flip:
            unflip_files = [os.path.join(path, 'unflip', name) for name in names]
            flip_files = [os.path.join(path, 'flip', name) for name in names]
        else:
            # Single stack: images live in tfs/ if one of the listed images is
            # found there, otherwise in unflip/.
            if os.path.isfile(os.path.join(path, 'tfs', fls[2])):
                tfs_dir = 'tfs'
            else:
                tfs_dir = 'unflip'
            unflip_files = [os.path.join(path, tfs_dir, name) for name in names]
    else:
        flip_fls = _read_entries(flip_fls_full)
        assert int(fls[0]) == int(flip_fls[0]), \
            "unflip and flip fls files list a different number of images"
        unflip_files = [os.path.join(path, "unflip", name) for name in names]
        flip_files = [os.path.join(path, "flip", name)
                      for name in flip_fls[1:num_files + 1]]

    # Actually load the data using hyperspy
    imstack = hs.load(unflip_files)
    flipstack = hs.load(flip_files) if flip else []

    # convert scale dimensions to nm
    for sig in imstack + flipstack:
        sig.axes_manager.convert_units(units=['nm', 'nm'])

    if not unflip_files[0].endswith(('.dm3', '.dm4')):
        # if not dm3's then they generally don't have the title metadata.
        for sig in imstack + flipstack:
            sig.metadata.General.title = sig.metadata.General.original_filename

    # load the aligned tifs and update the dm3 data to match
    # The data from the dm3's will be replaced with the aligned image data.
    try:
        al_tifs = io.imread(os.path.join(path, al_file))
    except FileNotFoundError:
        print('Incorrect aligned stack filename given.')
        raise

    tot_files = 2 * num_files if flip else num_files

    for i in range(tot_files):
        # pull slices from correct axis, assumes fewer slices than images are tall
        if al_tifs.shape[0] < al_tifs.shape[2]:
            im = al_tifs[i]
        elif al_tifs.shape[0] > al_tifs.shape[2]:
            im = al_tifs[:, :, i]
        else:
            print("Bad stack\n Or maybe the second axis is slice axis?")
            print('Loading failed.\n')
            sys.exit(1)

        # then median filter to remove "hot pixels"; a size of 0 means "no
        # filter" and must not reach scipy, which rejects an empty window.
        if filtersize:
            im = median_filter(im, size=filtersize)

        # and assign to appropriate stack
        if i < num_files:
            print('loading unflip:', unflip_files[i])
            imstack[i].data = im
        else:
            j = i - num_files
            print('loading flip:', flip_files[j])
            flipstack[j].data = im

    # read the defocus values: the last num_files//2 entries of the fls file
    defvals = fls[-(num_files // 2):]
    assert num_files == 2 * len(defvals) + 1, \
        "fls file must list an odd number of images and num_files//2 defocus values"
    defvals = [float(val) for val in defvals]  # defocus values +/-

    # Create a TIE_params object
    ptie = TIE_params(imstack, flipstack, defvals, flip, path)
    print('Data loaded successfully.')
    return (imstack, flipstack, ptie)
def load_data_GUI(path, fls_file1, fls_file2, al_file='', single=False, filtersize=3):
    """Load files in a directory (from a .fls file) using hyperspy.

    For more information on how to organize the directory and load the data, as
    well as how to setup the .fls file please refer to the README or the
    TIE_template.ipynb notebook.

    Unlike ``load_data``, the fls files and the aligned stack are opened exactly
    as given (they are not searched for underneath ``path``).

    Args:
        path (str): Location of data directory. Image names from the fls files
            are joined onto ``path/unflip``, ``path/flip`` or ``path/tfs``.
        fls_file1 (str): Path of the .fls file which contains the image names
            and defocus values.
        fls_file2 (str): Path of the .fls file for the flip images if they are
            not named the same as the unflip files, or None to use
            ``fls_file1`` for both stacks.
        al_file (str): Path of the aligned stack image file.
        single (Bool): True if using a single stack, False otherwise. Uniformly
            thick films can be reconstructed with a single stack. The
            electrostatic phase shift will not be reconstructed.
        filtersize (int): (`optional`) The images are processed with a median
            filter to remove hot pixels which occur in experimental data. Set
            this to 0 (or None) to skip the filter, e.g. for simulated data.

    Returns:
        tuple: Tuple of length 3, containing the following items:

        - imstack: list of hyperspy signal2D objects (one per image)
        - flipstack: list of hyperspy signal2D objects, empty list if
          ``single`` is True
        - ptie: TIE_params object built from imstack, flipstack, the defocus
          values, a flip flag and ``path``.

    Raises:
        FileNotFoundError: If an fls file or the aligned stack does not exist.
        SystemExit: If the slice axis of the aligned stack cannot be determined.
        AssertionError: If the two fls files list different image counts, or
            the number of defocus values does not match the image count.
    """
    def _read_entries(fname):
        # One entry per non-blank line; blank lines carry no information and
        # would otherwise be read as an empty filename or defocus value.
        with open(fname) as file:
            return [entry for entry in (line.strip() for line in file) if entry]

    flip_files = []

    if fls_file2 is None:  # one fls file given
        u_files = _read_entries(fls_file1)
        num_files = int(u_files[0])
        names = u_files[1:num_files + 1]
        if not single:
            unflip_files = [os.path.join(path, 'unflip', name) for name in names]
            flip_files = [os.path.join(path, 'flip', name) for name in names]
        else:
            if os.path.exists(os.path.join(path, 'tfs')):
                sub_dir = 'tfs'
            else:
                sub_dir = 'unflip'
            unflip_files = [os.path.join(path, sub_dir, name) for name in names]
    else:  # there are 2 fls files given
        if single:
            print(textwrap.dedent("""
                You probably made a mistake.
                You're defining both unflip and flip fls files but have flip=False.
                Proceeding anyways, will only load unflip stack (if it doesnt break).\n"""))
        u_files = _read_entries(fls_file1)
        f_files = _read_entries(fls_file2)

        assert int(u_files[0]) == int(f_files[0]), \
            "unflip and flip fls files list a different number of images"
        num_files = int(u_files[0])
        unflip_files = [os.path.join(path, "unflip", name)
                        for name in u_files[1:num_files + 1]]
        flip_files = [os.path.join(path, "flip", name)
                      for name in f_files[1:num_files + 1]]

    # Actually load the data using hyperspy
    imstack = hs.load(unflip_files)
    flipstack = [] if single else hs.load(flip_files)

    # convert scale dimensions to nm
    for sig in imstack + flipstack:
        sig.axes_manager.convert_units(units=['nm', 'nm'])

    if not unflip_files[0].endswith(('.dm3', '.dm4')):
        # if not dm3's then they generally don't have the title metadata.
        for sig in imstack + flipstack:
            sig.metadata.General.title = sig.metadata.General.original_filename

    # load the aligned tifs and update the dm3 data to match
    # The data from the dm3's will be replaced with the aligned image data.
    try:
        al_tifs = io.imread(al_file)
    except FileNotFoundError:
        print('Incorrect aligned stack filename given.')
        raise

    tot_files = num_files if single else 2 * num_files

    for i in range(tot_files):
        # pull slices from correct axis, assumes fewer slices than images are tall
        if al_tifs.shape[0] < al_tifs.shape[2]:
            im = al_tifs[i]
        elif al_tifs.shape[0] > al_tifs.shape[2]:
            im = al_tifs[:, :, i]
        else:
            print("Bad stack\n Or maybe the second axis is slice axis?")
            print('Loading failed.\n')
            sys.exit(1)

        # then median filter to remove "hot pixels"; a size of 0 means "no
        # filter" and must not reach scipy, which rejects an empty window.
        if filtersize:
            im = median_filter(im, size=filtersize)

        # and assign to appropriate stack
        if i < num_files:
            print('loading unflip:', unflip_files[i])
            imstack[i].data = im
        else:
            j = i - num_files
            print('loading flip:', flip_files[j])
            flipstack[j].data = im

    # read the defocus values: the last num_files//2 entries of the fls file
    defvals = u_files[-(num_files // 2):]
    assert num_files == 2 * len(defvals) + 1, \
        "fls file must list an odd number of images and num_files//2 defocus values"
    defvals = [float(val) for val in defvals]  # defocus values +/-

    # Create a TIE_params object. It takes a flip flag, the opposite sense of
    # `single`: None for a single stack, True when a flip stack was loaded.
    flip = None if single else True
    ptie = TIE_params(imstack, flipstack, defvals, flip, path)
    print('Data loaded successfully.')
    return (imstack, flipstack, ptie)
def select_tifs(i, ptie, long_deriv=False):
    """Returns a list of the images which will be used in TIE() or SITIE().

    Uses copy.deepcopy() as the data will be modified in the reconstruction
    process, and we don't want to change the original data. This method is
    likely not best practice. In the future this might get moved to the
    TIE_params class.

    Args:
        i (int): Index of defvals for which to select the tifs. Negative values
            count from the end of ``ptie.defvals``. Ignored when
            ``long_deriv`` is True.
        ptie (``TIE_params`` object): Parameters for reconstruction, holds the
            images. The attributes read here are ``imstack``, ``flipstack``,
            ``flip``, ``num_files`` and ``defvals``.
        long_deriv (bool): (`optional`) If True return every image instead of
            the set belonging to a single defocus value.

    Returns:
        list: List of np arrays, return depends on parameters:

        - if long_deriv == False:

            - if ptie.flip == True: returns [ +- , -- , 0 , ++ , -+ ]
            - elif ptie.flip == False: returns [+-, 0, ++]
            - where first +/- is unflip/flip, second +/- is over/underfocus.
              E.g. -+ is the flipped overfocused image.
              0 is the averaged infocus image.

        - elif long_deriv == True: returns all images in imstack followed by
          all images in flipstack (flipstack only if ptie.flip).

    Raises:
        IndexError: If ``i`` does not address one of the ``num_files//2``
            defocus pairs.
    """
    if long_deriv:
        recon_tifs = [sig.data for sig in ptie.imstack]
        if ptie.flip:
            recon_tifs.extend(sig.data for sig in ptie.flipstack)
    else:
        if i < 0:
            i = len(ptie.defvals) + i
            print('new i: ', i)

        num_files = ptie.num_files
        center = num_files // 2
        # An index that is still negative here would make `under` and `over`
        # collapse onto (or wrap around) the in-focus image and silently return
        # the wrong pictures, so reject it explicitly.
        if not 0 <= i < center:
            raise IndexError(
                f"defocus index {i} out of range for {center} defocus value(s)")

        under = center - (i + 1)
        over = center + (i + 1)
        imstack = ptie.imstack
        flipstack = ptie.flipstack
        if ptie.flip:
            recon_tifs = [
                imstack[under].data,                                # +-
                flipstack[under].data,                              # --
                (imstack[center].data + flipstack[center].data)/2,  # infocus
                imstack[over].data,                                 # ++
                flipstack[over].data                                # -+
            ]
        else:
            recon_tifs = [
                imstack[under].data,   # +-
                imstack[center].data,  # 0
                imstack[over].data     # ++
            ]

    try:
        recon_tifs = deepcopy(recon_tifs)
    except TypeError:
        # Deliberate best effort: fall back to the (shared) original arrays.
        print("TypeError in select_tifs deepcopy. Proceeding with originals.")
    return recon_tifs
def dist(ny, nx, shift=False):
    """Creates a frequency array for Fourier processing.

    Each element is the radial frequency ``sqrt(fx**2 + fy**2)`` of the
    corresponding FFT sample, with ``fx = kx/nx`` and ``fy = ky/ny``.

    Args:
        ny (int): Height of array
        nx (int): Width of array
        shift (bool): Whether to center the frequency spectrum.

            - False: (default) smallest values are at the corners.
            - True: smallest values at center of array.

    Returns:
        ``ndarray``: Numpy array of shape (ny, nx).
    """
    # Integer centre (n//2) so that a zero-frequency sample exists for odd
    # sizes as well; with n/2 an odd-sized grid is offset by half a bin and
    # never contains 0. For even sizes both forms give identical values.
    ly = (np.arange(ny) - ny // 2) / ny
    lx = (np.arange(nx) - nx // 2) / nx
    X, Y = np.meshgrid(lx, ly)
    q = np.sqrt(X**2 + Y**2)
    if not shift:
        # Move the zero frequency from the centre to index [0, 0].
        q = np.fft.ifftshift(q)
    return q
def scale_stack(imstack):
    """Scale a stack of images so all have the same total intensity.

    Every image is multiplied by ``max(totals) / total`` so that all images sum
    to the largest total in the stack, then the whole stack is divided by its
    global maximum. The input is not modified.

    Args:
        imstack (list): List of 2D arrays (or a 3D array), all the same shape.

    Returns:
        ``ndarray``: Array of shape (n_images, ny, nx).
    """
    stack = np.array(imstack)  # np.array copies, so the caller's data is safe
    if not np.issubdtype(stack.dtype, np.floating):
        # Integer images cannot be rescaled in place by a float factor.
        stack = stack.astype(float)
    tots = np.sum(stack, axis=(1, 2))
    stack *= (np.max(tots) / tots)[:, np.newaxis, np.newaxis]
    return stack / np.max(stack)
# =============================================== #
#            Various display functions            #
# =============================================== #
""" Not all of these are used in TIE_reconstruct, but I often find them useful
to have handy when working in Jupyter notebooks."""
def show_im(image, title=None, simple=False, origin='upper', cbar=True, cbar_title='', scale=None, **kwargs):
    """Display an image on a new axis.

    Takes a 2D array and displays the image in grayscale with optional title on
    a new axis. In general it's nice to have things on their own axes, but if
    too many are open it's a good idea to close with plt.close('all').

    Args:
        image (2D array): Image to be displayed.
        title (str): (`optional`) Title of plot.
        simple (bool): (`optional`) Default output or additional labels.

            - True, will just show image.
            - False, (default) will show a colorbar with axes labels, and will
              adjust the contrast range for images with a very small range of
              values (<1e-12).

        origin (str): (`optional`) Control image orientation.

            - 'upper': (default) (0,0) in upper left corner, y-axis goes down.
            - 'lower': (0,0) in lower left corner, y-axis goes up.

        cbar (bool): (`optional`) Choose to display the colorbar or not. Only
            matters when simple = False.
        cbar_title (str): (`optional`) Title attached to the colorbar
            (indicating the units or significance of the values).
        scale (float): Scale of image in nm/pixel. Axis markers will be given
            in nm, um or m depending on the field of view. If None the axes
            are labelled in pixels.
        **kwargs: Passed on to ``ax.matshow``. Ignored when the contrast range
            is adjusted automatically (nearly constant image, simple=False).

    Returns:
        None
    """
    fig, ax = plt.subplots()
    if not simple and np.max(image) - np.min(image) < 1e-12:
        # adjust contrast range so a (nearly) constant image still renders
        vmin = np.min(image) - 1e-12
        vmax = np.max(image) + 1e-12
        im = ax.matshow(image, cmap='gray', origin=origin, vmin=vmin, vmax=vmax)
    else:
        im = ax.matshow(image, cmap='gray', origin=origin, **kwargs)

    if title is not None:
        ax.set_title(str(title), pad=0)

    if simple:
        plt.axis('off')
    else:
        plt.tick_params(axis='x', top=False)
        ax.xaxis.tick_bottom()
        ax.tick_params(direction='in')
        if scale is None:
            ticks_label = 'pixels'
        else:
            fov = scale * max(image.shape[0], image.shape[1])
            if fov < 4e3:  # if fov < 4um use nm scale
                ticks_label = ' nm '
                tick_scale = scale
            elif fov > 4e6:  # if fov > 4mm use m scale
                ticks_label = " m "
                tick_scale = scale / 1e9
            else:  # if fov between the two, use um
                # raw string: "\m" is not a valid escape sequence
                ticks_label = r" $\mu$m "
                tick_scale = scale / 1e3

            def mjrFormatter(x, pos):
                # pixel index -> physical length in the unit chosen above
                return f"{tick_scale*x:.3g}"

            ax.yaxis.set_major_formatter(mpl.ticker.FuncFormatter(mjrFormatter))
            ax.xaxis.set_major_formatter(mpl.ticker.FuncFormatter(mjrFormatter))

        if origin == 'lower':
            ax.text(y=0, x=0, s=ticks_label, rotation=-45, va='top', ha='right')
        elif origin == 'upper':  # keep label in lower left corner
            ax.text(y=image.shape[0], x=0, s=ticks_label, rotation=-45, va='top', ha='right')

        if cbar:
            plt.colorbar(im, ax=ax, pad=0.02, format="%.2g", label=str(cbar_title))

    plt.show()
    return
def show_stack(images, ptie=None, origin='upper', title=False):
    """Shows a stack of dm3s or np images with a slider to navigate slice axis.

    Uses ipywidgets.interact to allow user to view multiple images on the same
    axis using a slider. There is likely a better way to do this, but this was
    the first one I found that works...

    If a TIE_params object is given, the images are rotated and shifted by
    ptie.rotation / ptie.x_transl / ptie.y_transl (when any is non-zero) and
    only the regions corresponding to ptie.crop will be shown. The stack is
    copied into a new array first, so ``images`` itself is not modified.

    Args:
        images (list): List of 2D arrays, or of hyperspy Signal2D objects.
            Stack of images to be shown.
        ptie (``TIE_params`` object): Will use ptie.crop to show only the region
            that will remain after being cropped.
        origin (str): (`optional`) Control image orientation.
        title (bool): (`optional`) Try and pull a title from the signal objects.
            For plain arrays the title is the slice index.

    Returns:
        None
    """
    # isinstance rather than an exact type comparison, so that subclasses of
    # Signal2D are unpacked too.
    sig = isinstance(images[0], hyperspy._signals.signal2d.Signal2D)
    titles = []
    if sig:
        titles = [signal2D.metadata.General.title for signal2D in images]
        images = np.array([signal2D.data for signal2D in images])
    else:
        images = np.array(images)

    if ptie is None:
        top, bottom = 0, images[0].shape[0]
        left, right = 0, images[0].shape[1]
    else:
        if ptie.rotation != 0 or ptie.x_transl != 0 or ptie.y_transl != 0:
            rotate, x_shift, y_shift = ptie.rotation, ptie.x_transl, ptie.y_transl
            for idx in range(len(images)):
                images[idx] = ndimage.rotate(images[idx], rotate, reshape=False)
                images[idx] = ndimage.shift(images[idx], (-y_shift, x_shift))
        top = ptie.crop['top']
        bottom = ptie.crop['bottom']
        left = ptie.crop['left']
        right = ptie.crop['right']

    images = images[:, top:bottom, left:right]

    plt.subplots()
    plt.axis('off')
    n_images = images.shape[0]

    def view_image(i=0):
        # Callback for the slider: draw slice i on the current axes.
        plt.imshow(images[i], cmap='gray', interpolation='nearest', origin=origin)
        if title:
            if sig:
                plt.title('Image title: {:}'.format(titles[i]))
            else:
                plt.title('Stack[{:}]'.format(i))

    interact(view_image, i=(0, n_images - 1))
    return
def show_2D(mag_x, mag_y, mag_z=None, a=15, l=None, w=None, title=None, color=False, hsv=True, origin='upper', save=None, GUI_handle=False, GUI_color_array=None):
    """Display a 2D vector arrow plot.

    Displays an arrow plot of a vector field, with arrow length scaling with
    vector magnitude. If color=True, a colormap will be displayed under the
    arrow plot.

    If mag_z is included and color=True, a spherical colormap will be used with
    color corresponding to in-plane and white/black to out-of-plane vector
    orientation.

    Args:
        mag_x (2D array): x-component of magnetization.
        mag_y (2D array): y-component of magnetization.
        mag_z (2D array): optional z-component of magnetization. Only passed
            on to colorwheel.color_im().
        a (int): Number of arrows to plot along the x and y axes. Default 15.
            Converted internally to a sampling stride based on the height of
            mag_x, which is used along both axes.
        l (float): Scale factor of arrows. Larger l -> shorter arrows. Passed
            to quiver as ``scale``; None uses matplotlib default.
        w (float): Width scaling of arrows. None uses matplotlib default.
        title (str): (`optional`) Title for plot. Default None.
        color (bool): (`optional`) Whether or not to show a colormap underneath
            the arrow plot. Color image is made from colorwheel.color_im().
        hsv (bool): (`optional`) Only relevant if color == True. Whether to use
            an hsv or 4-fold color-wheel in the color image.
        origin (str): (`optional`) Control image orientation.
        save (str): (`optional`) Path to save the figure.
        GUI_handle (bool): (`optional`) Handle for indicating if using GUI.
            Default is False.
        GUI_color_array (2D array): (`optional`) The colored image array passed
            from the GUI, it is for overlaying the arrows without using
            color_im(). Only used when color is True, GUI_handle is True and
            save is None.

    Returns:
        tuple or None: ``(fig, ax)`` if GUI_handle is True, otherwise None.
    """
    # From here on `a` is the stride (in pixels) between plotted arrows.
    a = ((mag_x.shape[0] - 1)//a)+1

    dimy, dimx = mag_x.shape
    X = np.arange(0, dimx, 1)
    Y = np.arange(0, dimy, 1)
    U = mag_x
    V = mag_y

    sz_inches = 8
    # `aspect` (and `rad` for the color image) are only needed when the figure
    # is not the square GUI preview, i.e. when not in GUI mode or when saving.
    if not GUI_handle or save is not None:
        if color:
            rad = mag_x.shape[0]//16
            rad = max(rad, 16)
            pad = 10 # pixels
            width = np.shape(mag_y)[1] + 2*rad + pad
            aspect = dimy/width
        else:
            aspect = dimy/dimx

    if GUI_handle and save is None:
        fig, ax = plt.subplots(figsize=(10, 10))
        plt.ioff()
        ax.set_aspect('equal', adjustable='box')
    else:
        fig, ax = plt.subplots()
        ax.set_aspect(aspect)

    if color:
        if not GUI_handle or save is not None:
            from colorwheel import color_im
            im = ax.matshow(color_im(mag_x, mag_y, mag_z, hsvwheel=hsv, rad=rad), cmap='gray', origin=origin)
        else:
            im = ax.matshow(GUI_color_array, cmap='gray', origin=origin, aspect='equal')
        arrow_color = 'white'
        plt.axis('off')
    else:
        arrow_color = 'black'
        # In GUI mode the arrows are drawn over a plain white RGB image.
        if GUI_handle and save is None:
            white_array = np.zeros([dimy, dimx, 3], dtype=np.uint8)
            white_array.fill(255)
            im = ax.matshow(white_array, cmap='gray', origin=origin, aspect='equal')
            plt.axis('off')
        elif GUI_handle and save:
            white_array = np.zeros([dimy, dimx, 3], dtype=np.uint8)
            white_array.fill(255)
            im = ax.matshow(white_array, cmap='gray', origin=origin)
            fig.tight_layout(pad=0)
            ax.xaxis.set_major_locator(mpl.ticker.NullLocator())
            ax.yaxis.set_major_locator(mpl.ticker.NullLocator())
            plt.axis('off')

    # Offset of the first arrow; evaluates as ((dimx-1) % a) // 2 and is
    # applied to both axes.
    ashift = (dimx-1) % a//2
    q = ax.quiver(X[ashift::a], Y[ashift::a], U[ashift::a,ashift::a], V[ashift::a,ashift::a], units='xy', scale=l, scale_units='xy', width=w, angles='xy', pivot='mid', color=arrow_color)

    if not color:
        if not GUI_handle:
            qk = ax.quiverkey(q, X=0.95, Y=0.98, U=1, label=r'$Msat$', labelpos='S', coordinates='axes')
            qk.text.set_backgroundcolor('w')
            # No image was drawn in this case, so flip the y-axis by hand.
            if origin == 'upper':
                ax.invert_yaxis()

    # `tr` is the `transparent` flag handed to savefig below.
    if title is not None:
        tr = False
        ax.set_title(title)
    else:
        tr = True

    plt.tick_params(axis='x', labelbottom=False, bottom=False, top=False)
    plt.tick_params(axis='y', labelleft=False, left=False, right=False)
    # ax.set_aspect(aspect)
    if not GUI_handle:
        plt.show()

    if save is not None:
        if not color:
            tr = False
        fig.set_size_inches(8, 8/aspect)
        print(f'Saving: {save}')
        plt.axis('off')
        # sets dpi to 5 times original image dpi so arrows are reasonably sharp
        dpi2 = max(dimy, dimx) * 5 / sz_inches
        plt.savefig(save, dpi=dpi2, bbox_inches='tight', transparent=tr)

    if GUI_handle:
        return fig, ax
    else:
        return