Skip to content

Commit

Permalink
add fix_crossovers and remove_plate_rotations as gplately subcommands (
Browse files Browse the repository at this point in the history
…#128)

* move  feature_filter add_parser() into feature_filter.py to prevent __main__.py growing too big and messy

* add fix_crossover as a subcommand

* add remove_plate_rotations as subcommand
  • Loading branch information
michaelchin authored Nov 8, 2023
1 parent 533a238 commit da681cf
Show file tree
Hide file tree
Showing 5 changed files with 753 additions and 485 deletions.
123 changes: 28 additions & 95 deletions gplately/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@
import os
import sys
import warnings
from typing import (
List,
Optional,
Sequence,
Union,
)
from typing import List, Optional, Sequence, Union

import pygplates

from gplately import __version__, feature_filter

from .ptt import fix_crossovers, remove_plate_rotations


def combine_feature_collections(input_files: List[str], output_file: str):
"""Combine multiple feature collections into one."""
Expand All @@ -34,53 +31,6 @@ def _run_combine_feature_collections(args):
)


def filter_feature_collection(args):
"""Filter the input feature collection according to command line arguments."""
input_feature_collection = pygplates.FeatureCollection(args.filter_input_file)

filters = []
if args.names:
filters.append(
feature_filter.FeatureNameFilter(
args.names,
exact_match=args.exact_match,
case_sensitive=args.case_sensitive,
)
)
elif args.exclude_names:
filters.append(
feature_filter.FeatureNameFilter(
args.exclude_names,
exclude=True,
exact_match=args.exact_match,
case_sensitive=args.case_sensitive,
)
)

if args.pids:
filters.append(feature_filter.PlateIDFilter(args.pids))
elif args.exclude_pids:
filters.append(feature_filter.PlateIDFilter(args.exclude_pids, exclude=True))

# print(args.max_birth_age)
if args.max_birth_age is not None:
filters.append(
feature_filter.BirthAgeFilter(args.max_birth_age, keep_older=False)
)
elif args.min_birth_age is not None:
filters.append(feature_filter.BirthAgeFilter(args.min_birth_age))

new_fc = feature_filter.filter_feature_collection(
input_feature_collection,
filters,
)

new_fc.write(args.filter_output_file)
print(
f"Done! The filtered feature collection has been saved to {args.filter_output_file}."
)


