Skip to content

Commit

Permalink
Merge pull request #21 from NOAA-GFDL/more-robust-plugin-discovery
Browse files Browse the repository at this point in the history
More robust plugin discovery + analysis script updates
  • Loading branch information
ceblanton authored Jan 15, 2025
2 parents 97f61dc + e2cd82c commit 80e9aac
Show file tree
Hide file tree
Showing 15 changed files with 1,055 additions and 580 deletions.
95 changes: 74 additions & 21 deletions core/analysis_scripts/analysis_scripts/plugins.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,87 @@
import importlib
import inspect
from pathlib import Path
import pkgutil

from .base_class import AnalysisScript


# Dictionary of found plugins.
discovered_plugins = {}
for finder, name, ispkg in pkgutil.iter_modules():
if name.startswith("freanalysis_"):
discovered_plugins[name] = importlib.import_module(name)


class UnknownPluginError(BaseException):
"""Custom exception for when an invalid plugin name is used."""
pass


def _find_plugin_class(module):
"""Looks for a class that inherits from AnalysisScript.
Args:
module: Module object.
Returns:
Class that inherits from AnalysisScript.
Raises:
UnknownPluginError if no class is found.
"""
for attribute in vars(module).values():
# Try to find a class that inherits from the AnalysisScript class.
if inspect.isclass(attribute) and AnalysisScript in attribute.__bases__:
# Return the class so an object can be instantiated from it later.
return attribute
raise UnknownPluginError("could not find class that inherts from AnalysisScripts")


_sanity_counter = 0 # How much recursion is happening.
_maximum_craziness = 100 # This is too much recursion.


def _recursive_search(name, ispkg):
"""Recursively search for a module that has a class that inherits from AnalysisScript.
Args:
name: String name of the module.
ispkg: Flag telling whether or not the module is a package.
Returns:
Class that inherits from AnalysisScript.
Raises:
UnknownPluginError if no class is found.
ValueError if there is too much recursion.
"""
global _sanity_counter
_sanity_counter += 1
if _sanity_counter > _maximum_craziness:
raise ValueError(f"recursion level {_sanity_counter} too high.")

module = importlib.import_module(name)
try:
return _find_plugin_class(module)
except UnknownPluginError:
if not ispkg:
# Do not recurse further.
raise
paths = module.__spec__.submodule_search_locations
for finder, subname, ispkg in pkgutil.iter_modules(paths):
subname = f"{name}.{subname}"
try:
return _recursive_search(subname, ispkg)
except UnknownPluginError:
# Didn't find it, so continue to iterate.
pass


# Dictionary of found plugins.
_discovered_plugins = {}
for finder, name, ispkg in pkgutil.iter_modules():
if name.startswith("freanalysis_") and ispkg:
_sanity_counter = 0
_discovered_plugins[name] = _recursive_search(name, True)


