import dataclasses
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from os import devnull
from typing import Any, Dict, Iterator, List, TextIO, Tuple, Union

import numpy as np
import tensorflow as tf
# https://stackoverflow.com/questions/11130156/suppress-stdout-stderr-print-from-python-functions
@contextmanager
def suppress_stdout_stderr() -> Iterator[Tuple[TextIO, TextIO]]:
    """Redirect ``sys.stdout`` and ``sys.stderr`` to ``os.devnull``.

    Both streams are restored, and the devnull handle is closed, when the
    ``with`` block exits (including on exceptions).

    Yields:
        An ``(err, out)`` tuple holding the objects returned by entering
        ``redirect_stderr`` and ``redirect_stdout`` respectively.
    """
    with open(devnull, "w") as fnull:
        with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out:
            yield (err, out)
def recursively_convert_ndarray_in_dict_to_list(item: Any) -> Any:
    """Recursively convert every ``np.ndarray`` inside ``item`` to a list.

    Args:
        item: an ndarray, a dict (possibly nested), a list, or any other
            value. Dicts are updated IN PLACE and returned; lists are rebuilt
            as new lists; ndarrays become nested Python lists via
            ``ndarray.tolist()``; anything else is returned unchanged.

    Returns:
        The converted value.
    """
    # isinstance (not ``type(...) is``) so ndarray subclasses and dict
    # subclasses such as OrderedDict/defaultdict are converted as well.
    if isinstance(item, np.ndarray):
        return item.tolist()
    if isinstance(item, dict):
        # Only values are reassigned (keys untouched), so updating while
        # iterating over ``items()`` is safe.
        for key, value in item.items():
            item[key] = recursively_convert_ndarray_in_dict_to_list(value)
        return item
    if isinstance(item, list):
        # Descend into lists so e.g. {"runs": [array, array]} is fully
        # converted too.
        return [recursively_convert_ndarray_in_dict_to_list(v) for v in item]
    return item
@dataclasses.dataclass
class Feature:
    """Defines a feature variable.

    Features compare equal and hash by ``name`` only.

    Args:
        vmin: the minimum value of the feature.
        vmax: the maximum value of the feature.
        name: the name of the feature. It is also used as the name of the
            ``tf.Variable`` created by ``set_variable``, and must not contain
            the character ':'.
        initial_value: optional starting value. If None, ``initialize_value``
            draws a random value between vmin and vmax instead.
        sampling: the number of samples to take between vmin and vmax. If None,
            the sampling is undefined.
        value: the current value of the feature; None until
            ``initialize_value`` or ``set_value`` is called.

    Raises:
        ValueError: if ``name`` contains ':'.
    """

    vmin: float
    vmax: float
    name: str
    initial_value: Union[float, None] = None
    sampling: Union[int, None] = None
    value: Union[tf.Variable, tf.Tensor, float, None] = None

    def __post_init__(self):
        # Previously misspelled ``_post_init__``, so dataclasses never called
        # it and this validation was dead code. NOTE(review): ':' is
        # presumably rejected because ``name`` becomes a tf.Variable name,
        # where ':' separates the output index in TensorFlow tensor names.
        if ":" in self.name:
            raise ValueError(
                "The name of the feature cannot contain the character ':'"
            )

    def __eq__(self, other):
        # Returning NotImplemented (instead of touching ``other.name``) lets
        # comparisons against floats/tensors fall back to identity-based
        # inequality rather than raising AttributeError.
        if not isinstance(other, Feature):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        # Explicit __hash__ keeps Feature hashable; dataclass(eq=True) would
        # otherwise set __hash__ to None. Consistent with __eq__ (name only).
        return hash(self.name)

    def initialize_value(self) -> None:
        """Initialize ``self.value`` as a float32 tensor.

        If ``initial_value`` is set, its real part is cast to float32 and
        wrapped in a 1-element tensor (for a scalar ``initial_value``).
        Otherwise a shape-``[1]`` value is drawn with ``tf.random.uniform``
        between ``vmin`` and ``vmax``.
        """
        if self.initial_value is not None:
            tensor = tf.math.real(self.initial_value)
            tensor = tf.cast([tensor], tf.float32)
        else:
            tensor = tf.random.uniform([1], self.vmin, self.vmax, tf.float32)
        self.value = tensor

    def set_variable(self) -> None:
        """Convert ``self.value`` to a ``tf.Variable`` named ``self.name``.

        The variable's constraint clips it to ``[vmin, vmax]``; ``vmin`` and
        ``vmax`` are read from ``self`` each time the constraint is applied.
        """

        def _clip_to_bounds(x):
            return tf.clip_by_value(x, self.vmin, self.vmax)

        self.value = tf.Variable(
            self.value, constraint=_clip_to_bounds, name=self.name
        )

    def set_value(self, value: Any) -> None:
        """Set the value of the feature to ``value`` (stored as given)."""
        self.value = value
@dataclasses.dataclass
class Incidence:
    """Defines the physical properties of the incident light.

    Args:
        wavelength: the wavelengths of the light in meters.
        theta: tuple of the angles of incidence in degrees on the xz plane.
            Defaults to (0,).
        phi: tuple of the angles of incidence in degrees on the yz plane.
            Defaults to (0,).
        jones_vector: the ``(x, y)`` Jones vector of the incident light; the
            components may be complex. Defaults to (1, 0) which corresponds
            to a linearly polarized light with the electric field vector
            parallel to the x axis.
    """

    # Variable-length tuples: ``Tuple[float]`` would mean exactly one element.
    wavelength: Tuple[float, ...]
    theta: Tuple[float, ...] = (0,)
    phi: Tuple[float, ...] = (0,)
    # Exactly two components (x, y); ints/floats are valid where complex is.
    jones_vector: Tuple[complex, complex] = (1, 0)
