Skip to content

Commit

Permalink
🚧 Remove --output-group-by-missing-weights and allow default weight
Browse files Browse the repository at this point in the history
  • Loading branch information
victorlin committed Aug 13, 2024
1 parent 03dd3df commit a00d3b5
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 47 deletions.
14 changes: 5 additions & 9 deletions augur/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,11 @@ def register_arguments(parser):
(1) Any ``--group-by`` columns absent from this file will be given equal
weighting across all values *within* groups defined by the other
weighted columns.
(2) All combinations of weighted column values that are present in the
metadata must be included in this file. Absence from this file will
cause augur filter to exit with an error describing how to add the
weights explicitly.
(2) An entry with the value ``default`` under all columns will be
treated as the default weight for specific groups present in the
metadata but missing from the weights file. If there is no default
weight and the metadata contains rows that are not covered by the
given weights, augur filter will exit with an error.
""")
subsample_group.add_argument('--priority', type=str, help="""tab-delimited file with list of priority scores for strains (e.g., "<strain>\\t<priority>") and no header.
When scores are provided, Augur converts scores to floating point values, sorts strains within each subsampling group from highest to lowest priority, and selects the top N strains per group where N is the calculated or requested number of strains per group.
Expand All @@ -104,11 +105,6 @@ def register_arguments(parser):
output_group.add_argument('--output-metadata', help="metadata for strains that passed filters")
output_group.add_argument('--output-strains', help="list of strains that passed filters (no header)")
output_group.add_argument('--output-log', help="tab-delimited file with one row for each filtered strain and the reason it was filtered. Keyword arguments used for a given filter are reported in JSON format in a `kwargs` column.")
output_group.add_argument('--output-group-by-missing-weights', type=str, metavar="FILE", help="""
TSV file formatted for --group-by-weights with an empty weight column.
Represents groups with entries in --metadata but absent from
--group-by-weights.
""")
output_group.add_argument('--output-group-by-sizes', help="tab-delimited file one row per group with target size.")
output_group.add_argument(
'--empty-output-reporting',
Expand Down
1 change: 0 additions & 1 deletion augur/filter/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,6 @@ def run(args):
group_by,
args.group_by_weights,
args.subsample_max_sequences,
args.output_group_by_missing_weights,
args.output_group_by_sizes,
args.subsample_seed,
)
Expand Down
46 changes: 24 additions & 22 deletions augur/filter/subsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from augur.io.metadata import METADATA_DATE_COLUMN
from augur.io.print import print_err
from . import constants
from .weights_file import WEIGHTS_COLUMN, get_weighted_columns, read_weights_file
from .weights_file import WEIGHTS_COLUMN, COLUMN_VALUE_FOR_DEFAULT_WEIGHT, get_default_weight, get_weighted_columns, read_weights_file