def _plugin_object(name):
"""Searches for a class that inherits from AnalysisScript in the plugin module.
"""Attempts to create an object from a class that inherits from AnalysisScript in
the plugin module.
Args:
name: Name of the plugin.
Expand All @@ -27,27 +90,17 @@ def _plugin_object(name):
The object that inherits from AnalysisScript.
Raises:
KeyError if the input name does not match any installed plugins.
ValueError if no object that inhertis from AnalysisScript is found in the
plugin module.
UnknownPluginError if the input name is not in the disovered_plugins dictionary.
"""
# Loop through all attributes in the plugin package with the input name.
try:
plugin_module = discovered_plugins[name]
return _discovered_plugins[name]()
except KeyError:
raise UnknownPluginError(f"could not find analysis script plugin {name}.")

for attribute in vars(plugin_module).values():
# Try to find a class that inherits from the AnalysisScript class.
if inspect.isclass(attribute) and AnalysisScript in attribute.__bases__:
# Instantiate an object of this class.
return attribute()
raise ValueError(f"could not find compatible object in {name}.")


def available_plugins():
"""Returns a list of plugin names."""
return sorted(list(discovered_plugins.keys()))
return sorted(list(_discovered_plugins.keys()))


def list_plugins():
Expand Down
3 changes: 2 additions & 1 deletion core/figure_tools/figure_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .anomaly_timeseries import AnomalyTimeSeries
from .common_plots import observation_vs_model_maps, radiation_decomposition, \
timeseries_and_anomalies, zonal_mean_vertical_and_column_integrated_map
timeseries_and_anomalies, zonal_mean_vertical_and_column_integrated_map, \
chuck_radiation
from .figure import Figure
from .global_mean_timeseries import GlobalMeanTimeSeries
from .lon_lat_map import LonLatMap
Expand Down
24 changes: 24 additions & 0 deletions core/figure_tools/figure_tools/common_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,30 @@
from .figure import Figure


def chuck_radiation(reference, model, title):
figure = Figure(num_rows=3, num_columns=1, title=title, size=(14, 12))

# Create a common color bar for the reference and model.
colorbar_range = [120, 340]

# Model data.
global_mean = model.global_mean()
figure.add_map(model, f"Model [Mean: {global_mean:.2f}]", 1,
colorbar_range=colorbar_range, num_levels=11, colormap="jet")

# Reference data.
global_mean = reference.global_mean()
figure.add_map(reference, f"Obersvations [Mean: {global_mean:.2f}]", 2,
colorbar_range=colorbar_range, num_levels=11, colormap="jet")

# Difference between the reference and the model.
difference = model - reference
color_range = [-34., 34.]
figure.add_map(difference, f"Model - Obs [Mean: {global_mean:.2f}]", 3,
colorbar_range=color_range, colormap="jet", normalize_colors=True)
return figure


def observation_vs_model_maps(reference, model, title):
figure = Figure(num_rows=2, num_columns=2, title=title, size=(14, 12))

Expand Down
21 changes: 19 additions & 2 deletions core/figure_tools/figure_tools/lon_lat_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __sub__(self, arg):

@classmethod
def from_xarray_dataset(cls, dataset, variable, time_method=None, time_index=None,
year=None):
year=None, year_range=None, month_range=None):
"""Instantiates a LonLatMap object from an xarray dataset."""
v = dataset.data_vars[variable]
data = array(v.data[...])
Expand All @@ -75,8 +75,25 @@ def from_xarray_dataset(cls, dataset, variable, time_method=None, time_index=Non
time = TimeSubset(time)
data = time.annual_mean(data, year)
timestamp = r"$\bar{t} = $" + str(year)
elif "climatology" in time_method:
if year_range == None or len(year_range) != 2:
raise ValueError("year_range is required ([star year, end year]" +
" when time_method is a climatology.")
if time_method == "annual climatology":
time = TimeSubset(time)
data = time.annual_climatology(data, year_range)
timestamp = f"{year_range[0]} - {year_range[1]} annual climatology"
elif time_method == "seasonal climatology":
if month_range == None or len(month_range) != 2:
raise ValueError("month_range is required ([start month, end month])" +
" when time_method='seasonal climatology'.")
time = TimeSubset(time)
data = time.seasonal_climatology(data, year_range, month_range)
timestamp = ""
else:
raise ValueError("time_method must be either 'instantaneous' or 'annual mean.'")
valid_values = ["instantaneous", "annual mean", "annual climatology",
"seasonal climatology"]
raise ValueError(f"time_method must one of :{valid_values}.")
else:
timestamp = None

Expand Down
39 changes: 39 additions & 0 deletions core/figure_tools/figure_tools/time_subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,45 @@ def __init__(self, data):
"""
self.data = data

def annual_climatology(self, data, year_range):
years = [x for x in range(year_range[0], year_range[1] + 1)]
sum_, counter = None, 0
for i, point in enumerate(self.data):
_, year = self._month_and_year(point)
if year in years:
if sum_ is None:
sum_ = data[i, ...]
else:
sum_ += data[i, ...]
counter += 1

if counter != 12*len(years):
raise ValueError("Expected monthly data and did not find correct number of months.")
return array(sum_[...]/counter)

def seasonal_climatology(self, data, year_range, month_range):
years = [x for x in range(year_range[0], year_range[1] + 1)]
if month_range[1] - month_range[0] < 0:
# We have crossed to the next year.
months = [x for x in range(month_range[0], 13)] + \
[x for x in range(1, month_range[1] + 1)]
else:
months = [x for x in range(month_range[0], month_range[1] + 1)]

sum_, counter = None, 0
for i, point in enumerate(self.data):
month, year = self._month_and_year(point)
if month in months and year in years:
if sum_ is None:
sum_ = data[i, ...]
else:
sum_ += data[i, ...]
counter += 1

if counter != len(months)*len(years):
raise ValueError("Expected monthly data and did not find enough months.")
return array(sum_[...]/counter)

def annual_mean(self, data, year):
"""Calculates the annual mean of the input date for the input year.
Expand Down
1 change: 1 addition & 0 deletions core/figure_tools/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"cartopy",
"matplotlib",
"numpy",
"scipy",
"xarray",
]
requires-python = ">= 3.6"
Expand Down
Loading

0 comments on commit 80e9aac

Please sign in to comment.