Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add agegrid CLI #115

Merged
merged 2 commits into from
Oct 19, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 231 additions & 17 deletions gplately/__main__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
import argparse
import os
import sys
from typing import List
import warnings
from typing import (
List,
Optional,
Sequence,
Union,
)

import pygplates

from gplately import __version__, feature_filter


def combine_feature_collections(input_files: List[str], output_file: str):
"""combine multiply feature collections into one"""
"""Combine multiple feature collections into one."""
feature_collection = pygplates.FeatureCollection()
for file in input_files:
if not os.path.isfile(file):
Expand All @@ -21,8 +27,15 @@ def combine_feature_collections(input_files: List[str], output_file: str):
print(f"Done! The combined feature collection has been saved to {output_file}.")


def _run_combine_feature_collections(args):
combine_feature_collections(
[args.combine_first_input_file] + args.combine_other_input_files,
args.combine_output_file,
)


def filter_feature_collection(args):
"""filter the input feature collection according to command line arguments"""
"""Filter the input feature collection according to command line arguments."""
input_feature_collection = pygplates.FeatureCollection(args.filter_input_file)

filters = []
Expand Down Expand Up @@ -68,6 +81,99 @@ def filter_feature_collection(args):
)


def create_agegrids(
input_filenames: Union[str, Sequence[str]],
continents_filenames: Union[str, Sequence[str]],
output_dir: str,
min_time: float,
max_time: float,
ridge_time_step: float = 1,
n_jobs: int = 1,
refinement_levels: int = 5,
grid_spacing: float = 0.1,
ridge_sampling: float = 0.5,
initial_spreadrate: float = 75,
file_collection: Optional[str] = None,
unmasked: bool = False,
) -> None:
"""Create age grids for a plate model."""
from gplately import (
PlateReconstruction,
PlotTopologies,
SeafloorGrid,
)

features = pygplates.FeaturesFunctionArgument(
input_filenames
).get_features()
rotations = []
topologies = []
for i in features:
if (
i.get_feature_type().to_qualified_string()
== "gpml:TotalReconstructionSequence"
):
rotations.append(i)
else:
topologies.append(i)
topologies = pygplates.FeatureCollection(topologies)
rotations = pygplates.RotationModel(rotations)

if continents_filenames is None:
continents = pygplates.FeatureCollection()
else:
continents = pygplates.FeatureCollection(
pygplates.FeaturesFunctionArgument(
continents_filenames
).get_features()
)

with warnings.catch_warnings():
warnings.simplefilter("ignore", ImportWarning)
reconstruction = PlateReconstruction(
rotation_model=rotations,
topology_features=topologies,
)
gplot = PlotTopologies(
reconstruction,
continents=continents,
)

grid = SeafloorGrid(
reconstruction,
gplot,
min_time=min_time,
max_time=max_time,
save_directory=output_dir,
ridge_time_step=ridge_time_step,
refinement_levels=refinement_levels,
grid_spacing=grid_spacing,
ridge_sampling=ridge_sampling,
initial_ocean_mean_spreading_rate=initial_spreadrate,
file_collection=file_collection,
)
grid.reconstruct_by_topologies()
for val in ("SEAFLOOR_AGE", "SPREADING_RATE"):
grid.lat_lon_z_to_netCDF(val, unmasked=unmasked, nprocs=n_jobs)


def _run_create_agegrids(args):
create_agegrids(
input_filenames=args.input_filenames,
continents_filenames=args.continents_filenames,
output_dir=args.output_dir,
min_time=args.min_time,
max_time=args.max_time,
n_jobs=args.n_jobs,
refinement_levels=args.refinement_levels,
grid_spacing=args.grid_spacing,
ridge_sampling=args.ridge_sampling,
initial_spreadrate=args.initial_spreadrate,
file_collection=args.file_collection,
unmasked=args.unmasked,
)


class ArgParser(argparse.ArgumentParser):
def error(self, message):
sys.stderr.write(f"error: {message}\n")
Expand All @@ -81,16 +187,37 @@ def main():
parser.add_argument("-v", "--version", action="store_true")

