import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
import dataclasses
from typing import List, Tuple, Union
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from metabox.utils import CoordType, ParameterType
@dataclasses.dataclass
class Shape:
    """Base class for the shapes that a canvas can rasterize.

    Attributes:
        value: The value associated with the shape. ``Canvas.merge_shape``
            scales the shape's mask by ``value - background_value``.
        use_complex: Set in ``__post_init__``; True when ``value`` has a
            complex dtype. ``Canvas.rasterize`` reads it to decide whether
            the canvas has to be cast to complex.
    """

    value: ParameterType

    def __post_init__(self):
        # Computed once at construction so rasterization does not have to
        # re-inspect the dtype of every shape value.
        self.use_complex = _is_complex(self.value)
@dataclasses.dataclass
class Polygon(Shape):
    """Defines a polygon.

    Attributes:
        points: The vertices of the polygon as a list of at least three
            (x, y) coordinates. Each vertex is normalised with ``_floatt``
            when the instance is created.

    Raises:
        TypeError: If ``points`` is not a list.
        ValueError: If fewer than three points are given.
    """

    points: List[CoordType]

    def __post_init__(self):
        super().__post_init__()
        if not isinstance(self.points, list):
            raise TypeError("points must be a list.")
        if len(self.points) < 3:
            raise ValueError("Polygon must have at least 3 points.")
        # Normalise every vertex to a pair of real float32 values.
        self.points = [_floatt(vertex) for vertex in self.points]
@dataclasses.dataclass
class Circle(Shape):
    """Defines a circle.

    Attributes:
        center: The center of the circle. Normalised with ``_floatt`` when
            the instance is created.
        radius: The radius of the circle. Stored as given (not converted).
    """

    center: CoordType
    radius: ParameterType

    def __post_init__(self):
        super().__post_init__()
        # Only the center is normalised; the radius is left untouched.
        self.center = _floatt(self.center)
@dataclasses.dataclass
class Rectangle(Shape):
    """Defines a rectangle.

    Attributes:
        center: The center of the rectangle. Normalised with ``_floatt``
            when the instance is created.
        x_width: The width of the rectangle in the x direction.
        y_width: The width of the rectangle in the y direction.
        rotation_deg: The rotation of the rectangle in degrees.
            Defaults to 0.
    """

    center: CoordType
    x_width: ParameterType
    y_width: ParameterType
    rotation_deg: ParameterType = 0.0

    def __post_init__(self):
        super().__post_init__()
        # Only the center is normalised; widths and rotation are stored as given.
        self.center = _floatt(self.center)
