Source code for metabox.propagation

"""
This file contains the classes and functions to simulate the propagation of light.
"""
from __future__ import annotations

import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import copy
import dataclasses
import itertools
from typing import List, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from matplotlib.ticker import EngFormatter
from scipy import interpolate

from metabox import expansion, utils


[docs]@dataclasses.dataclass class FieldProperties: """Defines the properties of a 2d field. Args: n_pixels: the number of pixels per dim after 2D expansion. wavelength: the wavelength of the light in meters. theta: the angle of the light in degrees. phi: the phase of the light in degrees. period: the period of the pixels in meters. upsampling: the upsampling factor. use_padding: whether to use padding or not. use_antialiasing: whether to use antialiasing or not. """ n_pixels: int wavelength: List[float] theta: List[float] phi: List[float] period: float upsampling: int use_padding: bool use_antialiasing: bool
[docs] def copy(self): """Returns a copy of the field properties.""" return copy.deepcopy(self)
[docs]@dataclasses.dataclass class Field2D(FieldProperties): """Class to store the field data and its metadata. Args: tensor: the tensor of the field. wavelength: the wavelength of the light in meters. theta: the angle of the light in degrees. phi: the phase of the light in degrees. period: the period of the pixels in meters. upsampling: the upsampling factor. use_padding: whether to use padding or not. use_antialiasing: whether to use antialiasing or not. """ tensor: tf.Tensor
[docs] def modulated_by(self, other: Field2D) -> Field2D: """Modulate this field by another field. Args: other (Field2D): the field to modulate by. Returns: Field2D: the modulated field. """ # Check that the fields have the same properties. if not np.all(self.wavelength == other.wavelength): raise ValueError( "Wavelengths are not the same, got {0} and {1}.".format( self.wavelength, other.wavelength ) ) if not self.period == other.period: if self.period > other.period: fine_field = other.copy() coarse_field = self.copy() else: fine_field = self.copy() coarse_field = other.copy() enlarge_factor = coarse_field.period / float(fine_field.period) old_shape = coarse_field.tensor.shape[-2:] new_shape = ( int(old_shape[0] * enlarge_factor), int(old_shape[1] * enlarge_factor), ) real_part = tf.math.real(coarse_field.tensor) imag_part = tf.math.imag(coarse_field.tensor) parts = [] for part in [real_part, imag_part]: part = tf.image.resize( part[..., tf.newaxis], new_shape, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR, ) part = tf.image.resize_with_crop_or_pad( part, fine_field.tensor.shape[-2], fine_field.tensor.shape[-1], ) part = part[..., 0] parts.append(part) mod_tensor = tf.complex(parts[0], parts[1]) self = fine_field other = copy.deepcopy(fine_field) other.tensor = mod_tensor if not np.all(self.phi == other.phi): raise ValueError("angles_y are not the same.") if not np.all(self.theta == other.theta): raise ValueError("Angles are not the same.") if not self.tensor.shape == other.tensor.shape: if self.tensor.shape[-2] > other.tensor.shape[-2]: small_field = other large_field = self else: small_field = self large_field = other real_part = tf.math.real(small_field.tensor) imag_part = tf.math.imag(small_field.tensor) parts = [] for part in [real_part, imag_part]: part = tf.image.resize_with_crop_or_pad( part[..., tf.newaxis], large_field.tensor.shape[-2], large_field.tensor.shape[-1], ) part = part[..., 0] parts.append(part) mod_tensor = tf.complex(parts[0], parts[1]) self = large_field other = copy.deepcopy(large_field) other.tensor = mod_tensor new_field = copy.deepcopy(self) new_field.tensor *= other.tensor return new_field
def __post_init__(self): # Check that the field tensor is of the correct shape. if len(self.tensor.shape) != 3: raise ValueError( "Field tensor must have shape [batchsize, n_pix, n_pix]" ) # check batch size is correct n_batch = self.tensor.shape[0] expected_n_batch = ( len(self.wavelength) * len(self.theta) * len(self.phi) ) if n_batch != expected_n_batch: raise ValueError( """Batch size of field tensor does not match the number of wavelengths, angles, and angles_y multiplied together. Expected batch size: {0} Got batch size: {1}""".format( expected_n_batch, n_batch ) )
[docs] def get_intensity(self): """Returns the intensity tensor of the tensor.""" return get_intensity_2d(self)
[docs] def get_phase(self): """Returns the phase tensor of the tensor.""" return get_phase_2d(self)
[docs] def wavelength_average(self): """Returns the wavelength averaged field.""" return wavelength_average_2d(self)
[docs] def show_phase(self): """Shows the phase of the field.""" phase = self.get_phase() wl, ag = phase.shape[:2] diameter = self.period * phase.shape[-1] radius = diameter / 2.0 for i, j in itertools.product(range(wl), range(ag)): wave = self.wavelength[i] * 1e6 angle = list(itertools.product(self.theta, self.phi))[j] f = plt.figure(figsize=(5, 5), dpi=100) ax = plt.axes([0, 0.05, 0.9, 0.9]) im = ax.imshow( phase[i, j], extent=[-radius, radius, -radius, radius] ) formatter0 = EngFormatter(unit="m") ax.xaxis.set_major_formatter(formatter0) ax.yaxis.set_major_formatter(formatter0) plt.locator_params(axis="y", nbins=3) plt.locator_params(axis="x", nbins=3) ax.set_xlabel("X") ax.set_ylabel("Y") ax.grid(False) title = "Phase Distribution\n λ={0}µm, AOI={1}°,{2}°".format( round(wave, 2), round(angle[0], 2), round(angle[1], 2) ) ax.set_title(title) cax = plt.axes([0.95, 0.05, 0.05, 0.9]) plt.colorbar(mappable=im, cax=cax) plt.show()
[docs] def show_intensity(self, crop_factor=1.0): """Shows the intensity of the field. Args: crop_factor (float): The crop factor. Must be less than or equal to 1.0. """ if crop_factor > 1.0: raise ValueError("Zoom must be greater than or equal to 1.0.") intensity = self.get_intensity() # put angle in the last dimension intensity = tf.transpose(intensity, [0, 2, 3, 1]) if crop_factor != 1.0: intensity = tf.image.central_crop(intensity, crop_factor) # return to the original dim order intensity = tf.transpose(intensity, [0, 3, 1, 2]) wl, ag = intensity.shape[:2] diameter = self.period * intensity.shape[-1] * crop_factor radius = diameter / 2.0 for i, j in itertools.product(range(wl), range(ag)): wave = self.wavelength[i] * 1e6 angle = list(itertools.product(self.theta, self.phi))[j] f = plt.figure(figsize=(5, 5), dpi=100) ax = plt.axes([0, 0.05, 0.9, 0.9]) im = ax.imshow( intensity[i, j], extent=[-radius, radius, -radius, radius] ) formatter0 = EngFormatter(unit="m") ax.xaxis.set_major_formatter(formatter0) ax.yaxis.set_major_formatter(formatter0) plt.locator_params(axis="y", nbins=3) plt.locator_params(axis="x", nbins=3) ax.set_xlabel("X") ax.set_ylabel("Y") ax.grid(False) title = "Intensity Distribution\n λ={0}µm, AOI={1}°,{2}°".format( round(wave, 2), round(angle[0], 2), round(angle[1], 2) ) ax.set_title(title) cax = plt.axes([0.95, 0.05, 0.05, 0.9]) plt.colorbar(mappable=im, cax=cax) plt.show()
[docs] def show_color_intensity(self, crop_factor=1.0): """Shows the intensity of the field. Args: crop_factor (float): The crop factor. Must be less than or equal to 1.0. """ if crop_factor > 1.0: raise ValueError("Zoom must be greater than or equal to 1.0.") rgb_intensity = self.to_rgb_intensity() if crop_factor != 1.0: rgb_intensity = tf.image.central_crop(rgb_intensity, crop_factor) ag = rgb_intensity.shape[0] diameter = self.period * rgb_intensity.shape[-2] * crop_factor radius = diameter / 2.0 for j in range(ag): if len(self.wavelength) == 1: wave = round(self.wavelength[0] * 1e6, 2) else: wave_0 = round(self.wavelength[0] * 1e6, 2) wave_1 = round(self.wavelength[-1] * 1e6, 2) wave = f"{wave_0}-{wave_1}" angle = list(itertools.product(self.theta, self.phi))[j] f = plt.figure(figsize=(5, 5), dpi=100) ax = plt.axes([0, 0.05, 0.9, 0.9]) im = ax.imshow( rgb_intensity[j], extent=[-radius, radius, -radius, radius] ) formatter0 = EngFormatter(unit="m") ax.xaxis.set_major_formatter(formatter0) ax.yaxis.set_major_formatter(formatter0) plt.locator_params(axis="y", nbins=3) plt.locator_params(axis="x", nbins=3) ax.set_xlabel("X") ax.set_ylabel("Y") ax.grid(False) title = ( "Color Intensity Distribution\n λ={0}µm, AOI={1}°,{2}°".format( wave, round(angle[0], 2), round(angle[1], 2) ) ) ax.set_title(title) plt.show()
[docs] def to_rgb_intensity(self): """Return RGB image of the intensity.""" intensity_distributions = self.get_intensity().numpy() weighted_intensity_distributions = [] for wavelength, intensity_dist in zip( self.wavelength, intensity_distributions ): r, g, b = utils.wavelength_to_rgb(wavelength) r_dist = intensity_dist * r g_dist = intensity_dist * g b_dist = intensity_dist * b rgb_dist = np.stack([r_dist, g_dist, b_dist], axis=-1) weighted_intensity_distributions.append(rgb_dist) # normalize the image weighted_intensity_distributions = np.array( weighted_intensity_distributions ) weighted_intensity_distributions /= np.max( weighted_intensity_distributions ) return np.sum(weighted_intensity_distributions, axis=0)
[docs]@dataclasses.dataclass class Field1D(FieldProperties): """Class to store the field data and its metadata. Args: wavelength: the wavelength of the light in meters. theta: the angle of the light in degrees. phi: the phase of the light in degrees. period: the period of the pixels in meters. upsampling: the upsampling factor. use_padding: whether to use padding or not. use_antialiasing: whether to use antialiasing or not. """ tensor: tf.Tensor def __post_init__(self): # Check that the field tensor is of the correct shape. if len(self.tensor.shape) != 2: raise ValueError( "Field tensor must have shape [batchsize, n_pix_radius]" ) # check batch size is correct n_batch = self.tensor.shape[0] expected_n_batch = ( len(self.wavelength) * len(self.theta) * len(self.phi) ) if n_batch != expected_n_batch: raise ValueError( """Batch size of field tensor does not match the number of wavelengths, angles, and angles_y multiplied together.""" )
[docs] def expand_to_2d(self, basis_dir="basis_data") -> Field2D: """Function to expand a 1d field to a 2d field. Args: basis_dir: the directory where the basis is saved. The default directory is "basis_data". Returns: propagation2d.Field2D: a 2d field """ new_tensor = expansion.expand_to_2d(self.tensor, basis_dir) # create the 2d field field2d = Field2D( tensor=new_tensor, n_pixels=self.n_pixels, wavelength=self.wavelength, theta=self.theta, phi=self.phi, period=self.period, upsampling=self.upsampling, use_padding=self.use_padding, use_antialiasing=self.use_antialiasing, ) return field2d
[docs] def get_intensity(self): return get_intensity_1d(self)
[docs]def get_intensity_1d(field_1d: Field1D): """returns the intensity tensor of the field. Returns: tf.Tensor: intensity tensor of shape [wavelengths, angles, pixelsX] """ intensity = tf.math.abs(field_1d.tensor) ** 2 # new shape new_shape = tf.convert_to_tensor( [ len(field_1d.wavelength), len(field_1d.theta) * len(field_1d.phi), field_1d.tensor.shape[1], ] ) return tf.reshape(intensity, new_shape)
[docs]def get_phase_2d(field_2d: Field2D): """returns the phase tensor of the field. Returns: tf.Tensor: intensity tensor of shape [wavelengths, angles, pixelsX, pixelsY] """ phase = tf.math.angle(field_2d.tensor) # new shape new_shape = tf.convert_to_tensor( [ len(field_2d.wavelength), len(field_2d.theta) * len(field_2d.phi), field_2d.tensor.shape[1], field_2d.tensor.shape[2], ] ) return tf.reshape(phase, new_shape)
[docs]def get_intensity_2d(field_2d: Field2D): """returns the intensity tensor of the field. Returns: tf.Tensor: intensity tensor of shape [wavelengths, angles, pixelsX, pixelsY] """ intensity = tf.math.abs(field_2d.tensor) ** 2 # new shape new_shape = tf.convert_to_tensor( [ len(field_2d.wavelength), len(field_2d.theta) * len(field_2d.phi), field_2d.tensor.shape[1], field_2d.tensor.shape[2], ] ) return tf.reshape(intensity, new_shape)
[docs]def wavelength_average_2d(field: Field2D) -> Field2D: """Function to average the field over the wavelengths. Args: field (Field2D): the field to average over the wavelengths. Returns: Field2D: the averaged field. """ # seperate wavelength and angle dims field_tensor = tf.reshape( field.tensor, [ len(field.wavelength), len(field.theta) * len(field.phi), field.n_pixels, field.n_pixels, ], ) # average over the wavelengths new_tensor = tf.reduce_mean(field_tensor, axis=0, keepdims=False) # create the new field new_field = Field2D( tensor=new_tensor, n_pixels=field.n_pixels, wavelength=[np.mean(field.wavelength)], theta=field.theta, phi=field.phi, period=field.period, upsampling=field.upsampling, use_padding=field.use_padding, use_antialiasing=field.use_antialiasing, ) return new_field
[docs]def get_transfer_function( field_like: Field2D, ref_idx: float, prop_dist: float, lateral_shift: Union[None, Tuple[float, float]] = None, ) -> tf.Tensor: """Get the Propagator object given the propagation information. Args: ref_idx (float): refractive index of the medium prop_dist (float): propagation distance in meters upsampling (int): upsampling factor use_padding (bool): whether or not to use padding use_anti_aliasing (bool): whether or not to use anti-aliasing lateral_shift: the lateral shift of the sampling window on the detector in meters. If None, the shift is set so that the Chief Ray is at the center of the detector. If a tuple of two floats, the shift is set according to the first element (x shift) and the second element (y shift) of the input tuple. Returns: tf.Tensor: the complex field on the final plane """ return get_propagator_batched( ref_idx=ref_idx, prop_dist=prop_dist, n_pix=field_like.tensor.shape[1], period=field_like.period, wavelength_sampling=field_like.wavelength, theta_sampling=field_like.theta, phi_sampling=field_like.phi, upsampling=field_like.upsampling, use_padding=field_like.use_padding, use_antialiasing=field_like.use_antialiasing, lateral_shift=lateral_shift, )
[docs]def propagate( field: Field2D, transfer_function: tf.Tensor, ) -> Field2D: """Propagate a field through a given transfer function. Args: field (Field): the field to propagate transfer_function (tf.Tensor): the transfer function Returns: End field (Field): the complex field on the final plane """ field = copy.deepcopy(field) new_tensor = propagate_with_propagator_batched( field.tensor, transfer_function, upsampling=field.upsampling, use_padding=field.use_padding, ) field.tensor = new_tensor return field
[docs]def get_propagator_batched( ref_idx: float, prop_dist: float, n_pix: int, period: float, wavelength_sampling: List[float], theta_sampling: List[float], phi_sampling: List[float], upsampling=1, use_padding=True, use_antialiasing=True, lateral_shift: Union[None, Tuple[float, float]] = None, ) -> tf.Tensor: """ Returns the transfer function for a given propagation distance. Args: ref_idx: refractive index of the medium prop_dist: propagation distance in meters n_pix: number of pixels in each dimension period: pixel size in meters wavelength: wavelength in meters theta_sampling: list of angles to sample in degrees phi_sampling: list of angles to sample in degrees upsampling: upsampling factor use_padding: whether to pad the transfer function to prevent aliasing use_antialiasing: whether to limit the transfer function bandwidth to prevent aliasing lateral_shift: the lateral shift of the sampling window on the detector in meters. If None, the shift is set so that the Chief Ray is at the center of the detector. If a tuple of two floats, the shift is set according to the first element (x shift) and the second element (y shift) of the input tuple. Returns: propagator: transfer function """ batch_size = ( len(wavelength_sampling) * len(theta_sampling) * len(phi_sampling) ) lam0 = tf.convert_to_tensor( np.repeat( wavelength_sampling, np.size(theta_sampling) * np.size(phi_sampling), ), dtype=tf.float32, ) lam0 = lam0[:, tf.newaxis, tf.newaxis] lam0 = tf.tile(lam0, multiples=(1, n_pix, n_pix)) # Propagator definition. k = ref_idx * 2 * np.pi / lam0[:, 0, 0] k = k[:, np.newaxis, np.newaxis] samp = int(upsampling * n_pix) k = tf.cast(k, dtype=tf.complex64) if use_padding: k_xlist_pos = ( 2 * np.pi * np.linspace(0, 1 / (2 * period / upsampling), samp) ) front = k_xlist_pos[-(samp - 1) :] front = -front[::-1] k_xlist = np.hstack((front, k_xlist_pos)) k_x = np.kron(k_xlist, np.ones((2 * samp - 1, 1))) else: k_xlist = ( 2 * np.pi * np.linspace( -1 / (2 * period / upsampling), 1 / (2 * period / upsampling), samp, ) ) k_x = np.kron(k_xlist, np.ones((samp, 1))) k_x = k_x[np.newaxis, :, :] k_y = np.transpose(k_x, axes=[0, 2, 1]) k_x = tf.convert_to_tensor(k_x, dtype=tf.complex64) k_x = tf.tile(k_x, multiples=(batch_size, 1, 1)) k_y = tf.convert_to_tensor(k_y, dtype=tf.complex64) k_y = tf.tile(k_y, multiples=(batch_size, 1, 1)) k_z_arg = tf.square(k) - (tf.square(k_x) + tf.square(k_y)) k_z = tf.sqrt(k_z_arg) # Find shift amount theta_out = ( np.pi / 180.0 * tf.convert_to_tensor( np.tile( theta_sampling, np.size(wavelength_sampling) * np.size(phi_sampling), ), dtype=tf.float32, ) ) theta_out = theta_out[:, tf.newaxis, tf.newaxis] theta_out = tf.tile(theta_out, multiples=(1, n_pix, n_pix)) theta = theta_out[:, 0, 0] theta = theta[:, np.newaxis, np.newaxis] y0_real = tf.tan(theta) * prop_dist phi = np.repeat(phi_sampling, np.size(theta_sampling)) phi = np.tile(phi, np.size(wavelength_sampling)) phi = np.pi / 180.0 * tf.convert_to_tensor(phi, dtype=tf.float32) phi = phi[:, tf.newaxis, tf.newaxis] phi = tf.tile(phi, multiples=(1, n_pix, n_pix)) phi = phi[:, 0, 0] phi = phi[:, np.newaxis, np.newaxis] x0_real = tf.tan(phi) * prop_dist if lateral_shift is not None: x0_real = lateral_shift[0] y0_real = lateral_shift[1] y0 = tf.cast(y0_real, dtype=tf.complex64) x0 = tf.cast(x0_real, dtype=tf.complex64) propagator_arg = 1j * (k_z * prop_dist + k_x * x0 + k_y * y0) propagator = tf.exp(propagator_arg) # Limit transfer function bandwidth to prevent aliasing if use_antialiasing: S_x_src = n_pix * period S_y_src = S_x_src S_x_dest = n_pix * period S_y_dest = S_x_dest S_x_avg = 0.5 * (S_x_src + S_x_dest) kx_cond1 = S_x_avg < x0_real kx_cond2 = (-S_x_avg <= x0_real) & (x0_real < S_x_avg) S_y_avg = 0.5 * (S_y_src + S_y_dest) ky_cond1 = S_y_avg < y0_real ky_cond2 = (-S_y_avg <= y0_real) & (y0_real < S_y_avg) kx_limit_plus = k * tf.cast( 1 / tf.math.sqrt((prop_dist / (x0_real + S_x_avg)) ** 2 + 1), dtype=tf.complex64, ) kx_limit_minus = k * tf.cast( 1 / tf.math.sqrt((prop_dist / (x0_real - S_x_avg)) ** 2 + 1), dtype=tf.complex64, ) ky_limit_plus = k * tf.cast( 1 / tf.math.sqrt((prop_dist / (y0_real + S_y_avg)) ** 2 + 1), dtype=tf.complex64, ) ky_limit_minus = k * tf.cast( 1 / tf.math.sqrt((prop_dist / (y0_real - S_y_avg)) ** 2 + 1), dtype=tf.complex64, ) kx_region = tf.where( kx_cond1, _case1(k_x, k_y, k, kx_limit_minus, kx_limit_plus), _where_case2_case3( kx_cond2, k_x, k_y, k, kx_limit_minus, kx_limit_plus ), ) ky_region = tf.where( ky_cond1, _case1(k_y, k_x, k, ky_limit_minus, ky_limit_plus), _where_case2_case3( ky_cond2, k_y, k_x, k, ky_limit_minus, ky_limit_plus ), ) k_region = tf.cast(kx_region & ky_region, dtype=tf.complex64) return propagator * k_region else: return propagator
def _case1(k_x, k_y, k, kx_limit_minus, kx_limit_plus): """Case 1 of the transfer function bandwidth limiting.""" kx_r1 = k_x.numpy() >= 0 kx_r2 = (tf.square(k_x / kx_limit_minus) + tf.square(k_y / k)).numpy() >= 1 kx_r3 = (tf.square(k_x / kx_limit_plus) + tf.square(k_y / k)).numpy() <= 1 return kx_r1 & kx_r2 & kx_r3 def _case2(k_x, k_y, k, kx_limit_minus, kx_limit_plus): """Case 2 of the transfer function bandwidth limiting.""" kx_r1 = k_x.numpy() <= 0 kx_r2 = (tf.square(k_x / kx_limit_minus) + tf.square(k_y / k)).numpy() <= 1 kx_r3 = k_x.numpy() > 0 kx_r4 = (tf.square(k_x / kx_limit_plus) + tf.square(k_y / k)).numpy() <= 1 return (kx_r1 & kx_r2) | (kx_r3 & kx_r4) def _case3(k_x, k_y, k, kx_limit_minus, kx_limit_plus): """Case 3 of the transfer function bandwidth limiting.""" kx_r1 = k_x.numpy() <= 0 kx_r2 = (tf.square(k_x / kx_limit_plus) + tf.square(k_y / k)).numpy() >= 1 kx_r3 = (tf.square(k_x / kx_limit_minus) + tf.square(k_y / k)).numpy() <= 1 return kx_r1 & kx_r2 & kx_r3 def _where_case2_case3(kx_cond2, k_x, k_y, k, kx_limit_minus, kx_limit_plus): """Case 2 and 3 of the transfer function bandwidth limiting.""" kx_region = tf.where( kx_cond2, _case2(k_x, k_y, k, kx_limit_minus, kx_limit_plus), _case3(k_x, k_y, k, kx_limit_minus, kx_limit_plus), ) return kx_region
[docs]def propagate_with_propagator_batched( field: tf.Tensor, propagator: tf.Tensor, use_padding=True, upsampling=1, ) -> tf.Tensor: if use_padding: _, _, m = field.shape n = upsampling * m field = tf.transpose(field, perm=[1, 2, 0]) field_real = tf.math.real(field) field_imag = tf.math.imag(field) field_real = tf.image.resize(field_real, [n, n], method="nearest") field_imag = tf.image.resize(field_imag, [n, n], method="nearest") field = tf.cast(field_real, dtype=tf.complex64) + 1j * tf.cast( field_imag, dtype=tf.complex64 ) field = tf.image.resize_with_crop_or_pad(field, 2 * n - 1, 2 * n - 1) field = tf.transpose(field, perm=[2, 0, 1]) field_freq = tf.signal.fftshift(tf.signal.fft2d(field), axes=(1, 2)) field_filtered = tf.signal.ifftshift(field_freq * propagator, axes=(1, 2)) out = tf.signal.ifft2d(field_filtered) if use_padding: # Crop back down to n x n matrices out = tf.transpose(out, perm=[1, 2, 0]) out = tf.image.resize_with_crop_or_pad(out, n, n) out = tf.transpose(out, perm=[2, 0, 1]) return out
# def propagate_with_propagator( # field: tf.Tensor, propagator: tf.Tensor, use_padding=True, upsampling=1 # ) -> tf.Tensor: # """ # Progragates the field through a distance using the transfer function. # # Args: # field: The field to propagate. # propagator: The transfer function. # use_padding: Whether to use padding or not. # upsampling: The upsampling factor. # """ # if use_padding: # _, m = field.shape # field = tf.transpose(field, perm=[1, 2, 0]) # field_real = tf.math.real(field) # field_imag = tf.math.imag(field) # field_real = tf.image.resize(field_real, [m, m], method="nearest") # field_imag = tf.image.resize(field_imag, [m, m], method="nearest") # field = tf.cast(field_real, dtype=tf.complex64) + 1j * tf.cast( # field_imag, dtype=tf.complex64 # ) # field = tf.image.resize_with_crop_or_pad(field, 2 * m - 1, 2 * m - 1) # field = tf.transpose(field, perm=[2, 0, 1]) # # field_freq = tf.signal.fftshift(tf.signal.fft2d(field)) # field_filtered = tf.signal.ifftshift(field_freq * propagator) # out = tf.signal.ifft2d(field_filtered) # # if use_padding: # # Crop back down to m x m matrices # out = tf.image.resize_with_crop_or_pad(out, m, m) # # return out
[docs]def get_incident_field_2d( field_props: FieldProperties, ) -> tf.Tensor: """Defines the input electric fields for the given wavelengths and field angles. Args: field_props: the field properties. Returns: Field2d: The incident field. """ # Define the cartesian cross section # TODO: perioidicity for x and y seperately wavelength_base = field_props.wavelength theta_base = field_props.theta phi_base = field_props.phi # Define the cartesian cross section n_pix = field_props.n_pixels dx = field_props.period dy = field_props.period xa = np.linspace(0, n_pix - 1, n_pix) * dx # x axis array xa = xa - np.mean(xa) # center x axis at zero ya = np.linspace(0, n_pix - 1, n_pix) * dy # y axis vector ya = ya - np.mean(ya) # center y axis at zero [y_mesh, x_mesh] = np.meshgrid(ya, xa) x_mesh = x_mesh[np.newaxis, :, :] y_mesh = y_mesh[np.newaxis, :, :] lam0, theta, phi = utils.unravel_wavelength_theta_phi( wavelength_base, theta_base, phi_base ) lam0 = lam0[:, tf.newaxis, tf.newaxis] theta = theta[:, tf.newaxis, tf.newaxis] phi = phi[:, tf.newaxis, tf.newaxis] phase_def = ( 2 * np.pi / lam0 * (np.sin(theta) * x_mesh + np.sin(phi) * y_mesh) ) phase_def = tf.cast(phase_def, dtype=tf.complex64) tensor = tf.exp(1j * phase_def) total_energy = tf.reduce_sum(tf.math.abs(tensor) ** 2) tensor /= tf.sqrt(tf.cast(total_energy, dtype=tf.complex64)) return Field2D( tensor=tensor, n_pixels=field_props.n_pixels, wavelength=field_props.wavelength, theta=field_props.theta, phi=field_props.phi, period=field_props.period, upsampling=field_props.upsampling, use_padding=field_props.use_padding, use_antialiasing=field_props.use_antialiasing, )