diff --git a/augur/filter/_run.py b/augur/filter/_run.py
index c2a22ad43..8325f23f1 100644
--- a/augur/filter/_run.py
+++ b/augur/filter/_run.py
@@ -23,7 +23,7 @@
 from . import include_exclude_rules
 from .io import cleanup_outputs, get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs
 from .include_exclude_rules import apply_filters, construct_filters
-from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, get_group_sizes, create_queues_by_group, get_groups_for_subsampling
+from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, get_probabilistic_group_sizes, create_queues_by_group, get_groups_for_subsampling
 
 
 def run(args):
@@ -276,19 +276,20 @@ def run(args):
         except TooManyGroupsError as error:
             raise AugurError(error)
 
-        if (probabilistic_used):
-            print_err(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.")
-        else:
-            print_err(f"Sampling at {sequences_per_group} per group.")
-
     if queues_by_group is None:
         # We know all of the possible groups now from the first pass through
         # the metadata, so we can create queues for all groups at once.
-        group_sizes = get_group_sizes(
-            records_per_group.keys(),
-            sequences_per_group,
-            random_seed=args.subsample_seed,
-        )
+        if (probabilistic_used):
+            print_err(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.")
+            group_sizes = get_probabilistic_group_sizes(
+                records_per_group.keys(),
+                sequences_per_group,
+                random_seed=args.subsample_seed,
+            )
+        else:
+            print_err(f"Sampling at {sequences_per_group} per group.")
+            assert type(sequences_per_group) is int
+            group_sizes = {group: sequences_per_group for group in records_per_group.keys()}
         queues_by_group = create_queues_by_group(group_sizes)
 
     # Make a second pass through the metadata, only considering records that
diff --git a/augur/filter/subsample.py b/augur/filter/subsample.py
index 4c63f1507..ed0f73c9d 100644
--- a/augur/filter/subsample.py
+++ b/augur/filter/subsample.py
@@ -249,65 +249,52 @@ def get_items(self):
             yield item
 
 
-def get_group_sizes(groups, target_group_size, random_seed=None):
+def get_probabilistic_group_sizes(groups, target_group_size, random_seed=None):
     """Create a dictionary of maximum sizes per group.
 
-    When the target maximum size is fractional, probabilistically generate
-    varying sizes from a Poisson distribution. Make at least the given number of
-    maximum attempts to generate sizes for which the total of all sizes is
-    greater than zero.
-
-    Otherwise, each group's size is simply the target maximum size.
+    Probabilistically generate varying sizes from a Poisson distribution. Make
+    at least the given number of maximum attempts to generate sizes for which
+    the total of all sizes is greater than zero.
 
     Examples
     --------
-
-    Get sizes for two groups with a fixed maximum size.
-
-    >>> groups = ("2015", "2016")
-    >>> group_sizes = get_group_sizes(groups, 2)
-    >>> sum(group_sizes.values())
-    4
-
     Get sizes for two groups with a fractional maximum size. Their total size
     should still be an integer value greater than zero.
 
+    >>> groups = ("2015", "2016")
     >>> seed = 314159
-    >>> group_sizes = get_group_sizes(groups, 0.1, random_seed=seed)
+    >>> group_sizes = get_probabilistic_group_sizes(groups, 0.1, random_seed=seed)
     >>> int(sum(group_sizes.values())) > 0
     True
 
     A subsequent run of this function with the same groups and random seed
     should produce the same group sizes.
 
-    >>> more_group_sizes = get_group_sizes(groups, 0.1, random_seed=seed)
+    >>> more_group_sizes = get_probabilistic_group_sizes(groups, 0.1, random_seed=seed)
     >>> list(group_sizes.values()) == list(more_group_sizes.values())
     True
 
     """
-    if target_group_size < 1.0:
-        # For small fractional maximum sizes, it is possible to randomly select
-        # maximum queue sizes that all equal zero. When this happens, filtering
-        # fails unexpectedly. We make multiple attempts to create queues with
-        # maximum sizes greater than zero for at least one queue.
-        random_generator = np.random.default_rng(random_seed)
-        total_max_size = 0
-        attempts = 0
-        max_attempts = 100
-        max_sizes_per_group = {}
-
-        while total_max_size == 0 and attempts < max_attempts:
-            for group in sorted(groups):
-                max_sizes_per_group[group] = random_generator.poisson(target_group_size)
-
-            total_max_size = sum(max_sizes_per_group.values())
-            attempts += 1
-
-        return max_sizes_per_group
-    else:
-        assert type(target_group_size) is int
-
-        return {group: target_group_size for group in groups}
+    assert target_group_size < 1.0
+
+    # For small fractional maximum sizes, it is possible to randomly select
+    # maximum queue sizes that all equal zero. When this happens, filtering
+    # fails unexpectedly. We make multiple attempts to create queues with
+    # maximum sizes greater than zero for at least one queue.
+    random_generator = np.random.default_rng(random_seed)
+    total_max_size = 0
+    attempts = 0
+    max_attempts = 100
+    max_sizes_per_group = {}
+
+    while total_max_size == 0 and attempts < max_attempts:
+        for group in sorted(groups):
+            max_sizes_per_group[group] = random_generator.poisson(target_group_size)
+
+        total_max_size = sum(max_sizes_per_group.values())
+        attempts += 1
+
+    return max_sizes_per_group
 
 
 def create_queues_by_group(max_sizes_per_group):
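Review note: as a quick sanity check of the renamed helper, the snippet below mirrors the doctest carried in the diff (same groups, target size, and seed). It assumes a development checkout where augur.filter.subsample is importable; it is a sketch for reviewers, not part of the change.

    from augur.filter.subsample import get_probabilistic_group_sizes

    # With a fractional per-group target, per-group maxima are drawn from a
    # Poisson distribution; the helper retries (up to an internal cap of 100
    # attempts) so the total across groups is greater than zero, and a fixed
    # seed makes the draw reproducible.
    sizes = get_probabilistic_group_sizes(("2015", "2016"), 0.1, random_seed=314159)
    assert sum(sizes.values()) > 0
    assert sizes == get_probabilistic_group_sizes(("2015", "2016"), 0.1, random_seed=314159)

The integer case no longer goes through this helper at all: as the _run.py hunk shows, the caller now builds a plain dict of identical sizes, so the Poisson path is reserved for fractional targets and the new assertion enforces that contract.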