From 0790da4f8a6f390ec8c794c289285c9b334ede5d Mon Sep 17 00:00:00 2001 From: Shawn Carere Date: Wed, 30 Oct 2024 14:19:10 -0400 Subject: [PATCH] modified progress bar formatting to inherit from flwr formatting --- fl4health/clients/basic_client.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/fl4health/clients/basic_client.py b/fl4health/clients/basic_client.py index ce6962831..bfed7f1a7 100644 --- a/fl4health/clients/basic_client.py +++ b/fl4health/clients/basic_client.py @@ -1,15 +1,17 @@ import copy import datetime +import os from collections.abc import Iterable, Sequence from enum import Enum -from logging import INFO, WARNING +from inspect import currentframe, getframeinfo +from logging import INFO, WARNING, LogRecord from pathlib import Path from typing import Any, Optional, Tuple, Union import torch import torch.nn as nn from flwr.client import NumPyClient -from flwr.common.logger import LOG_COLORS, log +from flwr.common.logger import LOGGER_NAME, console_handler, log from flwr.common.typing import Config, NDArrays, Scalar from torch.nn.modules.loss import _Loss from torch.optim import Optimizer @@ -1264,6 +1266,19 @@ def maybe_progress_bar(self, iterable: Iterable) -> Iterable: if not self.progress_bar: return iterable else: + # We can use the flwr console handler to format progress bar + frame = currentframe() + lineno = 0 if frame is None else getframeinfo(frame).lineno + record = LogRecord( + name=LOGGER_NAME, + pathname=os.path.abspath(os.getcwd()), + lineno=lineno, # + args={}, + exc_info=None, + level=INFO, + msg="{l_bar}{bar}{r_bar}", + ) + format = console_handler.format(record) # Create a clean looking tqdm instance that matches the flwr logging kwargs: Any = { "leave": True, @@ -1271,7 +1286,7 @@ def maybe_progress_bar(self, iterable: Iterable) -> Iterable: # "desc": f"{LOG_COLORS['INFO']}INFO{LOG_COLORS['RESET']} ", "unit": "steps", "dynamic_ncols": True, - "bar_format": f"{LOG_COLORS['INFO']}INFO{LOG_COLORS['RESET']}" + " : {l_bar}{bar}{r_bar}", + "bar_format": format, } return tqdm(iterable, **kwargs)