# sub-commands
subparser = parser.add_subparsers(dest="command")
combine_cmd = subparser.add_parser("combine")
filter_cmd = subparser.add_parser("filter")
subparser = parser.add_subparsers(
dest="command",
title="subcommands",
description="valid subcommands",
)
combine_cmd = subparser.add_parser(
"combine",
help=combine_feature_collections.__doc__,
description=combine_feature_collections.__doc__,
)
filter_cmd = subparser.add_parser(
"filter",
help=filter_feature_collection.__doc__,
description=filter_feature_collection.__doc__,
)
agegrid_cmd = subparser.add_parser(
"agegrid",
aliases=("ag",),
help=create_agegrids.__doc__,
add_help=True,
description=create_agegrids.__doc__,
)

# combine command arguments
combine_cmd.set_defaults(func=_run_combine_feature_collections)
combine_cmd.add_argument("combine_first_input_file", type=str)
combine_cmd.add_argument("combine_other_input_files", nargs="+", type=str)
combine_cmd.add_argument("combine_output_file", type=str)

# feature filter command arguments
filter_cmd.set_defaults(func=filter_feature_collection)
filter_cmd.add_argument("filter_input_file", type=str)
filter_cmd.add_argument("filter_output_file", type=str)

Expand All @@ -115,6 +242,103 @@ def main():
)
filter_cmd.add_argument("--exact-match", dest="exact_match", action="store_true")

# agegrid command arguments
agegrid_cmd.set_defaults(func=_run_create_agegrids)
agegrid_cmd.add_argument(
metavar="INPUT_FILE",
nargs="+",
help="input reconstruction files",
dest="input_filenames",
)
agegrid_cmd.add_argument(
metavar="OUTPUT_DIR",
help="output directory",
dest="output_dir",
)
agegrid_cmd.add_argument(
"-c",
"--continents",
metavar="CONTINENTS_FILE",
nargs="+",
help="input continent files",
dest="continents_filenames",
default=None,
)
agegrid_cmd.add_argument(
"-r",
"--resolution",
metavar="RESOLUTION",
type=float,
help="grid resolution (degrees); default: 0.1",
default=0.1,
dest="grid_spacing",
)
agegrid_cmd.add_argument(
"--refinement-levels",
metavar="LEVELS",
type=int,
help="mesh refinement levels; default: 5",
default=5,
dest="refinement_levels",
)
agegrid_cmd.add_argument(
"--ridge-sampling",
metavar="RESOLUTION",
type=float,
help="MOR sampling resolution (degrees); default: 0.5",
default=0.5,
dest="ridge_sampling",
)
agegrid_cmd.add_argument(
"--initial-spreadrate",
metavar="SPREADRATE",
type=float,
help="initial ocean spreading rate (km/Myr); default: 75",
default=75,
dest="initial_spreadrate",
)
agegrid_cmd.add_argument(
"-e",
"--min-time",
metavar="MIN_TIME",
type=float,
help="minimum time (Ma); default: 0",
default=0,
dest="min_time",
)
agegrid_cmd.add_argument(
"-s",
"--max-time",
metavar="MAX_TIME",
type=float,
help="maximum time (Ma); default: 0",
default=0,
dest="max_time",
)
agegrid_cmd.add_argument(
"-j",
"--n_jobs",
help="number of processes to use; default: 1",
metavar="N_JOBS",
default=1,
dest="n_jobs",
)
agegrid_cmd.add_argument(
"-f",
"--file-collection",
help="file collection name (optional)",
metavar="NAME",
default=None,
dest="file_collection",
)
agegrid_cmd.add_argument(
"-u",
"--include-unmasked",
help="create unmasked grids in addition to masked ones",
action="store_true",
dest="unmasked",
)

if len(sys.argv) == 1:
parser.print_help(sys.stderr)
sys.exit(1)
Expand All @@ -125,17 +349,7 @@ def main():
print(__version__)
sys.exit(0)

if args.command == "combine":
combine_feature_collections(
[args.combine_first_input_file] + args.combine_other_input_files,
args.combine_output_file,
)
elif args.command == "filter":
filter_feature_collection(args)
else:
print(f"Unknow command {args.command}!")
parser.print_help(sys.stderr)
sys.exit(1)
args.func(args)


if __name__ == "__main__":
Expand Down
Loading