Skip to content

Commit

Permalink
Merge branch 'main' into dbe/maybe_checkpoint_logging
Browse files Browse the repository at this point in the history
  • Loading branch information
emersodb committed Nov 11, 2024
2 parents 3e96809 + 23b4224 commit 8f46917
Show file tree
Hide file tree
Showing 125 changed files with 604 additions and 585 deletions.
8 changes: 4 additions & 4 deletions examples/ae_examples/cvae_dim_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@


class CvaeDimClient(BasicClient):
def __init__(self, data_path: Path, metrics: Sequence[Metric], DEVICE: torch.device, condition: torch.Tensor):
super().__init__(data_path, metrics, DEVICE)
def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.device, condition: torch.Tensor):
super().__init__(data_path, metrics, device)
self.condition = condition

def get_data_loaders(self, config: Config) -> Tuple[DataLoader, DataLoader]:
Expand Down Expand Up @@ -64,11 +64,11 @@ def get_model(self, config: Config) -> nn.Module:
)
args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
set_all_random_seeds(42)
# Creating the condition vector used for training this CVAE.
condition_vector = torch.nn.functional.one_hot(torch.tensor(args.condition), num_classes=args.num_conditions)
client = CvaeDimClient(data_path, [Accuracy("accuracy")], DEVICE, condition_vector)
client = CvaeDimClient(data_path, [Accuracy("accuracy")], device, condition_vector)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ def binary_class_condition_data_converter(


class CondConvAutoEncoderClient(BasicClient):
def __init__(self, data_path: Path, metrics: Sequence[Metric], DEVICE: torch.device) -> None:
super().__init__(data_path, metrics, DEVICE)
def __init__(self, data_path: Path, metrics: Sequence[Metric], device: torch.device) -> None:
super().__init__(data_path, metrics, device)
# To train an autoencoder-based model we need to define a data converter that prepares the data
# for self-supervised learning, concatenates the inputs and condition (packing) to let the data
# fit into the training pipeline, and unpacks the input from condition for the model inference.
Expand Down Expand Up @@ -93,8 +93,8 @@ def get_model(self, config: Config) -> nn.Module:
parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset")
args = parser.parse_args()
set_all_random_seeds(42)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
client = CondConvAutoEncoderClient(data_path=data_path, metrics=[], DEVICE=DEVICE)
client = CondConvAutoEncoderClient(data_path=data_path, metrics=[], device=device)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
8 changes: 4 additions & 4 deletions examples/ae_examples/cvae_examples/mlp_cvae_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@

class CondAutoEncoderClient(BasicClient):
def __init__(
self, data_path: Path, metrics: Sequence[Metric], DEVICE: torch.device, condition: torch.Tensor
self, data_path: Path, metrics: Sequence[Metric], device: torch.device, condition: torch.Tensor
) -> None:
super().__init__(data_path, metrics, DEVICE)
super().__init__(data_path, metrics, device)
# In this example, condition is based on client ID.
self.condition_vector = condition
# To train an autoencoder-based model we need to define a data converter that prepares the data
Expand Down Expand Up @@ -96,13 +96,13 @@ def get_model(self, config: Config) -> nn.Module:
)
args = parser.parse_args()
set_all_random_seeds(42)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
# Create the condition vector. This creation needs to be "consistent" across clients.
# In this example, condition is based on client ID.
# Client should decide how they want to create their condition vector.
# Here we use simple one_hot_encoding but it can be any vector.
condition_vector = torch.nn.functional.one_hot(torch.tensor(args.condition), num_classes=args.num_conditions)
client = CondAutoEncoderClient(data_path, [], DEVICE, condition_vector)
client = CondAutoEncoderClient(data_path, [], device, condition_vector)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
4 changes: 2 additions & 2 deletions examples/ae_examples/fedprox_vae_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def get_model(self, config: Config) -> nn.Module:
parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset")
args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
client = VaeFedProxClient(data_path, [], DEVICE)
client = VaeFedProxClient(data_path, [], device)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
4 changes: 2 additions & 2 deletions examples/apfl_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def get_criterion(self, config: Config) -> _Loss:

args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)

# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistApflClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])
client = MnistApflClient(data_path, [Accuracy()], device, reporters=[JsonReporter()])
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown() # This will tell the JsonReporter to dump data
4 changes: 2 additions & 2 deletions examples/basic_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def get_model(self, config: Config) -> nn.Module:
parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset")
args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
client = CifarClient(data_path, [Accuracy("accuracy")], DEVICE)
client = CifarClient(data_path, [Accuracy("accuracy")], device)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
6 changes: 3 additions & 3 deletions examples/ditto_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ def get_criterion(self, config: Config) -> _Loss:
)
args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
log(INFO, f"Device to be used: {DEVICE}")
log(INFO, f"Device to be used: {device}")
log(INFO, f"Server Address: {args.server_address}")

# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistDittoClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])
client = MnistDittoClient(data_path, [Accuracy()], device, reporters=[JsonReporter()])
fl.client.start_client(server_address=args.server_address, client=client.to_client())

# Shutdown the client gracefully
Expand Down
4 changes: 2 additions & 2 deletions examples/docker_basic_example/fl_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def setup_client(self, config: Config) -> None:
args = parser.parse_args()

# Load model and data
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
client = CifarClient(data_path, [Accuracy("accuracy")], DEVICE)
client = CifarClient(data_path, [Accuracy("accuracy")], device)
fl.client.start_client(server_address="fl_server:8080", client=client.to_client())
4 changes: 2 additions & 2 deletions examples/dp_fed_examples/client_level_dp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def get_criterion(self, config: Config) -> _Loss:

# Load model and data
data_path = Path(args.dataset_path)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
client = CifarClient(data_path, [Accuracy("accuracy")], DEVICE)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
client = CifarClient(data_path, [Accuracy("accuracy")], device)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
4 changes: 2 additions & 2 deletions examples/dp_fed_examples/client_level_dp_weighted/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def get_criterion(self, config: Config) -> _Loss:
args = parser.parse_args()

# Load model and data
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
client = HospitalClient(data_path, [Accuracy("accuracy")], DEVICE)
client = HospitalClient(data_path, [Accuracy("accuracy")], device)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
4 changes: 2 additions & 2 deletions examples/dp_fed_examples/instance_level_dp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def get_criterion(self, config: Config) -> _Loss:

# Load model and data
data_path = Path(args.dataset_path)
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
client = CifarClient(data_path, [Accuracy("accuracy")], DEVICE, checkpointer=checkpointer)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
client = CifarClient(data_path, [Accuracy("accuracy")], device, checkpointer=checkpointer)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())

client.shutdown()
4 changes: 2 additions & 2 deletions examples/dp_scaffold_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def get_criterion(self, config: Config) -> _Loss:

args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)

client = MnistDPScaffoldClient(data_path=data_path, metrics=[Accuracy()], device=DEVICE)
client = MnistDPScaffoldClient(data_path=data_path, metrics=[Accuracy()], device=device)

fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
4 changes: 2 additions & 2 deletions examples/dynamic_layer_exchange_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset")
args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
client = CifarDynamicLayerClient(data_path, [Accuracy("accuracy")], DEVICE, store_initial_model=True)
client = CifarDynamicLayerClient(data_path, [Accuracy("accuracy")], device, store_initial_model=True)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
4 changes: 2 additions & 2 deletions examples/ensemble_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def get_criterion(self, config: Config) -> _Loss:

args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)

client = MnistEnsembleClient(data_path, [Accuracy()], DEVICE)
client = MnistEnsembleClient(data_path, [Accuracy()], device)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
6 changes: 3 additions & 3 deletions examples/feature_alignment_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ def get_data_frame(self, config: Config) -> pd.DataFrame:
)
args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

data_path = Path(args.dataset_path)

log(INFO, f"Device to be used: {DEVICE}")
log(INFO, f"Device to be used: {device}")
log(INFO, f"Server Address: {args.server_address}")

# ham_id is the id column and LOSgroupNum is the target column.
client = Mimic3TabularDataClient(data_path, [Accuracy("accuracy")], DEVICE, "hadm_id", ["LOSgroupNum"])
client = Mimic3TabularDataClient(data_path, [Accuracy("accuracy")], device, "hadm_id", ["LOSgroupNum"])
# This call demonstrates how the user may specify a particular sklearn pipeline for a specific feature.
client.preset_specific_pipeline("NumNotes", MaxAbsScaler())
fl.client.start_client(server_address=args.server_address, client=client.to_client())
8 changes: 4 additions & 4 deletions examples/fedbn_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ def get_parameter_exchanger(self, config: Config) -> ParameterExchanger:
)
args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
log(INFO, f"Device to be used: {DEVICE}")
log(INFO, f"Device to be used: {device}")
log(INFO, f"Server Address: {args.server_address}")

if args.dataset_name in ["Barcelona", "Rosendahl", "Vienna", "UFES", "Canada"]:
client: BasicClient = SkinCancerFedBNClient(data_path, [Accuracy()], DEVICE, args.dataset_name)
client: BasicClient = SkinCancerFedBNClient(data_path, [Accuracy()], device, args.dataset_name)
elif args.dataset_name == "mnist":
client = MnistFedBNClient(data_path, [Accuracy()], DEVICE)
client = MnistFedBNClient(data_path, [Accuracy()], device)
else:
raise ValueError(
"Unsupported dataset name. Please choose from 'Barcelona', 'Rosendahl', \
Expand Down
4 changes: 2 additions & 2 deletions examples/feddg_ga_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ def get_criterion(self, config: Config) -> _Loss:

args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)

# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistApflClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])
client = MnistApflClient(data_path, [Accuracy()], device, reporters=[JsonReporter()])
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
4 changes: 2 additions & 2 deletions examples/federated_eval_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def get_criterion(self, config: Config) -> _Loss:
data_path = Path(args.dataset_path)
client_checkpoint_path = Path(args.checkpoint_path) if args.checkpoint_path else None

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

client = CifarClient(
data_path=data_path,
metrics=[Accuracy("accuracy")],
device=DEVICE,
device=device,
model_checkpoint_path=client_checkpoint_path,
)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
Expand Down
4 changes: 2 additions & 2 deletions examples/fedopt_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def predict(
args = parser.parse_args()

# Load model and data
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
client = NewsClassifierClient(data_path, [CompoundMetric("Compound Metric")], DEVICE)
client = NewsClassifierClient(data_path, [CompoundMetric("Compound Metric")], device)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
4 changes: 2 additions & 2 deletions examples/fedpca_examples/dim_reduction/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,14 @@ def get_model(self, config: Config) -> nn.Module:
parser.add_argument("--seed", action="store", type=int, help="Random seed for this client.")
args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
seed = args.seed

# If the user wants to ensure that this example uses the same data as
# the data used in the perform_pca example, then both examples
# should use the same random seed.
set_all_random_seeds(seed)
client = MnistFedPcaClient(data_path, [Accuracy("accuracy")], DEVICE)
client = MnistFedPcaClient(data_path, [Accuracy("accuracy")], device)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
4 changes: 2 additions & 2 deletions examples/fedpca_examples/perform_pca/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_data_tensor(self, data_loader: DataLoader) -> Tensor:
parser.add_argument("--seed", action="store", type=int, help="Random seed for this client.")
args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
components_save_path = Path(args.components_save_path)
seed = args.seed
Expand All @@ -45,5 +45,5 @@ def get_data_tensor(self, data_loader: DataLoader) -> Tensor:
# the data used in the dim_reduction example, then both examples
# should use the same random seed.
set_all_random_seeds(seed)
client = MnistFedPCAClient(data_path=data_path, device=DEVICE, model_save_path=components_save_path)
client = MnistFedPCAClient(data_path=data_path, device=device, model_save_path=components_save_path)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
4 changes: 2 additions & 2 deletions examples/fedper_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def get_criterion(self, config: Config) -> _Loss:
)
args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
minority_numbers = {int(number) for number in args.minority_numbers}
client = MnistFedPerClient(data_path, [Accuracy("accuracy")], DEVICE, minority_numbers)
client = MnistFedPerClient(data_path, [Accuracy("accuracy")], device, minority_numbers)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
4 changes: 2 additions & 2 deletions examples/fedpm_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ def get_criterion(self, config: Config) -> _Loss:
)
args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
minority_numbers = {int(number) for number in args.minority_numbers}
client = MnistFedPmClient(data_path, [Accuracy("accuracy")], DEVICE, minority_numbers)
client = MnistFedPmClient(data_path, [Accuracy("accuracy")], device, minority_numbers)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
6 changes: 3 additions & 3 deletions examples/fedprox_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,15 @@ def get_criterion(self, config: Config) -> _Loss:
)
args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_path = Path(args.dataset_path)
log(INFO, f"Device to be used: {DEVICE}")
log(INFO, f"Device to be used: {device}")
log(INFO, f"Server Address: {args.server_address}")

# Set the random seed for reproducibility
set_all_random_seeds(args.seed)

client = MnistFedProxClient(data_path, [Accuracy()], DEVICE, reporters=[JsonReporter()])
client = MnistFedProxClient(data_path, [Accuracy()], device, reporters=[JsonReporter()])
fl.client.start_client(server_address=args.server_address, client=client.to_client())

# Shutdown the client gracefully
Expand Down
4 changes: 2 additions & 2 deletions examples/fedrep_example/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def get_criterion(self, config: Config) -> _Loss:
parser.add_argument("--dataset_path", action="store", type=str, help="Path to the local dataset")
args = parser.parse_args()

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_dir = Path(args.dataset_path)
client = CifarFedRepClient(data_dir, [Accuracy("accuracy")], DEVICE)
client = CifarFedRepClient(data_dir, [Accuracy("accuracy")], device)
fl.client.start_client(server_address="0.0.0.0:8080", client=client.to_client())
client.shutdown()
Loading

0 comments on commit 8f46917

Please sign in to comment.