Source code for metabox.expansion

"""
Defines functions to expand a 1d field to a 2d field.
"""
import os

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import pickle

import numpy as np
import tensorflow as tf


def expand_to_2d(tensor: tf.Tensor, basis_dir="basis_data") -> tf.Tensor:
    """Expand 1d radial profiles into 2d radially symmetric fields.

    Each row of ``tensor`` is a radial profile of ``radius_size`` samples. It is
    projected onto the radius-to-circle basis returned by ``load_basis``, so
    that each sample fills one concentric ring of a square grid with
    ``2 * radius_size`` pixels per side.

    Args:
        tensor (tf.Tensor): radial profiles, expected shape
            ``(batch, radius_size)`` (``sparse_dense_matmul`` needs a rank-2
            operand).
        basis_dir (str): directory where the basis is cached; forwarded to
            ``load_basis``. Defaults to "basis_data".

    Returns:
        tf.Tensor: complex64 fields of shape
        ``(batch, 2 * radius_size, 2 * radius_size)``.
    """
    n_radial = tensor.shape[-1]
    n_side = 2 * n_radial
    r2c_basis = load_basis(n_side, basis_dir)

    # (batch, n_radial) @ (n_radial, n_side**2) -> (batch, n_side**2) when the
    # basis comes from radius_to_circle_basis.
    profile = tf.cast(tensor, tf.complex64)
    flat_fields = tf.sparse.sparse_dense_matmul(profile, r2c_basis)

    # The original author's phase-convention "hack": conjugate the projected
    # field. With the real 0/1 basis built by radius_to_circle_basis this is
    # the same as conjugating the input profile.
    flat_fields = tf.math.conj(flat_fields)

    return tf.reshape(flat_fields, [-1, n_side, n_side])
def load_basis(n_pix, basis_dir=None) -> tf.Tensor:
    """Return the radius-to-circle basis for an ``n_pix`` x ``n_pix`` grid.

    The basis is built for ``n_pix // 2`` radial pixels. When ``basis_dir`` is
    given, the basis is cached there as a pickle file named
    ``r2c_basis_<n_pix // 2>``:
    - If that file exists, it is loaded.
    - Otherwise, the basis is generated with ``radius_to_circle_basis`` and
      written to the file.

    When ``basis_dir`` is None, the basis is generated in memory and nothing is
    read from or written to disk.

    Args:
        n_pix (int): number of pixels per axis of the 2d field.
        basis_dir (str, optional): directory holding the cached basis. If None,
            the basis is generated and neither loaded nor saved.

    Returns:
        tf.Tensor: the 1d to 2d basis cast to complex64 (a ``tf.SparseTensor``
        when generated here; a loaded file yields whatever object was pickled).
    """
    radius_size = n_pix // 2

    # Handle None before building the path: os.path.join(None, ...) raises
    # TypeError, which would make this documented mode unreachable.
    if basis_dir is None:
        return tf.cast(radius_to_circle_basis(radius_size), tf.complex64)

    basis_file_path = os.path.join(basis_dir, "r2c_basis_{}".format(radius_size))

    if os.path.exists(basis_file_path):
        # NOTE(security): pickle.load can execute arbitrary code. Only point
        # basis_dir at trusted locations.
        with open(basis_file_path, "rb") as pickled_file:
            basis = pickle.load(pickled_file)
        return tf.cast(basis, tf.complex64)

    print("Basis file not found. Creating basis.")
    basis_tensor = tf.cast(radius_to_circle_basis(radius_size), tf.complex64)
    # exist_ok avoids a check-then-create race. An empty basis_dir means the
    # current directory, which needs no creation.
    if basis_dir:
        os.makedirs(basis_dir, exist_ok=True)
    with open(basis_file_path, "wb") as output_file:
        pickle.dump(basis_tensor, output_file)
    return basis_tensor
# initializes the radius to circle basis
def radius_to_circle_basis(radius_size) -> tf.Tensor:
    """Create a basis that maps a 1d radial profile onto a 2d grid.

    The grid has ``2 * radius_size`` points per axis, evenly spaced over
    [-1, 1], with spacing ``step``. Row ``k`` of the basis marks the grid
    points whose distance ``d`` from the origin satisfies
    ``k * step < d <= (k + 1) * step``. Grid points farther out than
    ``radius_size * step`` (the corners) belong to no ring.

    Args:
        radius_size (int): number of pixels in the radius.

    Returns:
        tf.SparseTensor: int32 0/1 basis of shape
        ``(radius_size, (2 * radius_size) ** 2)``, one flattened ring per row.
    """
    n_side = radius_size * 2
    axis = tf.linspace(-1, 1, n_side)
    _, step = np.linspace(-1, 1, n_side, retstep=True)

    # Outer radius of each ring (step, 2*step, ...); each ring's inner radius
    # is the previous ring's outer radius, starting from 0.
    outer_radii = np.arange(0 + step, 1 + step, step)
    inner_radii = [0] + list(outer_radii[:-1])

    grid_x, grid_y = tf.meshgrid(axis, axis)
    distance = tf.sqrt(grid_x**2 + grid_y**2)

    rings = []
    for inner, outer in zip(inner_radii, outer_radii):
        within_outer = tf.cast(outer >= distance, tf.int32)
        within_inner = tf.cast(inner >= distance, tf.int32)
        ring = tf.squeeze(within_outer - within_inner)
        # Leading axis of size 1 so the rings can be concatenated row-wise.
        rings.append(tf.sparse.from_dense([ring]))

    stacked = tf.sparse.concat(sp_inputs=rings, axis=0)
    return tf.sparse.reshape(stacked, [radius_size, -1])