Group = Tuple[str, ...]
"""Combinations of grouping column values in tuple form."""
Expand Down Expand Up @@ -315,7 +315,6 @@ def get_weighted_group_sizes(
group_by: List[str],
weights_file: str,
target_total_size: int,
output_missing_weights: Optional[str],
output_sizes_file: Optional[str],
random_seed: Optional[int],
) -> Dict[Group, int]:
Expand All @@ -336,17 +335,17 @@ def get_weighted_group_sizes(
# weights to ensure equal weighting of unweighted columns *within* each
# weighted group defined by the weighted columns.
weights = _add_unweighted_columns(weights, groups, group_by, unweighted_columns)

weights = _handle_incomplete_weights(weights, weights_file, weighted_columns, group_by, groups)
weights = _drop_unused_groups(weights, groups, group_by)

weights = _adjust_weights_for_unweighted_columns(weights, weighted_columns, unweighted_columns)
else:
weights = _handle_incomplete_weights(weights, weights_file, weighted_columns, group_by, groups)
weights = _drop_unused_groups(weights, groups, group_by)

weights = _calculate_weighted_group_sizes(weights, target_total_size, random_seed)

missing_groups = set(groups) - set(weights[group_by].apply(tuple, axis=1))
if missing_groups:
weights = _handle_incomplete_weights(weights, weights_file, weighted_columns, group_by, missing_groups, output_missing_weights)

# Add columns to summarize the input data
weights[INPUT_SIZE_COLUMN] = weights.apply(lambda row: records_per_group[tuple(row[group_by].values)], axis=1)
weights[OUTPUT_SIZE_COLUMN] = weights[[INPUT_SIZE_COLUMN, TARGET_SIZE_COLUMN]].min(axis=1)
Expand Down Expand Up @@ -462,11 +461,15 @@ def _handle_incomplete_weights(
weights_file: str,
weighted_columns: List[str],
group_by: List[str],
missing_groups: Collection[Group],
output_missing_weights: Optional[str],
groups: Iterable[Group],
) -> pd.DataFrame:
"""Handle the case where the weights file does not cover all rows in the metadata.
"""
missing_groups = set(groups) - set(weights[group_by].apply(tuple, axis=1))

if not missing_groups:
return weights

# Collect the column values that are missing weights.
missing_values_by_column = defaultdict(set)
for group in missing_groups:
Expand All @@ -478,27 +481,26 @@ def _handle_incomplete_weights(
for column in weighted_columns:
missing_values_by_column[column].add(column_to_value_map[column])

columns_with_values = '\n - '.join(f'{column!r}: {list(values)}' for column, values in missing_values_by_column.items())
if not output_missing_weights:
columns_with_values = '\n - '.join(f'{column!r}: {list(sorted(values))}' for column, values in sorted(missing_values_by_column.items()))

default_weight = get_default_weight(weights, weighted_columns)

if not default_weight:
raise AugurError(dedent(f"""\
The input metadata contains these values under the following columns that are not covered by {weights_file!r}:
- {columns_with_values}
Re-run with --output-group-by-missing-weights to continue."""))
To fix this, either:
(1) specify weights explicitly - add entries to {weights_file!r} for the values above, or
(2) specify a default weight - add an entry to {weights_file!r} with the value {COLUMN_VALUE_FOR_DEFAULT_WEIGHT!r} for all columns"""))
else:
missing_weights = pd.DataFrame(sorted(missing_groups), columns=group_by)
missing_weights_weighted_columns_only = missing_weights[weighted_columns].drop_duplicates()
missing_weights_weighted_columns_only[WEIGHTS_COLUMN] = ''
missing_weights_weighted_columns_only.to_csv(output_missing_weights, index=False, sep='\t')
print_err(dedent(f"""\
The input metadata contains these values under the following columns that are not covered by {weights_file!r}:
WARNING: The input metadata contains these values under the following columns that are not directly covered by {weights_file!r}:
- {columns_with_values}
Sequences associated with these values will be dropped.
A separate weights file has been generated with implicit weight of zero for these values: {output_missing_weights!r}
Consider updating {weights_file!r} with nonzero weights and re-running without --output-group-by-missing-weights."""))
The default weight of {default_weight!r} will be used for all groups defined by those values."""))

# Set the weight for these groups to zero, effectively dropping all sequences.
missing_weights[TARGET_SIZE_COLUMN] = 0
return pd.merge(weights, missing_weights, on=[*group_by, TARGET_SIZE_COLUMN], how='outer')
missing_weights = pd.DataFrame(sorted(missing_groups), columns=group_by)
missing_weights[WEIGHTS_COLUMN] = default_weight
return pd.merge(weights, missing_weights, on=[*group_by, WEIGHTS_COLUMN], how='outer')


def create_queues_by_group(max_sizes_per_group):
Expand Down
11 changes: 11 additions & 0 deletions augur/filter/weights_file.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import pandas as pd
from textwrap import dedent
from typing import List
from augur.errors import AugurError


WEIGHTS_COLUMN = 'weight'
COLUMN_VALUE_FOR_DEFAULT_WEIGHT = 'default'


class InvalidWeightsFile(AugurError):
Expand Down Expand Up @@ -42,3 +44,12 @@ def get_weighted_columns(weights_file):
raise InvalidWeightsFile(weights_file, "File is empty.")
columns.remove(WEIGHTS_COLUMN)
return columns


def get_default_weight(weights: pd.DataFrame, weighted_columns: List[str]):
    """Return the default weight from a weights table, if one is defined.

    A row is a "default" entry when every weighted column holds the sentinel
    value ``COLUMN_VALUE_FOR_DEFAULT_WEIGHT``. Returns the single default
    weight value if exactly one distinct value is defined, or ``None``
    (implicitly) when no default entry exists.

    Raises
    ------
    InvalidWeightsFile
        If more than one distinct default weight value is present.
    """
    # Boolean mask: rows where *all* weighted columns equal the sentinel.
    is_default_row = (weights[weighted_columns] == COLUMN_VALUE_FOR_DEFAULT_WEIGHT).all(axis=1)
    unique_defaults = weights.loc[is_default_row, WEIGHTS_COLUMN].unique()

    if len(unique_defaults) > 1:
        # NOTE(review): InvalidWeightsFile is called elsewhere with two
        # arguments (file, message) — confirm this single-argument call
        # matches the constructor's signature.
        raise InvalidWeightsFile(f"Multiple default weights were specified: {', '.join(repr(weight) for weight in unique_defaults)}. Only one default weight entry can be accepted.")
    if len(unique_defaults) == 1:
        return unique_defaults[0]
