Skip to content

Commit

Permalink
Add max_num_validation_steps member of config and client and related …
Browse files Browse the repository at this point in the history
…logic
  • Loading branch information
jewelltaylor committed Dec 16, 2024
1 parent a2fd930 commit 0413f2d
Showing 1 changed file with 23 additions and 2 deletions.
25 changes: 23 additions & 2 deletions fl4health/clients/basic_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def __init__(
self.num_val_samples: int
self.num_test_samples: Optional[int] = None
self.learning_rate: Optional[float] = None
# Config can contain max_num_validation_steps key, which determines an upper bound
# for the validation steps taken. If not specified, no upper bound will be enforced.
self.max_num_validation_steps: int | None = None

def _maybe_checkpoint(self, loss: float, metrics: dict[str, Scalar], checkpoint_mode: CheckpointMode) -> None:
"""
Expand Down Expand Up @@ -231,6 +234,7 @@ def process_config(self, config: Config) -> Tuple[Union[int, None], Union[int, N
"""
current_server_round = narrow_dict_type(config, "current_server_round", int)

# Parse config to determine train by steps or train by epochs
if ("local_epochs" in config) and ("local_steps" in config):
raise ValueError("Config cannot contain both local_epochs and local_steps. Please specify only one.")
elif "local_epochs" in config:
Expand Down Expand Up @@ -726,7 +730,8 @@ def _validate_or_test(
include_losses_in_metrics: bool = False,
) -> Tuple[float, Dict[str, Scalar]]:
"""
Evaluate the model on the given validation or test dataset.
Evaluate the model on the given validation or test dataset. If the max_num_validation_steps attribute
is not None and the client is in the validation phase, the number of steps is limited to max_num_validation_steps.
Args:
loader (DataLoader): The data loader for the dataset (validation or test).
Expand All @@ -748,7 +753,10 @@ def _validate_or_test(
metric_manager.clear()
loss_meter.clear()
with torch.no_grad():
for input, target in maybe_progress_bar(loader, self.progress_bar):
for i, (input, target) in enumerate(maybe_progress_bar(loader, self.progress_bar)):
# Limit validation to self.max_num_validation_steps if it is defined
if logging_mode == LoggingMode.VALIDATION and self.max_num_validation_steps == i:
break
input = move_data_to_device(input, self.device)
target = move_data_to_device(target, self.device)
losses, preds = self.val_step(input, target)
Expand Down Expand Up @@ -830,11 +838,24 @@ def setup_client(self, config: Config) -> None:
self.val_loader = val_loader
self.test_loader = self.get_test_data_loader(config)

if "max_num_validation_steps" in config:
self.max_num_validation_steps = narrow_dict_type(config, "max_num_validation_steps", int)
else:
self.max_num_validation_steps = None

# The following lines are type ignored because torch datasets are not "Sized",
# i.e., __len__ is considered optionally defined. In practice, it is almost always
# defined, so we make that assumption.
self.num_train_samples = len(self.train_loader.dataset) # type: ignore

# If max_num_validation_steps is defined, cap the recorded number of validation samples at the
# minimum of batch_size * max_num_validation_steps and the length of the validation set.
self.num_val_samples = len(self.val_loader.dataset) # type: ignore
if self.max_num_validation_steps is not None:
val_batch_size = self.val_loader.batch_size
max_val_size = self.max_num_validation_steps * val_batch_size # type: ignore
self.num_val_samples = min(self.num_val_samples, max_val_size) # type: ignore

if self.test_loader:
self.num_test_samples = len(self.test_loader.dataset) # type: ignore

Expand Down

0 comments on commit 0413f2d

Please sign in to comment.