зеркало из https://github.com/microsoft/msrflute.git
Merged PR 1402: Fix DP accounting bug
Fix DP accounting bug. The accountant needed the size of the total pool of clients to sample from.
This commit is contained in:
Родитель
1c3f556648
Коммит
6430cbf708
|
@ -416,6 +416,7 @@ class OptimizationServer(federated.Server):
|
||||||
worker_trainer=self.worker_trainer,
|
worker_trainer=self.worker_trainer,
|
||||||
curr_iter=i,
|
curr_iter=i,
|
||||||
num_clients_curr_iter=num_clients_curr_iter,
|
num_clients_curr_iter=num_clients_curr_iter,
|
||||||
|
total_clients = len(self.client_idx_list),
|
||||||
client_stats=client_stats,
|
client_stats=client_stats,
|
||||||
logger=log_metric,
|
logger=log_metric,
|
||||||
)
|
)
|
||||||
|
|
|
@ -42,7 +42,7 @@ class BaseStrategy:
|
||||||
'''
|
'''
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, client_stats, logger=None):
|
def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, total_clients, client_stats, logger=None):
|
||||||
'''Combine payloads to update model
|
'''Combine payloads to update model
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -50,6 +50,7 @@ class BaseStrategy:
|
||||||
(aka model updater).
|
(aka model updater).
|
||||||
curr_iter (int): current iteration.
|
curr_iter (int): current iteration.
|
||||||
num_clients_curr_iter (int): number of clients on current iteration.
|
num_clients_curr_iter (int): number of clients on current iteration.
|
||||||
|
total_clients (int): size of total pool of clients (for privacy accounting)
|
||||||
client_stats (dict): stats being collected.
|
client_stats (dict): stats being collected.
|
||||||
logger (callback): function called to log quantities.
|
logger (callback): function called to log quantities.
|
||||||
'''
|
'''
|
||||||
|
|
|
@ -180,7 +180,7 @@ class DGA(BaseStrategy):
|
||||||
self.client_parameters_stack.append(payload['gradients'])
|
self.client_parameters_stack.append(payload['gradients'])
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, client_stats, logger=None):
|
def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, total_clients, client_stats, logger=None):
|
||||||
'''Combine payloads to update model
|
'''Combine payloads to update model
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -188,6 +188,7 @@ class DGA(BaseStrategy):
|
||||||
(aka model updater).
|
(aka model updater).
|
||||||
curr_iter (int): current iteration.
|
curr_iter (int): current iteration.
|
||||||
num_clients_curr_iter (int): number of clients on current iteration.
|
num_clients_curr_iter (int): number of clients on current iteration.
|
||||||
|
total_clients (int): size of total pool of clients (for privacy accounting)
|
||||||
client_stats (dict): stats being collected.
|
client_stats (dict): stats being collected.
|
||||||
logger (callback): function called to log quantities.
|
logger (callback): function called to log quantities.
|
||||||
|
|
||||||
|
@ -220,7 +221,7 @@ class DGA(BaseStrategy):
|
||||||
|
|
||||||
# DP-specific steps
|
# DP-specific steps
|
||||||
privacy.apply_global_dp(self.config, worker_trainer.model, num_clients_curr_iter=num_clients_curr_iter, select_grad=True, metric_logger=logger)
|
privacy.apply_global_dp(self.config, worker_trainer.model, num_clients_curr_iter=num_clients_curr_iter, select_grad=True, metric_logger=logger)
|
||||||
eps = privacy.update_privacy_accountant(self.config, num_clients_curr_iter, curr_iter=curr_iter, num_clients_curr_iter=num_clients_curr_iter)
|
eps = privacy.update_privacy_accountant(self.config, total_clients, curr_iter=curr_iter, num_clients_curr_iter=num_clients_curr_iter)
|
||||||
if eps:
|
if eps:
|
||||||
print_rank(f'DP result: {eps}')
|
print_rank(f'DP result: {eps}')
|
||||||
|
|
||||||
|
|
|
@ -116,7 +116,7 @@ class FedAvg(BaseStrategy):
|
||||||
self.client_parameters_stack.append(payload['gradients'])
|
self.client_parameters_stack.append(payload['gradients'])
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, client_stats, logger=None):
|
def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, total_clients, client_stats, logger=None):
|
||||||
'''Combine payloads to update model
|
'''Combine payloads to update model
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -161,6 +161,8 @@ class FedAvg(BaseStrategy):
|
||||||
print_rank('Updating learning rate scheduler')
|
print_rank('Updating learning rate scheduler')
|
||||||
losses = worker_trainer.run_lr_scheduler(force_run_val=False)
|
losses = worker_trainer.run_lr_scheduler(force_run_val=False)
|
||||||
|
|
||||||
|
# TODO: Global DP. See dga.py
|
||||||
|
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
def _aggregate_gradients(self, worker_trainer, num_clients_curr_iter, client_weights, metric_logger=None):
|
def _aggregate_gradients(self, worker_trainer, num_clients_curr_iter, client_weights, metric_logger=None):
|
||||||
|
|
Загрузка…
Ссылка в новой задаче