зеркало из https://github.com/microsoft/msrflute.git
Merged PR 1578: Include FedProx aggregation method
Implementation of FedProx aggregation method, taken from "Federated Learning on Non-IID Data Silos: An Experimental Study" paper (https://arxiv.org/pdf/2102.02079.pdf). [x] nlg_gru_fedprox: https://ml.azure.com/runs/8c052875-d053-4e70-b5b6-8f591faf5936?wsid=/subscriptions/d4404794-ab5b-48de-b7c7-ec1fefb0a04e/resourcegroups/gcr-singularity-octo/workspaces/msroctows&tid=72f988bf-86f1-41af-91ab-2d7cd011db47 **Comparison** - DGA ( Acc 0.15, Loss 5.5) ![image.png](https://msktg.visualstudio.com/c507252c-d1be-4d67-a4a1-03b0181c35c7/_apis/git/repositories/0392018c-4507-44bf-97e2-f2bb75d454f1/pullRequests/1578/attachments/image.png) - FedProx ( Acc 0.18, Loss 4.8) ![image (2).png](https://msktg.visualstudio.com/c507252c-d1be-4d67-a4a1-03b0181c35c7/_apis/git/repositories/0392018c-4507-44bf-97e2-f2bb75d454f1/pullRequests/1578/attachments/image%20%282%29.png)
This commit is contained in:
Родитель
e8fe10b6a2
Коммит
8bfe0854ab
|
@ -47,3 +47,8 @@ allowing users to run experiments using a single-GPU worker by instantiating bot
|
|||
and clients on the same device. For more documentation about how to run an experiments
|
||||
using a single GPU, please refer to the [README](README.md).
|
||||
|
||||
|
||||
### New features
|
||||
|
||||
- 🌟 Include FedProx aggregation method
|
||||
|
||||
|
|
|
@ -42,5 +42,6 @@ This software includes parts of Fast AutoAugment repository (https://github.com/
|
|||
Code from the paper "Fast AutoAugment" (Accepted at NeurIPS 2019). This example is licenced
|
||||
under MIT License, you can find a copy of this licence at https://github.com/kakaobrain/fast-autoaugment/blob/master/LICENSE
|
||||
|
||||
|
||||
|
||||
This software includes parts of NIID-Bench repository (https://github.com/Xtra-Computing/NIID-Bench).
|
||||
Code from the paper "Federated Learning on Non-IID Data Silos: An Experimental Study". This example is
|
||||
licenced under MIT License, you can find a copy of this licence at https://github.com/Xtra-Computing/NIID-Bench/blob/main/LICENSE
|
||||
|
|
|
@ -188,6 +188,8 @@ This software includes the model implementation of the FedNewsRec repository (ht
|
|||
For more information about third-party OSS licence, please refer to [NOTICE.txt](NOTICE.txt).
|
||||
|
||||
This software includes the Data Augmentation scripts of the Fast AutoAugment repository (https://github.com/kakaobrain/fast-autoaugment) to preprocess the data used in the [semisupervision](experiments/semisupervision/dataloaders/cifar_dataset.py) experiment.
|
||||
|
||||
This software included the FedProx logic implementation of the NIID-Bench repository (https://github.com/Xtra-Computing/NIID-Bench/tree/main) as Federated aggregation method used in the [trainer](core/trainer.py) object.
|
||||
## Support
|
||||
|
||||
You are welcome to open issues on this repository related to bug reports and feature requests.
|
||||
|
|
|
@ -35,7 +35,7 @@ dp_config:
|
|||
privacy_metrics_config:
|
||||
apply_metrics: false # If enabled, the rest of parameters is needed.
|
||||
|
||||
# Select the Federated optimizer to use (e.g. DGA or FedAvg)
|
||||
# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
|
||||
strategy: DGA
|
||||
|
||||
# Determines all the server-side settings for training and evaluation rounds
|
||||
|
|
|
@ -34,8 +34,8 @@ privacy_metrics_config:
|
|||
# type: adamax
|
||||
# amsgrad: false
|
||||
|
||||
# Select the Federated optimizer to use (e.g. DGA or FedAvg)
|
||||
strategy: DGA
|
||||
# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
|
||||
strategy: FedProx
|
||||
|
||||
# Determines all the server-side settings for training and evaluation rounds
|
||||
server_config:
|
||||
|
@ -119,6 +119,7 @@ server_config:
|
|||
|
||||
# Dictates the learning parameters for client-side model updates. Train data is defined inside this config.
|
||||
client_config:
|
||||
mu: 0.001 # Used only for FedProx aggregation method
|
||||
meta_learning: basic
|
||||
stats_on_smooth_grad: true
|
||||
ignore_subtask: false
|
||||
|
|
|
@ -260,7 +260,8 @@ class Client:
|
|||
privacy_metrics_config = config.get('privacy_metrics_config', None)
|
||||
model_path = config["model_path"]
|
||||
|
||||
StrategyClass = select_strategy(config['strategy'])
|
||||
strategy_algo = config['strategy']
|
||||
StrategyClass = select_strategy(strategy_algo)
|
||||
strategy = StrategyClass('client', config)
|
||||
print_rank(f'Client successfully instantiated strategy {strategy}', loglevel=logging.DEBUG)
|
||||
send_dicts = config['server_config'].get('send_dicts', False)
|
||||
|
@ -358,10 +359,11 @@ class Client:
|
|||
# This is where training actually happens
|
||||
algo_payload = None
|
||||
|
||||
if semisupervision_config != None:
|
||||
if strategy_algo == 'FedLabels':
|
||||
datasets =[get_dataset(data_path, config, task, mode="train", test_only=False, data_strct=data_strcts[i], user_idx=0) for i in range(3)]
|
||||
algo_payload = {'algo':'FedLabels', 'data': datasets, 'iter': iteration, 'config': semisupervision_config}
|
||||
|
||||
algo_payload = {'strategy':'FedLabels', 'data': datasets, 'iter': iteration, 'config': semisupervision_config}
|
||||
elif strategy_algo == 'FedProx':
|
||||
algo_payload = {'strategy':'FedProx', 'mu': client_config.get('mu',0.001)}
|
||||
train_loss, num_samples, algo_computation = trainer.train_desired_samples(desired_max_samples=desired_max_samples, apply_privacy_metrics=apply_privacy_metrics, algo_payload = algo_payload)
|
||||
print_rank('client={}: training loss={}'.format(client_id[0], train_loss), loglevel=logging.DEBUG)
|
||||
|
||||
|
|
|
@ -7,9 +7,15 @@ from .dga import DGA
|
|||
from .fedlabels import FedLabels
|
||||
|
||||
def select_strategy(strategy):
|
||||
''' Selects the aggregation strategy class
|
||||
|
||||
NOTE: FedProx uses FedAvg weights during aggregation,
|
||||
which are proportional to the number of samples in
|
||||
each client.
|
||||
'''
|
||||
if strategy.lower() == 'dga':
|
||||
return DGA
|
||||
elif strategy.lower() == 'fedavg':
|
||||
elif strategy.lower() in ['fedavg', 'fedprox']:
|
||||
return FedAvg
|
||||
elif strategy.lower() == 'fedlabels':
|
||||
return FedLabels
|
||||
|
|
|
@ -328,8 +328,10 @@ class Trainer(TrainerBase):
|
|||
|
||||
if algo_payload == None:
|
||||
num_samples_per_epoch, train_loss_per_epoch = self.run_train_epoch(desired_max_samples, apply_privacy_metrics)
|
||||
elif algo_payload['algo'] == 'FedLabels':
|
||||
elif algo_payload['strategy'] == 'FedLabels':
|
||||
num_samples_per_epoch, train_loss_per_epoch, algo_computation = self.run_train_epoch_sup(desired_max_samples, apply_privacy_metrics, algo_payload)
|
||||
elif algo_payload['strategy'] == 'FedProx':
|
||||
num_samples_per_epoch, train_loss_per_epoch = self.run_train_epoch_fedprox(desired_max_samples, apply_privacy_metrics, algo_payload)
|
||||
|
||||
num_samples += num_samples_per_epoch
|
||||
total_train_loss += train_loss_per_epoch
|
||||
|
@ -411,6 +413,93 @@ class Trainer(TrainerBase):
|
|||
|
||||
return num_samples, sum_train_loss
|
||||
|
||||
def run_train_epoch_fedprox(self, desired_max_samples=None, apply_privacy_metrics=False, algo_payload=None):
|
||||
"""Implementation example for training the model.
|
||||
|
||||
The training process should stop after the desired number of samples is processed.
|
||||
|
||||
Args:
|
||||
desired_max_samples (int): number of samples that you would like to process.
|
||||
apply_privacy_metrics (bool): whether to save the batches used for the round for privacy metrics evaluation.
|
||||
algo_payload (dict): hyperparameters needed to fine-tune FedProx algorithm.
|
||||
|
||||
Returns:
|
||||
2-tuple of (int, float): number of processed samples and total training loss.
|
||||
"""
|
||||
|
||||
sum_train_loss = 0.0
|
||||
num_samples = 0
|
||||
self.reset_gradient_power()
|
||||
|
||||
# Reset gradient just in case
|
||||
self.model.zero_grad()
|
||||
|
||||
# FedProx parameters
|
||||
mu = algo_payload['mu']
|
||||
global_model = to_device(copy.deepcopy(self.model))
|
||||
global_weight_collector = list(global_model.parameters())
|
||||
|
||||
train_loader = self.train_dataloader.create_loader()
|
||||
for batch in train_loader:
|
||||
if desired_max_samples is not None and num_samples >= desired_max_samples:
|
||||
break
|
||||
|
||||
# Compute loss
|
||||
if self.optimizer is not None:
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
if self.ignore_subtask is True:
|
||||
loss = self.model.single_task_loss(batch)
|
||||
else:
|
||||
if apply_privacy_metrics:
|
||||
if "x" in batch:
|
||||
indices = to_device(batch["x"])
|
||||
elif "input_ids" in batch:
|
||||
indices = to_device(batch["input_ids"])
|
||||
self.cached_batches.append(indices)
|
||||
loss = self.model.loss(batch)
|
||||
|
||||
# Fedprox regularization term
|
||||
fed_prox_reg = 0.0
|
||||
for param_index, param in enumerate(self.model.parameters()):
|
||||
fed_prox_reg += ((mu / 2) * torch.norm((param - global_weight_collector[param_index]))**2)
|
||||
loss += fed_prox_reg
|
||||
loss.backward()
|
||||
|
||||
# Apply gradient clipping
|
||||
if self.max_grad_norm is not None:
|
||||
grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
|
||||
|
||||
# Sum up the gradient power
|
||||
self.estimate_sufficient_stats()
|
||||
|
||||
# Now that the gradients have been scaled, we can apply them
|
||||
if self.optimizer is not None:
|
||||
self.optimizer.step()
|
||||
|
||||
print_rank("step: {}, loss: {}".format(self.step, loss.item()), loglevel=logging.DEBUG)
|
||||
|
||||
# Post-processing in this loop
|
||||
# Sum up the loss
|
||||
sum_train_loss += loss.item()
|
||||
|
||||
# Increment the number of frames processed already
|
||||
if "attention_mask" in batch:
|
||||
num_samples += torch.sum(batch["attention_mask"].detach().cpu() == 1).item()
|
||||
elif "total_frames" in batch:
|
||||
num_samples += batch["total_frames"]
|
||||
else:
|
||||
num_samples += len(batch["x"])
|
||||
|
||||
# Update the counters
|
||||
self.step += 1
|
||||
|
||||
# Take a step in lr_scheduler
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
|
||||
return num_samples, sum_train_loss
|
||||
|
||||
def run_train_epoch_sup(self, desired_max_samples=None, apply_privacy_metrics=False, algo_payload=None):
|
||||
"""Implementation example for training the model using semisupervision.
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ dp_config:
|
|||
privacy_metrics_config:
|
||||
apply_metrics: false # cache data to compute additional metrics
|
||||
|
||||
# Select the Federated optimizer to use (e.g. DGA or FedAvg)
|
||||
# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
|
||||
strategy: DGA
|
||||
|
||||
# Determines all the server-side settings for training and evaluation rounds
|
||||
|
|
|
@ -9,7 +9,7 @@ dp_config:
|
|||
privacy_metrics_config:
|
||||
apply_metrics: false # cache data to compute additional metrics
|
||||
|
||||
strategy: DGA
|
||||
strategy: DGA # Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
|
||||
|
||||
server_config:
|
||||
wantRL: false # whether to use RL-based meta-optimizers
|
||||
|
|
|
@ -12,7 +12,7 @@ dp_config:
|
|||
privacy_metrics_config:
|
||||
apply_metrics: false # cache data to compute additional metrics
|
||||
|
||||
# Select the Federated optimizer to use (e.g. DGA or FedAvg)
|
||||
# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
|
||||
strategy: FedAvg
|
||||
|
||||
# Determines all the server-side settings for training and evaluation rounds
|
||||
|
|
|
@ -13,7 +13,7 @@ dp_config:
|
|||
privacy_metrics_config:
|
||||
apply_metrics: false # cache data to compute additional metrics
|
||||
|
||||
# Select the Federated optimizer to use (e.g. DGA or FedAvg)
|
||||
# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
|
||||
strategy: FedAvg
|
||||
|
||||
# Determines all the server-side settings for training and evaluation rounds
|
||||
|
|
|
@ -12,7 +12,7 @@ dp_config:
|
|||
privacy_metrics_config:
|
||||
apply_metrics: false # cache data to compute additional metrics
|
||||
|
||||
# Select the Federated optimizer to use (e.g. DGA or FedAvg)
|
||||
# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
|
||||
strategy: FedAvg
|
||||
|
||||
# Determines all the server-side settings for training and evaluation rounds
|
||||
|
|
|
@ -12,7 +12,7 @@ dp_config:
|
|||
privacy_metrics_config:
|
||||
apply_metrics: false # cache data to compute additional metrics
|
||||
|
||||
# Select the Federated optimizer to use (e.g. DGA or FedAvg)
|
||||
# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
|
||||
strategy: DGA
|
||||
|
||||
# Determines all the server-side settings for training and evaluation rounds
|
||||
|
|
|
@ -11,7 +11,7 @@ dp_config:
|
|||
privacy_metrics_config:
|
||||
apply_metrics: false # cache data to compute additional metrics
|
||||
|
||||
# Select the Federated optimizer to use (e.g. DGA or FedAvg)
|
||||
# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
|
||||
strategy: FedAvg
|
||||
|
||||
# Determines all the server-side settings for training and evaluation rounds
|
||||
|
|
|
@ -12,7 +12,7 @@ dp_config:
|
|||
privacy_metrics_config:
|
||||
apply_metrics: false # cache data to compute additional metrics
|
||||
|
||||
# Select the Federated optimizer to use (e.g. DGA or FedAvg)
|
||||
# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
|
||||
strategy: FedAvg
|
||||
|
||||
# Determines all the server-side settings for training and evaluation rounds
|
||||
|
|
|
@ -13,7 +13,7 @@ dp_config:
|
|||
privacy_metrics_config:
|
||||
apply_metrics: false # cache data to compute additional metrics
|
||||
|
||||
# Select the Federated optimizer to use (e.g. DGA or FedAvg)
|
||||
# Select the Federated optimizer to use (e.g. DGA, FedAvg or FedProx)
|
||||
strategy: FedLabels
|
||||
|
||||
# Determines all the server-side settings for training and evaluation rounds
|
||||
|
|
Загрузка…
Ссылка в новой задаче