Only use group size function for probabilistic sampling
The logic for deterministic sampling is so simple that it can be brought
out into the calling context. This also allows the logging messages to
be paired with the logic they describe.
victorlin committed May 6, 2024
1 parent 90bfce9 commit 22a60b3
Showing 2 changed files with 39 additions and 51 deletions.
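The change in miniature, as a hypothetical, simplified Python sketch (not the actual augur code; sample_poisson_sizes is an invented stand-in for the Poisson logic):

# Before: one helper silently handles both cases.
def get_group_sizes(groups, target):
    if target < 1.0:
        return sample_poisson_sizes(groups, target)  # probabilistic path
    return {group: target for group in groups}       # trivial deterministic path

# After: the caller branches explicitly, logs next to each branch, and the
# helper keeps only the probabilistic path.
if probabilistic_used:
    sizes = get_probabilistic_group_sizes(groups, target)
else:
    sizes = {group: target for group in groups}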
23 changes: 12 additions & 11 deletions augur/filter/_run.py
@@ -23,7 +23,7 @@
 from . import include_exclude_rules
 from .io import cleanup_outputs, get_useful_metadata_columns, read_priority_scores, write_metadata_based_outputs
 from .include_exclude_rules import apply_filters, construct_filters
-from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, get_group_sizes, create_queues_by_group, get_groups_for_subsampling
+from .subsample import PriorityQueue, TooManyGroupsError, calculate_sequences_per_group, get_probabilistic_group_sizes, create_queues_by_group, get_groups_for_subsampling
 
 
 def run(args):
@@ -276,19 +276,20 @@ def run(args):
     except TooManyGroupsError as error:
         raise AugurError(error)
 
-    if (probabilistic_used):
-        print_err(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.")
-    else:
-        print_err(f"Sampling at {sequences_per_group} per group.")
-
     if queues_by_group is None:
         # We know all of the possible groups now from the first pass through
         # the metadata, so we can create queues for all groups at once.
-        group_sizes = get_group_sizes(
-            records_per_group.keys(),
-            sequences_per_group,
-            random_seed=args.subsample_seed,
-        )
+        if (probabilistic_used):
+            print_err(f"Sampling probabilistically at {sequences_per_group:0.4f} sequences per group, meaning it is possible to have more than the requested maximum of {args.subsample_max_sequences} sequences after filtering.")
+            group_sizes = get_probabilistic_group_sizes(
+                records_per_group.keys(),
+                sequences_per_group,
+                random_seed=args.subsample_seed,
+            )
+        else:
+            print_err(f"Sampling at {sequences_per_group} per group.")
+            assert type(sequences_per_group) is int
+            group_sizes = {group: sequences_per_group for group in records_per_group.keys()}
         queues_by_group = create_queues_by_group(group_sizes)
 
     # Make a second pass through the metadata, only considering records that
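For the deterministic branch, the inlined logic is just a dict comprehension over the known groups. A toy illustration with invented values (not augur code):

records_per_group = {"2015": 10, "2016": 3}   # group -> record count from the first pass
sequences_per_group = 2                       # an int whenever sampling is deterministic

group_sizes = {group: sequences_per_group for group in records_per_group.keys()}
print(group_sizes)  # {'2015': 2, '2016': 2}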
67 changes: 27 additions & 40 deletions augur/filter/subsample.py
@@ -249,65 +249,52 @@ def get_items(self):
             yield item
 
 
-def get_group_sizes(groups, target_group_size, random_seed=None):
+def get_probabilistic_group_sizes(groups, target_group_size, random_seed=None):
     """Create a dictionary of maximum sizes per group.
 
-    When the target maximum size is fractional, probabilistically generate
-    varying sizes from a Poisson distribution. Make at least the given number of
-    maximum attempts to generate sizes for which the total of all sizes is
-    greater than zero.
-
-    Otherwise, each group's size is simply the target maximum size.
+    Probabilistically generate varying sizes from a Poisson distribution. Make
+    at least the given number of maximum attempts to generate sizes for which
+    the total of all sizes is greater than zero.
 
     Examples
     --------
-    Get sizes for two groups with a fixed maximum size.
-
-    >>> groups = ("2015", "2016")
-    >>> group_sizes = get_group_sizes(groups, 2)
-    >>> sum(group_sizes.values())
-    4
-
     Get sizes for two groups with a fractional maximum size. Their total
     size should still be an integer value greater than zero.
 
     >>> groups = ("2015", "2016")
     >>> seed = 314159
-    >>> group_sizes = get_group_sizes(groups, 0.1, random_seed=seed)
+    >>> group_sizes = get_probabilistic_group_sizes(groups, 0.1, random_seed=seed)
     >>> int(sum(group_sizes.values())) > 0
     True
 
     A subsequent run of this function with the same groups and random seed
     should produce the same group sizes.
 
-    >>> more_group_sizes = get_group_sizes(groups, 0.1, random_seed=seed)
+    >>> more_group_sizes = get_probabilistic_group_sizes(groups, 0.1, random_seed=seed)
     >>> list(group_sizes.values()) == list(more_group_sizes.values())
     True
     """
-    if target_group_size < 1.0:
-        # For small fractional maximum sizes, it is possible to randomly select
-        # maximum queue sizes that all equal zero. When this happens, filtering
-        # fails unexpectedly. We make multiple attempts to create queues with
-        # maximum sizes greater than zero for at least one queue.
-        random_generator = np.random.default_rng(random_seed)
-        total_max_size = 0
-        attempts = 0
-        max_attempts = 100
-        max_sizes_per_group = {}
-
-        while total_max_size == 0 and attempts < max_attempts:
-            for group in sorted(groups):
-                max_sizes_per_group[group] = random_generator.poisson(target_group_size)
-
-            total_max_size = sum(max_sizes_per_group.values())
-            attempts += 1
-
-        return max_sizes_per_group
-    else:
-        assert type(target_group_size) is int
-
-        return {group: target_group_size for group in groups}
+    assert target_group_size < 1.0
+
+    # For small fractional maximum sizes, it is possible to randomly select
+    # maximum queue sizes that all equal zero. When this happens, filtering
+    # fails unexpectedly. We make multiple attempts to create queues with
+    # maximum sizes greater than zero for at least one queue.
+    random_generator = np.random.default_rng(random_seed)
+    total_max_size = 0
+    attempts = 0
+    max_attempts = 100
+    max_sizes_per_group = {}
+
+    while total_max_size == 0 and attempts < max_attempts:
+        for group in sorted(groups):
+            max_sizes_per_group[group] = random_generator.poisson(target_group_size)
+
+        total_max_size = sum(max_sizes_per_group.values())
+        attempts += 1
+
+    return max_sizes_per_group
 
 
 def create_queues_by_group(max_sizes_per_group):
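The retry loop in the function above exists because a small fractional target makes an all-zero draw likely: the sum of independent Poisson(λ) draws over n groups is Poisson(nλ), which is zero with probability exp(-nλ). A quick standalone check of that reasoning (assumes numpy; illustrative only, not augur code):

import numpy as np

# Two groups at a target of 0.1 sequences per group: the combined draw is
# Poisson(0.2), so about exp(-0.2) ~ 82% of single attempts yield all zeros.
rng = np.random.default_rng(0)
all_zero = sum(
    rng.poisson(0.1, size=2).sum() == 0
    for _ in range(10_000)
)
print(all_zero / 10_000)  # roughly 0.82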