Merged PR 1402: Fix DP accounting bug

Fix DP accounting bug: the privacy accountant needs the size of the total pool of clients that each round's cohort is sampled from, not just the number of clients sampled in the current iteration.
Robert Sim 2022-10-24 20:33:31 +00:00
Parent 1c3f556648
Commit 6430cbf708
4 changed files with 9 additions and 4 deletions
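
Background for the fix: accountants for subsampled DP mechanisms take the per-round sampling rate q = num_clients_curr_iter / total_clients. Telling the accountant that the pool equals the cohort forces q = 1, which forfeits privacy amplification by subsampling and overstates epsilon. The following sketch illustrates the size of that effect using Opacus's RDP accountant purely for illustration; FLUTE's internal accountant, noise parameters, and config path may differ:

    # Illustrative only: compares epsilon with and without the correct pool size.
    from opacus.accountants import RDPAccountant

    num_clients_curr_iter = 100        # cohort sampled each round (assumed value)
    total_clients = 10_000             # full client pool (what this PR threads through)

    def epsilon_after(rounds, sample_rate, noise_multiplier=1.0, delta=1e-5):
        accountant = RDPAccountant()
        for _ in range(rounds):
            accountant.step(noise_multiplier=noise_multiplier, sample_rate=sample_rate)
        return accountant.get_epsilon(delta=delta)

    # Buggy accounting: pool size == cohort size, so q = 1.0 (no amplification).
    print(epsilon_after(100, sample_rate=1.0))
    # Fixed accounting: q = 100 / 10_000 = 0.01, a far tighter epsilon.
    print(epsilon_after(100, sample_rate=num_clients_curr_iter / total_clients))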

View file

@@ -416,6 +416,7 @@ class OptimizationServer(federated.Server):
                 worker_trainer=self.worker_trainer,
                 curr_iter=i,
                 num_clients_curr_iter=num_clients_curr_iter,
+                total_clients=len(self.client_idx_list),
                 client_stats=client_stats,
                 logger=log_metric,
             )
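
On the server side the pool is simply the length of the client index list. A hypothetical helper (the name and assertion are mine, not the repo's) states the invariant the new argument relies on:

    def per_round_sample_rate(num_clients_curr_iter: int, total_clients: int) -> float:
        # Probability that any given client is selected in a single round.
        assert 0 < num_clients_curr_iter <= total_clients, \
            'the cohort cannot exceed the pool it is sampled from'
        return num_clients_curr_iter / total_clients

    # e.g. sampling 100 clients per round from the full pool:
    # per_round_sample_rate(100, len(self.client_idx_list))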

View file

@@ -42,7 +42,7 @@ class BaseStrategy:
        '''
        pass

-    def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, client_stats, logger=None):
+    def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, total_clients, client_stats, logger=None):
        '''Combine payloads to update model

        Args:
@@ -50,6 +50,7 @@ class BaseStrategy:
                (aka model updater).
            curr_iter (int): current iteration.
            num_clients_curr_iter (int): number of clients on current iteration.
+            total_clients (int): size of total pool of clients (for privacy accounting).
            client_stats (dict): stats being collected.
            logger (callback): function called to log quantities.
        '''

View file

@@ -180,7 +180,7 @@ class DGA(BaseStrategy):
        self.client_parameters_stack.append(payload['gradients'])
        return True

-    def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, client_stats, logger=None):
+    def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, total_clients, client_stats, logger=None):
        '''Combine payloads to update model

        Args:
@@ -188,6 +188,7 @@ class DGA(BaseStrategy):
                (aka model updater).
            curr_iter (int): current iteration.
            num_clients_curr_iter (int): number of clients on current iteration.
+            total_clients (int): size of total pool of clients (for privacy accounting).
            client_stats (dict): stats being collected.
            logger (callback): function called to log quantities.
@@ -220,7 +221,7 @@ class DGA(BaseStrategy):
        # DP-specific steps
        privacy.apply_global_dp(self.config, worker_trainer.model, num_clients_curr_iter=num_clients_curr_iter, select_grad=True, metric_logger=logger)
-        eps = privacy.update_privacy_accountant(self.config, num_clients_curr_iter, curr_iter=curr_iter, num_clients_curr_iter=num_clients_curr_iter)
+        eps = privacy.update_privacy_accountant(self.config, total_clients, curr_iter=curr_iter, num_clients_curr_iter=num_clients_curr_iter)
        if eps:
            print_rank(f'DP result: {eps}')
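
The one-line semantic change, isolated from the hunk above (both calls come from this diff; the inline comments are my gloss, assuming the second positional argument is the sampling population):

    # Before: the accountant was told the pool is only this round's cohort,
    # so the implied sampling rate was 1.0.
    eps = privacy.update_privacy_accountant(self.config, num_clients_curr_iter,
                                            curr_iter=curr_iter, num_clients_curr_iter=num_clients_curr_iter)
    # After: the accountant sees the true pool, so the implied sampling rate
    # is num_clients_curr_iter / total_clients.
    eps = privacy.update_privacy_accountant(self.config, total_clients,
                                            curr_iter=curr_iter, num_clients_curr_iter=num_clients_curr_iter)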

View file

@@ -116,7 +116,7 @@ class FedAvg(BaseStrategy):
        self.client_parameters_stack.append(payload['gradients'])
        return True

-    def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, client_stats, logger=None):
+    def combine_payloads(self, worker_trainer, curr_iter, num_clients_curr_iter, total_clients, client_stats, logger=None):
        '''Combine payloads to update model

        Args:
@@ -161,6 +161,8 @@ class FedAvg(BaseStrategy):
        print_rank('Updating learning rate scheduler')
        losses = worker_trainer.run_lr_scheduler(force_run_val=False)

+        # TODO: Global DP. See dga.py
+
        return losses

    def _aggregate_gradients(self, worker_trainer, num_clients_curr_iter, client_weights, metric_logger=None):
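
For the FedAvg TODO, a plausible resolution is to mirror the DGA code path shown earlier in this diff. The following is a sketch only; whether FedAvg should use the same select_grad=True behavior is an assumption the repo has not confirmed:

    # Sketch: mirrors the DP-specific steps from DGA.combine_payloads above.
    privacy.apply_global_dp(self.config, worker_trainer.model,
                            num_clients_curr_iter=num_clients_curr_iter,
                            select_grad=True, metric_logger=logger)  # select_grad assumed
    eps = privacy.update_privacy_accountant(self.config, total_clients,
                                            curr_iter=curr_iter,
                                            num_clients_curr_iter=num_clients_curr_iter)
    if eps:
        print_rank(f'DP result: {eps}')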