@dataclasses.dataclass
class Canvas:
    """A class for drawing on a canvas.

    Attributes:
        x_width: The width of the canvas in the x direction.
        y_width: The width of the canvas in the y direction.
        spacing: The spacing between pixels. Defaults to 1.
        background_value: The value of the background. Defaults to 0.
        enforce_4fold_symmetry: whether to make the layer 4-fold symmetric.
            If true, each rasterized shape is summed with rotated and
            mirrored copies of itself before being merged (see
            ``merge_shape``).

    The following attributes are created in ``__post_init__``:
        x_pixels, y_pixels: Pixel counts, ``int(width / spacing)``.
        xx, yy: Pixel-coordinate grids from ``tf.meshgrid``.
        map: The raster, initialised to zeros of shape
            ``[y_pixels, x_pixels]``.
        use_complex: Whether the raster is complex-valued. Starts as False
            and is updated by ``rasterize``.
    """

    x_width: ParameterType
    y_width: ParameterType
    spacing: ParameterType = 1.0
    background_value: ParameterType = 0.0
    enforce_4fold_symmetry: bool = False

    def __post_init__(self):
        self.x_pixels = int(self.x_width / self.spacing)
        self.y_pixels = int(self.y_width / self.spacing)
        # The grids are expressed in pixel units and centred on zero.
        xx = tf.linspace(-self.x_pixels / 2, self.x_pixels / 2, self.x_pixels)
        yy = tf.linspace(-self.y_pixels / 2, self.y_pixels / 2, self.y_pixels)
        self.xx, self.yy = tf.meshgrid(xx, yy)
        self.map = tf.zeros([self.y_pixels, self.x_pixels])
        # merge_shape / merge_with read this flag; give it a value up front
        # so they can be used before rasterize() has been called.
        self.use_complex = False

    def rasterize(self, shape_list: List[Shape]) -> tf.Tensor:
        """Rasterizes a list of shapes.

        High level API for rasterizing a list of shapes. The shapes are
        merged onto ``self.map`` one by one; ``self.map`` is not reset first.

        Args:
            shape_list: The list of shapes to rasterize.

        Returns:
            ``self.map`` plus the background value.

        Raises:
            TypeError: If ``shape_list`` is not a list.
        """
        if not isinstance(shape_list, list):
            raise TypeError("shape_list must be a list.")
        # The raster must be complex as soon as the background or any shape is.
        use_complex = _is_complex(self.background_value) or any(
            shape.use_complex for shape in shape_list
        )
        if use_complex:
            self.map = tf.cast(self.map, tf.complex64)
            self.background_value = tf.cast(self.background_value, tf.complex64)
        self.use_complex = use_complex
        for shape in shape_list:
            self.merge_shape(shape, enforce_4fold_symmetry=self.enforce_4fold_symmetry)
        if use_complex:
            return self.map + tf.cast(self.background_value, tf.complex64)
        return self.map + self.background_value

    def merge_shape(self, shape: Shape, enforce_4fold_symmetry: bool) -> None:
        """Adds a shape onto the canvas.

        The shape is drawn on a fresh blank canvas, optionally symmetrized,
        thresholded, scaled by ``shape.value - self.background_value`` and
        then merged into this canvas with ``merge_with``.

        Args:
            shape: The shape to rasterize.
            enforce_4fold_symmetry: If True, the shape's mask is summed with
                rotated and mirrored copies of itself before thresholding.

        Raises:
            TypeError: If ``shape`` is not a Polygon, Rectangle or Circle.
        """
        tabula_rasa = _blank_canvas_like(self)
        if isinstance(shape, Polygon):
            tabula_rasa.add_polygon(shape.points)
        elif isinstance(shape, Rectangle):
            tabula_rasa.add_rectangle(
                shape.center, shape.x_width, shape.y_width, shape.rotation_deg
            )
        elif isinstance(shape, Circle):
            tabula_rasa.add_circle(shape.center, shape.radius)
        else:
            raise TypeError("Unsupported shape type {}.".format(type(shape)))
        if enforce_4fold_symmetry:
            # tf.image ops expect a trailing channel axis.
            # NOTE(review): adding a 90-degree rotation presumably requires a
            # square canvas -- confirm.
            image = tabula_rasa.map[..., tf.newaxis]
            image += tf.image.rot90(image, k=1)
            image += tf.image.rot90(image, k=2)
            image += tf.image.flip_left_right(image)
            tabula_rasa.map = image[..., 0]
        # Clamp the (possibly summed) mask before scaling it by the value.
        tabula_rasa = _apply_threshold(tabula_rasa)
        if self.use_complex:
            tabula_rasa.map = tf.cast(tabula_rasa.map, tf.complex64)
        tabula_rasa.map *= shape.value - self.background_value
        self.merge_with(tabula_rasa)

    def merge_with(self, other: "Canvas") -> None:
        """Merges the canvas with another canvas.

        For every pixel the larger of the two values is kept; for complex
        canvases the comparison is done on the magnitudes.

        Args:
            other: The other canvas to merge with.
        """
        if self.use_complex:
            self.map = tf.where(
                tf.abs(other.map) > tf.abs(self.map), other.map, self.map
            )
        else:
            self.map = tf.where(other.map > self.map, other.map, self.map)

    def draw(self):
        """Draws the canvas in a new matplotlib figure with a colorbar."""
        plt.figure()
        plt.xlabel("x")
        plt.ylabel("y")
        plt.imshow(
            self.map.numpy(),
            extent=[
                -self.x_width / 2,
                self.x_width / 2,
                -self.y_width / 2,
                self.y_width / 2,
            ],
        )
        plt.colorbar()

    # NOTE: the arithmetic operators below update ``self.map`` in place and
    # return ``self`` rather than a new Canvas. Helpers in this module chain
    # them (e.g. ``l0 * l1 * l2``, ``slate *= line``) and rely on that.
    def __add__(self, other):
        self.map += other.map
        return self

    def __sub__(self, other):
        self.map -= other.map
        return self

    def __mul__(self, other):
        self.map *= other.map
        return self

    def __truediv__(self, other):
        self.map /= other.map
        return self

    def add_point(self, p: Tuple[float, float], radius=0.5) -> None:
        """Adds a point to the canvas.

        Args:
            p: The point to add.
            radius: The radius of the point.
        """
        self.map = _add_point(self, p, radius).map

    def add_triangle(
        self,
        p0: Tuple[float, float],
        p1: Tuple[float, float],
        p2: Tuple[float, float],
    ) -> None:
        """Adds a triangle to the canvas.

        Args:
            p0: The first point of the triangle.
            p1: The second point of the triangle.
            p2: The third point of the triangle.

        If the points are in clockwise order, the triangle will be rasterized
        as a positive shape. If the points are in counter-clockwise order, the
        triangle will be rasterized as a negative shape.
        """
        self.map = _add_triangle(self, p0, p1, p2).map

    def add_polygon(
        self,
        points: List[CoordType],
        keep_positive: bool = True,
    ) -> None:
        """Adds a polygon to the canvas.

        Args:
            points: The vertices of the polygon.
            keep_positive: Keeps the values positive. Defaults to True.
        """
        self.map = _add_polygon(self, points, keep_positive=keep_positive).map

    def add_regular_polygon(
        self,
        center: Tuple[float, float],
        radius: float,
        n: int,
        keep_positive: bool = True,
        apply_threshold: bool = True,
    ) -> None:
        """Adds a regular polygon to the canvas.

        Args:
            center: The center of the polygon.
            radius: The radius of the polygon.
            n: The number of sides of the polygon.
            keep_positive: Accepted but currently not forwarded to the
                rasterization helper.
            apply_threshold: Accepted but currently not forwarded to the
                rasterization helper.
        """
        self.map = _add_regular_polygon(
            self,
            center,
            radius,
            n,
        ).map

    def add_regular_star(
        self,
        center: CoordType,
        radius: ParameterType,
        n: int,
        keep_positive: bool = True,
        apply_threshold: bool = True,
    ) -> None:
        """Adds a regular star to the canvas.

        Args:
            center: The center of the star.
            radius: The radius of the star.
            n: The number of points of the star.
            keep_positive: Keeps the values positive. Defaults to True.
            apply_threshold: Applies threshold from 0 to 1. Defaults to True.
        """
        # NOTE(review): this call requires ``_add_regular_star`` to accept the
        # ``keep_positive`` and ``apply_threshold`` keyword arguments.
        self.map = _add_regular_star(
            self,
            center,
            radius,
            n,
            keep_positive=keep_positive,
            apply_threshold=apply_threshold,
        ).map

    def add_rectangle(
        self,
        center: CoordType,
        x_width: ParameterType,
        y_width: ParameterType,
        rotation_deg: ParameterType = 0.0,
    ) -> None:
        """Adds a rectangle to the canvas.

        Args:
            center: The center of the rectangle.
            x_width: The width of the rectangle in the x direction.
            y_width: The width of the rectangle in the y direction.
            rotation_deg: The rotation of the rectangle in degrees. Defaults to 0.
        """
        # A rectangle is rasterized as the polygon spanned by its four corners.
        vertices = rectangle_to_vertices(center, x_width, y_width, rotation_deg)
        self.map = _add_polygon(self, vertices).map

    def add_circle(
        self,
        center: CoordType,
        radius: ParameterType,
    ) -> None:
        """Adds a circle to the canvas.

        Args:
            center: The center of the circle.
            radius: The radius of the circle.
        """
        self.map = _add_circle(
            self,
            center,
            radius,
        ).map
