# Source code for PyLorentz.phase.DIP_NN

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class DIP_NN(nn.Module):
    """
    Autoencoder for reconstructing object wave and amplitude of LTEM images.

    Architecture (``f = nb_filters``):

    * Encoder: four ``conv2dblock`` stages with ``f, 2f, 4f, 8f`` output
      channels, each followed by 2x2 max pooling (stride 2).
    * Bottleneck: one ``conv2dblock`` at ``8f`` channels.
    * Decoder: four ``upsample_block`` stages mapping ``8f -> 4f -> 2f -> f
      -> num_images`` channels. Each of the first three is followed by a
      ``conv2dblock``.

    The decoder has no skip connections. ``self.concat`` is stored but not
    used in ``forward``.

    Because the input is pooled four times (floor division) and then
    upsampled four times by exactly 2x, the output has the same spatial size
    as the input only when height and width are divisible by 16. The output
    has ``num_images`` channels.

    Args:
        num_images: int, number of input channels, equal to number of images
            in TFS. Also the number of output channels.
        nb_filters: int, number of filters in the first convolutional block.
        use_dropout: bool, whether to use dropout in the inner layers.
            If enabled, the rate is 0.1 in ``cd3``, ``cd4`` and ``cu4p``,
            and 0.2 in the bottleneck.
        batch_norm: bool, whether to use batch normalization after each
            convolutional layer.
        upsampling_mode: str, "bilinear" or "nearest". Forwarded to every
            ``upsample_block``. Those blocks are built with their default
            scale factor of 2, which uses a transposed convolution, so this
            setting currently does not change the computation.
    """

    def __init__(
        self,
        num_images: int = 1,
        nb_filters: int = 16,
        use_dropout: bool = False,
        batch_norm: bool = False,
        upsampling_mode: str = "nearest",
    ):
        super().__init__()
        self.num_images = num_images
        self.nb_filters = nb_filters
        self.use_dropout = use_dropout
        self.batch_norm = batch_norm
        self.upsampling_mode = upsampling_mode

        # Dropout rates: deep encoder stages, bottleneck, first decoder stage.
        if use_dropout:
            p_enc, p_mid, p_dec = 0.1, 0.2, 0.1
        else:
            p_enc = p_mid = p_dec = 0

        f = nb_filters
        bnorm = batch_norm

        # Submodules are created in a fixed order on purpose: the order
        # determines parameter-initialization RNG draws and state_dict layout.

        # --- encoder ---
        self.cd1 = conv2dblock(
            nb_layers=2,
            input_channels=num_images,
            output_channels=f,
            use_batchnorm=bnorm,
        )
        self.cd2 = conv2dblock(2, f, 2 * f, use_batchnorm=bnorm)
        self.cd3 = conv2dblock(2, 2 * f, 4 * f, use_batchnorm=bnorm, dropout_=p_enc)
        self.cd4 = conv2dblock(2, 4 * f, 8 * f, use_batchnorm=bnorm, dropout_=p_enc)

        # --- bottleneck ---
        self.bn = conv2dblock(2, 8 * f, 8 * f, use_batchnorm=bnorm, dropout_=p_mid)

        # --- decoder ---
        self.upsample_block4p = upsample_block(8 * f, 4 * f, mode=upsampling_mode)
        self.cu4p = conv2dblock(2, 4 * f, 4 * f, use_batchnorm=bnorm, dropout_=p_dec)
        self.upsample_block3p = upsample_block(4 * f, 2 * f, mode=upsampling_mode)
        self.cu3p = conv2dblock(2, 2 * f, 2 * f, use_batchnorm=bnorm)
        self.upsample_block2p = upsample_block(2 * f, f, mode=upsampling_mode)
        self.cu2p = conv2dblock(1, f, f, use_batchnorm=bnorm)
        self.upsample_block1p = upsample_block(f, num_images, mode=upsampling_mode)

        self.maxpool = F.max_pool2d
        self.concat = torch.cat

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the encoder, the bottleneck and the decoder.

        Args:
            x: Input tensor with ``num_images`` channels.

        Returns:
            Output tensor with ``num_images`` channels.
        """
        h = x
        # Encoder: conv stage, then halve the spatial size.
        for stage in (self.cd1, self.cd2, self.cd3, self.cd4):
            h = self.maxpool(stage(h), kernel_size=2, stride=2)

        h = self.bn(h)

        # Decoder: upsample 2x, then refine with a conv stage.
        decoder = (
            (self.upsample_block4p, self.cu4p),
            (self.upsample_block3p, self.cu3p),
            (self.upsample_block2p, self.cu2p),
        )
        for up, refine in decoder:
            h = refine(up(h))

        return self.upsample_block1p(h)
class conv2dblock(nn.Module):
    """
    A stack of convolutional layers with optional dropout and batch norm.

    Each of the ``nb_layers`` layers appends, in this order:
    ``Conv2d`` -> ``Dropout`` (only if ``dropout_ > 0``) -> activation ->
    ``BatchNorm2d`` (only if ``use_batchnorm``).

    The activation is ``LeakyReLU`` except on the last layer. There,
    ``last_sigmoid`` takes precedence over ``last_tanh``, which takes
    precedence over ``last_skipReLU`` (no activation). Batch norm, if
    enabled, is still appended after the last layer's activation.

    Args:
        nb_layers: int, number of convolutional layers (0 gives an identity
            block).
        input_channels: int, number of input channels of the first layer.
        output_channels: int, number of output channels of every layer.
        kernel_size: int, size of the convolutional kernel.
        stride: int, stride of the convolution.
        padding: int, padding for the convolution.
        use_batchnorm: bool, whether to use batch normalization.
        lrelu_a: float, negative slope for the Leaky ReLU activation.
        dropout_: float, dropout rate.
        last_sigmoid: bool, whether to use a sigmoid activation on the last
            layer.
        last_tanh: bool, whether to use a tanh activation on the last layer.
        last_skipReLU: bool, whether to skip the activation on the last layer.
    """

    def __init__(
        self,
        nb_layers: int,
        input_channels: int,
        output_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: int = 1,
        use_batchnorm: bool = False,
        lrelu_a: float = 0.01,
        dropout_: float = 0,
        last_sigmoid: bool = False,
        last_tanh: bool = False,
        last_skipReLU: bool = False,
    ):
        super().__init__()
        layers = []
        final_idx = nb_layers - 1
        for idx in range(nb_layers):
            # Only the first layer sees the caller's input channel count.
            in_ch = input_channels if idx == 0 else output_channels
            layers.append(
                nn.Conv2d(
                    in_ch,
                    output_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            if dropout_ > 0:
                layers.append(nn.Dropout(dropout_))
            act = self._activation(
                idx == final_idx, lrelu_a, last_sigmoid, last_tanh, last_skipReLU
            )
            if act is not None:
                layers.append(act)
            if use_batchnorm:
                layers.append(nn.BatchNorm2d(output_channels))
        self.block = nn.Sequential(*layers)

    @staticmethod
    def _activation(
        is_last: bool,
        lrelu_a: float,
        last_sigmoid: bool,
        last_tanh: bool,
        last_skipReLU: bool,
    ):
        """Return the activation module for one layer, or ``None`` for none."""
        if is_last:
            if last_sigmoid:
                return nn.Sigmoid()
            if last_tanh:
                return nn.Tanh()
            if last_skipReLU:
                return None
        return nn.LeakyReLU(negative_slope=lrelu_a)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the layer stack.

        Args:
            x: Input tensor.

        Returns:
            Output tensor after passing through the block.
        """
        return self.block(x)
