Source code for metabox.rcwa

import os, csv, tqdm, copy, dataclasses, gc, warnings, logging, glob

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
from typing import Any, Dict, List, Tuple, Union

_ROOT = os.path.abspath(os.path.dirname(__file__))

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from scipy import interpolate

from metabox import raster, rcwa_tf, utils
from metabox.utils import CoordType, Feature, Incidence, ParameterType

# Suppress tensorflow warnings
tf.get_logger().setLevel(logging.ERROR)


def _get_features(parameter) -> List[Feature]:
    """Returns the features of the shape.

    Args:
        parameter: the parameter to get the features from.

    Yields:
        The features of the parameter."""
    if isinstance(parameter, Feature):
        yield parameter
    # recursively get features from nested iterables
    elif isinstance(parameter, (list, tuple, set)):
        for item in parameter:
            yield from _get_features(item)


[docs]class Parameterizable: """Defines a parameterizable object.""" def __init__(self): self.unique_features = self.get_unique_features()
[docs] def get_unique_features(self) -> List[Feature]: """Returns the unique features of the shape (non-recursively).""" return list(set(self.get_features()))
[docs] def get_features(self) -> List[Feature]: """Returns the features of the shape.""" for field in dataclasses.fields(self): parameter = getattr(self, field.name) yield from _get_features(parameter)
[docs] def initialize_values( self, value_assignment: Union[ None, Tuple[List[Feature], List[float]] ] = None, ) -> None: """Initializes the variables.""" if value_assignment is not None: if len(value_assignment) != 2: raise ValueError( "value_assignment must be a tuple of length 2." ) elif len(value_assignment[0]) != len(value_assignment[1]): raise ValueError( "value_assignment must be a tuple of lists of equal length." ) # Assign the value to the field that contain the feature. features, values = value_assignment for feature, value in zip(features, values): # use value = None to initialize the feature randomly feature.initial_value = value for feature in self.unique_features: feature.initialize_value()
[docs] def replace_feature_with_value(self) -> None: # set the field to the value for field in dataclasses.fields(self): parameter = getattr(self, field.name) if isinstance(parameter, Feature): setattr(self, field.name, parameter.value) elif isinstance(parameter, (list, tuple, set)): new_parameter = [] for item in parameter: if isinstance(item, Feature): new_parameter.append(item.value) else: new_parameter.append(item) setattr(self, field.name, new_parameter)
[docs] def get_variables(self) -> List[tf.Variable]: """Returns the variables of the shape.""" variables = [] for feature in self.unique_features: if feature.value is not None: if isinstance(feature.value, tf.Variable): variables.append(feature.value) return variables
[docs]@dataclasses.dataclass class Shape(Parameterizable): """Defines a shape. Args: material: the `Material` or the ref. index of the shape. """ material: Union[ParameterType, None] def __post_init__(self): return super().__init__()
[docs]@dataclasses.dataclass class Polygon(Shape): """Defines a polygon. Args: material: the `Material` or the ref. index of the shape. vertices: the vertices of the polygon. List of (x, y) coordinates. Example_0: [(0, 0), (1, 0), (1, 1), (0, 1)] Example_1: var = Feature(vmin=0, vmax=1, name="var") [(0, 0), (var, 0), (var, 1), (0, 1)] """ vertices: List[CoordType] def __post_init__(self): if len(self.vertices) < 3: raise ValueError("A polygon must have at least 3 vertices.") for vertex in self.vertices: if len(vertex) != 2: raise ValueError( "Each vertex must be a tuple of (x, y) coordinates." ) return super().__post_init__()
[docs] def get_shape(self, wavelength: Union[float, None] = None): if type(self.material) is Material: if wavelength is None: raise ValueError( "The wavelength must be given to rasterize a polygon with " "a Material index." ) value = self.material.index_at(wavelength) else: value = self.material return raster.Polygon(value=value, points=self.vertices)
[docs] def get_vertices(self): """Returns the vertices of the polygon.""" return self.vertices
[docs]@dataclasses.dataclass class Rectangle(Shape): """Defines a rectangle. Args: material: the ref. index of the shape. x_width: the width of the rectangle in the x direction. y_width: the width of the rectangle in the y direction. x_pos: the x position of the rectangle. Default: 0 y_pos: the y position of the rectangle. Default: 0 rotation_deg: the rotation of the rectangle in degrees. Default: 0 use_4_fold_symmetry: whether to use 4-fold symmetry. Default: False If True, the rectangle will be rotated by 0, 90, 180, 270 degrees. Then """ x_width: ParameterType y_width: ParameterType x_pos: ParameterType = 0 y_pos: ParameterType = 0 rotation_deg: ParameterType = 0 def __post_init__(self): return super().__post_init__()
[docs] def get_shape(self, wavelength: Union[float, None] = None): if type(self.material) is Material: if wavelength is None: raise ValueError( "The wavelength must be given to rasterize a rectangle with " "a Material index." ) value = self.material.index_at(wavelength) else: value = self.material return raster.Rectangle( value=value, center=(self.x_pos, self.y_pos), x_width=self.x_width, y_width=self.y_width, rotation_deg=self.rotation_deg, )
[docs] def get_vertices(self): """Returns the vertices of the rectangle.""" return raster.rectangle_to_vertices( center=(self.x_pos, self.y_pos), x_width=self.x_width, y_width=self.y_width, rotation_deg=self.rotation_deg, )
[docs]@dataclasses.dataclass class Circle(Shape): """Defines a circle. Args: material: the ref. index of the shape. radius: the radius of the circle. center: the center of the circle. A tuple of (x, y) coordinates. Default: (0, 0) """ radius: ParameterType x_pos: ParameterType = 0 y_pos: ParameterType = 0 def __post_init__(self): return super().__post_init__()
[docs] def get_shape(self, wavelength: Union[float, None] = None): if type(self.material) is Material: if wavelength is None: raise ValueError( "The wavelength must be given to rasterize a circle with " "a Material index." ) value = self.material.index_at(wavelength) else: value = self.material return raster.Circle( value=value, center=(self.x_pos, self.y_pos), radius=self.radius, )
[docs] def get_vertices(self, num_of_vertices: int = 21): """Returns the vertices of the circle.""" vertices = [] for i in range(num_of_vertices): angle = i * 2 * np.pi / num_of_vertices x = self.x_pos + self.radius * np.cos(angle) y = self.y_pos + self.radius * np.sin(angle) vertices.append((x, y)) return vertices
[docs]def duplicate_shape(shape: Shape, num_of_duplicates: int) -> List[Shape]: """Generates a list of duplicate parameterized shapes. The returned shapes share the same parameters as the input shape. But with different unique names. Args: shape: the shape to be duplicated. num_of_duplicates: the number of duplicates. Returns: A list of duplicated shapes. """ shapes = [] for i in range(num_of_duplicates): new_shape = copy.deepcopy(shape) features = new_shape.unique_features for feature in features: feature.name = f"{feature.name}~{i}" shapes.append(new_shape) return shapes
[docs]@dataclasses.dataclass class Layer(Parameterizable): """Defines a layer. Args: material: the ref. index of the shape. thickness: the thickness of the layer in meters. shapes: the shapes of the layer. Tuple of Shape objects. Default: () """ material: Union[Feature, float, "Material"] thickness: Union[Feature, float] shapes: Tuple[Shape] = () enforce_4fold_symmetry: bool = False def __post_init__(self): Parameterizable.__init__(self)
[docs] def get_shapes(self, wavelength: Union[float, None] = None): """Returns the shapes of the layer.""" shapes = [] for shape in self.shapes: shapes.append(shape.get_shape(wavelength)) return shapes
[docs] def initialize_values( self, value_assignment: Union[ None, Tuple[List[Feature], List[float]] ] = None, ) -> None: """Initializes the layer variables.""" super().initialize_values(value_assignment) for shape in self.shapes: shape.initialize_values(value_assignment)
[docs] def get_layer_unique_features(self) -> List[Feature]: """Returns the unique features of the layer.""" all_features = copy.deepcopy(self.unique_features) for shape in self.shapes: all_features.extend(shape.unique_features) return list(set(all_features))
[docs] def get_variables(self) -> List[tf.Variable]: """Returns the variables of the layer.""" all_features = copy.deepcopy(self.unique_features) for shape in self.shapes: all_features.extend(shape.unique_features) unique_features = list(set(all_features)) variables = [] for feature in unique_features: if feature.value is not None: if isinstance(feature.value, tf.Variable): variables.append(feature.value) return variables
[docs]@dataclasses.dataclass class UnitCell(Parameterizable): """Defines a unit cell. Attributes: layers: the layers of the unit cell. periodicity: a tuple of (x, y) in meters that define the periodicity of the unit cell in the x and y direction. refl_index: the ref. index in the reflection region. tran_index: the ref. index in the transmission region. """ layers: List[Layer] periodicity: Tuple[ParameterType, ParameterType] refl_index: ParameterType = 1.0 tran_index: ParameterType = 1.0 def __post_init__(self): super().__init__() if len(self.periodicity) != 2: raise ValueError( "The periodicity must be a tuple of (x, y) in meters." )
[docs] def initialize_values( self, value_assignment: Union[ None, Tuple[List[Feature], List[float]] ] = None, ) -> None: """Initializes the layer variables.""" super().initialize_values(value_assignment) for layer in self.layers: layer.initialize_values(value_assignment)
[docs] def replace_features(self): """Replaces the features with the given values.""" _replace_feature_with_value_in_dataclass(self)
[docs] def get_cell_unique_features(self) -> List[Feature]: """Returns the unique features of the unit cell.""" all_features = copy.deepcopy(self.unique_features) for layer in self.layers: all_features.extend(layer.get_layer_unique_features()) return list(set(all_features))
[docs] def get_variables(self) -> List[tf.Variable]: unique_features = self.get_cell_unique_features() variables = [] for feature in unique_features: if feature.value is not None: if isinstance(feature.value, tf.Variable): variables.append(feature.value) return variables
[docs] def get_epsilon(self, x_resolution: int, wavelength: float) -> tf.Tensor: """Returns the permittivity of the unit cell. Args: x_resolution: the resolution of the permittivity in the x direction. Returns: The permittivity of the unit cell as a tf.Tensor. """ pixel_density = self.periodicity[0] / float(x_resolution) epsilon_all = [] for layer in self.layers: epsilon_layer = ( _rasterize_layer( layer=layer, periodicity=self.periodicity, pixel_density=pixel_density, enforce_4fold_symmetry=layer.enforce_4fold_symmetry, wavelength=wavelength, ) ** 2 ) epsilon_all.append(tf.cast(epsilon_layer, tf.complex64)) epsilon_all = tf.cast(epsilon_all, tf.complex64) return tf.stack(epsilon_all, axis=0)
[docs] def get_thickness(self) -> tf.Tensor: """Returns the thickness of the unit cell as a tf.Tensor.""" return tf.cast( [tf.math.real(layer.thickness) for layer in self.layers], tf.float32, )
[docs] def find_feature_index(self, feature_str): """Returns the index of the feature with the given name.""" for i, feature in enumerate(self.unique_features): if feature.name == feature_str: return i raise ValueError("Feature not found.")
def _replace_this_feature_with_value_recursively( parent: Any, child_field: Any ) -> None: """Recursively replaces the features with their values. if the field is a feature, replace it with its value. If the field is a list, tuple, or set, recursively call this function on each element of the list, tuple, or set. If the field is a dataclass, recursively call this function on each field of the dataclass. Args: parent: the parent of the field. child_field: the field to replace. """ parent_is_list_or_tuple = isinstance(parent, (list, tuple)) parent_is_tuple = isinstance(parent, tuple) if parent_is_list_or_tuple: field_content = child_field else: field_content = getattr(parent, child_field.name) if isinstance(field_content, Feature): if field_content.value is None: field_content.initialize_value() if parent_is_list_or_tuple: if parent_is_tuple: parent = list(parent) idx = parent.index(child_field) parent[idx] = field_content.value if parent_is_tuple: parent = tuple(parent) else: setattr(parent, child_field.name, field_content.value) elif isinstance(field_content, (list, tuple)): for child_field_child_field in field_content: _replace_this_feature_with_value_recursively( parent=field_content, child_field=child_field_child_field ) elif dataclasses.is_dataclass(field_content): _replace_feature_with_value_in_dataclass(field_content) # do nothing for the other cases def _replace_feature_with_value_in_dataclass(dataclass_instance) -> None: """Replaces the features with their values in a dataclass instance. Args: dataclass_instance: the dataclass instance to replace the features in. """ # set the field to the value for field in dataclasses.fields(dataclass_instance): _replace_this_feature_with_value_recursively(dataclass_instance, field)
[docs]@dataclasses.dataclass class ProtoUnitCell: """Defines an archetype of UnitCell (i.e. parameterized by `Feature`s). Provides an interface to generate an array of unit cells from a tensor with shape (n_features, n_unit_cells). Attributes: proto_unit_cell: the unit cell that the children units cells are based of. The children unit cells will share the same Features as the parent unit cell. """ proto_unit_cell: UnitCell def __post_init__(self): self.features = self.proto_unit_cell.get_cell_unique_features() self.period = self.proto_unit_cell.periodicity[0]
[docs] def generate_initial_variables(self, n_cells: int) -> tf.Tensor: """Returns a tensor of initial variable parameters for the unit cells. The tensor has shape (n_features, n_unit_cells). Args: n_cell: the number of unit cells to generate. Returns: A tensor of initial parameters for the unit cells. """ # Initialize the tensor tensor_columns = [] clip_value_min = [] clip_value_max = [] for feature in self.proto_unit_cell.get_cell_unique_features(): vmin = [feature.vmin] vmax = [feature.vmax] clip_value_min.append(vmin) clip_value_max.append(vmax) tensor_columns.append(tf.random.uniform([n_cells], vmin, vmax)) tensor = tf.stack(tensor_columns, axis=0) constraint_func = lambda x: tf.clip_by_value( x, clip_value_min, clip_value_max ) return tf.Variable(tensor, constraint=constraint_func)
[docs] def generate_cells_from_parameter_tensor( self, tensor: tf.Tensor ) -> List[UnitCell]: """Returns an array of unit cells from a tensor shape: (n_cell, n_feat). Args: tensor(tf.Tensor): a tensor with shape (n_unit_cells, n_features). Raises: ValueError: when the tensor does not have the correct shape. """ if tensor.shape[0] != len(self.features): raise ValueError( "The tensor must have shape (n_features, n_unit_cells)." ) unit_cell_array = [] for i in range(tensor.shape[-1]): unit_cell = copy.deepcopy(self.proto_unit_cell) features = unit_cell.get_cell_unique_features() parameters = tensor[:, i] for feature, init_value in zip(features, parameters): feature.set_value(init_value) unit_cell.replace_features() unit_cell_array.append(unit_cell) return unit_cell_array
def _rasterize_layer( layer: Layer, periodicity: Tuple[ParameterType, ParameterType], pixel_density: float, enforce_4fold_symmetry: bool = False, wavelength: Union[float, None] = None, ) -> raster.Canvas: """Rasterizes a layer. Args: layer: the layer to rasterize. periodicity: a tuple of (x, y) in meters that define the periodicity of the unit cell in the x and y direction. pixel_density: the pixel density in meters. enforce_4fold_symmetry: whether to make the layer 4-fold symmetric. If true, then the layer will be mirrored along the x and y axes, then added to its transpose. wavelength: the wavelength of the simulation in meters. Returns: The rasterized layer. """ if type(layer.material) is Material: if wavelength is None: raise ValueError( "The wavelength must be given to rasterize a layer with " "a Material index." ) layer_value = layer.material.index_at(wavelength) else: layer_value = layer.material return raster.Canvas( x_width=periodicity[0], y_width=periodicity[1], spacing=pixel_density, background_value=layer_value, enforce_4fold_symmetry=enforce_4fold_symmetry, ).rasterize(layer.get_shapes(wavelength))
[docs]def get_avaliable_materials( custom_csv_dir: Union[str, None] = None ) -> List[str]: """Returns a list of avaliable material strings in the given directory.""" if custom_csv_dir is None: custom_csv_dir = os.path.join(_ROOT, "material_data") avail_materials_dir = glob.glob(os.path.join(custom_csv_dir, "*.csv")) avali_materials = [ os.path.split(file_path)[-1].split(".")[0] for file_path in avail_materials_dir ] return avali_materials
[docs]@dataclasses.dataclass class Material: """Defines a material class. A material provides a way to define the refractive index of a material given the wavelength of the simulation. Attributes: name: the name of the material. The .csv file name must be [name].csv The file extension i.e. `.csv` must be in lowercase. custom_csv_dir: the path to the csv file folder that contains the refractive index data. The data can be downloaded from refractiveindex.info. Just search for the material and download the csv file, under the "Data" section. Save the [CSV - comma separated] file as `csv_file_dir`. """ name: str custom_csv_dir: Union[str, None] = None def __post_init__(self): if self.custom_csv_dir is None: self.custom_csv_dir = os.path.join(_ROOT, "material_data") csv_dir = os.path.join(self.custom_csv_dir, self.name + ".csv") if not os.path.exists(csv_dir): avaliable_materials = get_avaliable_materials(self.custom_csv_dir) default_mat_str = ", ".join(avaliable_materials) raise ValueError( f"The csv file for {self.name} does not exist in {self.custom_csv_dir}.\n" f"Could not find {csv_dir}\n" f"Avaliable materials: {default_mat_str}" ) self.wl_n = [] self.wl_k = [] self.n = [] self.k = [] with open(csv_dir, "r") as file: reader = csv.reader(file) mode = None for row in reader: if len(row) == 0 or row[0].strip() == "": continue # Skip empty rows if row[1] == "n": mode = "n" continue elif row[1] == "k": mode = "k" continue if mode == "n": self.wl_n.append(float(row[0]) * 1e-6) self.n.append(float(row[1])) elif mode == "k": self.wl_k.append(float(row[0]) * 1e-6) self.k.append(float(row[1])) self.min_wl_n = min(self.wl_n) self.max_wl_n = max(self.wl_n) if len(self.wl_k) > 0: self.min_wl_k = min(self.wl_k) self.max_wl_k = max(self.wl_k) self.n_interp = interpolate.interp1d(self.wl_n, self.n) if len(self.wl_k) > 0: self.k_interp = interpolate.interp1d(self.wl_k, self.k)
[docs] def index_at(self, wavelength): """Returns the refractive index at the given wavelength.""" if not (self.min_wl_n <= wavelength <= self.max_wl_n): raise ValueError(f"Wavelength {wavelength} is out of range.") n_value = self.n_interp(wavelength) if hasattr(self, "k_interp") and self.k_interp is not None: k_value = self.k_interp(wavelength) else: k_value = 0j return n_value + 1.0j * k_value
[docs]@dataclasses.dataclass class SimConfig: """Defines a simulation configuration. The SimConfig class is an immutable dataclass that provides the configuration and precomuputed data for the RCWA simulation. TODO: add fast convolution matrix Attributes: xy_harmonics: a tuple of (x, y) positive odd ints of Fourier harmonics. x_resolution: the grid x resolution of the simulation of the real space. Note that the grid y resolution is determined by the aspect ratio of the unit cell. minibatch_size: the minibatch size of the simulation. return_tensor: whether to use tensor as the output. If False, then the following arguments are ignored: return_zeroth_order, use_transmission, include_tz. And the output is the in the form of `SimResult`. If True, the output is in the form of a tf.Tensor, and the following arguments are used: return_zeroth_order=True, use_transmission=True, include_z_comp=False. return_zeroth_order: whether to use zeroth order diffraction as the output. use_transmission: whether to use transmission as the output. include_z: whether to include the z component of the electric field. """ xy_harmonics: Tuple[int, int] resolution: int minibatch_size: int = 100 return_tensor: bool = False return_zeroth_order: Union[bool, None] = None use_transmission: Union[bool, None] = None include_z_comp: Union[bool, None] = None def __post_init__(self): if self.xy_harmonics[0] % 2 != 1 or self.xy_harmonics[0] < 1: raise ValueError("xy_harmonics[0] must be a positive odd int.") elif self.xy_harmonics[1] % 2 != 1 or self.xy_harmonics[1] < 1: raise ValueError("xy_harmonics[1] must be a positive odd int.") if self.return_tensor: if self.return_zeroth_order is None: self.return_zeroth_order = True if self.use_transmission is None: self.use_transmission = True if self.include_z_comp is None: self.include_z_comp = False else: if ( (self.return_zeroth_order is not None) or (self.use_transmission is not None) or (self.include_z_comp is not None) ): warnings.warn( "When return_tensor is False, the following arguments are " "ignored: return_zeroth_order, use_transmission, include_z_comp." )
[docs]@dataclasses.dataclass class SimInstance: """Defines a simulation instance. Attributes: unit_cell_array: an array of unit cells to be simulated. incidence: the incidence of the simulation. sim_config: the simulation configuration. """ unit_cell_array: List[UnitCell] incidence: Incidence sim_config: SimConfig def __post_init__(self): # check all x and y periods are the same x_periodicitys = [ unit_cell.periodicity[0] for unit_cell in self.unit_cell_array ] y_periodicitys = [ unit_cell.periodicity[1] for unit_cell in self.unit_cell_array ] if len(set(x_periodicitys)) != 1: raise ValueError("All x periods must be the same.") if len(set(y_periodicitys)) != 1: raise ValueError("All y periods must be the same.") # check all ref. indices are the same for transmission and reflection regions refl_indices = [ unit_cell.refl_index for unit_cell in self.unit_cell_array ] tran_indices = [ unit_cell.tran_index for unit_cell in self.unit_cell_array ] if len(set(refl_indices)) != 1: raise ValueError( "All ref. indices must be the same for reflection region." ) if len(set(tran_indices)) != 1: raise ValueError( "All ref. indices must be the same for transmission region." )
[docs] def get_variables(self) -> List[tf.Variable]: """Returns the variables of the simulation instance.""" variables = [] for unit_cell in self.unit_cell_array: variables.extend(unit_cell.get_variables()) return variables
[docs]@dataclasses.dataclass class SimResult: """The result of an RCWA simulation. Attributes: rx: the x component of the reflected diffraction coeff. ry: the y component of the reflected diffraction coeff. rz: the z component of the reflected diffraction coeff. r_eff: the reflective efficiency. r_power: the total reflected power. tx: the x component of the transmitted diffraction coeff. ty: the y component of the transmitted diffraction coeff. tz: the z component of the transmitted diffraction coeff. t_eff: the transmissive efficiency. t_power: the total transmitted power. """ rx: tf.Tensor ry: tf.Tensor rz: tf.Tensor r_eff: tf.Tensor r_power: tf.Tensor tx: tf.Tensor ty: tf.Tensor tz: tf.Tensor t_eff: tf.Tensor t_power: tf.Tensor xy_harmonics: Tuple[int, int] @staticmethod def _get_0th(fields, xy_harmonics): return fields[:, :, 0, np.prod(xy_harmonics) // 2, 0]
[docs] def ref_field(self, config: SimConfig) -> tf.Tensor: """Returns the reflected diffraction coefficients. Returns: The reflected field according to the simulation configuation. """ if config.return_zeroth_order: rx = self._get_0th(self.rx, self.xy_harmonics) ry = self._get_0th(self.ry, self.xy_harmonics) rz = self._get_0th(self.rz, self.xy_harmonics) else: rx, ry, rz = self.rx, self.ry, self.rz if config.include_z_comp: return tf.stack([rx, ry, rz], axis=-1) else: return tf.stack([rx, ry], axis=-1)
[docs] def trn_field(self, config: SimConfig) -> tf.Tensor: """Returns the transmitted diffraction coefficients. Args: config: the simulation configuration. Returns: The transmitted field according to the simulation configuation. """ if config.return_zeroth_order: tx = self._get_0th(self.tx, self.xy_harmonics) ty = self._get_0th(self.ty, self.xy_harmonics) tz = self._get_0th(self.tz, self.xy_harmonics) else: tx, ty, tz = self.tx, self.ty, self.tz if config.include_z_comp: return tf.stack([tx, ty, tz], axis=-1) else: return tf.stack([tx, ty], axis=-1)
[docs] def get_result_using_config( self, config: SimConfig ) -> Union["SimResult", tf.Tensor]: """Returns the result according to the simulation configuation. Args: config: the simulation configuration. Returns: The result according to the simulation configuation. """ if not config.return_tensor: return self if config.use_transmission: return self.trn_field(config) else: return self.ref_field(config)
[docs]def minibatch_sim_instance( sim_instance: SimInstance, minibatch_size: int ) -> List[SimInstance]: """Generates a list of minibatch simulation instances. Args: sim_instance: the simulation instance. batch_size: the batch size. Returns: A list of minibatch simulation instances. """ unit_cell_array_chunks = [ sim_instance.unit_cell_array[i : i + minibatch_size] for i in range(0, len(sim_instance.unit_cell_array), minibatch_size) ] sim_instance_array = [] for unit_cell_array in unit_cell_array_chunks: sim_instance_array.append( SimInstance( unit_cell_array=unit_cell_array, incidence=sim_instance.incidence, sim_config=sim_instance.sim_config, ) ) return sim_instance_array
[docs]def combine_sim_results( sim_results: Union[List[SimResult], List[tf.Tensor]] ) -> Union[SimResult, tf.Tensor]: """Combines a list of simulation results into one Args: sim_results: the list of simulation results. Returns: The combined simulation result. """ if isinstance(sim_results[0], tf.Tensor): return tf.concat(sim_results, axis=1) rx = tf.concat([sim_result.rx for sim_result in sim_results], axis=1) ry = tf.concat([sim_result.ry for sim_result in sim_results], axis=1) rz = tf.concat([sim_result.rz for sim_result in sim_results], axis=1) r_eff = tf.concat([sim_result.r_eff for sim_result in sim_results], axis=1) r_power = tf.concat( [sim_result.r_power for sim_result in sim_results], axis=1 ) tx = tf.concat([sim_result.tx for sim_result in sim_results], axis=1) ty = tf.concat([sim_result.ty for sim_result in sim_results], axis=1) tz = tf.concat([sim_result.tz for sim_result in sim_results], axis=1) t_eff = tf.concat([sim_result.t_eff for sim_result in sim_results], axis=1) t_power = tf.concat( [sim_result.t_power for sim_result in sim_results], axis=1 ) xy_harmonics = sim_results[0].xy_harmonics return SimResult( rx=rx, ry=ry, rz=rz, r_eff=r_eff, r_power=r_power, tx=tx, ty=ty, tz=tz, t_eff=t_eff, t_power=t_power, xy_harmonics=xy_harmonics, )
[docs]def simulate_parameterized_unit_cells( parameter_tensor: tf.Tensor, proto_cell: ProtoUnitCell, incidence: Incidence, sim_config: SimConfig, ) -> tf.Tensor: """Simulate RCWA and precompute the JVP with better memory efficiency. This method computes the zeroth order diffraction coefficients and the Jacobian of the diffraction coefficients with respect to the unit cell parameters. Args: parameter_tensor: the tensor of unit cell parameters, in the shape (num_features, num_unit_cells). proto_cell: a parameterized unit cell. incidence: the incidence data. sim_config: the simulation configuration. Returns: The 0th order diffraction coefficients. In the form of a tf.Tensor of shape (batch_size, num_unit_cells, 2). """ minibatch_size = sim_config.minibatch_size simulate_func = simulate_parameterized_unit_cells_one_batch if minibatch_size < parameter_tensor.shape[1]: # TODO: multi-GPU support parameters_chunks = [ parameter_tensor[:, i : i + minibatch_size] for i in range(0, parameter_tensor.shape[1], minibatch_size) ] sim_results = [] for parameters_chunk in tqdm.tqdm(parameters_chunks): sim_results.append( simulate_func( parameter_tensor=parameters_chunk, proto_cell=proto_cell, incidence=incidence, sim_config=sim_config, ) ) gc.collect() return combine_sim_results(sim_results) return simulate_func( parameter_tensor=parameter_tensor, proto_cell=proto_cell, incidence=incidence, sim_config=sim_config, )
[docs]def simulate_parameterized_unit_cells_one_batch( parameter_tensor: tf.Tensor, proto_cell: ProtoUnitCell, incidence: Incidence, sim_config: SimConfig, ) -> tf.Tensor: if not sim_config.return_tensor: raise ValueError( "SimConfig.return_tensor=True is required for this method." ) if len(proto_cell.features) == 0: raise ValueError("The proto cell has no features (not parameterized).") children = proto_cell.generate_cells_from_parameter_tensor( parameter_tensor ) sim_instance = SimInstance( unit_cell_array=children, incidence=incidence, sim_config=sim_config, ) return simulate_one(sim_instance)
[docs]def simulate(sim_instance: SimInstance) -> SimResult: """Simulates the periodic unit cell using RCWA. Calculates the transmission/reflection coefficients for a unit cell with a given a simulation instance (SimInstance), which contains the unit cell, incidence, and simulation configuration. Args: sim_instance: the simulation instance. Returns: The simulation result. """ minibatch_size = sim_instance.sim_config.minibatch_size if minibatch_size < len(sim_instance.unit_cell_array): sim_instances = minibatch_sim_instance( sim_instance=sim_instance, minibatch_size=minibatch_size ) sim_results = simulate_batch(sim_instances=sim_instances) return combine_sim_results(sim_results=sim_results) else: return simulate_one(sim_instance=sim_instance)
[docs]def simulate_batch(sim_instances: List[SimInstance]) -> List[SimResult]: """Simulates a batch of periodic unit cells using RCWA. Calculates the transmission/reflection coefficients for a batch of unit cells with a given a list of simulation instances (SimInstance), which contains the unit cell, incidence, and simulation configuration. TODO: parallelize this function. Args: sim_instances: list of simulation instances. Returns: The list of simulation results. """ return [simulate_one(sim_instance) for sim_instance in sim_instances]
[docs]def simulate_one(sim_instance: "SimInstance") -> "SimResult": """Simulates the periodic unit cell using RCWA. Calculates the transmission/reflection coefficients for a unit cell with a given a simulation instance (SimInstance), which contains the unit cell, incidence, and simulation configuration. Args: sim_instance: the simulation instance. Returns: The simulation result. """ incidence_dict = utils.unravel_incidence(sim_instance.incidence) batched_layer_thicknesses = [] for unit_cell in sim_instance.unit_cell_array: layer_thicknesses = unit_cell.get_thickness() layer_thicknesses = tf.cast(layer_thicknesses, tf.complex64) layer_thicknesses = layer_thicknesses[ tf.newaxis, tf.newaxis, tf.newaxis, :, tf.newaxis, tf.newaxis ] batched_layer_thicknesses.append(layer_thicknesses) layer_thicknesses = tf.concat(batched_layer_thicknesses, axis=1) # permittivity in reflection region er1 = sim_instance.unit_cell_array[0].refl_index ** 2 # permittivity in transmission region er2 = sim_instance.unit_cell_array[0].tran_index ** 2 ### Step 2: Compute permittivity and permeability ### n_cells = len(sim_instance.unit_cell_array) batch_size = len(incidence_dict["wavelength"]) ER_t_arrary = [] for unit_cell in sim_instance.unit_cell_array: ER_t_wl = [] for wavelength in incidence_dict["wavelength"]: this_epsilon = unit_cell.get_epsilon( sim_instance.sim_config.resolution, wavelength=wavelength, ) this_epsilon = this_epsilon[ tf.newaxis, tf.newaxis, tf.newaxis, :, :, : ] ER_t_wl.append(this_epsilon) ER_t = tf.concat(ER_t_wl, axis=0) ER_t_arrary.append(ER_t) ER_t = tf.concat(ER_t_arrary, axis=1) # Dielectric materials for now UR_t = tf.ones_like(ER_t) refl_n = sim_instance.unit_cell_array[0].refl_index output = rcwa_tf.simulate_rcwa( incidence_dict, PQ=sim_instance.sim_config.xy_harmonics, n_cells=n_cells, n_layers=len(sim_instance.unit_cell_array[0].layers), layer_thicknesses=layer_thicknesses, L_xy=sim_instance.unit_cell_array[0].periodicity, er1=er1, er2=er2, ER_t=ER_t, UR_t=UR_t, refl_n=refl_n, ) # Store the transmission/reflection coefficients and powers in a SimResult. result = SimResult( rx=tf.math.conj(output["rx"]), ry=tf.math.conj(output["ry"]), rz=tf.math.conj(output["rz"]), r_eff=tf.math.conj(output["R"]), r_power=tf.math.conj(output["REF"]), tx=tf.math.conj(output["tx"]), ty=tf.math.conj(output["ty"]), tz=tf.math.conj(output["tz"]), t_eff=tf.math.conj(output["T"]), t_power=tf.math.conj(output["TRN"]), xy_harmonics=sim_instance.sim_config.xy_harmonics, ) return result.get_result_using_config(sim_instance.sim_config)