# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

'''
In this file, we define the classes that live inside 'worker 0', the worker
responsible for orchestration and aggregation. The main class is the
OptimizationServer, which sends clients to the other workers to process and
combines the resulting models.
'''
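# Rough per-round flow implemented by OptimizationServer.train() below (a summary of
# this file, not an external contract): broadcast the current model and the client
# learning rate to the workers, collect each client's pseudo-gradient and statistics,
# let the configured strategy combine them into a model update, optionally run
# server-side replay training, then evaluate on val/test and checkpoint at the
# configured frequencies.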
import json
import logging
import os
import random
import shutil
import time
from collections import defaultdict

import numpy as np
import torch

# Internal imports
import core.federated as federated
from core.evaluation import Evaluation
from core.client import Client
from .strategies import select_strategy
from .trainer import (
    ModelUpdater,
    Trainer,
    set_component_wise_lr,
)
from utils import (
    get_lr,
    print_rank,
    update_json_log,
    to_device,
)

# For profiling
import cProfile
import pstats

# AzureML-related libs
from azureml.core import Run
run = Run.get_context()
class OptimizationServer(federated.Server):
    def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, server_train_dataloader,
                 config, idx_val_clients, idx_test_clients, single_worker):
        '''Implement the Server's orchestration and aggregation.

        This is the main Server class: it implements orchestration and
        aggregation, inheriting from `federated.Server`, which deals with
        communication only.

        The `train` method is central in FLUTE, as it defines a good part of
        what happens during training.

        Args:
            num_clients (int): total number of available clients.
            model (torch.nn.Module): neural network model.
            optimizer (torch.optim.Optimizer): optimizer.
            ss_scheduler: scheduled sampling scheduler.
            data_path (str): points to where the data is.
            model_path (str): points to where the pretrained model is.
            server_train_dataloader (torch.utils.data.DataLoader): dataloader for server-side training.
            config (dict): JSON-style configuration parameters.
            idx_val_clients (list): ids of the validation clients.
            idx_test_clients (list): ids of the test clients.
            single_worker (bool): whether the job runs with a single worker
                (no separate client-processing workers).
        '''
        super().__init__()

        # Initialize all attributes from arguments
        self.client_idx_list = list(range(num_clients))
        self.config = config
        server_config = config['server_config']
        decoder_config = config.get('decoder_config', None)

        self.max_iteration = server_config['max_iteration']
        self.do_clustering = server_config.get('clustering', False)
        self.send_dicts = server_config.get('send_dicts', False)

        self.num_clients_per_iteration = [int(x) for x in server_config['num_clients_per_iteration'].split(',')] \
            if isinstance(server_config['num_clients_per_iteration'], str) \
            else [server_config['num_clients_per_iteration']]
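        # Illustrative parsing (values assumed, not defaults): the string '10,20'
        # becomes [10, 20], from which a random client count is drawn each round
        # further below, while a plain integer such as 35 becomes the fixed list [35].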
        self.val_freq = server_config['val_freq']
        self.req_freq = server_config['rec_freq']

        self.evaluation = Evaluation(config, model_path, self.process_testvalidate, idx_val_clients, idx_test_clients, single_worker)

        # TODO: does this need to be adjusted for custom metrics?
        self.metrics = dict()

        self.model_backup_freq = server_config.get('model_backup_freq', 100)
        self.worker_trainer_config = server_config.get('trainer_config', {})

        self.aggregate_median = server_config['aggregate_median']
        self.initial_lr_client = server_config.get('initial_lr_client', -1.0)
        self.lr_decay_factor = server_config.get('lr_decay_factor', 1.0)

        self.model_type = config['model_config']['model_type']
        self.quant_thresh = config['client_config'].get('quant_thresh', None)
        self.quant_bits = config['client_config'].get('quant_bits', 10)

        self.list_of_train_data = config['client_config']['data_config']['train']['list_of_train_data']
        self.data_path = data_path
        self.single_worker = single_worker

        # Get max grad norm from data config
        if 'train' in server_config['data_config']:
            max_grad_norm = server_config['data_config']['train'].get('max_grad_norm', None)
        else:
            max_grad_norm = None

        # Creating an instance to update the model with stats aggregated from workers
        self.worker_trainer = ModelUpdater(
            model=model,
            optimizer=optimizer,
            ss_scheduler=ss_scheduler,
            train_dataloader=server_train_dataloader,
            val_dataloader=None,
            max_grad_norm=max_grad_norm,
            anneal_config=server_config['annealing_config'],
            model_type=self.model_type,
            decoder_config=decoder_config
        )
        self.metrics['worker_trainer'] = self.worker_trainer

        # Creating an instance for the server-side trainer (runs mini-batch SGD)
        self.server_replay_iterations = None
        self.server_trainer = None
        if server_train_dataloader is not None:
            assert 'server_replay_config' in server_config, 'server_replay_config is not set'
            assert 'optimizer_config' in server_config[
                'server_replay_config'], 'server-side replay training optimizer is not set'
            self.server_optimizer_config = server_config['server_replay_config']['optimizer_config']
            self.server_trainer_config = server_config['server_replay_config'].get('trainer_config', {})
            self.server_replay_iterations = server_config['server_replay_config']['server_iterations']
            self.server_trainer = Trainer(
                model=model,
                optimizer=None,
                ss_scheduler=ss_scheduler,
                train_dataloader=server_train_dataloader,
                server_replay_config=server_config['server_replay_config'],
                max_grad_norm=server_config['server_replay_config']\
                    .get('max_grad_norm', server_config['data_config']['train']\
                    .get('max_grad_norm', None)),
                anneal_config=server_config['server_replay_config'].get('annealing_config', None),
                ignore_subtask=server_config['server_replay_config'].get('ignore_subtask', False)
            )

        self.skip_model_update = False  # will not update the model if True

        self.train_loss = 0.0
        self.model_path = model_path
        self.best_model_criterion = server_config['best_model_criterion']
        self.fall_back_to_best_model = server_config['fall_back_to_best_model']
        self.last_model_path = os.path.join(self.model_path, 'latest_model.tar')
        self.best_model_path = os.path.join(self.model_path,
            'best_val_{}_model.tar'.format(self.best_model_criterion))
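        # e.g. with best_model_criterion='loss' this resolves to
        # '<model_path>/best_val_loss_model.tar', the same file that
        # fall_back_to_prev_best_status() reloads when falling back.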
        self.log_path = os.path.join(self.model_path, 'status_log.json')
        self.cur_iter_no = 0  # keep the iteration number for TensorBoard plotting
        self.lr_weight = 1.0

        self.losses = []
        self.no_label_updates = 0  # number of label updates

        # Update the parameters above from the log file when resuming from a checkpoint
        if server_config.get('resume_from_checkpoint', False):
            self.load_saved_status()

        # Decoding config
        self.decoder_config = decoder_config
        self.spm_model = server_config['data_config']['test'].get('spm_model', None)

        self.do_profiling = server_config.get('do_profiling', False)

        StrategyClass = select_strategy(config['strategy'])
        self.strategy = StrategyClass('server', self.config, self.model_path)
        print_rank(f'Server successfully instantiated strategy {self.strategy}', loglevel=logging.DEBUG)
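        # config['strategy'] is a string key resolved by select_strategy; FedAvg-style
        # averaging is the typical choice in FLUTE, though the exact set of names is an
        # assumption here (see the strategies module). The resulting object handles
        # per-client payload processing and aggregation later in this file.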
    def load_saved_status(self):
        '''Load checkpoint from disk.'''

        # Check if the model is on disk; if so, load it onto the trainer
        if os.path.exists(self.last_model_path):
            print_rank('Resuming from checkpoint model {}'.format(self.last_model_path))
            self.worker_trainer.load(self.last_model_path, update_lr_scheduler=True, update_ss_scheduler=True)
            if self.server_trainer is not None:
                self.server_trainer.model = self.worker_trainer.model  # make sure that the models are in sync

        # Check if the log is on disk; if so, load it onto the current stats
        if os.path.exists(self.log_path):
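            # status_log.json is the flat dict written by update_json_log() at the
            # end of each round; illustrative contents (values assumed):
            # {"i": 120, "best_val_loss": 2.31, "best_val_acc": 0.44,
            #  "best_test_loss": 2.40, "best_test_acc": 0.42,
            #  "weight": 0.9, "num_label_updates": 0}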
            with open(self.log_path, 'r') as logfp:  # loading the iteration no. and the best losses/accuracies
                elems = json.load(logfp)
                self.cur_iter_no = elems.get('i', 0)
                self.metrics['best_val_loss'] = elems.get('best_val_loss', float('inf'))
                self.metrics['best_val_acc'] = elems.get('best_val_acc', 0)
                self.metrics['best_test_loss'] = elems.get('best_test_loss', float('inf'))
                self.metrics['best_test_acc'] = elems.get('best_test_acc', 0)
                self.lr_weight = elems.get('weight', 1.0)
                self.no_label_updates = elems.get('num_label_updates', 0)
            print_rank(f'Resuming from status_log: cur_iter: {self.cur_iter_no}')

    def run(self):
        '''Trigger training.

        This is a simple wrapper around the `train` method.
        '''
        print_rank('server started')
        self.train()
        print_rank('server terminated')

    def train(self):
        '''Main method for training.'''
        self.run_stats = {
            'secsPerClientRound': [],
            'secsPerClient': [],
            'secsPerClientTraining': [],
            'secsPerClientSetup': [],
            'secsPerClientFull': [],
            'secsPerRoundHousekeeping': [],
            'secsPerRoundTotal': [],
            'communicationCosts': []
        }

        run.log('Max iterations', self.max_iteration)
        try:
            self.worker_trainer.model = to_device(self.worker_trainer.model)

            # Do an initial validation round to understand the pretrained model's validation accuracy
            # Skip if we resumed from a checkpoint (cur_iter_no > 0)
            eval_list = []
            if self.cur_iter_no == 0:
                if self.config['server_config']['initial_rec']:
                    eval_list.append('test')
                if self.config['server_config']['initial_val']:
                    eval_list.append('val')
                run.log('LR for agg. opt.', get_lr(self.worker_trainer.optimizer))

                print_rank("Running {} at itr={}".format(eval_list, self.cur_iter_no))
                self.metrics = self.evaluation.run(eval_list, self.metrics, metric_logger=run.log)
                eval_list = []  # some cleanup

            # Dump all the information in aggregate_metric
            print_rank('Saving Model Before Starting Training', loglevel=logging.INFO)
            for token in ['best_val_loss', 'best_val_acc', 'best_test_acc', 'latest']:
                self.worker_trainer.save(
                    model_path=self.model_path,
                    token=token,
                    config=self.config['server_config']
                )

            # Training loop
            self.worker_trainer.model.train()
            for i in range(self.cur_iter_no, self.max_iteration):
                begin = time.time()
                metrics_payload = {}

                def log_metric(k, v):
                    metrics_payload[k] = v

                print_rank('==== iteration {}'.format(i))
                log_metric('Current iteration', i)

                # Initial value for the learning rate of the worker
                initial_lr = self.initial_lr_client * self.lr_weight
                print_rank('Client learning rate {}'.format(initial_lr))
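                # Illustrative numbers (assumed, not defaults): with initial_lr_client=1.0
                # and lr_decay_factor=0.9, lr_weight stays at 1.0 until a validation round
                # fails to improve best_val_loss, after which clients train at lr 0.9,
                # then 0.81 after a second non-improving round, and so on.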
                # Run training on clients
                self.worker_trainer.model.zero_grad()
                self.train_loss = []

                if self.send_dicts:  # Send state dictionaries
                    glob_payload = [self.worker_trainer.model.state_dict()[param_key].to(torch.device('cpu'))
                                    for param_key in self.worker_trainer.model.state_dict()]
                else:  # Send parameters
                    glob_payload = [p.data.to(torch.device('cpu')) for p in self.worker_trainer.model.parameters()]

                server_data = (initial_lr, glob_payload, i)
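                # server_data is what each sampled client receives this round: the
                # client learning rate, CPU copies of the global model (full state_dict
                # values or just the parameters, depending on send_dicts), and the
                # current iteration index.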
                # Random number of clients per iteration
                if len(self.num_clients_per_iteration) > 1:
                    num_clients_curr_iter = random.randint(
                        self.num_clients_per_iteration[0],
                        self.num_clients_per_iteration[1]
                    )
                else:
                    num_clients_curr_iter = self.num_clients_per_iteration[0]
                log_metric('Clients for round', num_clients_curr_iter)

                # Perform annealing on the quantization threshold
                if self.quant_thresh is not None:
                    self.config['client_config']['quant_thresh'] *= self.config['client_config'].get('quant_anneal', 1.0)
                    self.quant_thresh = self.config['client_config']['quant_thresh']
                    log_metric('Quantization Thresh.', self.config['client_config']['quant_thresh'])
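                    # Worked example with assumed values: quant_thresh=1e-3 and
                    # quant_anneal=0.9 give a threshold of 9e-4 after this round
                    # and 8.1e-4 after the next.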
                # Create the pool of clients -- sample from this pool to assign to workers
                sampled_idx_clients = random.sample(self.client_idx_list,
                    num_clients_curr_iter) if num_clients_curr_iter > 0 else self.client_idx_list

                # Initialize stats
                clients_begin = time.time()

                client_losses = []
                client_mag_grads = []
                client_mean_grads = []
                client_var_grads = []
                client_norm_grads = []

                self.run_stats['secsPerClient'].append([])
                self.run_stats['secsPerClientFull'].append([])
                self.run_stats['secsPerClientTraining'].append([])
                self.run_stats['secsPerClientSetup'].append([])
                self.run_stats['communicationCosts'].append([])

                # Check if we want privacy metrics
                apply_privacy_metrics = self.config.get('privacy_metrics_config', None) and \
                    self.config['privacy_metrics_config']['apply_metrics']
                adaptive_leakage = apply_privacy_metrics and \
                    self.config['privacy_metrics_config'].get('adaptive_leakage_threshold', None)
                if apply_privacy_metrics:
                    privacy_metrics_stats = defaultdict(list)

                # Initialize profiler
                profiler = None
                if self.do_profiling:
                    profiler = cProfile.Profile()
                    profiler.enable()

                # Reset gradient for the model before assigning the new gradients
                self.worker_trainer.model.zero_grad()

                print_rank(f"Clients sampled from server {sampled_idx_clients}", loglevel=logging.DEBUG)
                for client_output in self.process_clients(sampled_idx_clients, server_data, self.single_worker):
                    # Process client output
                    client_timestamp = client_output['ts']
                    client_stats = client_output['cs']
                    client_loss = client_output['tl']
                    client_mag_grad = client_output['mg']
                    client_mean_grad = client_output['ng']
                    client_var_grad = client_output['vg']
                    client_norm_grad = client_output['rg']
                    client_payload = client_output['pl']

                    if apply_privacy_metrics:
                        privacy_stats = client_output['ps']
                        for metric, value in privacy_stats.items():
                            privacy_metrics_stats[metric].append(value)

                    self.run_stats['communicationCosts'][-1].append(time.time() - client_timestamp)

                    # Get actual pseudo-gradients for aggregation
                    payload_processed = self.strategy.process_individual_payload(self.worker_trainer, client_payload)
                    if not payload_processed:
                        print_rank('Dropping client', loglevel=logging.DEBUG)
                        num_clients_curr_iter -= 1
                        continue

                    # Aggregate stats
                    self.train_loss.append(client_loss)
                    client_losses.append(client_loss)
                    client_mag_grads.append(client_mag_grad.item())
                    client_mean_grads.append(client_mean_grad.item())
                    client_var_grads.append(client_var_grad.item())
                    client_norm_grads.append(client_norm_grad.item())

                    # Mark the end of client processing
                    client_end = time.time()

                    self.run_stats['secsPerClientFull'][-1].append(client_stats['full cost'])
                    self.run_stats['secsPerClientTraining'][-1].append(client_stats['training'])
                    self.run_stats['secsPerClientSetup'][-1].append(client_stats['setup'])
                    self.run_stats['secsPerClient'][-1].append(client_end - clients_begin)

                # Tear down profiler
                if self.do_profiling:
                    profiler.disable()
                    stats = pstats.Stats(profiler)
                    stats.sort_stats('cumulative').print_stats()

                # Prepare output
                client_mag_grads = np.array(client_mag_grads)
                client_mean_grads = np.array(client_mean_grads)
                client_var_grads = np.array(client_var_grads)
                client_norm_grads = np.array(client_norm_grads)

                client_stats = (client_mag_grads, client_mean_grads, client_var_grads)
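                # The (magnitude, mean, variance) arrays above are handed to the
                # aggregation strategy below; the norm statistics are only written to
                # norm_stats.txt when dump_norm_stats is enabled.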
                dump_norm_stats = self.config.get('dump_norm_stats', False)
                if dump_norm_stats:
                    with open(os.path.join(self.model_path, 'norm_stats.txt'), 'a', encoding='utf-8') as outF:
                        outF.write('{}\n'.format(json.dumps(list(client_norm_grads))))

                # Print the privacy metrics
                if apply_privacy_metrics:
                    for metric, values in privacy_metrics_stats.items():
                        if metric == 'Dropped clients':
                            log_metric(metric, sum(values))
                        else:
                            log_metric(metric, max(values))

                    if type(adaptive_leakage) is float:
                        values = privacy_metrics_stats['Practical epsilon (Max leakage)']
                        new_threshold = sorted(values)[int(adaptive_leakage*len(values))]
                        print_rank('Updating leakage threshold to {}'.format(new_threshold))
                        self.config['privacy_metrics_config']['max_allowed_leakage'] = new_threshold
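                        # Worked example with assumed numbers: adaptive_leakage=0.95 and
                        # 40 recorded values put the new threshold at index int(0.95*40)=38
                        # of the sorted list, i.e. roughly the 95th percentile of this
                        # round's practical-epsilon estimates.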
                # Mark that all clients have been processed
                end = time.time()
                self.run_stats['secsPerClientRound'].append(end - begin)
                begin = end

                # Log the training loss to tensorboard/AML
                log_metric('Training loss', sum(self.train_loss))

                # Combine payloads
                self.losses = self.strategy.combine_payloads(
                    worker_trainer=self.worker_trainer,
                    curr_iter=i,
                    num_clients_curr_iter=num_clients_curr_iter,
                    total_clients=len(self.client_idx_list),
                    client_stats=client_stats,
                    logger=log_metric,
                )
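                # The configured strategy decides how the per-client payloads
                # accumulated above are folded into the global model; for a FedAvg-style
                # strategy this is essentially a weighted average of pseudo-gradients.
                # This file only orchestrates the call.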
                # Run a couple of iterations of training data on the server
                if self.server_trainer is not None:
                    print_rank('Running replay iterations on server')

                    if 'updatable_names' in self.server_trainer_config:
                        set_component_wise_lr(
                            self.worker_trainer.model,
                            self.server_optimizer_config,
                            self.server_trainer_config['updatable_names']
                        )
                    self.server_trainer.prepare_iteration(self.worker_trainer.model)
                    self.server_trainer.train_desired_samples(self.server_replay_iterations)
                    self.worker_trainer.model.load_state_dict(self.server_trainer.model.state_dict())
                    torch.cuda.empty_cache()

                # Update the sampling scheduler
                print_rank('Run ss scheduler')
                self.worker_trainer.run_ss_scheduler()
                # Run inference and score on val/test depending on the iteration number
                if ((i+1) % self.val_freq) == 0:
                    eval_list.append('val')
                if ((i+1) % self.req_freq) == 0:
                    eval_list.append('test')

                if len(eval_list) > 0:
                    print_rank('Running {} at itr={}'.format(eval_list, i+1))
                    self.metrics['worker_trainer'] = self.worker_trainer
                    if hasattr(self.strategy, 'tmp_unsup'):
                        self.metrics['tmp_sup'] = self.strategy.tmp_sup
                        self.metrics['tmp_unsup'] = self.strategy.tmp_unsup
                    self.metrics = self.evaluation.run(eval_list, self.metrics, metric_logger=run.log)
                    self.losses = self.evaluation.losses

                # Create a schedule for the initial_lr (for the worker); eval_list is
                # only cleared after this check, so a round that just ran 'val' can
                # actually trigger the decay
                if 'val' in eval_list:
                    run.log('LR for agg. opt.', get_lr(self.worker_trainer.optimizer))
                    if not (self.losses[0] < self.metrics['best_val_loss']):
                        self.lr_weight *= self.lr_decay_factor
                        print_rank('LOG: Client weight of learning rate {}..'.format(self.lr_weight))
                eval_list = []
                # Backup the current best models
                self.backup_models(i)

                # Fall back to the best model if the option is enabled
                self.fall_back_to_prev_best_status()

                # Log the latest best values only after the 1st val/test round has been executed
                if len(self.metrics) > 1:
                    update_json_log(
                        self.log_path,
                        {
                            'i': i + 1,
                            'best_val_loss': float(self.metrics['best_val_loss']),
                            'best_val_acc': float(self.metrics['best_val_acc']),
                            'best_test_loss': float(self.metrics['best_test_loss']),
                            'best_test_acc': float(self.metrics['best_test_acc']),
                            'weight': float(self.lr_weight),
                            'num_label_updates': int(self.no_label_updates)
                        },
                    )

                end = time.time()

                # Aggregate stats
                self.run_stats['secsPerRoundHousekeeping'].append(end - begin)
                self.run_stats['secsPerRoundTotal'].append(self.run_stats['secsPerClientRound'][-1] + \
                    self.run_stats['secsPerRoundHousekeeping'][-1])

                log_metric('secsPerRoundTotal', self.run_stats['secsPerRoundTotal'][-1])
                if self.do_profiling:
                    log_metric('secsPerClientRound', self.run_stats['secsPerClientRound'][-1])
                    log_metric('secsPerRoundHousekeeping', self.run_stats['secsPerRoundHousekeeping'][-1])

                metrics_for_stats = [
                    'secsPerClient',
                    'secsPerClientTraining',
                    'secsPerClientFull',
                    'secsPerClientSetup',
                    'communicationCosts',
                ]

                for metric in metrics_for_stats:
                    log_metric(f'{metric}Mean', np.mean(self.run_stats[metric][-1]))
                    log_metric(f'{metric}Median', np.median(self.run_stats[metric][-1]))
                    log_metric(f'{metric}Max', max(self.run_stats[metric][-1]))

                for k in self.run_stats:
                    if k in metrics_for_stats:
                        print_rank('{}: {}'.format(k, max(self.run_stats[k][-1])), loglevel=logging.DEBUG)
                    else:
                        print_rank('{}: {}'.format(k, self.run_stats[k][-1]), loglevel=logging.DEBUG)

                # Log all the metrics
                for k in metrics_payload:
                    run.log(k, metrics_payload[k])

        finally:  # perform cleanup even if an error was raised above
            self.terminate_workers(terminate=(not self.do_clustering))
    def backup_models(self, i):
        '''Back up the latest and best models.

        Always saves the latest model; at the configured backup frequency it also
        snapshots the current best models (best validation loss/accuracy and best
        test accuracy).

        Args:
            i (int): current iteration number.
        '''
        # Always save the latest model
        self.worker_trainer.save(
            model_path=self.model_path,
            token='latest',
            config=self.config['server_config'],
        )

        if (i % self.model_backup_freq) == 0:  # save the current best models
            self.worker_trainer.save(
                model_path=self.model_path,
                token='epoch{}'.format(i),
                config=self.config['server_config']
            )

            for bodyname in ['best_val_acc', 'best_val_loss', 'best_test_acc']:
                src_model_path = os.path.join(self.model_path, '{}_model.tar'.format(bodyname))
                if os.path.exists(src_model_path):
                    dst_model_path = os.path.join(self.model_path, 'epoch{}_{}_model.tar'.format(i, bodyname))
                    shutil.copyfile(src_model_path, dst_model_path)
                    print_rank('Saved {}'.format(dst_model_path))
    def fall_back_to_prev_best_status(self):
        '''Go back to the past best status and switch to the most recent best model.'''

        if self.fall_back_to_best_model:
            print_rank('falling back to model {}'.format(self.best_model_path))

            # Save the current learning rate
            tmp_lr = get_lr(self.worker_trainer.optimizer)

            # Load the previous best model
            self.worker_trainer.load(self.best_model_path, update_lr_scheduler=False, update_ss_scheduler=False)

            # Restore the saved learning rate on the optimizer
            for g in self.worker_trainer.optimizer.param_groups:
                g['lr'] = tmp_lr

            if self.server_trainer is not None:
                self.server_trainer.model = self.worker_trainer.model  # make sure that the models are in sync
def select_server(server_type):
    '''Select a server class from its string name.

    Returns `PersonalizationServer` when `server_type` is "personalization",
    and `OptimizationServer` otherwise.

    Args:
        server_type (str): indicates the server choice.
    '''
    if server_type == "personalization":
        from experiments.cv.server import PersonalizationServer
        return PersonalizationServer
    else:
        return OptimizationServer
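# A minimal usage sketch (illustrative only -- in FLUTE the launcher script builds
# these arguments from the task configuration, and the names below are assumptions):
#
#   ServerClass = select_server(config['server_config'].get('type'))
#   server = ServerClass(num_clients, model, optimizer, ss_scheduler, data_path,
#                        model_path, server_train_dataloader, config,
#                        idx_val_clients, idx_test_clients, single_worker)
#   server.run()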