Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

viz: Automatically visualize Cell color attributes in cell space #2558

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 79 additions & 44 deletions mesa/visualization/mpl_space_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def draw_orthogonal_grid(
draw_grid: bool = True,
**kwargs,
):
"""Visualize a orthogonal grid.
"""Visualize a orthogonal grid with automatic cell coloring.

Args:
space: the space to visualize
Expand All @@ -263,13 +263,34 @@ def draw_orthogonal_grid(
Returns:
Returns the Axes object with the plot drawn onto it.

``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.

Cell colors will be automatically visualized if cells have a 'color' attribute. The color
attribute can be any valid matplotlib color specification (name, hex, RGB tuple, etc.).
"""
if ax is None:
fig, ax = plt.subplots()

# First draw cell colors if they exist
if hasattr(space, "all_cells"): # Check if it's a cell space
cell_colors = np.full(
(space.height, space.width, 4), [1, 1, 1, 0]
) # Transparent default

for cell in space.all_cells:
if hasattr(cell, "color"):
x, y = cell.coordinate
try:
rgba_color = to_rgba(cell.color)
cell_colors[y, x] = rgba_color
except ValueError:
warnings.warn(
f"Invalid color value '{cell.color}' for cell at {cell.coordinate}",
UserWarning,
stacklevel=2,
)

# Plot the cell colors
ax.imshow(cell_colors, origin="lower", interpolation="nearest")

# gather agent data
s_default = (180 / max(space.width, space.height)) ** 2
arguments = collect_agent_data(space, agent_portrayal, size=s_default)
Expand Down Expand Up @@ -298,7 +319,7 @@ def draw_hex_grid(
draw_grid: bool = True,
**kwargs,
):
"""Visualize a hex grid.
"""Visualize a hex grid with automatic cell coloring.

Args:
space: the space to visualize
Expand All @@ -310,28 +331,56 @@ def draw_hex_grid(
Returns:
Returns the Axes object with the plot drawn onto it.

``agent_portrayal`` is called with an agent and should return a dict. Valid fields in this dict are "color",
"size", "marker", and "zorder". Other field are ignored and will result in a user warning.

Cell colors will be automatically visualized if cells have a 'color' attribute. The color
attribute can be any valid matplotlib color specification (name, hex, RGB tuple, etc.).
"""
if ax is None:
fig, ax = plt.subplots()

# gather data
s_default = (180 / max(space.width, space.height)) ** 2
arguments = collect_agent_data(space, agent_portrayal, size=s_default)
# First create hexagons for cells if they exist
if hasattr(space, "all_cells"):
patches = []
offset = math.sqrt(0.75)

# for hexgrids we have to go from logical coordinates to visual coordinates
# this is a bit messy.
for cell in space.all_cells:
x, y = cell.coordinate
if y % 2 == 0:
x += 0.5
y *= offset

hex_patch = RegularPolygon(
(x, y),
numVertices=6,
radius=math.sqrt(1 / 3),
orientation=np.radians(120),
)

# give all even rows an offset in the x direction
# give all rows an offset in the y direction
if hasattr(cell, "color"):
try:
hex_patch.set_facecolor(cell.color)
except ValueError:
warnings.warn(
f"Invalid color value '{cell.color}' for cell at {cell.coordinate}",
UserWarning,
stacklevel=2,
)
hex_patch.set_facecolor("none")
else:
hex_patch.set_facecolor("none")

patches.append(hex_patch)

# Add colored hexagons
cell_collection = PatchCollection(patches, match_original=True)
ax.add_collection(cell_collection)

# gather data for agents
s_default = (180 / max(space.width, space.height)) ** 2
arguments = collect_agent_data(space, agent_portrayal, size=s_default)

# numbers here are based on a distance of 1 between centers of hexes
# Convert logical to visual coordinates for agents
offset = math.sqrt(0.75)

loc = arguments["loc"].astype(float)

logical = np.mod(loc[:, 1], 2) == 0
loc[:, 0][logical] += 0.5
loc[:, 1] *= offset
Expand All @@ -340,43 +389,29 @@ def draw_hex_grid(
# plot the agents
_scatter(ax, arguments, **kwargs)

# further styling and adding of grid
# further styling
ax.set_xlim(-1, space.width + 0.5)
ax.set_ylim(-offset, space.height * offset)

def setup_hexmesh(
width,
height,
):
"""Helper function for creating the hexmaesh."""
# fixme: this should be done once, rather than in each update
# fixme check coordinate system in hexgrid (see https://www.redblobgames.com/grids/hexagons/#coordinates-offset)

patches = []
for x, y in itertools.product(range(width), range(height)):
if draw_grid:
# Grid lines
grid_patches = []
for x, y in itertools.product(range(space.width), range(space.height)):
if y % 2 == 0:
x += 0.5 # noqa: PLW2901
y *= offset # noqa: PLW2901
hex = RegularPolygon(
x += 0.5
y *= offset
hex_patch = RegularPolygon(
(x, y),
numVertices=6,
radius=math.sqrt(1 / 3),
orientation=np.radians(120),
)
patches.append(hex)
mesh = PatchCollection(
patches, edgecolor="k", facecolor=(1, 1, 1, 0), linestyle="dotted", lw=1
grid_patches.append(hex_patch)
grid_collection = PatchCollection(
grid_patches, edgecolor="k", facecolor="none", linestyle="dotted", lw=1
)
return mesh
ax.add_collection(grid_collection)

if draw_grid:
# add grid
ax.add_collection(
setup_hexmesh(
space.width,
space.height,
)
)
return ax


Expand Down
Loading