зеркало из https://github.com/microsoft/msrflute.git
Merged PR 1402: Fix DP accounting bug
Fix DP accounting bug. The accountant needed the size of the total pool of clients to sample from.
This commit is contained in:
Родитель
1c3f556648
Коммит
6430cbf708
|
@ -416,6 +416,7 @@ class OptimizationServer(federated.Server):
|
|||
worker_trainer=self.worker_trainer,
|
||||
curr_iter=i,
|
||||
num_clients_curr_iter=num_clients_curr_iter,
|
||||
total_clients = len(self.client_idx_list),
|
||||
client_stats=client_stats,
|
||||
logger=log_metric,
|
||||
)
|
||||
|
|
|
@ -42,7 +42,7 @@ class BaseStrategy:
|
|||
'''
|
||||
pass
|
||||
|
||||
def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, client_stats, logger=None):
|
||||
def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, total_clients, client_stats, logger=None):
|
||||
'''Combine payloads to update model
|
||||
|
||||
Args:
|
||||
|
@ -50,6 +50,7 @@ class BaseStrategy:
|
|||
(aka model updater).
|
||||
curr_iter (int): current iteration.
|
||||
num_clients_curr_iter (int): number of clients on current iteration.
|
||||
total_clients (int): size of total pool of clients (for privacy accounting)
|
||||
client_stats (dict): stats being collected.
|
||||
logger (callback): function called to log quantities.
|
||||
'''
|
||||
|
|
|
@ -180,7 +180,7 @@ class DGA(BaseStrategy):
|
|||
self.client_parameters_stack.append(payload['gradients'])
|
||||
return True
|
||||
|
||||
def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, client_stats, logger=None):
|
||||
def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, total_clients, client_stats, logger=None):
|
||||
'''Combine payloads to update model
|
||||
|
||||
Args:
|
||||
|
@ -188,6 +188,7 @@ class DGA(BaseStrategy):
|
|||
(aka model updater).
|
||||
curr_iter (int): current iteration.
|
||||
num_clients_curr_iter (int): number of clients on current iteration.
|
||||
total_clients (int): size of total pool of clients (for privacy accounting)
|
||||
client_stats (dict): stats being collected.
|
||||
logger (callback): function called to log quantities.
|
||||
|
||||
|
@ -220,7 +221,7 @@ class DGA(BaseStrategy):
|
|||
|
||||
# DP-specific steps
|
||||
privacy.apply_global_dp(self.config, worker_trainer.model, num_clients_curr_iter=num_clients_curr_iter, select_grad=True, metric_logger=logger)
|
||||
eps = privacy.update_privacy_accountant(self.config, num_clients_curr_iter, curr_iter=curr_iter, num_clients_curr_iter=num_clients_curr_iter)
|
||||
eps = privacy.update_privacy_accountant(self.config, total_clients, curr_iter=curr_iter, num_clients_curr_iter=num_clients_curr_iter)
|
||||
if eps:
|
||||
print_rank(f'DP result: {eps}')
|
||||
|
||||
|
|
|
@ -116,7 +116,7 @@ class FedAvg(BaseStrategy):
|
|||
self.client_parameters_stack.append(payload['gradients'])
|
||||
return True
|
||||
|
||||
def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, client_stats, logger=None):
|
||||
def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, total_clients, client_stats, logger=None):
|
||||
'''Combine payloads to update model
|
||||
|
||||
Args:
|
||||
|
@ -161,6 +161,8 @@ class FedAvg(BaseStrategy):
|
|||
print_rank('Updating learning rate scheduler')
|
||||
losses = worker_trainer.run_lr_scheduler(force_run_val=False)
|
||||
|
||||
# TODO: Global DP. See dga.py
|
||||
|
||||
return losses
|
||||
|
||||
def _aggregate_gradients(self, worker_trainer, num_clients_curr_iter, client_weights, metric_logger=None):
|
||||
|
|
Загрузка…
Ссылка в новой задаче