class upsample_block(nn.Module):
    """
    Spatial upsampling followed by a 1x1 convolution.

    When ``scale_factor == 2``, upsampling uses the learned
    ``nn.ConvTranspose2d`` ``upsample2x``. Its settings (kernel 3, stride 2,
    padding 1, output_padding 1) exactly double height and width, and
    ``mode`` is not used. For any other ``scale_factor``, the input is
    resized with ``F.interpolate`` using ``mode``. In both cases a 1x1
    convolution then maps ``input_channels`` to ``output_channels``.

    Args:
        input_channels: int, number of input channels.
        output_channels: int, number of output channels.
        scale_factor: int, factor by which to scale the input.
        mode: str, interpolation mode, either "bilinear" or "nearest". Only
            used when ``scale_factor != 2``.

    Raises:
        ValueError: if ``mode`` is not "bilinear" or "nearest".
    """

    def __init__(
        self,
        input_channels: int,
        output_channels: int,
        scale_factor: int = 2,
        mode: str = "bilinear",
    ):
        super().__init__()
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would silently accept an invalid mode.
        if mode not in ("bilinear", "nearest"):
            raise ValueError("Mode must be 'bilinear' or 'nearest'.")
        self.scale_factor = scale_factor
        self.mode = mode
        # 1x1 conv: channel projection only, spatial size unchanged.
        self.conv = nn.Conv2d(
            input_channels,
            output_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        # NOTE(review): built unconditionally (and after ``conv``) even though
        # it is only used when scale_factor == 2. This keeps the parameter
        # layout and init-RNG order independent of scale_factor.
        self.upsample2x = nn.ConvTranspose2d(
            input_channels,
            input_channels,
            kernel_size=3,
            stride=2,
            padding=(1, 1),
            output_padding=(1, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Upsample ``x`` and project it to ``output_channels``.

        Args:
            x: Input tensor.

        Returns:
            Output tensor after upsampling and convolution.
        """
        if self.scale_factor == 2:
            x = self.upsample2x(x)
        else:
            x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
        return self.conv(x)
def rng_seed(seed: int) -> None:
    """
    Seed the global RNGs used here and request deterministic cuDNN behavior.

    Seeds torch's global generator and NumPy's global ``np.random`` state,
    empties the CUDA caching allocator, and seeds the CUDA generator via
    ``torch.cuda.manual_seed``. It then sets ``cudnn.deterministic = True``
    and ``cudnn.benchmark = False``. Python's built-in ``random`` module is
    not seeded.

    Args:
        seed: The seed value to use.

    Returns:
        None
    """
    # Host-side generators.
    for seeder in (torch.manual_seed, np.random.seed):
        seeder(seed)

    # CUDA: release cached memory, then seed the CUDA generator.
    torch.cuda.empty_cache()
    torch.cuda.manual_seed(seed)

    # cuDNN: deterministic kernels, no autotuning.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
def weight_reset(m: nn.Module) -> None:
    """
    Re-initialize a module's parameters if it defines ``reset_parameters``.

    Modules without a callable ``reset_parameters`` are left untouched, so
    this is safe to pass to ``nn.Module.apply``.

    Args:
        m: The neural network module to reset.

    Returns:
        None
    """
    reset = getattr(m, "reset_parameters", None)
    if not callable(reset):
        return
    reset()