def unravel_wavelength_theta_phi(
    wavelength: List[float], theta: List[float], phi: List[float]
) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Expand wavelength, theta and phi into aligned batch tensors.

    The three inputs are combined as a Cartesian product in which
    wavelength varies slowest, phi next, and theta fastest, so element ``i``
    of each returned tensor belongs to the same (wavelength, theta, phi)
    combination (for 1-D or scalar inputs).

    Args:
        wavelength: a list of wavelengths in meters.
        theta: a list of angles of incidence in degrees on the xz plane.
        phi: a list of angles of incidence in degrees on the yz plane.

    Returns:
        ``(wavelength, theta, phi)`` float32 tensors, each of length
        ``size(wavelength) * size(theta) * size(phi)``. Wavelength keeps its
        input units; theta and phi are converted from degrees to radians.
    """
    n_wavelength = np.size(wavelength)
    n_theta = np.size(theta)
    n_phi = np.size(phi)
    deg_to_rad = np.pi / 180.0

    # Each wavelength is held for a whole (phi, theta) sub-grid.
    wavelength_grid = np.repeat(wavelength, n_theta * n_phi)
    # The theta sweep is replayed for every (wavelength, phi) pair.
    theta_grid = np.tile(theta, n_wavelength * n_phi)
    # Each phi is held for one theta sweep; that pattern repeats per wavelength.
    phi_grid = np.tile(np.repeat(phi, n_theta), n_wavelength)

    return (
        tf.convert_to_tensor(wavelength_grid, dtype=tf.float32),
        deg_to_rad * tf.convert_to_tensor(theta_grid, dtype=tf.float32),
        deg_to_rad * tf.convert_to_tensor(phi_grid, dtype=tf.float32),
    )
def unravel_incidence(incidence: Incidence) -> Dict[str, Any]:
    """Flatten an ``Incidence`` into a dict of per-sample batch tensors.

    Returns:
        A dict with keys:
        - ``"wavelength"``, ``"theta"``, ``"phi"``: the tensors produced by
          ``unravel_wavelength_theta_phi`` for the incidence's wavelengths
          and angles.
        - ``"ptm"``: the Jones vector's x component, repeated once per batch
          element and cast to ``tf.complex64``.
        - ``"pte"``: the Jones vector's y component, repeated and cast the
          same way.
    """
    wavelength_batch, theta_batch, phi_batch = unravel_wavelength_theta_phi(
        wavelength=incidence.wavelength,
        theta=incidence.theta,
        phi=incidence.phi,
    )
    batch_size = len(wavelength_batch)
    x_pol, y_pol = incidence.jones_vector

    def _broadcast_pol(component):
        # One complex64 entry per batch element.
        return tf.cast(tf.repeat(component, batch_size), tf.complex64)

    return {
        "wavelength": wavelength_batch,
        "theta": theta_batch,
        "phi": phi_batch,
        "ptm": _broadcast_pol(x_pol),
        "pte": _broadcast_pol(y_pol),
    }
# A parameter may be an optimizable Feature, a plain float, or a tensor.
ParameterType = Union[Feature, float, tf.Tensor]
# A pair of parameters.
CoordType = Tuple[ParameterType, ParameterType]

# Names of the TensorFlow functions this module supports.
# NOTE(review): the consumer is not visible here; presumably each name is
# resolved as ``getattr(tf, name)`` — confirm. To support another function,
# add its name to the appropriate group below.
TF_FUNCTIONS = (
    # elementary math
    [
        "abs", "acos", "acosh", "add", "asin", "asinh", "atan", "atanh",
        "cos", "cosh", "sin", "sinh", "tan", "tanh", "exp", "sqrt", "square",
    ]
    # reductions
    + ["reduce_sum", "reduce_mean", "reduce_max", "reduce_min", "reduce_prod"]
    # matrix / shape manipulation
    + [
        "matmul", "transpose", "reshape", "expand_dims", "squeeze", "stack",
        "concat",
    ]
)
def wavelength_to_rgb(wavelength):
    """Map a visible-light wavelength to an approximate RGB colour.

    Args:
        wavelength: wavelength in meters.

    Returns:
        An ``(R, G, B)`` tuple of floats in ``[0, 1]``. Wavelengths outside
        380-750 nm map to ``(0.0, 0.0, 0.0)``. Channel ramps use a 0.8 gamma,
        and intensity fades toward 0.3 at both ends of the range.
    """
    nm = wavelength * 1e9
    gamma = 0.8

    if 380 <= nm < 440:
        # violet end: red fades out, overall intensity ramps up from 0.3
        fade = 0.3 + 0.7 * (nm - 380) / (440 - 380)
        red = ((-(nm - 440) / (440 - 380)) * fade) ** gamma
        return (red, 0.0, (1.0 * fade) ** gamma)
    if 440 <= nm < 490:
        return (0.0, ((nm - 440) / (490 - 440)) ** gamma, 1.0)
    if 490 <= nm < 510:
        return (0.0, 1.0, (-(nm - 510) / (510 - 490)) ** gamma)
    if 510 <= nm < 580:
        return (((nm - 510) / (580 - 510)) ** gamma, 1.0, 0.0)
    if 580 <= nm < 645:
        return (1.0, (-(nm - 645) / (645 - 580)) ** gamma, 0.0)
    if 645 <= nm <= 750:
        # red end: intensity ramps down to 0.3
        fade = 0.3 + 0.7 * (750 - nm) / (750 - 645)
        return ((1.0 * fade) ** gamma, 0.0, 0.0)
    return (0.0, 0.0, 0.0)