def create_agegrids(
input_filenames: Union[str, Sequence[str]],
continents_filenames: Union[str, Sequence[str]],
Expand All @@ -97,15 +47,9 @@ def create_agegrids(
unmasked: bool = False,
) -> None:
"""Create age grids for a plate model."""
from gplately import (
PlateReconstruction,
PlotTopologies,
SeafloorGrid,
)
from gplately import PlateReconstruction, PlotTopologies, SeafloorGrid

features = pygplates.FeaturesFunctionArgument(
input_filenames
).get_features()
features = pygplates.FeaturesFunctionArgument(input_filenames).get_features()
rotations = []
topologies = []
for i in features:
Expand All @@ -123,9 +67,7 @@ def create_agegrids(
continents = pygplates.FeatureCollection()
else:
continents = pygplates.FeatureCollection(
pygplates.FeaturesFunctionArgument(
continents_filenames
).get_features()
pygplates.FeaturesFunctionArgument(continents_filenames).get_features()
)

with warnings.catch_warnings():
Expand Down Expand Up @@ -192,16 +134,17 @@ def main():
title="subcommands",
description="valid subcommands",
)

# add combine feature sub-command
combine_cmd = subparser.add_parser(
"combine",
help=combine_feature_collections.__doc__,
description=combine_feature_collections.__doc__,
)
filter_cmd = subparser.add_parser(
"filter",
help=filter_feature_collection.__doc__,
description=filter_feature_collection.__doc__,
)

# add feature filter sub-command
feature_filter.add_parser(subparser)

agegrid_cmd = subparser.add_parser(
"agegrid",
aliases=("ag",),
Expand All @@ -210,38 +153,28 @@ def main():
description=create_agegrids.__doc__,
)

# add fix crossovers sub-command
fix_crossovers_cmd = subparser.add_parser(
"fix_crossovers",
help="fix crossovers",
add_help=True,
)
fix_crossovers.add_arguments(fix_crossovers_cmd)

# add remove plate rotations sub-command
remove_plate_rotations_cmd = subparser.add_parser(
"remove_rotations",
help="remove plate rotations",
add_help=True,
)
remove_plate_rotations.add_arguments(remove_plate_rotations_cmd)

# combine command arguments
combine_cmd.set_defaults(func=_run_combine_feature_collections)
combine_cmd.add_argument("combine_first_input_file", type=str)
combine_cmd.add_argument("combine_other_input_files", nargs="+", type=str)
combine_cmd.add_argument("combine_output_file", type=str)

# feature filter command arguments
filter_cmd.set_defaults(func=filter_feature_collection)
filter_cmd.add_argument("filter_input_file", type=str)
filter_cmd.add_argument("filter_output_file", type=str)

name_group = filter_cmd.add_mutually_exclusive_group()
name_group.add_argument("-n", "--names", type=str, dest="names", nargs="+")
name_group.add_argument(
"--exclude-names", type=str, dest="exclude_names", nargs="+"
)

pid_group = filter_cmd.add_mutually_exclusive_group()
pid_group.add_argument("-p", "--pids", type=int, dest="pids", nargs="+")
pid_group.add_argument("--exclude-pids", type=int, dest="exclude_pids", nargs="+")

birth_age_group = filter_cmd.add_mutually_exclusive_group()
birth_age_group.add_argument(
"-a", "--min-birth-age", type=float, dest="min_birth_age"
)
birth_age_group.add_argument("--max-birth-age", type=float, dest="max_birth_age")

filter_cmd.add_argument(
"--case-sensitive", dest="case_sensitive", action="store_true"
)
filter_cmd.add_argument("--exact-match", dest="exact_match", action="store_true")

# agegrid command arguments
agegrid_cmd.set_defaults(func=_run_create_agegrids)
agegrid_cmd.add_argument(
Expand Down
80 changes: 80 additions & 0 deletions gplately/feature_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,83 @@ def filter_feature_collection(
if keep_flag:
new_feature_collection.add(feature)
return new_feature_collection


def add_parser(subparser):
"""add feature filter command line argument parser"""
filter_cmd = subparser.add_parser(
"filter",
help=filter_feature_collection.__doc__,
description=filter_feature_collection.__doc__,
)

# feature filter command arguments
filter_cmd.set_defaults(func=run_filter_feature_collection)
filter_cmd.add_argument("filter_input_file", type=str)
filter_cmd.add_argument("filter_output_file", type=str)

name_group = filter_cmd.add_mutually_exclusive_group()
name_group.add_argument("-n", "--names", type=str, dest="names", nargs="+")
name_group.add_argument(
"--exclude-names", type=str, dest="exclude_names", nargs="+"
)

pid_group = filter_cmd.add_mutually_exclusive_group()
pid_group.add_argument("-p", "--pids", type=int, dest="pids", nargs="+")
pid_group.add_argument("--exclude-pids", type=int, dest="exclude_pids", nargs="+")

birth_age_group = filter_cmd.add_mutually_exclusive_group()
birth_age_group.add_argument(
"-a", "--min-birth-age", type=float, dest="min_birth_age"
)
birth_age_group.add_argument("--max-birth-age", type=float, dest="max_birth_age")

filter_cmd.add_argument(
"--case-sensitive", dest="case_sensitive", action="store_true"
)
filter_cmd.add_argument("--exact-match", dest="exact_match", action="store_true")


def run_filter_feature_collection(args):
"""Filter the input feature collection according to command line arguments."""
input_feature_collection = pygplates.FeatureCollection(args.filter_input_file)

filters = []
if args.names:
filters.append(
FeatureNameFilter(
args.names,
exact_match=args.exact_match,
case_sensitive=args.case_sensitive,
)
)
elif args.exclude_names:
filters.append(
FeatureNameFilter(
args.exclude_names,
exclude=True,
exact_match=args.exact_match,
case_sensitive=args.case_sensitive,
)
)

if args.pids:
filters.append(PlateIDFilter(args.pids))
elif args.exclude_pids:
filters.append(PlateIDFilter(args.exclude_pids, exclude=True))

# print(args.max_birth_age)
if args.max_birth_age is not None:
filters.append(BirthAgeFilter(args.max_birth_age, keep_older=False))
elif args.min_birth_age is not None:
filters.append(BirthAgeFilter(args.min_birth_age))

new_fc = filter_feature_collection(
input_feature_collection,
filters,
)

new_fc.write(args.filter_output_file)
print(
f"Done! The filtered feature collection has been saved to {args.filter_output_file}."
)
2 changes: 1 addition & 1 deletion gplately/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2367,7 +2367,7 @@ def plot_subduction_teeth(
projection = ax.projection
except AttributeError:
print(
"The ax.projection does not exist. You must set project to plot Cartopy maps, such as ax = plt.subplot(211, projection=cartopy.crs.PlateCarree())"
"The ax.projection does not exist. You must set projection to plot Cartopy maps, such as ax = plt.subplot(211, projection=cartopy.crs.PlateCarree())"
)
projection = None

Expand Down
Loading

0 comments on commit da681cf

Please sign in to comment.