diff --git a/augur/filter/__init__.py b/augur/filter/__init__.py index 299b506bd..3a6c5d864 100644 --- a/augur/filter/__init__.py +++ b/augur/filter/__init__.py @@ -88,10 +88,11 @@ def register_arguments(parser): (1) Any ``--group-by`` columns absent from this file will be given equal weighting across all values *within* groups defined by the other weighted columns. - (2) All combinations of weighted column values that are present in the - metadata must be included in this file. Absence from this file will - cause augur filter to exit with an error describing how to add the - weights explicitly. + (2) An entry with the value ``default`` under all columns will be + treated as the default weight for specific groups present in the + metadata but missing from the weights file. If there is no default + weight and the metadata contains rows that are not covered by the + given weights, augur filter will exit with an error. """) subsample_group.add_argument('--priority', type=str, help="""tab-delimited file with list of priority scores for strains (e.g., "\\t") and no header. When scores are provided, Augur converts scores to floating point values, sorts strains within each subsampling group from highest to lowest priority, and selects the top N strains per group where N is the calculated or requested number of strains per group. @@ -104,11 +105,6 @@ def register_arguments(parser): output_group.add_argument('--output-metadata', help="metadata for strains that passed filters") output_group.add_argument('--output-strains', help="list of strains that passed filters (no header)") output_group.add_argument('--output-log', help="tab-delimited file with one row for each filtered strain and the reason it was filtered. Keyword arguments used for a given filter are reported in JSON format in a `kwargs` column.") - output_group.add_argument('--output-group-by-missing-weights', type=str, metavar="FILE", help=""" - TSV file formatted for --group-by-weights with an empty weight column. 
- Represents groups with entries in --metadata but absent from - --group-by-weights. - """) output_group.add_argument('--output-group-by-sizes', help="tab-delimited file one row per group with target size.") output_group.add_argument( '--empty-output-reporting', diff --git a/augur/filter/_run.py b/augur/filter/_run.py index 144251d0d..84e7bac3f 100644 --- a/augur/filter/_run.py +++ b/augur/filter/_run.py @@ -286,7 +286,6 @@ def run(args): group_by, args.group_by_weights, args.subsample_max_sequences, - args.output_group_by_missing_weights, args.output_group_by_sizes, args.subsample_seed, ) diff --git a/augur/filter/subsample.py b/augur/filter/subsample.py index cf340ce86..35fd7ab62 100644 --- a/augur/filter/subsample.py +++ b/augur/filter/subsample.py @@ -12,7 +12,7 @@ from augur.io.metadata import METADATA_DATE_COLUMN from augur.io.print import print_err from . import constants -from .weights_file import WEIGHTS_COLUMN, get_weighted_columns, read_weights_file +from .weights_file import WEIGHTS_COLUMN, COLUMN_VALUE_FOR_DEFAULT_WEIGHT, get_default_weight, get_weighted_columns, read_weights_file Group = Tuple[str, ...] """Combinations of grouping column values in tuple form.""" @@ -315,7 +315,6 @@ def get_weighted_group_sizes( group_by: List[str], weights_file: str, target_total_size: int, - output_missing_weights: Optional[str], output_sizes_file: Optional[str], random_seed: Optional[int], ) -> Dict[Group, int]: @@ -336,17 +335,17 @@ def get_weighted_group_sizes( # weights to ensure equal weighting of unweighted columns *within* each # weighted group defined by the weighted columns. 
weights = _add_unweighted_columns(weights, groups, group_by, unweighted_columns) + + weights = _handle_incomplete_weights(weights, weights_file, weighted_columns, group_by, groups) weights = _drop_unused_groups(weights, groups, group_by) + weights = _adjust_weights_for_unweighted_columns(weights, weighted_columns, unweighted_columns) else: + weights = _handle_incomplete_weights(weights, weights_file, weighted_columns, group_by, groups) weights = _drop_unused_groups(weights, groups, group_by) weights = _calculate_weighted_group_sizes(weights, target_total_size, random_seed) - missing_groups = set(groups) - set(weights[group_by].apply(tuple, axis=1)) - if missing_groups: - weights = _handle_incomplete_weights(weights, weights_file, weighted_columns, group_by, missing_groups, output_missing_weights) - # Add columns to summarize the input data weights[INPUT_SIZE_COLUMN] = weights.apply(lambda row: records_per_group[tuple(row[group_by].values)], axis=1) weights[OUTPUT_SIZE_COLUMN] = weights[[INPUT_SIZE_COLUMN, TARGET_SIZE_COLUMN]].min(axis=1) @@ -462,11 +461,15 @@ def _handle_incomplete_weights( weights_file: str, weighted_columns: List[str], group_by: List[str], - missing_groups: Collection[Group], - output_missing_weights: Optional[str], + groups: Iterable[Group], ) -> pd.DataFrame: """Handle the case where the weights file does not cover all rows in the metadata. """ + missing_groups = set(groups) - set(weights[group_by].apply(tuple, axis=1)) + + if not missing_groups: + return weights + # Collect the column values that are missing weights. 
missing_values_by_column = defaultdict(set) for group in missing_groups: @@ -478,27 +481,27 @@ def _handle_incomplete_weights( for column in weighted_columns: missing_values_by_column[column].add(column_to_value_map[column]) - columns_with_values = '\n - '.join(f'{column!r}: {list(values)}' for column, values in missing_values_by_column.items()) - if not output_missing_weights: + columns_with_values = '\n - '.join(f'{column!r}: {list(sorted(values))}' for column, values in sorted(missing_values_by_column.items())) + + default_weight = get_default_weight(weights, weighted_columns) + + # Compare against None: an explicit default weight of 0 is falsy and would otherwise be reported as missing. + if default_weight is None: raise AugurError(dedent(f"""\ The input metadata contains these values under the following columns that are not covered by {weights_file!r}: - {columns_with_values} - Re-run with --output-group-by-missing-weights to continue.""")) + To fix this, either: + (1) specify weights explicitly - add entries to {weights_file!r} for the values above, or + (2) specify a default weight - add an entry to {weights_file!r} with the value {COLUMN_VALUE_FOR_DEFAULT_WEIGHT!r} for all columns""")) else: - missing_weights = pd.DataFrame(sorted(missing_groups), columns=group_by) - missing_weights_weighted_columns_only = missing_weights[weighted_columns].drop_duplicates() - missing_weights_weighted_columns_only[WEIGHTS_COLUMN] = '' - missing_weights_weighted_columns_only.to_csv(output_missing_weights, index=False, sep='\t') print_err(dedent(f"""\ - The input metadata contains these values under the following columns that are not covered by {weights_file!r}: + WARNING: The input metadata contains these values under the following columns that are not directly covered by {weights_file!r}: - {columns_with_values} - Sequences associated with these values will be dropped. 
- A separate weights file has been generated with implicit weight of zero for these values: {output_missing_weights!r} - Consider updating {weights_file!r} with nonzero weights and re-running without --output-group-by-missing-weights.""")) + The default weight of {default_weight!r} will be used for all groups defined by those values.""")) - # Set the weight for these groups to zero, effectively dropping all sequences. - missing_weights[TARGET_SIZE_COLUMN] = 0 - return pd.merge(weights, missing_weights, on=[*group_by, TARGET_SIZE_COLUMN], how='outer') + missing_weights = pd.DataFrame(sorted(missing_groups), columns=group_by) + missing_weights[WEIGHTS_COLUMN] = default_weight + return pd.merge(weights, missing_weights, on=[*group_by, WEIGHTS_COLUMN], how='outer') def create_queues_by_group(max_sizes_per_group): diff --git a/augur/filter/weights_file.py b/augur/filter/weights_file.py index 8546bb22a..42e073e2c 100644 --- a/augur/filter/weights_file.py +++ b/augur/filter/weights_file.py @@ -1,9 +1,11 @@ import pandas as pd from textwrap import dedent +from typing import List from augur.errors import AugurError WEIGHTS_COLUMN = 'weight' +COLUMN_VALUE_FOR_DEFAULT_WEIGHT = 'default' class InvalidWeightsFile(AugurError): @@ -42,3 +44,13 @@ def get_weighted_columns(weights_file): raise InvalidWeightsFile(weights_file, "File is empty.") columns.remove(WEIGHTS_COLUMN) return columns + + +def get_default_weight(weights: pd.DataFrame, weighted_columns: List[str]): + # .tolist() converts the unique weights to built-in Python scalars, so formatting a weight with repr()/!r prints e.g. 1 rather than a NumPy scalar repr. + default_weight_values = weights[(weights[weighted_columns] == COLUMN_VALUE_FOR_DEFAULT_WEIGHT).all(axis=1)][WEIGHTS_COLUMN].unique().tolist() + + if len(default_weight_values) > 1: + raise InvalidWeightsFile(f"Multiple default weights were specified: {', '.join(repr(weight) for weight in default_weight_values)}. 
Only one default weight entry can be accepted.") + if len(default_weight_values) == 1: + return default_weight_values[0] diff --git a/tests/functional/filter/cram/subsample-weighted.t b/tests/functional/filter/cram/subsample-weighted.t index 7ec8b5da4..bc5498743 100644 --- a/tests/functional/filter/cram/subsample-weighted.t +++ b/tests/functional/filter/cram/subsample-weighted.t @@ -89,10 +89,18 @@ Sampling with incomplete weights should show an error. Sampling with weights defined by weights.tsv. ERROR: The input metadata contains these values under the following columns that are not covered by 'weights.tsv': - 'location': ['B'] - Re-run with --output-group-by-missing-weights to continue. + To fix this, either: + (1) specify weights explicitly - add entries to 'weights.tsv' for the values above, or + (2) specify a default weight - add an entry to 'weights.tsv' with the value 'default' for all columns [2] -Re-running with --output-group-by-missing-weights shows a warning and a file to use for fixing. +Re-running with a default weight shows a warning and continues. + + $ cat >weights.tsv <<~~ + > location weight + > A 2 + > default 1 + > ~~ $ ${AUGUR} filter \ > --metadata metadata.tsv \ @@ -100,23 +108,65 @@ Re-running with --output-group-by-missing-weights shows a warning and a file to > --group-by-weights weights.tsv \ > --subsample-max-sequences 6 \ > --subsample-seed 0 \ - > --output-group-by-missing-weights missing-weights.tsv \ > --output-strains strains.txt Sampling with weights defined by weights.tsv. - NOTE: Skipping 1 group due to lack of entries in metadata. + WARNING: The input metadata contains these values under the following columns that are not directly covered by 'weights.tsv': + - 'location': ['B'] + The default weight of 1 will be used for all groups defined by those values. + NOTE: Skipping 4 groups due to lack of entries in metadata. NOTE: Weights were not provided for the column 'month'. Using equal weights across values in that column. 
- The input metadata contains these values under the following columns that are not covered by 'weights.tsv': + 2 strains were dropped during filtering + 2 were dropped because of subsampling criteria + 6 strains passed all filters + +To specify a default weight, the value 'default' must be set for all weighted columns. + + $ cat >weights.tsv <<~~ + > location month weight + > A 2000-01 2 + > A 2000-02 2 + > default 1 + > ~~ + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by month location \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + Sampling with weights defined by weights.tsv. + ERROR: The input metadata contains these values under the following columns that are not covered by 'weights.tsv': - 'location': ['B'] - Sequences associated with these values will be dropped. - A separate weights file has been generated with implicit weight of zero for these values: 'missing-weights.tsv' - Consider updating 'weights.tsv' with nonzero weights and re-running without --output-group-by-missing-weights. - 4 strains were dropped during filtering - 4 were dropped because of subsampling criteria - 4 strains passed all filters - - $ cat missing-weights.tsv - location weight - B + - 'month': ['2000-01', '2000-03'] + To fix this, either: + (1) specify weights explicitly - add entries to 'weights.tsv' for the values above, or + (2) specify a default weight - add an entry to 'weights.tsv' with the value 'default' for all columns + [2] + + $ cat >weights.tsv <<~~ + > location month weight + > A 2000-01 2 + > A 2000-02 2 + > default default 1 + > ~~ + + $ ${AUGUR} filter \ + > --metadata metadata.tsv \ + > --group-by month location \ + > --group-by-weights weights.tsv \ + > --subsample-max-sequences 6 \ + > --subsample-seed 0 \ + > --output-strains strains.txt + Sampling with weights defined by weights.tsv. 
+ WARNING: The input metadata contains these values under the following columns that are not directly covered by 'weights.tsv': + - 'location': ['B'] + - 'month': ['2000-01', '2000-03'] + The default weight of 1 will be used for all groups defined by those values. + NOTE: Skipping 1 group due to lack of entries in metadata. + 2 strains were dropped during filtering + 2 were dropped because of subsampling criteria + 6 strains passed all filters When --group-by-weights is specified, all columns must be provided in --group-by.