80 changes: 65 additions & 15 deletions tests/functional/filter/cram/subsample-weighted.t
Original file line number Diff line number Diff line change
Expand Up @@ -89,34 +89,84 @@ Sampling with incomplete weights should show an error.
Sampling with weights defined by weights.tsv.
ERROR: The input metadata contains these values under the following columns that are not covered by 'weights.tsv':
- 'location': ['B']
Re-run with --output-group-by-missing-weights to continue.
To fix this, either:
(1) specify weights explicitly - add entries to 'weights.tsv' for the values above, or
(2) specify a default weight - add an entry to 'weights.tsv' with the value 'default' for all columns
[2]

Re-running with --output-group-by-missing-weights shows a warning and a file to use for fixing.
Re-running with a default weight shows a warning and continues.

$ cat >weights.tsv <<~~
> location weight
> A 2
> default 1
> ~~

$ ${AUGUR} filter \
> --metadata metadata.tsv \
> --group-by month location \
> --group-by-weights weights.tsv \
> --subsample-max-sequences 6 \
> --subsample-seed 0 \
> --output-group-by-missing-weights missing-weights.tsv \
> --output-strains strains.txt
Sampling with weights defined by weights.tsv.
NOTE: Skipping 1 group due to lack of entries in metadata.
WARNING: The input metadata contains these values under the following columns that are not directly covered by 'weights.tsv':
- 'location': ['B']
The default weight of 1 will be used for all groups defined by those values.
NOTE: Skipping 4 groups due to lack of entries in metadata.
NOTE: Weights were not provided for the column 'month'. Using equal weights across values in that column.
The input metadata contains these values under the following columns that are not covered by 'weights.tsv':
2 strains were dropped during filtering
2 were dropped because of subsampling criteria
6 strains passed all filters

To specify a default weight, the value 'default' must be set for all weighted columns.

$ cat >weights.tsv <<~~
> location month weight
> A 2000-01 2
> A 2000-02 2
> default 1
> ~~

$ ${AUGUR} filter \
> --metadata metadata.tsv \
> --group-by month location \
> --group-by-weights weights.tsv \
> --subsample-max-sequences 6 \
> --subsample-seed 0 \
> --output-strains strains.txt
Sampling with weights defined by weights.tsv.
ERROR: The input metadata contains these values under the following columns that are not covered by 'weights.tsv':
- 'location': ['B']
Sequences associated with these values will be dropped.
A separate weights file has been generated with implicit weight of zero for these values: 'missing-weights.tsv'
Consider updating 'weights.tsv' with nonzero weights and re-running without --output-group-by-missing-weights.
4 strains were dropped during filtering
4 were dropped because of subsampling criteria
4 strains passed all filters

$ cat missing-weights.tsv
location weight
B
- 'month': ['2000-01', '2000-03']
To fix this, either:
(1) specify weights explicitly - add entries to 'weights.tsv' for the values above, or
(2) specify a default weight - add an entry to 'weights.tsv' with the value 'default' for all columns
[2]

$ cat >weights.tsv <<~~
> location month weight
> A 2000-01 2
> A 2000-02 2
> default default 1
> ~~

$ ${AUGUR} filter \
> --metadata metadata.tsv \
> --group-by month location \
> --group-by-weights weights.tsv \
> --subsample-max-sequences 6 \
> --subsample-seed 0 \
> --output-strains strains.txt
Sampling with weights defined by weights.tsv.
WARNING: The input metadata contains these values under the following columns that are not directly covered by 'weights.tsv':
- 'location': ['B']
- 'month': ['2000-01', '2000-03']
The default weight of 1 will be used for all groups defined by those values.
NOTE: Skipping 1 group due to lack of entries in metadata.
2 strains were dropped during filtering
2 were dropped because of subsampling criteria
6 strains passed all filters

When --group-by-weights is specified, all columns must be provided in
--group-by.
Expand Down

0 comments on commit a00d3b5

Please sign in to comment.