Skip to content

Commit

Permalink
244 experiment pygmt plotting (#295)
Browse files Browse the repository at this point in the history
* add a new folder plotting

* initial effort to add pygmt plot

* add a new notebook to test pygmt

* add a new function get_gplot()

* some unfinished code

* finish the unfinished code yesterday

* remove old placeholder files

* rollback change in setup.py

* add parameters for pygmt plot

* plot subduction with pygmt

* add PlotEngine class

* add PygmtPlotEngine and CartopyPlotEngine

* clean up code, remove temporary comments

* add comments for PlotEngine classes
  • Loading branch information
michaelchin authored Jan 20, 2025
1 parent 55d2e4e commit 053dd33
Show file tree
Hide file tree
Showing 12 changed files with 651 additions and 141 deletions.
2 changes: 2 additions & 0 deletions docker/env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ dependencies:
- gplately
- jupyter
- moviepy
- gmt
- pygmt
57 changes: 57 additions & 0 deletions gplately/auxiliary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import Union

from plate_model_manager import PlateModel, PlateModelManager

from .mapping.plot_engine import PlotEngine
from .mapping.cartopy_plot import CartopyPlotEngine
from .plot import PlotTopologies
from .reconstruction import PlateReconstruction


def get_gplot(
model_name: str,
model_repo_dir: str,
age: Union[int, float],
plot_engine: PlotEngine = CartopyPlotEngine(),
) -> PlotTopologies:
"""auxiliary function to get gplot object"""
try:
model = PlateModelManager().get_model(model_name, data_dir=model_repo_dir)
except:
model = PlateModel(model_name, data_dir=model_repo_dir, readonly=True)

if model is None:
raise Exception(f"Unable to get model ({model_name})")

topology_features = None
static_polygons = None
coastlines = None
COBs = None
continents = None

all_layers = model.get_avail_layers()

if "Topologies" in all_layers:
topology_features = model.get_layer("Topologies")
if "StaticPolygons" in all_layers:
static_polygons = model.get_layer("StaticPolygons")
if "Coastlines" in all_layers:
coastlines = model.get_layer("Coastlines")
if "COBs" in all_layers:
COBs = model.get_layer("COBs")
if "ContinentalPolygons" in all_layers:
continents = model.get_layer("ContinentalPolygons")

m = PlateReconstruction(
model.get_rotation_model(),
topology_features=topology_features,
static_polygons=static_polygons,
)
return PlotTopologies(
m,
coastlines=coastlines,
COBs=COBs,
continents=continents,
time=age,
plot_engine=plot_engine,
)
2 changes: 1 addition & 1 deletion gplately/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def inner(func_pointer):
@wraps(func_pointer)
def wrapper(self, ax, **kwargs):
if not self.plate_reconstruction.topology_features:
logger.warn(
logger.warning(
f"Plate model does not have topology features. Unable to plot {feature_name}."
)
return ax
Expand Down
2 changes: 1 addition & 1 deletion gplately/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def read_netcdf_grid(
x_dimension_name: str = "",
y_dimension_name: str = "",
data_variable_name: str = "",
):
) -> tuple[np.ma.MaskedArray, np.ma.MaskedArray, np.ma.MaskedArray] | np.ma.MaskedArray:
"""Read a `netCDF` (.nc) grid from a given `filename` and return its data as a
`MaskedArray`.
Expand Down
1 change: 1 addition & 0 deletions gplately/mapping/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# This submodule contains code to plot maps.
112 changes: 112 additions & 0 deletions gplately/mapping/cartopy_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import logging, math
from geopandas.geodataframe import GeoDataFrame
import cartopy.crs as ccrs
from .plot_engine import PlotEngine
from ..utils.plot_utils import _clean_polygons, plot_subduction_teeth
from ..tools import EARTH_RADIUS

logger = logging.getLogger("gplately")

DEFAULT_CARTOPY_PROJECTION = ccrs.PlateCarree()


class CartopyPlotEngine(PlotEngine):
def __init__(self):
pass

def plot_geo_data_frame(self, ax_or_fig, gdf: GeoDataFrame, **kwargs):
"""Plot GeoDataFrame object with Cartopy
Parameters
----------
ax_or_fig : cartopy.mpl.geoaxes.GeoAxes
Cartopy GeoAxes instance
gdf : GeoDataFrame
GeoPandas GeoDataFrame object
"""
if hasattr(ax_or_fig, "projection"):
gdf = _clean_polygons(data=gdf, projection=ax_or_fig.projection)
else:
kwargs["transform"] = DEFAULT_CARTOPY_PROJECTION

return gdf.plot(ax=ax_or_fig, **kwargs)

def plot_pygplates_features(self, ax_or_fig, features, **kwargs):
"""TODO"""
pass

def plot_subduction_zones(
self,
ax_or_fig,
gdf_subduction_left: GeoDataFrame,
gdf_subduction_right: GeoDataFrame,
color="blue",
**kwargs,
):
"""Plot subduction zones with "teeth" using pygmt
Parameters
----------
ax_or_fig : cartopy.mpl.geoaxes.GeoAxes
Cartopy GeoAxes instance
gdf_subduction_left : GeoDataFrame
subduction zone with "left" polarity
gdf_subduction_right : GeoDataFrame
subduction zone with "right" polarity
color : str
The colour used to fill the "teeth".
"""
if "transform" in kwargs.keys():
logger.warning(
"'transform' keyword argument is ignored by CartopyPlotEngine."
)
kwargs.pop("transform")

spacing = kwargs.pop("spacing")
size = kwargs.pop("size")
aspect = kwargs.pop("aspect")

try:
projection = ax_or_fig.projection
except AttributeError:
logger.warning(
"The ax.projection does not exist. You must set projection to plot Cartopy maps, such as ax = plt.subplot(211, projection=cartopy.crs.PlateCarree())"
)
projection = None

if isinstance(projection, ccrs.PlateCarree):
spacing = math.degrees(spacing)
else:
spacing = spacing * EARTH_RADIUS * 1e3

if aspect is None:
aspect = 2.0 / 3.0
if size is None:
size = spacing * 0.5

height = size * aspect

plot_subduction_teeth(
gdf_subduction_left,
size,
"l",
height,
spacing,
projection=projection,
ax=ax_or_fig,
color=color,
**kwargs,
)
plot_subduction_teeth(
gdf_subduction_right,
size,
"r",
height,
spacing,
projection=projection,
ax=ax_or_fig,
color=color,
**kwargs,
)
50 changes: 50 additions & 0 deletions gplately/mapping/plot_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#
# Copyright (C) 2024 The University of Sydney, Australia
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License, version 2, as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
# for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#

from enum import Enum
from abc import ABC, abstractmethod

from geopandas.geodataframe import GeoDataFrame


class PlotEngineType(Enum):
CARTOPY = 1
PYGMT = 2


class PlotEngine(ABC):
@abstractmethod
def plot_geo_data_frame(self, ax_or_fig, gdf: GeoDataFrame, **kwargs):
"""Plot GeoPandas GeoDataFrame object"""
pass # This is an abstract method, no implementation here.

@abstractmethod
def plot_pygplates_features(self, ax_or_fig, features, **kwargs):
"""Plot one or more pygplates feature(s)"""
pass # This is an abstract method, no implementation here.

@abstractmethod
def plot_subduction_zones(
self,
ax_or_fig,
gdf_subduction_left: GeoDataFrame,
gdf_subduction_right: GeoDataFrame,
color="blue",
**kwargs,
):
"""Plot subduction zones with "teeth" """
pass # This is an abstract method, no implementation here.
135 changes: 135 additions & 0 deletions gplately/mapping/pygmt_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
#
# Copyright (C) 2024 The University of Sydney, Australia
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License, version 2, as published by
# the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
# for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
from geopandas.geodataframe import GeoDataFrame
import pygmt
from .plot_engine import PlotEngine

pygmt.config(
FONT_ANNOT=8,
FONT_LABEL=8,
FONT=8,
MAP_TICK_PEN="0.75p",
MAP_FRAME_PEN="0.75p",
MAP_TICK_LENGTH_PRIMARY="4p",
)

# NW's example is at https://gist.github.com/nickywright/f53018a8eda29223cca6f39ab2cfa25d


class PygmtPlotEngine(PlotEngine):
def __init__(self):
pass

def plot_geo_data_frame(self, ax_or_fig, gdf: GeoDataFrame, **kwargs):
"""Plot GeoDataFrame object with pygmt
Parameters
----------
ax_or_fig : pygmt.Figure()
pygmt Figure object
gdf : GeoDataFrame
GeoPandas GeoDataFrame object
edgecolor : str
For polygons, it is the border colour. For polylines, it is the line colour.
Currently, only colour names are tested and officially supported, for example, "red", "blue", etc.
facecolor : str
The colour used to fill the polygon.
fill : str
GMT "fill" parameter
pen : str
GMT "pen" parameter
style : str
GMT "style" parameter
gmtlabel : str
GMT "label" parameter
"""
line_color = kwargs.pop("edgecolor", "blue")
line_width = f"{kwargs.pop('linewidth',0.1)}p"

fill = kwargs.pop("facecolor", None)
if fill and fill.lower() == "none":
fill = None
fill = kwargs.pop("fill", fill) # the "fill" parameter override the "facecolor"

if line_color.lower() == "none":
# line_width = "0"
# line_color = fill
pen = None
else:
pen = kwargs.pop("pen", f"{line_width},{line_color}")
style = kwargs.pop("style", None)
label = kwargs.pop("gmtlabel", None)

ax_or_fig.plot(
data=gdf.geometry,
pen=pen,
fill=fill,
style=style,
transparency=0,
label=label,
)

def plot_pygplates_features(self, ax_or_fig, features, **kwargs):
"""TODO"""
pass

def plot_subduction_zones(
self,
ax_or_fig,
gdf_subduction_left: GeoDataFrame,
gdf_subduction_right: GeoDataFrame,
color="blue",
**kwargs,
):
"""Plot subduction zones with "teeth" using pygmt
Parameters
----------
ax_or_fig : pygmt.Figure()
pygmt Figure object
gdf_subduction_left : GeoDataFrame
subduction zone with "left" polarity
gdf_subduction_right : GeoDataFrame
subduction zone with "right" polarity
color : str
The colour used to fill the "teeth".
gmtlabel : str
GMT "label" parameter
"""
label = kwargs.pop("gmtlabel", None)

ax_or_fig.plot(
data=gdf_subduction_left,
pen=f"0.5p,{color}",
fill=color,
style="f0.2/0.08+l+t",
label=label,
)
ax_or_fig.plot(
data=gdf_subduction_right,
pen=f"0.5p,{color}",
fill=color,
style="f0.2/0.08+r+t",
)


def get_pygmt_basemap_figure(projection="N180/10c", region="d"):
"""return a pygmt.Figure() object"""
fig = pygmt.Figure()
fig.basemap(region=region, projection=projection, frame="lrtb")
return fig
Loading

0 comments on commit 053dd33

Please sign in to comment.