def _is_complex(x: ParameterType):
    """Tell whether a parameter is complex-valued.

    Args:
        x: A ParameterType; it is converted to a tensor to inspect its dtype.

    Returns:
        True if the converted tensor has a complex dtype, False otherwise.
    """
    return tf.convert_to_tensor(x).dtype.is_complex
def _floatt(x: CoordType) -> CoordType:
    """Normalise a coordinate pair to real float32 values.

    Args:
        x: The coordinates to clean up; only ``x[0]`` and ``x[1]`` are read.

    Returns:
        A 2-tuple holding the real parts of both components, cast to
        ``tf.float32``.
    """
    # Drop any imaginary part first, then unify the dtype.
    real_parts = (tf.math.real(x[0]), tf.math.real(x[1]))
    return tuple(tf.cast(part, tf.float32) for part in real_parts)
def _blank_canvas_like(canvas) -> Canvas:
    """Create an empty canvas with the same geometry as another canvas.

    Args:
        canvas: The canvas to copy the settings from.

    Returns:
        A new Canvas built from the input's ``x_width``, ``y_width``,
        ``spacing`` and ``background_value``. Nothing else is copied
        (in particular neither ``enforce_4fold_symmetry`` nor ``map``).
    """
    copied_fields = ("x_width", "y_width", "spacing", "background_value")
    settings = {name: getattr(canvas, name) for name in copied_fields}
    return Canvas(**settings)
def _add_point(canvas, p: CoordType, radius: ParameterType = 0.5) -> Canvas:
    """Add a Gaussian-shaped point to the canvas.

    Args:
        canvas: The canvas to draw on; its ``map`` is updated.
        p: The location of the point.
        radius: The radius of the point; the exponent is divided by
            ``radius**2``.

    Returns:
        The same canvas with the point added.
    """
    px, py = _floatt(p)
    # Grid offsets relative to the point, in pixel units. The y offset is
    # added while the x offset is subtracted.
    dx = canvas.xx - px / canvas.spacing
    dy = canvas.yy + py / canvas.spacing
    exponent = (-tf.square(dx) - tf.square(dy)) / (radius**2)
    canvas.map += tf.exp(exponent)
    return canvas
def _add_circle(canvas, center: CoordType, radius: ParameterType) -> Canvas:
    """Add a circle to the canvas.

    Args:
        canvas: The canvas to draw on; its ``map`` is updated.
        center: The center of the circle.
        radius: The radius of the circle.

    Returns:
        The same canvas after the circle was added and ``_apply_threshold``
        (with its default arguments) was applied.
    """
    cx, cy = _floatt(center)
    # Grid offsets relative to the center, in pixel units. The y offset is
    # added while the x offset is subtracted.
    dx = canvas.xx - cx / canvas.spacing
    dy = canvas.yy + cy / canvas.spacing
    distance = tf.sqrt((dx**2) + (dy**2))
    # Positive inside the radius, negative outside; thresholding clamps it.
    canvas.map += radius / canvas.spacing - distance
    return _apply_threshold(canvas)
def _add_line_ramp(
    canvas,
    p0: Tuple[float, float],
    p1: Tuple[float, float],
) -> Canvas:
    """Add a linear ramp oriented along the line p0 -> p1 to the canvas.

    Args:
        canvas: The canvas to draw on; its ``map`` is updated.
        p0: The first point of the line; the ramp is zero there.
        p1: The second point of the line; only used for the direction.

    Returns:
        The same canvas with the ramp added.
    """
    (x0, y0), (x1, y1) = _floatt(p0), _floatt(p1)
    # Grid offsets relative to p0, in pixel units.
    dx = canvas.xx - x0 / canvas.spacing
    dy = canvas.yy + y0 / canvas.spacing
    angle = tf.atan2(y1 - y0, x1 - x0)
    canvas.map += -tf.sin(angle) * dx - tf.cos(angle) * dy
    return canvas
def _line_function(
    canvas,
    p0: Tuple[float, float],
    p1: Tuple[float, float],
    invert: bool = False,
) -> Canvas:
    """Build a thresholded mask for the line through p0 and p1.

    The input canvas is only used as a template: the mask is drawn on a new
    blank canvas and the input is left untouched.

    Args:
        canvas: The canvas whose settings the mask canvas copies.
        p0: The first point of the line.
        p1: The second point of the line.
        invert: If True, the mask is replaced by ``1 - mask``.
            Defaults to False.

    Returns:
        A new canvas holding the line mask.
    """
    start, end = _floatt(p0), _floatt(p1)
    ramp = _add_line_ramp(_blank_canvas_like(canvas), start, end)
    mask = _apply_threshold(ramp)
    if invert:
        mask.map = 1 - mask.map
    return mask
def _apply_threshold(
    canvas,
    high_threshold: float = 1,
    low_threshold: float = 0,
    offset: float = 0.5,
    norm: Union[float, None] = None,
) -> Canvas:
    """Shift, clamp and rescale the canvas values.

    Args:
        canvas: The canvas to process; its ``map`` is replaced.
        high_threshold: Values above this are set to it. Defaults to 1.
        low_threshold: Values below this are set to it. Defaults to 0.
        offset: Added to the map before clamping. Defaults to 0.5.
        norm: Divisor applied after clamping. When None, the divisor is
            ``high_threshold - low_threshold``. Defaults to None.

    Returns:
        The same canvas with the processed map.
    """
    shifted = canvas.map + offset
    # Cap at the upper bound first, then at the lower bound.
    capped = tf.where(shifted > high_threshold, high_threshold, shifted)
    clamped = tf.where(capped < low_threshold, low_threshold, capped)
    divisor = high_threshold - low_threshold if norm is None else norm
    canvas.map = clamped / divisor
    return canvas
def _equal_t(a: CoordType, b: CoordType) -> bool:
"""Checks if two points are equal.
Args:
a: The first point.
b: The second point.
Returns:
True if the points are equal, False otherwise.
"""
return (a[0] == b[0]) and (a[1] == b[1])
def _add_triangle(
    canvas,
    p0: CoordType,
    p1: CoordType,
    p2: CoordType,
) -> Canvas:
    """Add a signed triangle to the canvas.

    Args:
        canvas: The canvas to draw on; its ``map`` is updated.
        p0: The first point of the triangle.
        p1: The second point of the triangle.
        p2: The third point of the triangle.

    Returns:
        The same canvas. It is returned unchanged when two of the points
        coincide or when the orientation test evaluates to zero.
    """
    p0, p1, p2 = (_floatt(point) for point in (p0, p1, p2))
    edges = ((p0, p1), (p1, p2), (p2, p0))
    if any(_equal_t(start, end) for start, end in edges):
        # Two corners coincide: not a triangle.
        return canvas

    # Orientation test on the first two edges. The edge vectors are stored
    # as (dy, dx), so the sign below follows that component ordering.
    edge_a = (p1[1] - p0[1], p1[0] - p0[0])
    edge_b = (p2[1] - p1[1], p2[0] - p1[0])
    orientation = edge_a[0] * edge_b[1] - edge_a[1] * edge_b[0]
    if orientation == 0:
        # Degenerate (zero orientation): not a triangle.
        return canvas

    # A negative orientation is drawn with weight +1 and plain masks; any
    # other orientation is drawn with weight -1 and inverted masks.
    is_negative = bool(orientation < 0)
    invert = not is_negative
    weight = 1.0 if is_negative else -1.0

    # The triangle is the product of the three edge masks.
    masks = [_line_function(canvas, start, end, invert) for start, end in edges]
    canvas.map += (masks[0] * masks[1] * masks[2]).map * weight
    return canvas
def _is_convex(canvas, points: List[Tuple[float, float]]) -> bool:
"""Checks if a polygon is convex.
Args:
canvas: The canvas to check the polygon on.
points: The points of the polygon.
Returns:
True if the polygon is convex, False otherwise.
"""
chirolities = []
for i in range(len(points)):
p0, p1, p2 = (
points[i],
points[(i + 1) % len(points)],
points[(i + 2) % len(points)],
)
vector_0 = (p1[1] - p0[1], p1[0] - p0[0])
vector_1 = (p2[1] - p1[1], p2[0] - p1[0])
cross_product = vector_0[0] * vector_1[1] - vector_0[1] * vector_1[0]
if cross_product == 0:
# not a triangle
return False
else:
chirolities.append(cross_product > 0)
if all(chirolities) or not any(chirolities):
return True
return False
def _add_convex_polygon(canvas, points: List[CoordType]) -> Canvas:
    """Add a convex polygon to the canvas.

    The polygon mask is the product of one line mask per edge, starting
    from a canvas filled with ones.

    Args:
        canvas: The canvas to draw on; it is updated in place.
        points: The points of the polygon (at least three are indexed).

    Returns:
        The same canvas with the polygon mask added.
    """
    mask = _blank_canvas_like(canvas)
    mask.map += 1.0

    # Orientation test on the first three vertices (vectors stored as
    # (dy, dx)); a positive result means the vertex order is reversed first.
    first, second, third = points[0], points[1], points[2]
    edge_a = (second[1] - first[1], second[0] - first[0])
    edge_b = (third[1] - second[1], third[0] - second[0])
    orientation = edge_a[0] * edge_b[1] - edge_a[1] * edge_b[0]
    if orientation > 0:
        points = points[::-1]

    vertex_count = len(points)
    for idx, start in enumerate(points):
        end = points[(idx + 1) % vertex_count]
        # Canvas multiplication updates ``mask`` in place.
        mask *= _line_function(canvas, start, end)
    canvas += mask
    return canvas
def _add_polygon(
    canvas,
    points: List[CoordType],
    keep_positive: bool = True,
    apply_threshold: bool = True,
) -> Canvas:
    """Add a polygon to the canvas.

    Polygons that pass ``_is_convex`` are delegated to
    ``_add_convex_polygon``; in that case ``keep_positive`` and
    ``apply_threshold`` are not used. Any other polygon is built from signed
    triangles spanned by the origin and each edge.

    Args:
        canvas: The canvas to add the polygon to.
        points: The points of the polygon.
        keep_positive: Whether to take the absolute value of the map after
            the triangles were added.
        apply_threshold: Whether to clamp the map to [0, 1] (without offset)
            afterwards.

    Returns:
        The canvas with the polygon added.
    """
    if _is_convex(canvas, points):
        return _add_convex_polygon(canvas, points)

    origin = (0, 0)
    vertex_count = len(points)
    for idx, vertex in enumerate(points):
        next_vertex = points[(idx + 1) % vertex_count]
        canvas = _add_triangle(canvas, origin, vertex, next_vertex)

    if keep_positive:
        canvas.map = tf.abs(canvas.map)
    if apply_threshold:
        canvas = _apply_threshold(
            canvas, high_threshold=1.0, low_threshold=0.0, offset=0.0
        )
    return canvas
def rectangle_to_vertices(
    center: Tuple[float, float],
    x_width: float,
    y_width: float,
    rotation_deg: float = 0,
) -> List[np.ndarray]:
    """Computes the corner points of a rotated rectangle.

    Args:
        center: The center of the rectangle.
        x_width: The width of the rectangle in the x direction.
        y_width: The width of the rectangle in the y direction.
        rotation_deg: The rotation of the rectangle in degrees. The angle is
            negated before the rotation matrix is built.

    Returns:
        The four corner points of the rotated rectangle as a list of numpy
        arrays. Before rotation the corners are ordered (+x, +y), (-x, +y),
        (-x, -y), (+x, -y) relative to the center.
    """
    center = _floatt(center)
    x_width, y_width = float(x_width), float(y_width)
    rotation_rad = -rotation_deg * np.pi / 180
    rotation_matrix = np.array(
        [
            [tf.cos(rotation_rad), -tf.sin(rotation_rad)],
            [tf.sin(rotation_rad), tf.cos(rotation_rad)],
        ]
    )
    translation = np.array([center[0], center[1]])
    half_x, half_y = x_width / 2, y_width / 2
    corners = [
        np.array([half_x, half_y]),
        np.array([-half_x, half_y]),
        np.array([-half_x, -half_y]),
        np.array([half_x, -half_y]),
    ]
    # Rotate each corner about the origin, then move it to the center.
    return [rotation_matrix @ corner + translation for corner in corners]
def _add_regular_polygon(
    canvas,
    center: Tuple[float, float],
    radius: float,
    n: int,
) -> Canvas:
    """Add a regular n-sided polygon to the canvas.

    Args:
        canvas: The canvas to add the polygon to.
        center: The center of the polygon.
        radius: The distance from the center to each vertex.
        n: The number of sides of the polygon.

    Returns:
        The canvas with the polygon added.
    """
    # Vertices are evenly spaced on a circle around the center.
    angles = [2 * np.pi * i / n for i in range(n)]
    vertices = [
        (
            radius * tf.cos(angle) + center[0],
            radius * tf.sin(angle) + center[1],
        )
        for angle in angles
    ]
    return _add_polygon(canvas, vertices)
def _add_regular_star(
    canvas,
    center: Tuple[float, float],
    radius: float,
    n: int,
    keep_positive: bool = True,
    apply_threshold: bool = True,
) -> Canvas:
    """Adds a regular star to the canvas.

    The star is the polygon obtained by visiting n evenly spaced points on
    a circle in steps of ``n // 2``.

    Args:
        canvas: The canvas to add the star to.
        center: The center of the star.
        radius: The distance from the center to each tip of the star.
        n: The number of points of the star. For even ``n`` the stepping
            only ever visits two distinct points, so odd values are needed
            for a proper star.
        keep_positive: Forwarded to ``_add_polygon``. Defaults to True.
        apply_threshold: Forwarded to ``_add_polygon``. Defaults to True.

    Returns:
        The canvas with the star added.
    """
    # Stepping by n // 2 connects (nearly) opposite points, which produces
    # the self-intersecting star outline.
    visit_order = [(k * (n // 2)) % n for k in range(n)]
    # The center is added to both coordinates, matching the other
    # primitives in this module (the y flip is handled when rasterizing).
    points = [
        (
            radius * tf.cos(2 * np.pi * i / n) + center[0],
            radius * tf.sin(2 * np.pi * i / n) + center[1],
        )
        for i in visit_order
    ]
    return _add_polygon(
        canvas,
        points,
        keep_positive=keep_positive,
        apply_threshold=apply_threshold,
    )