зеркало из https://github.com/microsoft/msrflute.git
Merged PR 1272: Replace MPI -> torch.distributed
This PR replaces MPI by torch.distributed as main communication backbone, allowing to use NCCL with GPUs and Gloo for CPU distributed jobs. Most significative changes are inside _federated.py_. Asynchronous mode is enabled when using NCCL , which means that the workers are being reassigned to a new Client as soon as they finish, improving the overall GPU utilization + reducing the total time of the job, as shown in the figure below. ![COMPARISON (2).png](https://msktg.visualstudio.com/c507252c-d1be-4d67-a4a1-03b0181c35c7/_apis/git/repositories/0392018c-4507-44bf-97e2-f2bb75d454f1/pullRequests/1272/attachments/COMPARISON%20%282%29.png) However Gloo does not have a native implementation for non-blocking ways to check if the recv/send request have been completed (see details here: https://github.com/pytorch/pytorch/issues/30723 ) Therefore, when using Gloo the communication works in synchronous way. I've added a fix for the CUDA OOM issues I was receiving when running the bert experiment, the GPU memory was being overloaded during training. Comparison below MPI (https://aka.ms/amlt?q=dcbbn) vs NCCL now, some cleanup is performed after the server receives the gradient. ![image (14).png](https://msktg.visualstudio.com/c507252c-d1be-4d67-a4a1-03b0181c35c7/_apis/git/repositories/0392018c-4507-44bf-97e2-f2bb75d454f1/pullRequests/1272/attachments/image%20%2814%29.png) There are a couple minor changes in _server_, _client_ and _evaluation_ as well. The main reason is that now the Server doesn't hold the list of clients, these ones live inside the worker since the moment is created and the Server is only passing the indexes of the Client to the Worker. The reason behind this change is that torch.distributed does not allow to send objects P2P, only tensors. The rest of modified files are only to update the documentation + the testing file. I tried to be very explicit for each new function inside _federated.py_ to explain the new flow. Let me know if something it's not clear enough. I've tested all experiments already in the sandbox using NCCL and in my local machine (Windows) using Gloo (surprisingly for this case is not as slow as I was expecting, I used some dummy datasets that I had prepared though) --> pending task compare the new performance using CPU. So for now the only thing left is to run the sanity-checks on AML, links below. Sanity checks after cleanup: [x] nlg_gru: https://aka.ms/amlt?q=c9tih [x] mlm_bert: https://aka.ms/amlt?q=dbs8y [x] classif_cnn: https://aka.ms/amlt?q=da2qr [x] ecg: https://aka.ms/amlt?q=c9jof [x] cv: https://aka.ms/amlt?q=da2k4
This commit is contained in:
Родитель
887c8ac74f
Коммит
30e41400a4
|
@ -0,0 +1,40 @@
|
|||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
## [0.1.0] - 2021-11-22
|
||||
|
||||
We're super excited to announce FLUTE: Federated Learning Utilities for Testing and Experimentation, a platform for conducting high-performance federated learning simulations!
|
||||
|
||||
This first release fully focuses on implementing fast prototyping to validate different CL scenarios
|
||||
in an Federated environment.
|
||||
|
||||
### Features
|
||||
|
||||
- large scale simulation (millions of clients, sampling tens of thousands per round).
|
||||
- multi-GPU and multi-node orchestration backed up by MPI.
|
||||
- local or global differential privacy.
|
||||
- model quantization.
|
||||
- a variety of standard optimizers and aggregation methods.
|
||||
- most model types including CNNs, RNNs, and Huggingface Transformers.
|
||||
- extensibility, enabling new models, dataloaders, optimizers, and aggregators.
|
||||
- local or cloud-based job staging using AzureML.
|
||||
|
||||
|
||||
## [1.0.0] - 2022-08-29
|
||||
|
||||
This release contain major changes in the communication backbone , in order
|
||||
to run previous experiments you have already integrated in FLUTE, please make sure
|
||||
to use `torch.distributed` instead of `MPI `to launch the jobs. For more documentation
|
||||
about the new command, please refer to the [README](README.md).
|
||||
|
||||
|
||||
### New features
|
||||
|
||||
- 🏎 Better performance: Support for NCCL and Gloo as backend communication protocols.
|
||||
- Improvements in GPU utilization and overall communication speed (on the order of minutes!) for projects with huge models and datasets.
|
||||
- 🌟 Remove file type dependency on client.py, now FLUTE can receive any kind of dataset and even download the data on-the-fly. The data intantiation is completely under control of each task dataset.
|
||||
- In older versions FLUTE only allowed `json` and `hdf5` files, so the client could recognize it.
|
||||
- 🌟 Abstract classes for new models/dataloaders.
|
||||
- 🌟 Allows Federated Learning with Personalization.
|
||||
- Personalization allows you to leverage each client local data to obtain models that are better adjusted to their own data distribution. You can run the `cv` task in order to try out this feature.
|
13
README.md
13
README.md
|
@ -24,15 +24,15 @@ conda create -n FLUTE python==3.8
|
|||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
You will also need some MPI runtime such as OpenMPI (on Linux) or MS-MPI (on Windows). There is no `setup.py` as FLUTE is not currently distributed as a package, but instead meant to run from the root of the repository.
|
||||
FLUTE uses torch.distributed API as its main communication backbone, supporting three built-in backends. For more information please refer to [Distributed Communication Package](https://pytorch.org/docs/stable/distributed.html). Therefore, we highly suggest to use NCCL backend for distributed GPU training and Gloo for distributed CPU training. There is no `setup.py` as FLUTE is not currently distributed as a package, but instead meant to run from the root of the repository.
|
||||
|
||||
After this initial setup, you can use the data created for the integration test inside of `testing` for a first local run. Note that this data needs to be download manually, for more instructions please look at [the README file inside `testing`](testing/README.md).
|
||||
|
||||
```
|
||||
mpiexec -n 3 python e2e_trainer.py -dataPath ./testing/mockup -outputPath scratch -config testing/configs/hello_world_local.yaml -task nlg_gru
|
||||
python -m torch.distributed.run --nproc_per_node=3 e2e_trainer.py -dataPath ./testing/mockup -outputPath scratch -config testing/configs/hello_world_local.yaml -task nlg_gru -backend nccl
|
||||
```
|
||||
|
||||
This config uses 1 MPI node with 3 workers (1 server, 2 clients). The config file `testing/configs/hello_world_local.yaml` has some comments explaining the major sections and some important details; essentially, it consists in a very short experiment where a couple of iterations are done for just a few clients. A `scratch` folder will be created containing detailed logs.
|
||||
This config uses 1 node with 3 workers (1 server, 2 clients). The config file `testing/configs/hello_world_local.yaml` has some comments explaining the major sections and some important details; essentially, it consists in a very short experiment where a couple of iterations are done for just a few clients. A `scratch` folder will be created containing detailed logs.
|
||||
|
||||
## Documentation
|
||||
|
||||
|
@ -55,7 +55,7 @@ The core client/server training code is inside the `core` folder.
|
|||
- Server-side federation and global DP application takes place in `server.py`, more specifically in the `OptimizationServer.train()` method.
|
||||
- Client-side training updates take place in the static method `Client.process_round()`, inside `client.py`.
|
||||
|
||||
General FL orchestration code is in `federated.py`, but for most hub and spoke federation scenarios you won't need to touch this (unless you want to invest in optimizing MPI, which would be great!). Note that FLUTE does not implement secure aggregation since this is primarily a security feature for production scenarios; contributors are invited to add it for experimentation purposes.
|
||||
General FL orchestration code is in `federated.py`, but for most hub and spoke federation scenarios you won't need to touch this (unless you want to invest in optimizing server-client calls, which would be great!). Note that FLUTE does not implement secure aggregation since this is primarily a security feature for production scenarios; contributors are invited to add it for experimentation purposes.
|
||||
|
||||
The primary entry point for an experiment is in the script `e2e_trainer.py`. Primary config scripts for experiments are in `configs`. For instance, a basic training scenario for a next-word prediction task is set up in `hello_world_nlg_gru_json.yaml`.
|
||||
|
||||
|
@ -88,14 +88,15 @@ command: >
|
|||
apt -y install openmpi-bin libopenmpi-dev openssh-client &&
|
||||
python3 -m pip install --upgrade pip &&
|
||||
python3 -m pip install -r requirements.txt &&
|
||||
mpiexec --allow-run-as-root -n 4 python e2e_trainer.py
|
||||
python -m torch.distributed.run --nproc_per_node=4 e2e_trainer.py
|
||||
-outputPath=./outputs
|
||||
-dataPath={inputs.data}
|
||||
-task=classif_cnn
|
||||
-config=./experiments/classif_cnn/config.yaml
|
||||
-backend=nccl
|
||||
```
|
||||
|
||||
You should replace `compute` with the name of the one you created before, and adjust the path of the datastore containing the data -- in the example above, we created a datastore called `data` and added to it a folder called `cifar`, which contained the two HDF5 files. The command passed above will install dependencies and then launch an MPI job with 4 threads, for the experiment defined in `experiments/classif_cnn`. Details on how to run a job using the AzureML CLI are given [in its documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-train-cli), but typically it suffices to set up the environment and type `az ml job create -f <name-of-the-yaml-file>`. In the same page of the documentation, you can also find more info about how to set up the YAML file above, in case other changes are needed.
|
||||
You should replace `compute` with the name of the one you created before, and adjust the path of the datastore containing the data -- in the example above, we created a datastore called `data` and added to it a folder called `cifar`, which contained the two HDF5 files. The command passed above will install dependencies and then launch a distributed job with 4 threads, for the experiment defined in `experiments/classif_cnn`. Details on how to run a job using the AzureML CLI are given [in its documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-train-cli), but typically it suffices to set up the environment and type `az ml job create -f <name-of-the-yaml-file>`. In the same page of the documentation, you can also find more info about how to set up the YAML file above, in case other changes are needed.
|
||||
|
||||
Note that the `local_path` above is relative to the location of the YAML file, so setting it to `.` assumes it is in the same folder as `e2e_trainer.py`. All files on this folder will be uploaded to Azure, including hidden folders such as `.git`, so make sure to temporarily get rid of large files and folders that are not needed.
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
'''
|
||||
The Client object is short-lived, instantiated inside of worker 0 and moved to
|
||||
workers 1 to N for processing a given client's data. It's main method is the
|
||||
`process_round` function, used to update the model given a client's data.
|
||||
The Client object is short-lived, instantiated inside workers 1 to N for
|
||||
processing a given client's data. It's main method is the `process_round`
|
||||
function, used to update the model given a client's data.
|
||||
'''
|
||||
|
||||
import copy
|
||||
|
@ -17,12 +17,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
# Internal imports
|
||||
from core.globals import TRAINING_FRAMEWORK_TYPE
|
||||
if TRAINING_FRAMEWORK_TYPE == 'mpi':
|
||||
import core.federated as federated
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported'.format(TRAINING_FRAMEWORK_TYPE))
|
||||
|
||||
import core.federated as federated
|
||||
from .strategies import select_strategy
|
||||
from .trainer import (
|
||||
Trainer,
|
||||
|
@ -41,18 +36,20 @@ from utils.dataloaders_utils import (
|
|||
make_train_dataloader,
|
||||
make_val_dataloader,
|
||||
make_test_dataloader,
|
||||
get_dataset,
|
||||
)
|
||||
|
||||
import extensions.privacy
|
||||
from extensions.privacy import metrics as privacy_metrics
|
||||
from experiments import make_model
|
||||
|
||||
global train_dataset
|
||||
|
||||
class Client:
|
||||
# It's unclear why, but sphinx refuses to generate method docs
|
||||
# if there is no docstring for this class.
|
||||
"""Client class for specifying individual client training tasks"""
|
||||
|
||||
def __init__(self, client_id, config, send_gradients, dataloader):
|
||||
def __init__(self, client_id, config, send_gradients):
|
||||
'''
|
||||
Client side processing: computing gradients, update the model and send them back to the server
|
||||
|
||||
|
@ -61,47 +58,38 @@ class Client:
|
|||
config (dict): dictionary with parameters loaded from config file.
|
||||
send_gradients (bool): if True, model gradients are sent back;
|
||||
otherwise, model weights are sent back.
|
||||
dataloader (torch.utils.data.DataLoader): dataloader that generates
|
||||
training data for the client.
|
||||
'''
|
||||
super().__init__()
|
||||
|
||||
self.client_id = client_id
|
||||
self.client_data = self.get_data(client_id, dataloader)
|
||||
self.config = copy.deepcopy(config)
|
||||
self.send_gradients = send_gradients
|
||||
|
||||
def get_client_data(self):
|
||||
def get_client_data(self, dataset=None):
|
||||
'''"Getter" method that returns all object's attributes at once.'''
|
||||
return self.client_id, self.client_data, self.config, self.send_gradients
|
||||
|
||||
client_data = self.get_data(self.client_id, dataset)
|
||||
return self.client_id, client_data, self.config, self.send_gradients
|
||||
|
||||
@staticmethod
|
||||
def get_train_dataset(data_path, client_train_config, task):
|
||||
'''This function will obtain the training dataset for all
|
||||
'''This function will obtain the dataset for all training
|
||||
users.
|
||||
|
||||
Args:
|
||||
data_path (str): path to file containing taining data.
|
||||
client_train_config (dict): trainig data config.
|
||||
task (str): task name.
|
||||
'''
|
||||
|
||||
try:
|
||||
dir = os.path.join('experiments',task,'dataloaders','dataset.py')
|
||||
loader = SourceFileLoader("Dataset",dir).load_module()
|
||||
dataset = loader.Dataset
|
||||
train_file = os.path.join(data_path, client_train_config['list_of_train_data']) if client_train_config['list_of_train_data'] != None else None
|
||||
train_dataset = dataset(train_file, args=client_train_config)
|
||||
num_users = len(train_dataset.user_list)
|
||||
print_rank("Total amount of training users: {}".format(num_users))
|
||||
except:
|
||||
print_rank("Dataset not found, please make sure is located inside the experiment folder")
|
||||
|
||||
return num_users, train_dataset
|
||||
global train_dataset
|
||||
train_dataset = get_dataset(data_path, client_train_config, task, mode="train")
|
||||
return len(train_dataset.user_list)
|
||||
|
||||
@staticmethod
|
||||
def get_data(clients, dataset):
|
||||
''' Create training dictionary'''
|
||||
|
||||
dataset = train_dataset if len(clients) ==1 else dataset # clients is an integer only for training mode
|
||||
data_with_labels = hasattr(dataset,"user_data_label")
|
||||
input_strct = {'users': [], 'num_samples': [],'user_data': dict(), 'user_data_label': dict()} if data_with_labels else {'users': [], 'num_samples': [],'user_data': dict()}
|
||||
|
||||
|
@ -226,7 +214,7 @@ class Client:
|
|||
|
||||
# Ensure the client is assigned to the correct GPU
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == federated.size():
|
||||
torch.cuda.set_device(federated.local_rank())
|
||||
torch.cuda.set_device(federated.rank())
|
||||
|
||||
# Process inputs and initialize variables
|
||||
client_id, data_strct, config, send_gradients = client_data
|
||||
|
|
|
@ -10,16 +10,9 @@ import torch
|
|||
import numpy as np
|
||||
|
||||
# Internal imports
|
||||
from core.globals import TRAINING_FRAMEWORK_TYPE
|
||||
if TRAINING_FRAMEWORK_TYPE == 'mpi':
|
||||
import core.federated as federated
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported'.format(TRAINING_FRAMEWORK_TYPE))
|
||||
|
||||
import core.federated as federated
|
||||
from core.client import Client
|
||||
from utils import (
|
||||
print_rank
|
||||
)
|
||||
from utils import print_rank
|
||||
|
||||
# AzureML-related libs
|
||||
from azureml.core import Run
|
||||
|
@ -27,14 +20,14 @@ run = Run.get_context()
|
|||
|
||||
class Evaluation():
|
||||
|
||||
def __init__(self, config, model_path, process_testvalidate, val_dataloader, test_dataloader):
|
||||
def __init__(self, config, model_path, process_testvalidate, idx_val_clients, idx_test_clients):
|
||||
|
||||
self.config = config
|
||||
self.model_path = model_path
|
||||
self.process_testvalidate = process_testvalidate
|
||||
self.test_dataloader = val_dataloader
|
||||
self.val_dataloader = test_dataloader
|
||||
self.server_type = config['server_config']['type']
|
||||
self.idx_val_clients = idx_val_clients
|
||||
self.idx_test_clients = idx_test_clients
|
||||
|
||||
super().__init__()
|
||||
|
||||
|
@ -108,37 +101,35 @@ class Evaluation():
|
|||
def run_distributed_inference(self, mode):
|
||||
'''Call `run_distributed_evaluation` specifically for test or validation.
|
||||
|
||||
This is just a helper function that fetches the dataloader depending on
|
||||
the mode and calls `run_distributed_evaluation` using that dataloader.
|
||||
This is just a helper function that fetches the clients depending on
|
||||
the mode and calls `run_distributed_evaluation` using that list.
|
||||
|
||||
Args:
|
||||
mode (str): `test` or `val`.
|
||||
'''
|
||||
if mode == 'val':
|
||||
dataloader = self.val_dataloader
|
||||
clients = self.idx_val_clients
|
||||
elif mode == 'test':
|
||||
dataloader = self.test_dataloader
|
||||
clients = self.idx_test_clients
|
||||
else:
|
||||
raise NotImplementedError('Unsupported mode: {}'.format(mode))
|
||||
return self.run_distributed_evaluation(dataloader, mode)
|
||||
|
||||
def run_distributed_evaluation(self, dataloader, mode):
|
||||
return self.run_distributed_evaluation(mode, clients)
|
||||
|
||||
def run_distributed_evaluation(self, mode, clients):
|
||||
'''Perform evaluation using available workers.
|
||||
|
||||
See also `process_test_validate` on federated.py.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader): used to fetch data.
|
||||
mode (str): `test` or `val`.
|
||||
clients (list): clients for test/val round.
|
||||
'''
|
||||
val_clients = list(self.make_eval_clients(dataloader))
|
||||
print_rank(f'mode: {mode} evaluation_clients {len(val_clients)}', loglevel=logging.DEBUG)
|
||||
|
||||
total = 0
|
||||
self.logits = {'predictions': [], 'probabilities': [], 'labels': []}
|
||||
server_data = (0.0, [p.data.to(torch.device('cpu')) for p in self.worker_trainer.model.parameters()])
|
||||
|
||||
for result in self.process_testvalidate(val_clients, server_data, mode):
|
||||
for result in self.process_testvalidate(clients, server_data, mode):
|
||||
output, metrics, count = result
|
||||
val_metrics = {key: {'value':0, 'higher_is_better': False} for key in metrics.keys()} if total == 0 else val_metrics
|
||||
|
||||
|
@ -164,34 +155,35 @@ class Evaluation():
|
|||
self.losses = [val_metrics['loss']['value'], val_metrics['acc']['value']] # For compatibility with Server
|
||||
return val_metrics
|
||||
|
||||
def make_eval_clients(self, dataloader):
|
||||
'''Generator that yields clients for evaluation, continuously.
|
||||
def make_eval_clients(dataset, config):
|
||||
'''Generator that yields clients for evaluation, continuously.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader): used to get client's data
|
||||
'''
|
||||
Args:
|
||||
dataset (torch.utils.data.Dataset): used to get client's data
|
||||
config (dict): used for the client's constructor
|
||||
'''
|
||||
|
||||
total = sum(dataloader.dataset.num_samples)
|
||||
clients = federated.size() - 1
|
||||
delta = total / clients + 1
|
||||
threshold = delta
|
||||
current_users_idxs = list()
|
||||
current_total = 0
|
||||
total = sum(dataset.num_samples)
|
||||
clients = federated.size() - 1
|
||||
delta = total / clients + 1
|
||||
threshold = delta
|
||||
current_users_idxs = list()
|
||||
current_total = 0
|
||||
|
||||
if self.server_type == "personalization":
|
||||
for i in range(len(dataloader.dataset.user_list)):
|
||||
yield Client([i], self.config, False, dataloader.dataset)
|
||||
else:
|
||||
for i in range(len(dataloader.dataset.user_list)):
|
||||
current_users_idxs.append(i)
|
||||
count = dataloader.dataset.num_samples[i]
|
||||
current_total += count
|
||||
if current_total > threshold:
|
||||
print_rank(f'sending {len(current_users_idxs)} users', loglevel=logging.DEBUG)
|
||||
yield Client(current_users_idxs, self.config, False, dataloader.dataset)
|
||||
current_users_idxs = list()
|
||||
current_total = 0
|
||||
if config["server_config"]["type"] == "personalization":
|
||||
for i in range(len(dataset.user_list)):
|
||||
yield Client([i], config, False)
|
||||
else:
|
||||
for i in range(len(dataset.user_list)):
|
||||
current_users_idxs.append(i)
|
||||
count = dataset.num_samples[i]
|
||||
current_total += count
|
||||
if current_total > threshold:
|
||||
print_rank(f'sending {len(current_users_idxs)} users', loglevel=logging.DEBUG)
|
||||
yield Client(current_users_idxs, config, False)
|
||||
current_users_idxs = list()
|
||||
current_total = 0
|
||||
|
||||
if len(current_users_idxs) != 0:
|
||||
print_rank(f'sending {len(current_users_idxs)} users -- residual', loglevel=logging.DEBUG)
|
||||
yield Client(current_users_idxs, self.config, False, dataloader.dataset)
|
||||
if len(current_users_idxs) != 0:
|
||||
print_rank(f'sending {len(current_users_idxs)} users -- residual', loglevel=logging.DEBUG)
|
||||
yield Client(current_users_idxs, config, False)
|
||||
|
|
|
@ -1,53 +1,242 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import copy
|
||||
import cProfile
|
||||
import gc
|
||||
import os
|
||||
import pickle
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from mpi4py import MPI
|
||||
import torch.distributed as dist
|
||||
import numpy as np
|
||||
|
||||
from core.client import Client
|
||||
from utils import (
|
||||
print_rank,
|
||||
print_profiler
|
||||
print_profiler,
|
||||
to_device,
|
||||
)
|
||||
from utils.queue import process_in_parallel
|
||||
|
||||
COMMAND_UPDATE = 0
|
||||
COMMAND_TRAIN = 1
|
||||
COMMAND_TERMINATE = 10
|
||||
COMMAND_TESTVAL = 11
|
||||
COMMAND_SYNC_NODES = 9
|
||||
|
||||
SPLIT_SIZE = 512 * 1024 * 1024 # messages above this size (in bytes) are split
|
||||
|
||||
COMMAND_UPDATE = "update"
|
||||
COMMAND_TRAIN = "train"
|
||||
COMMAND_TERMINATE = "terminate"
|
||||
COMMAND_TESTVAL = "testvalidate"
|
||||
def encode_string(word, string_to_int = True):
|
||||
""" Encodes/Decodes the dictionary keys into an array of integers to be sent
|
||||
as tensors of the same shape during NCCL/Gloo P2P communication.
|
||||
|
||||
Args:
|
||||
word (string/array): key to be encoded/decoded.
|
||||
string_to_int (bool): flag that indicates which action to perform.
|
||||
"""
|
||||
|
||||
if string_to_int: # encode
|
||||
word = word.ljust(8, ' ') if len(word) < 8 else word # padding -- 8 is max length, all tensors must have the same size during communication
|
||||
word_encoded = [letter for letter in word.encode()]
|
||||
return word_encoded
|
||||
else: #decode
|
||||
cleanup_array = [letter for letter in word if letter!= 32] # Remove padding
|
||||
word_decoded = bytes(cleanup_array).decode()
|
||||
return word_decoded
|
||||
|
||||
def rank():
|
||||
"""Return rank of node"""
|
||||
return MPI.COMM_WORLD.Get_rank()
|
||||
|
||||
def local_rank():
|
||||
"""Return local rank of MPI node"""
|
||||
assert (
|
||||
"OMPI_COMM_WORLD_LOCAL_RANK" in os.environ
|
||||
), "local rank can only be determined when using OpenMPI"
|
||||
return int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
|
||||
""" Return rank of node. """
|
||||
return dist.get_rank()
|
||||
|
||||
def size():
|
||||
"""Returns number of MPI nodes including server"""
|
||||
return MPI.COMM_WORLD.Get_size()
|
||||
""" Returns number of nodes in the distributed group, including server. """
|
||||
return dist.get_world_size()
|
||||
|
||||
def _recv(x, src=0):
|
||||
""" Receives tensors with a single element or a list of tensors
|
||||
with the same shape during distributed communication. """
|
||||
|
||||
x = torch.tensor(x) if torch.is_tensor(x) == False else x
|
||||
x = to_device(x)
|
||||
dist.recv(tensor=x, src=src)
|
||||
x.to('cpu')
|
||||
|
||||
try:
|
||||
return x.item() # single element
|
||||
except:
|
||||
return x.tolist() # list of tensors
|
||||
|
||||
def _recv_gradients(src):
|
||||
""" Receives a list of tensors with different shape during
|
||||
distributed communication. """
|
||||
|
||||
n, n_dimensions, grads = 0, 0, [] # tensors intialization -- required by torch.
|
||||
n = _recv(n,src)
|
||||
for i in range(n):
|
||||
n_dimensions = _recv(n_dimensions,src)
|
||||
dimensions = [0 for i in range(n_dimensions)]
|
||||
dimensions = _recv(dimensions, src)
|
||||
print_rank(f"Received dimensions {dimensions}", loglevel=logging.DEBUG)
|
||||
param = to_device(torch.zeros(dimensions))
|
||||
print_rank(f"Shape assigned {param.shape}", loglevel=logging.DEBUG)
|
||||
dist.recv(param,src)
|
||||
grads.append(param.detach().cpu())
|
||||
torch.cuda.empty_cache()
|
||||
return grads
|
||||
|
||||
def _send(x, dst=0):
|
||||
""" Send tensors with a single element or a list of tensors
|
||||
with the same shape during distributed communication. """
|
||||
x = torch.tensor(x)
|
||||
x = to_device(x)
|
||||
dist.send(x, dst)
|
||||
del x
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _send_metrics(output):
|
||||
""" Organize the keys and values from the resulting dictionary
|
||||
from test/val rounds into arrays that are sent as independent
|
||||
tensors during distributed communication. """
|
||||
|
||||
keys = [encode_string(key) for key in output.keys()]
|
||||
values = [float(output[key]['value']) for key in output.keys()]
|
||||
higher_is_better = [int(output[key]['higher_is_better']) for key in output.keys()] # send the boolean as int
|
||||
|
||||
_send(len(keys),0)
|
||||
_send(keys)
|
||||
_send(values)
|
||||
_send(higher_is_better)
|
||||
|
||||
def _send_gradients(gradients, dst):
|
||||
""" Send a list of tensors with different shape during
|
||||
distributed communication. """
|
||||
|
||||
_send(len(gradients), dst)
|
||||
for i in gradients:
|
||||
dimensions = [int(d) for d in i.shape]
|
||||
_send(len(dimensions),dst)
|
||||
_send(dimensions,dst)
|
||||
param = to_device(i)
|
||||
dist.send(param,dst)
|
||||
del param
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def _send_train_output(output):
|
||||
""" Organize the keys and values from the the returning ´client_output´
|
||||
dictionary in ´Client.proces_round()´ function during training rounds,
|
||||
into arrays that are sent as independent tensors during distributed
|
||||
communication. """
|
||||
|
||||
cs_values = [float(cs_v) for cs_v in output['cs'].values()] # cs dict -- values are flatten in 1d array
|
||||
pl_values = [float(output['pl']['weight'])] # pl dict
|
||||
gradients = output['pl']['gradients'] # gradients are sent independently
|
||||
values = cs_values + [float(output[key]) for key in output.keys() if key not in ['cs','pl']] + pl_values # reorganizing values in the order expected by the Server
|
||||
|
||||
# Send data
|
||||
_send(values, 0)
|
||||
_send_gradients(gradients, 0)
|
||||
|
||||
def build_grads_dict(node):
|
||||
""" Reconstruct the dictionary ´client_output´ returned by
|
||||
´Client.proces_round()´ function on the Server side during
|
||||
distributed communication. """
|
||||
|
||||
# Initialize tensors
|
||||
keys = ['cs','tl','mg','vg','ng','rg','ns','ts','pl']
|
||||
values = [0.0 for i in range(11)] # initializing tensor shape -- 11 is fixed number of keys expected
|
||||
|
||||
# Read data
|
||||
values = _recv(values,node)
|
||||
grads = _recv_gradients(node)
|
||||
|
||||
# Rebuilding original dictionary
|
||||
cs_values = [{key: values.pop(0) for key in ['setup','training','full cost']}] # recreating cs dict
|
||||
pl_values = [{'weight':values.pop(), 'gradients': grads}] # recreating pl dict
|
||||
values_list = cs_values + [values.pop(0) for i in range(7)] + pl_values # 7 is fixed length for remaining items
|
||||
result = dict(zip(keys,values_list))
|
||||
|
||||
# Cast values to original type
|
||||
for key in ['mg','vg','ng','rg']:
|
||||
result[key] = np.float32(result[key])
|
||||
result['ns'] = int(result['ns'] )
|
||||
|
||||
return result
|
||||
|
||||
def build_metrics_dict(node):
|
||||
""" Reconstruct the dictionary returned during test/val rounds
|
||||
on the Server side during distributed communication. """
|
||||
|
||||
# Initialize tensors
|
||||
n = 0
|
||||
n = _recv(n,node)
|
||||
keys = [[0 for j in range(8)] for i in range(n)] # max_seq_len for metric name is 8
|
||||
values = [0.0 for i in range(n)]
|
||||
higher_is_better = [0 for i in range(n)]
|
||||
|
||||
# Read data
|
||||
keys = _recv(keys,node)
|
||||
values = _recv(values,node)
|
||||
higher_is_better = _recv(higher_is_better,node)
|
||||
|
||||
# Reorganize output + decode dict keys
|
||||
orig_keys = [encode_string(key, string_to_int=False) for key in keys]
|
||||
values_dict = [{'value': float(v), 'higher_is_better': bool(higher_is_better[i])} for i, v in enumerate(values)]
|
||||
metrics = dict(zip(orig_keys,values_dict))
|
||||
num_instances = int(metrics.pop('num')['value'])
|
||||
|
||||
result = None, metrics, num_instances
|
||||
|
||||
return result
|
||||
|
||||
def receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes):
|
||||
""" Receives the clients output on the Server side in async/sync mode.
|
||||
Asynchronous mode is only enabled when using NCCL backend given that Gloo
|
||||
does not provide native non-blocking implementation to check if the operation
|
||||
has been completed during distributed training"""
|
||||
|
||||
if dist.get_backend() == "nccl": # Async
|
||||
for node, req in node_request_map:
|
||||
if req.is_completed():
|
||||
result = build_metrics_dict(node) if command == COMMAND_TESTVAL else build_grads_dict(node)
|
||||
results_list.append(result)
|
||||
free_nodes.append(node)
|
||||
node_request_map.remove((node,req))
|
||||
print_rank(f"Finished releasing the nodes {free_nodes}", loglevel=logging.DEBUG)
|
||||
else: # Sync
|
||||
print_rank(f"Waiting for a workers", loglevel=logging.DEBUG)
|
||||
gather_objects = [(None,None,None) for i in range(size())]
|
||||
output = [None for _ in gather_objects]
|
||||
dist.all_gather_object(output, gather_objects[rank()])
|
||||
print_rank(f" All workers have finished ... taking the remaining clients {len(output)}", loglevel=logging.DEBUG)
|
||||
output = [e for i,e in enumerate(output) if i not in idle_nodes ] # Cleanup for idle workers
|
||||
results_list = results_list + output[1:]
|
||||
free_nodes = list(range(1, size()))
|
||||
|
||||
return node_request_map, results_list, free_nodes
|
||||
|
||||
def append_async_requests(node_request_map, node):
|
||||
""" Appends the asynchronous request sent to each worker during
|
||||
asynchronous training. """
|
||||
|
||||
ack = to_device(torch.zeros(1))
|
||||
req = dist.irecv(tensor=ack, src=node)
|
||||
node_request_map.append((node,req))
|
||||
return node_request_map
|
||||
|
||||
def sync_idle_nodes(client_queue, free_nodes):
|
||||
""" Request dummy outputs to the odd (idle) nodes during synchronous training
|
||||
to prevent them to get trapped in the state of the previous iterations """
|
||||
|
||||
idle_nodes = []
|
||||
if len(client_queue) == 0:
|
||||
print_rank(f"Free idle nodes {len(free_nodes)}", loglevel=logging.DEBUG)
|
||||
while len(free_nodes) > 0:
|
||||
node = free_nodes.pop()
|
||||
idle_nodes.append(node)
|
||||
_send(COMMAND_SYNC_NODES, node)
|
||||
return idle_nodes
|
||||
|
||||
class Server:
|
||||
"""Server object responsible for orchestration and aggregation.
|
||||
|
||||
The Server is one of the two objects that may exist inside of a thread, all
|
||||
throughout its execution (the other being the Worker). At every round, the
|
||||
Server samples clients and sends their data for an available Worker to process.
|
||||
Server samples clients ids and send their data for an available Worker to process.
|
||||
The Workers then each produce a new model, and all models are sent to the Server
|
||||
for aggregation.
|
||||
|
||||
|
@ -59,161 +248,137 @@ class Server:
|
|||
It thus only serves the purpose of grouping the methods, but nothing
|
||||
is actually stored inside of the object.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def dispatch_clients(clients, server_data, payload_fn, clients_in_parallel=None):
|
||||
"""Perform execution of client code on the worker nodes.
|
||||
def dispatch_clients(clients, server_data, command, mode=None, do_profiling=False):
|
||||
"""Perform the orchestration between Clients and Workers.
|
||||
|
||||
This function does the following:
|
||||
1. It sends the server_data to all workers
|
||||
2. For each client:
|
||||
2a. It sends the function process_round of the client
|
||||
to a free worker.
|
||||
2b. It calls get_client_data on the client.
|
||||
2c. It triggers the execution of the payload_fn on the
|
||||
worker with parameters server_data and client_data.
|
||||
2. For each available Worker:
|
||||
2a. It sends the index of the client to instantiate
|
||||
2c. It triggers the execution of the command on the
|
||||
Client.
|
||||
3. Collect and return all client outputs.
|
||||
|
||||
Notes:
|
||||
This function yields the gradients of different clients
|
||||
as they are received. Therefore, the order of the results generally
|
||||
does not correspond to the order of the clients.
|
||||
|
||||
All commands used during Server-Worker communication must be
|
||||
float/integers given that torch.distributed only allows to
|
||||
send/recv tensors.
|
||||
|
||||
Args:
|
||||
clients (list): list of clients to be processed.
|
||||
server_data (dict): server data sent to the workers and passed to
|
||||
clients, typically includes the global model at that step.
|
||||
payload_fn (callback): instructions for worker to execute.
|
||||
clients_in_parallel (int or None): how many threads will be used for
|
||||
processing clients, defaults to None in which case all of them
|
||||
are processed on the same thread.
|
||||
|
||||
command (int): instruction for worker to execute on the Client.
|
||||
mode (int): test/val only provided during evaluation rounds.
|
||||
do_profiling (bool): enables profiler during comunication.
|
||||
|
||||
Returns:
|
||||
Generator of results sent by server via MPI.
|
||||
Generator of results.
|
||||
"""
|
||||
# Send args to workers
|
||||
data_pickled = pickle.dumps(server_data) # pickle once
|
||||
for worker_rank in range(1, MPI.COMM_WORLD.Get_size()):
|
||||
MPI.COMM_WORLD.send(COMMAND_UPDATE, worker_rank)
|
||||
_send(data_pickled, worker_rank, pickled=True)
|
||||
|
||||
# Perform payload_fn on clients
|
||||
# Some cleanup
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||
|
||||
# Initialize communication profiler
|
||||
profiler = None
|
||||
if do_profiling:
|
||||
profiler = cProfile.Profile()
|
||||
profiler.enable()
|
||||
|
||||
# Update lr + model parameters each round for all workers
|
||||
lr, model_params = server_data
|
||||
for worker_rank in range(1, size()):
|
||||
_send(COMMAND_UPDATE, worker_rank)
|
||||
_send(lr,worker_rank)
|
||||
_send_gradients(model_params, worker_rank)
|
||||
print_rank(f"Finished sending lr {lr} and n_params {len(model_params)} to worker {worker_rank}", loglevel=logging.DEBUG)
|
||||
|
||||
print_rank(f"Finished sending server_data to workers", loglevel=logging.DEBUG)
|
||||
|
||||
client_queue = clients.copy()
|
||||
free_nodes = list(range(1, MPI.COMM_WORLD.Get_size()))
|
||||
node_request_map = []
|
||||
print_rank(f"Clients queue: {client_queue}", loglevel=logging.DEBUG)
|
||||
free_nodes = list(range(1, size()))
|
||||
results_list, node_request_map = [], []
|
||||
|
||||
# Initiate computation for all clients
|
||||
while client_queue:
|
||||
if clients_in_parallel is not None:
|
||||
clients_to_process = [client_queue.pop() for _ in range(clients_in_parallel) if len(client_queue) > 0]
|
||||
else:
|
||||
clients_to_process = client_queue.pop()
|
||||
|
||||
print_rank(f"Queueing {clients_to_process}, {len(client_queue)} remaining", loglevel=logging.DEBUG)
|
||||
|
||||
# Wait for free worker node
|
||||
if not free_nodes:
|
||||
print_rank(f"Waiting for a worker", loglevel=logging.DEBUG)
|
||||
assert(len(node_request_map) > 0)
|
||||
status = MPI.Status()
|
||||
ix, _ = MPI.Request.waitany(node_request_map, status=status)
|
||||
|
||||
# Collects worker output after processing has finished
|
||||
output = _recv(status.source)
|
||||
if isinstance(output, list):
|
||||
yield from output
|
||||
else:
|
||||
yield output
|
||||
|
||||
free_nodes.append(status.source)
|
||||
print_rank(f"Found free worker {ix}:{status.source}", loglevel=logging.DEBUG)
|
||||
node_request_map.pop(ix)
|
||||
|
||||
# Run client computation on free worker node
|
||||
print_rank(f"Clients queue: {client_queue}", loglevel=logging.DEBUG)
|
||||
assert len(free_nodes) > 0
|
||||
node = free_nodes.pop()
|
||||
print_rank(f"Sending to worker {node}", loglevel=logging.DEBUG)
|
||||
payload_fn(clients_to_process, node)
|
||||
print_rank(f"Payload sent. Queueing irecv on {node}", loglevel=logging.DEBUG)
|
||||
node_request_map.append(MPI.COMM_WORLD.irecv(source=node))
|
||||
print_rank(f"Queued irecv for {node}", loglevel=logging.DEBUG)
|
||||
index = len(client_queue)-1
|
||||
client_to_process = client_queue.pop(index)
|
||||
print_rank(f"Sending client {index} to worker {node}", loglevel=logging.DEBUG)
|
||||
_send(command, node) # The command should indicate the worker which function to run on the client
|
||||
|
||||
print_rank(f"Done queuing clients. Waiting on workers")
|
||||
if command == COMMAND_TESTVAL:
|
||||
_send(mode,node) # Only for test/val has a value
|
||||
_send(index, node) # Worker receives the index of the client to pop
|
||||
elif command == COMMAND_TRAIN:
|
||||
_send(client_to_process, node)
|
||||
print_rank(f"Finished assigning worker {node}, free nodes {free_nodes}", loglevel=logging.DEBUG)
|
||||
|
||||
if dist.get_backend() == "nccl":
|
||||
append_async_requests(node_request_map, node)
|
||||
idle_nodes = None
|
||||
else:
|
||||
idle_nodes = sync_idle_nodes(client_queue, free_nodes)
|
||||
|
||||
# Waits until receive the output from all ranks
|
||||
if not free_nodes:
|
||||
print_rank(f"Waiting for a workers, free nodes {free_nodes}, reqs_lst {node_request_map}", loglevel=logging.DEBUG)
|
||||
while len(free_nodes) == 0:
|
||||
node_request_map, results_list, free_nodes = receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes)
|
||||
|
||||
# Wait for all workers to finish
|
||||
for i, request in enumerate(node_request_map):
|
||||
status = MPI.Status()
|
||||
request.wait(status)
|
||||
print_rank(f"Result for item {i}: source: {status.source}", loglevel=logging.DEBUG)
|
||||
while (len(node_request_map)) != 0:
|
||||
node_request_map, results_list, free_nodes = receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes)
|
||||
|
||||
print_rank(f"Calling _recv for {status.source}", loglevel=logging.DEBUG)
|
||||
output = _recv(status.source)
|
||||
if isinstance(output, list):
|
||||
yield from output
|
||||
else:
|
||||
yield output
|
||||
for output in results_list:
|
||||
yield output
|
||||
|
||||
if do_profiling:
|
||||
profiler.disable()
|
||||
print_profiler(profiler)
|
||||
|
||||
# Some cleanup
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||
|
||||
@staticmethod
|
||||
def process_clients(clients, server_data, clients_in_parallel):
|
||||
"""Ask workers to process client data.
|
||||
|
||||
The payload function defined below will send a free worker instructions
|
||||
on how to process the data of one or more clients. This payload function
|
||||
is then passed to :code:`dispatch_clients`, which continuously looks for
|
||||
free workers and sends them more clients to process.
|
||||
def process_clients(clients, server_data):
|
||||
"""Ask workers to perform training on Clients.
|
||||
|
||||
Args:
|
||||
clients (list): list of client.Client objects.
|
||||
clients (list): list of clients indexes sampled by ´Server.py´
|
||||
object per iteration.
|
||||
server_data (dict): dictionary containing model.
|
||||
clients_in_parallel (None or int): how many threads to use for
|
||||
processing the clients on a given worker.
|
||||
|
||||
Returns:
|
||||
Generator of results sent by server via MPI.
|
||||
Generator of results.
|
||||
"""
|
||||
|
||||
def payload_fn(clients, node):
|
||||
"""Payload function for a training round."""
|
||||
|
||||
# Send command for training and function to process round
|
||||
MPI.COMM_WORLD.send(COMMAND_TRAIN, node)
|
||||
|
||||
# Loop through clients and send their data
|
||||
if clients_in_parallel is None:
|
||||
MPI.COMM_WORLD.send(clients.process_round, node)
|
||||
MPI.COMM_WORLD.send(clients.get_client_data(), node)
|
||||
else:
|
||||
MPI.COMM_WORLD.send(clients[0].process_round, node) # clients is a list
|
||||
MPI.COMM_WORLD.send(len(clients), node)
|
||||
for client in clients:
|
||||
MPI.COMM_WORLD.send(client.get_client_data(), node)
|
||||
|
||||
return Server.dispatch_clients(clients, server_data, payload_fn, clients_in_parallel=clients_in_parallel)
|
||||
return Server.dispatch_clients(clients, server_data, COMMAND_TRAIN)
|
||||
|
||||
@staticmethod
|
||||
def process_testvalidate(clients, server_data, mode):
|
||||
"""Ask workers to use clients data to compute metrics.
|
||||
|
||||
Similar to :code:`process_round` but asks workers to
|
||||
compute metrics instead, by using a different payload function.
|
||||
"""Ask workers to perform test/val on Clients.
|
||||
|
||||
Args:
|
||||
clients (list): list of client.Client objects.
|
||||
clients (list): list of clients indexes for test/val rounds.
|
||||
server_data (dict): dictionary containing model.
|
||||
mode(str): whether to :code:`test` or :code:`validate`.
|
||||
mode (str): test/val.
|
||||
|
||||
Returns:
|
||||
Generator of results sent by server via MPI.
|
||||
Generator of results.
|
||||
"""
|
||||
|
||||
def payload_fn(client, node):
|
||||
"""Payload function for a test/validation round."""
|
||||
|
||||
MPI.COMM_WORLD.send(COMMAND_TESTVAL, node)
|
||||
MPI.COMM_WORLD.send(client.run_testvalidate, node)
|
||||
MPI.COMM_WORLD.send(client.get_client_data(), node)
|
||||
MPI.COMM_WORLD.send(mode, node)
|
||||
|
||||
return Server.dispatch_clients(clients, server_data, payload_fn)
|
||||
mode = [-2] if mode == "test" else [2]
|
||||
return Server.dispatch_clients(clients, server_data, COMMAND_TESTVAL, mode)
|
||||
|
||||
@staticmethod
|
||||
def terminate_workers(terminate=True):
|
||||
|
@ -221,15 +386,15 @@ class Server:
|
|||
|
||||
if terminate:
|
||||
print_rank("Terminating worker processes")
|
||||
for worker_rank in range(1, MPI.COMM_WORLD.Get_size()):
|
||||
MPI.COMM_WORLD.send(COMMAND_TERMINATE, worker_rank)
|
||||
|
||||
for worker_rank in range(1, size()):
|
||||
_send(COMMAND_TERMINATE, worker_rank)
|
||||
|
||||
class Worker:
|
||||
"""Worker object responsible for processing clients' data.
|
||||
"""Worker object responsible for instantiate Clients based on incoming data
|
||||
from the Server and perform train/eval functions on it.
|
||||
|
||||
Each worker lives on a different MPI thread and is assigned to a different
|
||||
GPU. Via the :code:`dispatch_clients` function, the Server passes the
|
||||
Each worker lives on a different NCCL/Gloo thread and is assigned to a different
|
||||
GPU. Via the :code:`dispatch_clients` function, the Server passes to the
|
||||
Worker specific instructions to process clients' data, typically in order
|
||||
to generate a new model or to compute metrics.
|
||||
|
||||
|
@ -237,116 +402,142 @@ class Worker:
|
|||
model (torch.nn.Module): model being trained.
|
||||
data_path (str): path where all clients' data is located.
|
||||
do_profiling (bool): if True, analyzes execution in depth.
|
||||
clients_in_parallel (None or int): if not None, processes clients in
|
||||
threads during training round.
|
||||
server_data (dict): stores data received from Server when an update
|
||||
command is received.
|
||||
val_clients (list): clients list for validation rounds.
|
||||
test_clients (list): clients list for testing rounds.
|
||||
config (dict): clients configuration.
|
||||
val_dataset (torch.utils.data.Dataset): validation dataset.
|
||||
test_dataset (torch.utils.data.Dataset): testing dataset.
|
||||
"""
|
||||
def __init__(self, model=None, data_path=None, do_profiling=False, val_clients= None, \
|
||||
test_clients=None, config=None, val_dataset = None, test_dataset = None):
|
||||
|
||||
def __init__(self, model=None, data_path=None, do_profiling=False, clients_in_parallel=None):
|
||||
"""
|
||||
Set the GPU workspace for the model to be exchanged between the server and clients
|
||||
This prevents a model instance from being created on the GPU worker many time
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module, optional): model being trained, defaults to None.
|
||||
data_path (str, optional): path where all clients' data is located,
|
||||
defaults to None.
|
||||
do_profiling (bool, optional): if True, analyzes execution in depth; defaults
|
||||
to False.
|
||||
clients_in_parallel (None or int, optional): if not None, processes clients in
|
||||
threads during training round. Defaults to None.
|
||||
"""
|
||||
self.model = model
|
||||
self.data_path = data_path
|
||||
self.do_profiling = do_profiling
|
||||
self.clients_in_parallel = clients_in_parallel
|
||||
|
||||
self.server_data = None
|
||||
|
||||
# For processing in different threads, we need copies of the model
|
||||
if clients_in_parallel is not None:
|
||||
device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
|
||||
self.model_copies = [copy.deepcopy(model).to(device) for _ in range(clients_in_parallel)]
|
||||
self.config = config
|
||||
self.val_clients = val_clients
|
||||
self.test_clients = test_clients
|
||||
self.val_dataset = val_dataset
|
||||
self.test_dataset = test_dataset
|
||||
|
||||
def run(self):
|
||||
"""Main loop executed by worker nodes.
|
||||
|
||||
This method triggers the MPI communication between the worker and
|
||||
This method handles the NCCL/Gloo communication between the worker and
|
||||
the server. It keeps listening for commands from the Server,
|
||||
and performs different actions depending on the command received.
|
||||
and performs different actions on the Client assigned depending on
|
||||
the command received.
|
||||
"""
|
||||
|
||||
while True: # keeps listening for incoming server calls
|
||||
|
||||
while True: # keeps listening for commands on MPI
|
||||
command = MPI.COMM_WORLD.recv()
|
||||
assert isinstance(command, str)
|
||||
# Initialize tensors -- required by torch.distributed
|
||||
command, client_idx, mode = 0, 0, 0 # int
|
||||
lr = torch.zeros(1) # float
|
||||
|
||||
# Read command
|
||||
command = _recv(command)
|
||||
print_rank(f"Command received {command} on worker {rank()}", loglevel=logging.DEBUG)
|
||||
|
||||
# Receive server data -- lr, model_params
|
||||
if command == COMMAND_UPDATE:
|
||||
self.server_data = _recv(0)
|
||||
|
||||
print_rank(f"COMMMAND_UPDATE received {rank()}", loglevel=logging.DEBUG)
|
||||
lr = _recv(lr, 0)
|
||||
model_params = _recv_gradients(0)
|
||||
server_data = (lr, model_params)
|
||||
print_rank(f"Received lr: {lr} and n_params: {len(model_params)}", loglevel=logging.DEBUG)
|
||||
|
||||
elif command == COMMAND_TRAIN:
|
||||
print_rank(f"COMMMAND_TRAIN received {rank()}", loglevel=logging.DEBUG)
|
||||
|
||||
# Init profiler in training worker
|
||||
profiler = None
|
||||
if self.do_profiling:
|
||||
profiler = cProfile.Profile()
|
||||
profiler.enable()
|
||||
|
||||
# Receive client id from Server
|
||||
client_idx = _recv(client_idx)
|
||||
print_rank(f"Cliend idx received from Server: {client_idx}", loglevel=logging.DEBUG)
|
||||
|
||||
client_fn = MPI.COMM_WORLD.recv() # NOTE: assumes function is same for all clients
|
||||
# Instantiate client
|
||||
client_to_process = Client(
|
||||
[client_idx],
|
||||
self.config,
|
||||
self.config['client_config']['type'] == 'optimization')
|
||||
|
||||
# Execute Client.get_data()
|
||||
client_data = client_to_process.get_client_data()
|
||||
|
||||
# Pick whether to do processing in batches or not
|
||||
if self.clients_in_parallel is None:
|
||||
client_data = MPI.COMM_WORLD.recv()
|
||||
# Execute Client.process_round()
|
||||
output = client_to_process.process_round(client_data, server_data, self.model, self.data_path)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
output = client_fn(client_data, self.server_data, self.model, self.data_path)
|
||||
# Send output back to Server
|
||||
if dist.get_backend() == "nccl":
|
||||
# ASYNC mode -- enabled only for nccl backend
|
||||
ack = to_device(torch.tensor(1))
|
||||
dist.isend(tensor=ack, dst=0)
|
||||
_send_train_output(output)
|
||||
else:
|
||||
n_clients = MPI.COMM_WORLD.recv()
|
||||
client_data = [MPI.COMM_WORLD.recv() for _ in range(n_clients)]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
output = process_in_parallel(client_fn, client_data, self.server_data, self.model_copies, self.data_path)
|
||||
print_rank(f"Processed batch of size {len(client_data)}, got {len(output)} outputs", loglevel=logging.DEBUG)
|
||||
|
||||
# Wait for server to be available and send output(s)
|
||||
MPI.COMM_WORLD.isend(None, 0).wait()
|
||||
_send(output, 0)
|
||||
|
||||
# Make sure that memory is cleaned up
|
||||
if self.clients_in_parallel is not None:
|
||||
for args in client_data:
|
||||
del args
|
||||
del client_fn, client_data, output
|
||||
# SYNC mode -- gloo backend does not have a non-blocking way to check if the operation is completed
|
||||
gather_objects = [output for i in range(size())]
|
||||
output = [None for _ in gather_objects]
|
||||
dist.all_gather_object(output, gather_objects[rank()])
|
||||
|
||||
# Some cleanup
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||
|
||||
if self.do_profiling:
|
||||
profiler.disable()
|
||||
print_profiler(profiler)
|
||||
|
||||
|
||||
elif command == COMMAND_TESTVAL:
|
||||
print_rank(f"COMMMAND_TESTVAL received {rank()}", loglevel=logging.DEBUG)
|
||||
|
||||
# Init profiler in validation worker
|
||||
profiler = None
|
||||
if self.do_profiling:
|
||||
profiler = cProfile.Profile()
|
||||
profiler.enable()
|
||||
|
||||
# Receive mode and client id from Server
|
||||
mode = _recv(mode)
|
||||
mode = "test" if mode == -2 else "val"
|
||||
client_idx = _recv(client_idx)
|
||||
print_rank(f"Client idx received from Server: {client_idx}, {mode}", loglevel=logging.DEBUG)
|
||||
|
||||
# Get client and dataset
|
||||
clients = self.val_clients if mode == "val" else self.test_clients
|
||||
dataset = self.val_dataset if mode == "val" else self.test_dataset
|
||||
clients_queue = clients.copy()
|
||||
assert 0 <= client_idx < len(clients_queue)
|
||||
client_to_process = clients_queue.pop(client_idx)
|
||||
|
||||
client_fn = MPI.COMM_WORLD.recv()
|
||||
client_data = MPI.COMM_WORLD.recv()
|
||||
client_mode = MPI.COMM_WORLD.recv()
|
||||
# Execute Client.get_data()
|
||||
client_data = client_to_process.get_client_data(dataset)
|
||||
|
||||
# Execute Client.run_testvalidate()
|
||||
output = client_to_process.run_testvalidate(client_data, server_data, mode, self.model)
|
||||
|
||||
# Clean up memory before client processing
|
||||
torch.cuda.empty_cache()
|
||||
# Send output back to Server
|
||||
if dist.get_backend() == "nccl":
|
||||
# ASYNC mode -- enabled only for nccl backend
|
||||
_, metrics, num_instances = output
|
||||
metrics['num']= {'value': float(num_instances), 'higher_is_better': False}
|
||||
output = metrics
|
||||
print_rank(f"Worker {rank()} output {output}", loglevel=logging.DEBUG)
|
||||
ack = to_device(torch.tensor(1))
|
||||
dist.isend(tensor=ack, dst=0)
|
||||
_send_metrics(output)
|
||||
else:
|
||||
# SYNC mode -- gloo backend does not have a non-blocking way to check if the operation is completed
|
||||
gather_objects = [output for i in range(size())]
|
||||
output = [None for _ in gather_objects]
|
||||
dist.all_gather_object(output, gather_objects[rank()])
|
||||
print_rank(f"Worker {rank()} sent output back", loglevel=logging.DEBUG)
|
||||
|
||||
try:
|
||||
output = client_fn(client_data, self.server_data, client_mode, self.model)
|
||||
except RuntimeError as e:
|
||||
_dump_tensors(gpu_only=True)
|
||||
raise RuntimeError("Federated Error: {}".format(str(e)))
|
||||
|
||||
MPI.COMM_WORLD.isend(None, 0).wait()
|
||||
_send(output, 0)
|
||||
|
||||
# Make sure that memory is cleaned up
|
||||
del client_fn, client_data, output
|
||||
# Some cleanup
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||
|
||||
|
@ -355,82 +546,23 @@ class Worker:
|
|||
print_profiler(profiler)
|
||||
|
||||
elif command == COMMAND_TERMINATE:
|
||||
print_rank(f"COMMMAND_TERMINATE received {rank()}", loglevel=logging.DEBUG)
|
||||
|
||||
# Some cleanup
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||
return
|
||||
|
||||
elif command == COMMAND_SYNC_NODES: # Only for sync calls
|
||||
print_rank(f"COMMMAND_SYNC_NODES received {rank()}", loglevel=logging.DEBUG)
|
||||
|
||||
gather_objects = [None for i in range(size())]
|
||||
output = [None for _ in gather_objects]
|
||||
dist.all_gather_object(output, gather_objects[rank()])
|
||||
print_rank(f"Worker IDLE {rank()} sent dummy output back", loglevel=logging.DEBUG)
|
||||
|
||||
# Some cleanup
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||
else:
|
||||
assert False, "unknown command"
|
||||
|
||||
|
||||
def _send(data, rank, pickled=False, verbose=False):
|
||||
"""Send large object by chunking it into multiple MPI messages."""
|
||||
|
||||
# Pickle data
|
||||
data_pickled = data
|
||||
if not pickled:
|
||||
data_pickled = pickle.dumps(data_pickled)
|
||||
|
||||
# Compute in how many chunks data will be sent
|
||||
num_chunks = len(data_pickled) // SPLIT_SIZE + 1
|
||||
if verbose:
|
||||
print_rank(f"_send data_pickled size: {len(data_pickled)}, {num_chunks} chunks")
|
||||
|
||||
# Send data in chunks
|
||||
MPI.COMM_WORLD.send(num_chunks, rank)
|
||||
|
||||
ix = 0
|
||||
while len(data_pickled) - ix > SPLIT_SIZE:
|
||||
MPI.COMM_WORLD.send(data_pickled[ix:ix+SPLIT_SIZE], rank)
|
||||
ix += SPLIT_SIZE
|
||||
MPI.COMM_WORLD.send(data_pickled[ix:], rank)
|
||||
|
||||
def _recv(rank):
|
||||
"""Receive large object by chunking it into multiple MPI messages."""
|
||||
|
||||
num_chunks = MPI.COMM_WORLD.recv(source=rank)
|
||||
pickled_chunks = []
|
||||
for _ in range(num_chunks):
|
||||
pickled_chunks.append(MPI.COMM_WORLD.recv(source=rank))
|
||||
data_pickled = b"".join(pickled_chunks)
|
||||
return pickle.loads(data_pickled)
|
||||
|
||||
def _dump_tensors(gpu_only=True):
|
||||
"""Print a list of the Tensors being tracked by the garbage collector."""
|
||||
|
||||
def pretty_size(size):
|
||||
"""Pretty prints a torch.Size object."""
|
||||
assert(isinstance(size, torch.Size))
|
||||
return " × ".join(map(str, size))
|
||||
|
||||
print_rank("Dump memory allocated")
|
||||
print_rank(torch.cuda.memory_allocated())
|
||||
print_rank("Dump max memory allocated")
|
||||
print_rank(torch.cuda.max_memory_allocated())
|
||||
print_rank("Dump memory cached")
|
||||
print_rank(torch.cuda.memory_cached())
|
||||
print_rank("Dump max memory cached")
|
||||
print_rank(torch.cuda.max_memory_cached())
|
||||
|
||||
total_size = 0
|
||||
for obj in gc.get_objects():
|
||||
try:
|
||||
if torch.is_tensor(obj):
|
||||
if not gpu_only or obj.is_cuda:
|
||||
print("%s:%s%s %s" % (type(obj).__name__,
|
||||
" GPU" if obj.is_cuda else "",
|
||||
" pinned" if obj.is_pinned else "",
|
||||
pretty_size(obj.size())))
|
||||
total_size += obj.numel()
|
||||
elif hasattr(obj, "data") and torch.is_tensor(obj.data):
|
||||
if not gpu_only or obj.is_cuda:
|
||||
print("%s -> %s:%s%s%s%s %s" % (type(obj).__name__,
|
||||
type(obj.data).__name__,
|
||||
" GPU" if obj.is_cuda else "",
|
||||
" pinned" if obj.data.is_pinned else "",
|
||||
" grad" if obj.requires_grad else "",
|
||||
" volatile" if obj.volatile else "",
|
||||
pretty_size(obj.data.size())))
|
||||
total_size += obj.data.numel()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
print_rank("Total size: {}".format(total_size))
|
|
@ -1,8 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
|
||||
# Macro variable that sets which distributed trainig framework is used (e.g. mpi, syft, horovod)
|
||||
TRAINING_FRAMEWORK_TYPE = 'mpi'
|
||||
logging_level = logging.INFO # DEBUG | INFO
|
|
@ -19,11 +19,7 @@ import numpy as np
|
|||
import torch
|
||||
|
||||
# Internal imports
|
||||
from core.globals import TRAINING_FRAMEWORK_TYPE
|
||||
if TRAINING_FRAMEWORK_TYPE == 'mpi':
|
||||
import core.federated as federated
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported'.format(TRAINING_FRAMEWORK_TYPE))
|
||||
import core.federated as federated
|
||||
from core.evaluation import Evaluation
|
||||
from core.client import Client
|
||||
from .strategies import select_strategy
|
||||
|
@ -49,8 +45,8 @@ run = Run.get_context()
|
|||
|
||||
|
||||
class OptimizationServer(federated.Server):
|
||||
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, train_dataloader, train_dataset,
|
||||
val_dataloader, test_dataloader, config, config_server):
|
||||
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, server_train_dataloader,
|
||||
config, idx_val_clients, idx_test_clients):
|
||||
'''Implement Server's orchestration and aggregation.
|
||||
|
||||
This is the main Server class, that actually implements orchestration
|
||||
|
@ -67,11 +63,10 @@ class OptimizationServer(federated.Server):
|
|||
ss_scheduler: scheduled sampling scheduler.
|
||||
data_path (str): points to where data is.
|
||||
model_path (str): points to where pretrained model is.
|
||||
train_dataloader (torch.utils.data.DataLoader): dataloader for training
|
||||
val_dataloader (torch.utils.data.DataLoader): dataloader for validation
|
||||
test_dataloader (torch.utils.data.DataLoader): dataloader for test, can be None
|
||||
server_train_dataloader (torch.utils.data.DataLoader): dataloader for training
|
||||
config (dict): JSON style configuration parameters
|
||||
config_server: deprecated, kept for API compatibility only.
|
||||
idx_val_clients (list): validation client ids
|
||||
idx_test_clients (list): testing clients ids
|
||||
'''
|
||||
|
||||
super().__init__()
|
||||
|
@ -92,7 +87,7 @@ class OptimizationServer(federated.Server):
|
|||
self.val_freq = server_config['val_freq']
|
||||
self.req_freq = server_config['rec_freq']
|
||||
|
||||
self.evaluation = Evaluation(config, model_path, self.process_testvalidate, val_dataloader, test_dataloader)
|
||||
self.evaluation = Evaluation(config, model_path, self.process_testvalidate, idx_val_clients, idx_test_clients)
|
||||
|
||||
# TODO: does this need to be adjusted for custom metrics?
|
||||
self.metrics = dict()
|
||||
|
@ -122,8 +117,8 @@ class OptimizationServer(federated.Server):
|
|||
model=model,
|
||||
optimizer=optimizer,
|
||||
ss_scheduler=ss_scheduler,
|
||||
train_dataloader=train_dataloader if train_dataloader is not None else val_dataloader,
|
||||
val_dataloader=val_dataloader,
|
||||
train_dataloader=server_train_dataloader if server_train_dataloader is not None else None,
|
||||
val_dataloader=None,
|
||||
max_grad_norm=max_grad_norm,
|
||||
anneal_config=server_config['annealing_config'],
|
||||
model_type=self.model_type,
|
||||
|
@ -133,8 +128,7 @@ class OptimizationServer(federated.Server):
|
|||
# Creating an instance for the server-side trainer (runs mini-batch SGD)
|
||||
self.server_replay_iterations = None
|
||||
self.server_trainer = None
|
||||
self.train_dataset = train_dataset
|
||||
if train_dataloader is not None:
|
||||
if server_train_dataloader is not None:
|
||||
assert 'server_replay_config' in server_config, 'server_replay_config is not set'
|
||||
assert 'optimizer_config' in server_config[
|
||||
'server_replay_config'], 'server-side replay training optimizer is not set'
|
||||
|
@ -145,7 +139,7 @@ class OptimizationServer(federated.Server):
|
|||
model=model,
|
||||
optimizer=None,
|
||||
ss_scheduler=ss_scheduler,
|
||||
train_dataloader=train_dataloader,
|
||||
train_dataloader=server_train_dataloader,
|
||||
server_replay_config=server_config['server_replay_config'],
|
||||
val_dataloader=None,
|
||||
max_grad_norm=server_config['server_replay_config']\
|
||||
|
@ -180,9 +174,6 @@ class OptimizationServer(federated.Server):
|
|||
|
||||
self.do_profiling = server_config.get('do_profiling', False)
|
||||
|
||||
# Parallel processing
|
||||
self.clients_in_parallel = config['client_config'].get('clients_in_parallel', None)
|
||||
|
||||
StrategyClass = select_strategy(config['strategy'])
|
||||
self.strategy = StrategyClass('server', self.config, self.model_path)
|
||||
print_rank(f'Server successfully instantiated strategy {self.strategy}', loglevel=logging.DEBUG)
|
||||
|
@ -230,7 +221,7 @@ class OptimizationServer(federated.Server):
|
|||
'secsPerClientFull': [],
|
||||
'secsPerRoundHousekeeping': [],
|
||||
'secsPerRoundTotal': [],
|
||||
'mpiCosts': []
|
||||
'communicationCosts': []
|
||||
}
|
||||
|
||||
run.log('Max iterations', self.max_iteration)
|
||||
|
@ -304,15 +295,7 @@ class OptimizationServer(federated.Server):
|
|||
# Create the pool of clients -- sample from this pool to assign to workers
|
||||
sampled_idx_clients = random.sample(self.client_idx_list,
|
||||
num_clients_curr_iter) if num_clients_curr_iter > 0 else self.client_idx_list
|
||||
sampled_clients = [
|
||||
Client(
|
||||
[client_id],
|
||||
self.config,
|
||||
self.config['client_config']['type'] == 'optimization',
|
||||
self.train_dataset
|
||||
) for client_id in sampled_idx_clients
|
||||
]
|
||||
|
||||
|
||||
# Initialize stats
|
||||
clients_begin = time.time()
|
||||
|
||||
|
@ -326,7 +309,7 @@ class OptimizationServer(federated.Server):
|
|||
self.run_stats['secsPerClientFull'].append([])
|
||||
self.run_stats['secsPerClientTraining'].append([])
|
||||
self.run_stats['secsPerClientSetup'].append([])
|
||||
self.run_stats['mpiCosts'].append([])
|
||||
self.run_stats['communicationCosts'].append([])
|
||||
|
||||
# Check if we want privacy metrics
|
||||
apply_privacy_metrics = self.config.get('privacy_metrics_config', None) and \
|
||||
|
@ -345,7 +328,8 @@ class OptimizationServer(federated.Server):
|
|||
# Reset gradient for the model before assigning the new gradients
|
||||
self.worker_trainer.model.zero_grad()
|
||||
|
||||
for client_output in self.process_clients(sampled_clients, server_data, self.clients_in_parallel):
|
||||
print_rank(f"Clients sampled from server {sampled_idx_clients}", loglevel=logging.DEBUG)
|
||||
for client_output in self.process_clients(sampled_idx_clients, server_data):
|
||||
# Process client output
|
||||
client_timestamp = client_output['ts']
|
||||
client_stats = client_output['cs']
|
||||
|
@ -361,7 +345,7 @@ class OptimizationServer(federated.Server):
|
|||
for metric, value in privacy_stats.items():
|
||||
privacy_metrics_stats[metric].append(value)
|
||||
|
||||
self.run_stats['mpiCosts'][-1].append(time.time() - client_timestamp)
|
||||
self.run_stats['communicationCosts'][-1].append(time.time() - client_timestamp)
|
||||
|
||||
# Get actual pseudo-gradients for aggregation
|
||||
payload_processed = self.strategy.process_individual_payload(self.worker_trainer, client_payload)
|
||||
|
@ -513,7 +497,7 @@ class OptimizationServer(federated.Server):
|
|||
'secsPerClientTraining',
|
||||
'secsPerClientFull',
|
||||
'secsPerClientSetup',
|
||||
'mpiCosts',
|
||||
'communicationCosts',
|
||||
]
|
||||
|
||||
for metric in metrics_for_stats:
|
||||
|
|
|
@ -10,13 +10,13 @@ Install the requirements stated inside of requirements.txt. Ideally this sould b
|
|||
conda create -n FLUTE python==3.8
|
||||
pip install -r requirements.txt
|
||||
|
||||
You will also need some MPI runtime such as OpenMPI (on Linux) or MS-MPI (on Windows). There is no setup.py as FLUTE is not currently distributed as a package, but instead meant to run from the root of the repository.
|
||||
FLUTE uses torch.distributed API as its main communication backbone, supporting three buil-in backends. For more information please refer to [Distributed Communication Package](https://pytorch.org/docs/stable/distributed.html). Therefore, we highly suggest to use NCCL backend for distributed GPU training and Gloo for distributed CPU training. There is no `setup.py` as FLUTE is not currently distributed as a package, but instead meant to run from the root of the repository.
|
||||
|
||||
After this initial setup you can use your data for launching a local run. However the following instructions will be adapted to run ``nlg_gru`` task. For running this example, you need to first download and preprocess the data. Instructions can be found `here`_. Once the data is available you can run FLUTE from root as follows:
|
||||
|
||||
.. code:: bash
|
||||
|
||||
mpiexec -n 3 python e2e_trainer.py -dataPath ./testing/mockup -outputPath scratch -config testing/configs/hello_world_local.yaml -task nlg_gru
|
||||
python -m torch.distributed.run --nproc_per_node=3 e2e_trainer.py -dataPath ./testing/mockup -outputPath scratch -config testing/configs/hello_world_local.yaml -task nlg_gru -backend nccl
|
||||
|
||||
.. _here: https://github.com/microsoft/msrflute/tree/main/testing
|
||||
|
||||
|
@ -56,14 +56,15 @@ For running experiments on AzureML, the CLI can help. You should first install t
|
|||
apt -y install openmpi-bin libopenmpi-dev openssh-client &&
|
||||
python3 -m pip install --upgrade pip &&
|
||||
python3 -m pip install -r requirements.txt &&
|
||||
mpiexec --allow-run-as-root -n 4 python e2e_trainer.py
|
||||
python -m torch.distributed.run --nproc_per_node=4 e2e_trainer.py
|
||||
-outputPath=./outputs
|
||||
-dataPath={inputs.data}
|
||||
-task=classif_cnn
|
||||
-config=./experiments/classif_cnn/config.yaml|
|
||||
-config=./experiments/classif_cnn/config.yaml
|
||||
-backend=nccl
|
||||
|
||||
|
||||
You should replace ``compute`` with the name of the one you created before, and adjust the path of the datastore containing the data. In the example above, we created a datastore called ``data`` and added to it a folder called ``cifar``, which contained the two HDF5 files. The command passed above will install dependencies and then launch an MPI job with 4 threads, for the experiment defined in ``experiments/classif_cnn``. Details on how to run a job using the AzureML CLI are given in its `documentation`_ , but typically it suffices to set up the environment and type ``az ml job create -f <name-of-the-yaml-file>``. In the same page of the documentation, you can also find more info about how to set up the YAML file above, in case other changes are needed.
|
||||
You should replace ``compute`` with the name of the one you created before, and adjust the path of the datastore containing the data. In the example above, we created a datastore called ``data`` and added to it a folder called ``cifar``, which contained the two HDF5 files. The command passed above will install dependencies and then launch a NCCL job with 4 threads, for the experiment defined in ``experiments/classif_cnn``. Details on how to run a job using the AzureML CLI are given in its `documentation`_ , but typically it suffices to set up the environment and type ``az ml job create -f <name-of-the-yaml-file>``. In the same page of the documentation, you can also find more info about how to set up the YAML file above, in case other changes are needed.
|
||||
|
||||
.. _documentation: https://docs.microsoft.com/en-us/azure/machine-learning/how-to-train-cli
|
||||
|
||||
|
|
|
@ -27,7 +27,7 @@ Each worker>0 processes client tasks sequentially, consisting of data encoding a
|
|||
:align: center
|
||||
:width: 500
|
||||
|
||||
FLUTE uses a distributed processing architecture backed by OpenMPI.
|
||||
FLUTE uses a distributed processing architecture backed by torch.distributed.
|
||||
|
||||
Execution runs for up to N training rounds. In each round the orchestrator may sample a subset of clients, and may also randomly delay pseudo-gradient updates from some clients to future rounds. The orchestrator will also periodically distribute evaluation tasks to determine model quality on validation and test data.
|
||||
|
||||
|
|
|
@ -2,7 +2,7 @@
|
|||
# Licensed under the MIT license.
|
||||
|
||||
'''
|
||||
This is the main script to run on each MPI thread. It will spawn either a
|
||||
This is the main script to run on each NCCL/GLOO thread. It will spawn either a
|
||||
Server or Worker object -- the former is responsible for orchestrating and
|
||||
aggregating models, where as the latter processes clients' data to generate
|
||||
a new model. The Server lives on the very first thread, whereas remaining
|
||||
|
@ -13,16 +13,17 @@ import argparse
|
|||
import os
|
||||
import shutil
|
||||
import yaml
|
||||
import logging
|
||||
from psutil import virtual_memory
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from azureml.core import Run
|
||||
|
||||
from core import federated
|
||||
from core.config import FLUTEConfig
|
||||
from core.server import select_server
|
||||
from core.client import Client
|
||||
from core.globals import TRAINING_FRAMEWORK_TYPE, logging_level
|
||||
from experiments import make_model
|
||||
from utils import (
|
||||
make_optimizer,
|
||||
|
@ -32,12 +33,9 @@ from utils import (
|
|||
)
|
||||
from utils.dataloaders_utils import (
|
||||
make_train_dataloader,
|
||||
make_val_dataloader,
|
||||
make_test_dataloader,
|
||||
get_dataset,
|
||||
)
|
||||
|
||||
assert TRAINING_FRAMEWORK_TYPE == "mpi", "Unsupported platform {}".format(TRAINING_FRAMEWORK_TYPE)
|
||||
|
||||
from core.evaluation import make_eval_clients
|
||||
|
||||
def log_run_properties(config: FLUTEConfig):
|
||||
"""Log parameters on AzureML.
|
||||
|
@ -76,49 +74,64 @@ def log_run_properties(config: FLUTEConfig):
|
|||
run.log(k, properties[k])
|
||||
|
||||
|
||||
def run_worker(model_path, config, task, data_path, local_rank):
|
||||
"""Spawn worker object that lives throughout MPI thread.
|
||||
def run_worker(model_path, config, task, data_path, local_rank, backend):
|
||||
"""Spawn worker object that lives throughout NCCL/GLOO thread.
|
||||
|
||||
Args:
|
||||
model_path (str): path to the pretrained model.
|
||||
config (dict): dictionary containing parameters.
|
||||
task (str): what task to solve, must be a folder of :code:`experiments`.
|
||||
data_path (str): path to data.
|
||||
local_rank (int): the rank of the MPI thread.
|
||||
local_rank (int): the rank of the NCCL/GLOO thread.
|
||||
"""
|
||||
model_config = config["model_config"]
|
||||
server_config = config["server_config"]
|
||||
|
||||
# Get the rank on MPI
|
||||
# Backend initialization
|
||||
print_rank(f"Backend: {backend}")
|
||||
dist.init_process_group(backend=backend, init_method=None)
|
||||
rank = dist.get_rank()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(rank)
|
||||
|
||||
# Get the rank on NCCL/GLOO
|
||||
rank = local_rank if local_rank > -1 else federated.rank()
|
||||
|
||||
# Assign MPI thread to a specific GPU
|
||||
# Assign NCCL thread to a specific GPU
|
||||
if torch.cuda.is_available():
|
||||
n_gpus = torch.cuda.device_count()
|
||||
torch.cuda.set_device(federated.local_rank() % n_gpus)
|
||||
print_rank(f"Assigning worker to GPU {federated.local_rank() % n_gpus}")
|
||||
torch.cuda.set_device(federated.rank() % n_gpus)
|
||||
print_rank(f"Assigning worker to GPU {federated.rank() % n_gpus}")
|
||||
|
||||
# Make the Model to distribute to workers
|
||||
model = make_model(model_config)
|
||||
|
||||
# Get evaluation datasets
|
||||
data_config = config['server_config']['data_config']
|
||||
val_dataset = get_dataset(data_path, data_config["val"], task, mode="val")
|
||||
test_dataset = get_dataset(data_path, data_config["test"], task, mode="test")
|
||||
|
||||
# Create list of clients for test/val -- Server need the indexes and Worker the clients list
|
||||
val_clients = list(make_eval_clients(val_dataset, config))
|
||||
test_clients = list(make_eval_clients(test_dataset, config))
|
||||
|
||||
# pre-cache the training data and capture the number of clients for sampling
|
||||
client_train_config = config["client_config"]["data_config"]["train"]
|
||||
num_clients = Client.get_train_dataset(data_path, client_train_config,task)
|
||||
config["server_config"]["data_config"]["num_clients"] = num_clients
|
||||
|
||||
# Instantiate the Server object on the first thread
|
||||
if rank == 0:
|
||||
try:
|
||||
print_rank('Server data preparation')
|
||||
|
||||
# pre-cache the training data and capture the number of clients for sampling
|
||||
client_train_config = config["client_config"]["data_config"]["train"]
|
||||
num_clients, train_dataset = Client.get_train_dataset(data_path, client_train_config,task)
|
||||
config["server_config"]["data_config"]["num_clients"] = num_clients
|
||||
|
||||
# Make the Dataloaders
|
||||
data_config = config['server_config']['data_config']
|
||||
if 'train' in data_config:
|
||||
server_train_dataloader = make_train_dataloader(data_config['train'], data_path, task=task, clientx=None)
|
||||
else:
|
||||
server_train_dataloader = None
|
||||
val_dataloader = make_val_dataloader(data_config["val"], data_path, task=task)
|
||||
test_dataloader = make_test_dataloader(data_config["test"], data_path, task=task)
|
||||
|
||||
idx_val_clients = list(range(len(val_clients))) # Generates indexes for val clients
|
||||
idx_test_clients = list(range(len(test_clients))) # Generates indexes for test clients
|
||||
|
||||
print_rank("Prepared the dataloaders")
|
||||
|
||||
|
@ -135,18 +148,16 @@ def run_worker(model_path, config, task, data_path, local_rank):
|
|||
server_type = server_config["type"]
|
||||
server_setup = select_server(server_type) # Return the server class
|
||||
server = server_setup(
|
||||
data_config["num_clients"],
|
||||
model,
|
||||
optimizer,
|
||||
None,
|
||||
data_path,
|
||||
model_path,
|
||||
server_train_dataloader,
|
||||
train_dataset,
|
||||
val_dataloader,
|
||||
test_dataloader,
|
||||
config,
|
||||
server_config
|
||||
num_clients=data_config["num_clients"],
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
ss_scheduler=None,
|
||||
data_path=data_path,
|
||||
model_path=model_path,
|
||||
server_train_dataloader=server_train_dataloader,
|
||||
config=config,
|
||||
idx_val_clients=idx_val_clients,
|
||||
idx_test_clients=idx_test_clients,
|
||||
)
|
||||
log_run_properties(config)
|
||||
|
||||
|
@ -163,10 +174,14 @@ def run_worker(model_path, config, task, data_path, local_rank):
|
|||
print_rank("Worker on node {}: process started".format(rank))
|
||||
client_config = config["client_config"]
|
||||
worker = federated.Worker(
|
||||
model,
|
||||
data_path,
|
||||
model=model,
|
||||
data_path=data_path,
|
||||
do_profiling=client_config.get("do_profiling", False),
|
||||
clients_in_parallel=client_config.get("clients_in_parallel", None),
|
||||
val_clients=val_clients,
|
||||
test_clients=test_clients,
|
||||
val_dataset = val_dataset,
|
||||
test_dataset = test_dataset,
|
||||
config= config,
|
||||
)
|
||||
worker.run()
|
||||
|
||||
|
@ -178,6 +193,7 @@ if __name__ == "__main__":
|
|||
parser.add_argument("-outputPath")
|
||||
parser.add_argument("-dataPath", default=None)
|
||||
parser.add_argument("-task", default=None, help="Define the task for the run")
|
||||
parser.add_argument("-backend", default=None, help="Define the communication protocol")
|
||||
parser.add_argument("-num_skip_decoding", default=-1, type=int, help="Skip decoding in unsupervised learning mode")
|
||||
parser.add_argument("--local_rank", default=-1, type=int)
|
||||
|
||||
|
@ -185,6 +201,8 @@ if __name__ == "__main__":
|
|||
data_path = args.dataPath
|
||||
task = args.task
|
||||
local_rank = args.local_rank
|
||||
assert args.backend in ['nccl','gloo'], f"Backend {args.backend} not recognized, please select nccl or gloo"
|
||||
backend = args.backend
|
||||
|
||||
# The mount point can also be retrieved from input_datasets of the run context
|
||||
if data_path is None:
|
||||
|
@ -195,6 +213,7 @@ if __name__ == "__main__":
|
|||
id = Run.get_context().id
|
||||
experiment_name = "-".join(id.split("-")[-4:-2])
|
||||
experiment_root = os.path.join(args.outputPath, experiment_name)
|
||||
os.makedirs(experiment_root, exist_ok=True)
|
||||
model_path = os.path.join(experiment_root, "models")
|
||||
log_path = os.path.join(experiment_root, "log")
|
||||
|
||||
|
@ -207,7 +226,7 @@ if __name__ == "__main__":
|
|||
shutil.copyfile(args.config, cfg_out)
|
||||
|
||||
# Initialize logging
|
||||
init_logging(log_path, loglevel=logging_level)
|
||||
init_logging(log_path, loglevel=logging.INFO)
|
||||
|
||||
with open(args.config) as f:
|
||||
|
||||
|
@ -222,4 +241,4 @@ if __name__ == "__main__":
|
|||
config.validate()
|
||||
|
||||
# Instantiate either Server or Worker on the thread
|
||||
run_worker(model_path, config, task, data_path, local_rank)
|
||||
run_worker(model_path, config, task, data_path, local_rank, backend)
|
||||
|
|
|
@ -48,10 +48,10 @@ example is provided in `config.yaml`.
|
|||
## Running the experiment
|
||||
|
||||
Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
|
||||
script using MPI
|
||||
script using torch.distributed.
|
||||
|
||||
```
|
||||
mpiexec -n 4 python e2e_trainer.py -dataPath experiments/classif_cnn/utils/data -outputPath scratch -config experiments/classif_cnn/config.yaml -task classif_cnn
|
||||
python -m torch.distributed.run --nproc_per_node=4 e2e_trainer.py -dataPath experiments/classif_cnn/utils/data -outputPath scratch -config experiments/classif_cnn/config.yaml -task classif_cnn -backend gloo
|
||||
```
|
||||
|
||||
The `dataPath`, `outputPath` and `config` arguments should just specify the
|
||||
|
|
|
@ -44,10 +44,10 @@ example is provided in `config.yaml`.
|
|||
## Running the experiment
|
||||
|
||||
Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
|
||||
script using MPI
|
||||
script using torch.distributed.
|
||||
|
||||
```
|
||||
mpiexec -n 4 python e2e_trainer.py -dataPath ./ -outputPath scratch -config experiments/classif_cnn/config.yaml -task cv
|
||||
python -m torch.distributed.run --nproc_per_node=4 e2e_trainer.py -dataPath ./ -outputPath scratch -config experiments/classif_cnn/config.yaml -task cv -backend gloo
|
||||
```
|
||||
|
||||
The `dataPath`, `outputPath` and `config` arguments should just specify the
|
||||
|
|
|
@ -7,12 +7,12 @@ In this file, we define the local server that lives inside the client.
|
|||
from core.server import OptimizationServer
|
||||
|
||||
class PersonalizationServer(OptimizationServer):
|
||||
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, train_dataloader, train_dataset,
|
||||
val_dataloader, test_dataloader, config, config_server):
|
||||
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, server_train_dataloader,
|
||||
config, idx_val_clients, idx_test_clients):
|
||||
"""
|
||||
Personalization Server.
|
||||
|
||||
Customized routines for server can be included here.
|
||||
"""
|
||||
super().__init__(num_clients, model, optimizer, ss_scheduler, data_path, model_path, train_dataloader, train_dataset,
|
||||
val_dataloader, test_dataloader, config, config_server)
|
||||
super().__init__(num_clients, model, optimizer, ss_scheduler, data_path, model_path, server_train_dataloader,
|
||||
config, idx_val_clients, idx_test_clients)
|
|
@ -59,9 +59,9 @@ example is provided in `config.yaml`.
|
|||
## Running the experiment locally
|
||||
|
||||
Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
|
||||
script using MPI:
|
||||
script using torch.distributed:
|
||||
|
||||
`mpiexec -n 2 python .\e2e_trainer.py -dataPath experiments/ecg_cnn/data -outputPath scratch -config experiments/ecg_cnn/config.yaml -task ecg_cnn `
|
||||
`python -m torch.distributed.run --nproc_per_node=2 .\e2e_trainer.py -dataPath experiments/ecg_cnn/data -outputPath scratch -config experiments/ecg_cnn/config.yaml -task ecg_cnn -backend nccl`
|
||||
|
||||
The `dataPath`, `outputPath` and `config` arguments should just specify the
|
||||
respective files or folders, as in the example above -- in this case, a folder
|
||||
|
@ -90,11 +90,12 @@ command: >
|
|||
apt -y install openmpi-bin libopenmpi-dev openssh-client &&
|
||||
python3 -m pip install --upgrade pip &&
|
||||
python3 -m pip install -r requirements.txt &&
|
||||
mpiexec --allow-run-as-root -n 4 python e2e_trainer.py
|
||||
python -m torch.distributed.run --nproc_per_node=4 e2e_trainer.py
|
||||
-outputPath=./outputs
|
||||
-dataPath={inputs.data}
|
||||
-task=ecg_cnn
|
||||
-config=./experiments/ecg_cnn/config.yaml
|
||||
-backend=nccl
|
||||
|
||||
```
|
||||
To run your job, you can then use the following command:
|
||||
|
|
|
@ -24,10 +24,10 @@ the fields: list_of_train_data, test_data and val_data inside the config file.
|
|||
## Running the experiment locally
|
||||
|
||||
Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
|
||||
script using MPI:
|
||||
script using torch.distributed:
|
||||
|
||||
```code
|
||||
mpiexec -n 2 python .\e2e_trainer.py -dataPath data_folder -outputPath scratch -config configs\hello_world_mlm_bert_json.yaml -task mlm_bert
|
||||
python -m torch.distributed.run --nproc_per_node=2 .\e2e_trainer.py -dataPath data_folder -outputPath scratch -config configs\hello_world_mlm_bert_json.yaml -task mlm_bert -backend nccl
|
||||
```
|
||||
|
||||
For submitting jobs in Azure ML, we have included the instructions in the `Experiments`
|
||||
|
|
|
@ -7,6 +7,8 @@ from transformers import AutoTokenizer
|
|||
from transformers import DataCollatorForLanguageModeling
|
||||
from experiments.mlm_bert.dataloaders.dataset import Dataset
|
||||
from core.dataloader import BaseDataLoader
|
||||
from utils import print_rank
|
||||
import logging
|
||||
|
||||
class DataLoader(BaseDataLoader):
|
||||
"""
|
||||
|
@ -38,7 +40,7 @@ class DataLoader(BaseDataLoader):
|
|||
else:
|
||||
raise ValueError("You are instantiating a new tokenizer from scratch. This is not supported by this script.")
|
||||
|
||||
print("Tokenizer is: ",tokenizer)
|
||||
print_rank("Tokenizer is: {}".format(tokenizer), loglevel=logging.DEBUG)
|
||||
|
||||
dataset = Dataset(
|
||||
data,
|
||||
|
|
|
@ -115,7 +115,7 @@ class Dataset(BaseDataset):
|
|||
if len(self.utt_list) == 0:
|
||||
self.utt_list = [{'src_text': 'N/A', 'duration': 0, 'loss_weight': 1.0}]
|
||||
|
||||
print_rank('Processing json-structure for User: {} Utterances Processed: {}'.format(self.user, len(self.utt_list)), loglevel=logging.INFO)
|
||||
print_rank('Processing json-structure for User: {} Utterances Processed: {}'.format(self.user, len(self.utt_list)), loglevel=logging.DEBUG)
|
||||
|
||||
def process_user(self, user, user_data):
|
||||
counter=0
|
||||
|
|
|
@ -27,10 +27,10 @@ data to HDF5 format.
|
|||
## Running the experiment
|
||||
|
||||
Finally, to launch the experiment locally , it suffices to launch the `e2e_trainer.py`
|
||||
script using MPI, you can use as example the following line:
|
||||
script using torch.distributed , you can use as example the following line:
|
||||
|
||||
```code
|
||||
mpiexec -n 3 python e2e_trainer.py -dataPath .\testing\mockup\ -outputPath scratch -config .\testing\configs\hello_world_nlg_gru.yaml -task nlg_gru
|
||||
python -m torch.distributed.run --nproc_per_node=3 e2e_trainer.py -dataPath .\testing\mockup\ -outputPath scratch -config .\testing\configs\hello_world_nlg_gru.yaml -task nlg_gru -backend nccl
|
||||
```
|
||||
|
||||
For submitting jobs in Azure ML, we have included the instructions in the `Experiments`
|
||||
|
|
|
@ -58,8 +58,9 @@ class Dataset(BaseDataset):
|
|||
self.num_samples = orig_strct['num_samples']
|
||||
self.user_data = orig_strct['user_data']
|
||||
self.user = 'test_only' if self.test_only else self.user_list[user_idx]
|
||||
|
||||
self.process_x(self.user_data)
|
||||
|
||||
if user_idx != -1:
|
||||
self.process_x(self.user_data)
|
||||
|
||||
def process_x(self, user_data):
|
||||
print_rank('Processing data-structure: {} Utterances expected'.format(sum(self.num_samples)), loglevel=logging.DEBUG)
|
||||
|
|
|
@ -30,13 +30,13 @@ def run_pipeline(data_path, output_path, config_path, task):
|
|||
|
||||
# Adjust command to the task and OS
|
||||
sym = "&" if platform.system() == "Windows" else ";"
|
||||
command = 'cd .. '+ sym +' mpiexec '+'-np '+'2 '+ 'python '+ 'e2e_trainer.py '+ \
|
||||
command = 'cd .. '+ sym +' python '+'-m '+'torch.distributed.run '+ '--nproc_per_node=2 '+ 'e2e_trainer.py '+ \
|
||||
'-dataPath '+ data_path+' -outputPath '+output_path+' -config ' +config_path +\
|
||||
' -task '+ task
|
||||
' -task '+ task + ' -backend '+ 'nccl'
|
||||
|
||||
# Execute e2e_trainer + stores the exit code
|
||||
with open('logs.txt','w') as f:
|
||||
process= subprocess.run(command, shell=True,stdout=f,text=True,timeout=2000)
|
||||
process= subprocess.run(command, shell=True,stdout=f,text=True,timeout=900)
|
||||
return_code=process.returncode
|
||||
|
||||
# Print logs
|
||||
|
@ -53,6 +53,12 @@ def test_nlg_gru():
|
|||
data_path, output_path, config_path = get_info(task)
|
||||
assert run_pipeline(data_path, output_path, config_path, task)==0
|
||||
|
||||
def test_ecg_cnn():
|
||||
|
||||
task = 'ecg_cnn'
|
||||
data_path, output_path, config_path = get_info(task)
|
||||
assert run_pipeline(data_path, output_path, config_path, task)==0
|
||||
|
||||
@pytest.mark.xfail
|
||||
def test_mlm_bert():
|
||||
|
||||
|
@ -68,12 +74,3 @@ def test_classif_cnn():
|
|||
data_path, output_path, config_path = get_info(task)
|
||||
assert run_pipeline(data_path, output_path, config_path, task)==0
|
||||
print("PASSED")
|
||||
|
||||
def test_ecg_cnn():
|
||||
|
||||
task = 'ecg_cnn'
|
||||
data_path, output_path, config_path = get_info(task)
|
||||
assert run_pipeline(data_path, output_path, config_path, task)==0
|
||||
|
||||
|
||||
|
|
@ -82,4 +82,17 @@ def make_test_dataloader(data_config, data_path, task=None, data_strct=None):
|
|||
|
||||
return test_dataloader
|
||||
|
||||
def get_dataset(data_path, data_config, task, mode):
|
||||
""" Return the task train/val/test dataset """
|
||||
|
||||
dir = os.path.join('experiments',task,'dataloaders','dataset.py')
|
||||
loader = SourceFileLoader("Dataset",dir).load_module()
|
||||
dataset = loader.Dataset
|
||||
data_file = "val_data" if mode == "val" else "test_data" if mode == "test" else "list_of_train_data"
|
||||
data_file = data_config[data_file]
|
||||
data_file = os.path.join(data_path, data_file) if data_file != None else data_file
|
||||
dataset = dataset(data_file, user_idx=-1, args=data_config)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
|
|
|
@ -1,45 +0,0 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from mpi4py import MPI
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from utils import print_rank
|
||||
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
size = comm.Get_size()
|
||||
rank = comm.Get_rank()
|
||||
|
||||
""" Here we have classes and functions that allow one to send and process multiple
|
||||
messages in parallel on MPI. """
|
||||
|
||||
def process_in_parallel(client_fn, client_data, server_data, models, data_path):
|
||||
""" Process multiple orders in parallel
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client_fn: callback
|
||||
Function we want to call.
|
||||
client_data: list of tuples
|
||||
Arguments that will be passed to function.
|
||||
server_data: tuple
|
||||
Data passed from server to update model parameters.
|
||||
models: torch.nn.Module
|
||||
Models we will send to the clients.
|
||||
data_path: str
|
||||
Path to data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
Output of each callback in the list passed as input.
|
||||
"""
|
||||
with ThreadPoolExecutor(max_workers=len(client_data)) as pool:
|
||||
requests = []
|
||||
for k, args in enumerate(client_data):
|
||||
requests.append(pool.submit(client_fn, args, server_data, models[k], data_path))
|
||||
|
||||
results = [request.result() for request in requests]
|
||||
print_rank(f'finished processing batch of size {len(client_data)}')
|
||||
|
||||
return results
|
|
@ -18,19 +18,12 @@ from collections import OrderedDict
|
|||
from utils.optimizers.lars import LarsSGD
|
||||
from utils.optimizers.lamb import LAMB
|
||||
from utils.optimizers.adamW import AdamW
|
||||
from core.globals import TRAINING_FRAMEWORK_TYPE
|
||||
from easydict import EasyDict as edict
|
||||
from torch.optim.lr_scheduler import (
|
||||
StepLR,
|
||||
MultiStepLR,
|
||||
ReduceLROnPlateau )
|
||||
|
||||
|
||||
if TRAINING_FRAMEWORK_TYPE == 'mpi':
|
||||
from mpi4py import MPI
|
||||
else:
|
||||
raise NotImplementedError('Training framework is not yet supported')
|
||||
|
||||
def make_optimizer(optimizer_config, model):
|
||||
"""Initialization for optimizer."""
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче