Skip to content

Commit

Permalink
Fix: Property layer visualization for HexGrid (#2646)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sahil-Chhoker authored Jan 28, 2025
1 parent 79f2969 commit efed03e
Showing 1 changed file with 89 additions and 53 deletions.
142 changes: 89 additions & 53 deletions mesa/visualization/mpl_space_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import contextlib
import itertools
import warnings
from collections.abc import Callable
from collections.abc import Callable, Iterator
from functools import lru_cache
from itertools import pairwise
from typing import Any

Expand All @@ -18,7 +19,7 @@
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
from matplotlib.collections import LineCollection, PatchCollection
from matplotlib.collections import LineCollection, PatchCollection, PolyCollection
from matplotlib.colors import LinearSegmentedColormap, Normalize, to_rgba
from matplotlib.patches import Polygon

Expand Down Expand Up @@ -159,6 +160,37 @@ def draw_space(
return ax


@lru_cache(maxsize=1024, typed=True)
def _get_hexmesh(
width: int, height: int, size: float = 1.0
) -> Iterator[list[tuple[float, float]]]:
"""Generate hexagon vertices for the mesh. Yields list of vertex coordinates for each hexagon."""

# Helper function for getting the vertices of a hexagon given the center and size
def _get_hex_vertices(
center_x: float, center_y: float, size: float = 1.0
) -> list[tuple[float, float]]:
"""Get vertices for a hexagon centered at (center_x, center_y)."""
vertices = [
(center_x, center_y + size), # top
(center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right
(center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right
(center_x, center_y - size), # bottom
(center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left
(center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left
]
return vertices

x_spacing = np.sqrt(3) * size
y_spacing = 1.5 * size

for row, col in itertools.product(range(height), range(width)):
# Calculate center position with offset for even rows
x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2)
y = row * y_spacing
yield _get_hex_vertices(x, y, size)


def draw_property_layers(
space, propertylayer_portrayal: dict[str, dict[str, Any]], ax: Axes
):
Expand Down Expand Up @@ -205,46 +237,74 @@ def draw_property_layers(
vmax = portrayal.get("vmax", np.max(data))
colorbar = portrayal.get("colorbar", True)

# Draw the layer
# Prepare colormap
if "color" in portrayal:
data = data.T
rgba_color = to_rgba(portrayal["color"])
normalized_data = (data - vmin) / (vmax - vmin)
rgba_data = np.full((*data.shape, 4), rgba_color)
rgba_data[..., 3] *= normalized_data * alpha
rgba_data = np.clip(rgba_data, 0, 1)
cmap = LinearSegmentedColormap.from_list(
layer_name, [(0, 0, 0, 0), (*rgba_color[:3], alpha)]
)
im = ax.imshow(
rgba_data,
origin="lower",
)
if colorbar:
norm = Normalize(vmin=vmin, vmax=vmax)
sm = ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
ax.figure.colorbar(sm, ax=ax, orientation="vertical")

elif "colormap" in portrayal:
cmap = portrayal.get("colormap", "viridis")
if isinstance(cmap, list):
cmap = LinearSegmentedColormap.from_list(layer_name, cmap)
im = ax.imshow(
data.T,
cmap=cmap,
alpha=alpha,
vmin=vmin,
vmax=vmax,
origin="lower",
)
if colorbar:
plt.colorbar(im, ax=ax, label=layer_name)
elif isinstance(cmap, str):
cmap = plt.get_cmap(cmap)
else:
raise ValueError(
f"PropertyLayer {layer_name} portrayal must include 'color' or 'colormap'."
)

if isinstance(space, OrthogonalGrid):
if "color" in portrayal:
data = data.T
normalized_data = (data - vmin) / (vmax - vmin)
rgba_data = np.full((*data.shape, 4), rgba_color)
rgba_data[..., 3] *= normalized_data * alpha
rgba_data = np.clip(rgba_data, 0, 1)
ax.imshow(rgba_data, origin="lower")
else:
ax.imshow(
data.T,
cmap=cmap,
alpha=alpha,
vmin=vmin,
vmax=vmax,
origin="lower",
)

elif isinstance(space, HexGrid):
width, height = data.shape

# Generate hexagon mesh
hexagons = _get_hexmesh(width, height)

# Normalize colors
norm = Normalize(vmin=vmin, vmax=vmax)
colors = data.ravel() # flatten data to 1D array

if "color" in portrayal:
normalized_colors = np.clip(norm(colors), 0, 1)
rgba_colors = np.full((len(colors), 4), rgba_color)
rgba_colors[:, 3] = normalized_colors * alpha
else:
rgba_colors = cmap(norm(colors))

# Draw hexagons
collection = PolyCollection(hexagons, facecolors=rgba_colors, zorder=-1)
ax.add_collection(collection)

else:
raise NotImplementedError(
f"PropertyLayer visualization not implemented for {type(space)}."
)

# Add colorbar if requested
if colorbar:
norm = Normalize(vmin=vmin, vmax=vmax)
sm = ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
plt.colorbar(sm, ax=ax, label=layer_name)


def draw_orthogonal_grid(
space: OrthogonalGrid,
Expand Down Expand Up @@ -349,39 +409,15 @@ def draw_hex_grid(
def setup_hexmesh(width, height):
"""Helper function for creating the hexmesh with unique edges."""
edges = set()
size = 1.0
x_spacing = np.sqrt(3) * size
y_spacing = 1.5 * size

def get_hex_vertices(
center_x: float, center_y: float
) -> list[tuple[float, float]]:
"""Get vertices for a hexagon centered at (center_x, center_y)."""
vertices = [
(center_x, center_y + size), # top
(center_x + size * np.sqrt(3) / 2, center_y + size / 2), # top right
(center_x + size * np.sqrt(3) / 2, center_y - size / 2), # bottom right
(center_x, center_y - size), # bottom
(center_x - size * np.sqrt(3) / 2, center_y - size / 2), # bottom left
(center_x - size * np.sqrt(3) / 2, center_y + size / 2), # top left
]
return vertices

# Generate edges for each hexagon
for row, col in itertools.product(range(height), range(width)):
# Calculate center position for each hexagon with offset for even rows
x = col * x_spacing + (row % 2 == 0) * (x_spacing / 2)
y = row * y_spacing

vertices = get_hex_vertices(x, y)

for vertices in _get_hexmesh(width, height):
# Edge logic, connecting each vertex to the next
for v1, v2 in pairwise([*vertices, vertices[0]]):
# Sort vertices to ensure consistent edge representation and avoid duplicates.
edge = tuple(sorted([tuple(np.round(v1, 6)), tuple(np.round(v2, 6))]))
edges.add(edge)

# Return LineCollection for hexmesh
return LineCollection(edges, linestyle=":", color="black", linewidth=1, alpha=1)

if draw_grid:
Expand Down

0 comments on commit efed03e

Please sign in to comment.