Merged PR 1272: Replace MPI -> torch.distributed

This PR replaces MPI with torch.distributed as the main communication backbone, allowing NCCL to be used for GPU jobs and Gloo for distributed CPU jobs. The most significant changes are inside _federated.py_.
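
Backend selection with `torch.distributed` boils down to something like the sketch below (illustrative only; the `init_backend` helper and its argument are made up for this example, the actual initialization in FLUTE may differ):

```
import os
import torch
import torch.distributed as dist

def init_backend(backend="nccl"):
    """Minimal sketch: NCCL for GPU jobs, Gloo for CPU jobs.
    torch.distributed.run sets RANK/WORLD_SIZE/LOCAL_RANK in the environment."""
    assert backend in ("nccl", "gloo")
    dist.init_process_group(backend=backend, init_method="env://")
    if backend == "nccl":
        # Pin each process to its own GPU before any communication happens.
        torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", 0)))
    return dist.get_rank(), dist.get_world_size()
```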

Asynchronous mode is enabled when using NCCL, which means that workers are reassigned to a new Client as soon as they finish, improving overall GPU utilization and reducing the total job time, as shown in the figure below.
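
Roughly, the asynchronous pattern is the one sketched below (a simplified, hypothetical helper in the spirit of `append_async_requests` / `receive_workers_output` in _federated.py_): the server keeps one non-blocking `dist.irecv` per busy worker and polls `is_completed()`, so a finished worker can immediately be handed the next client.

```
import torch.distributed as dist

def poll_finished_workers(pending, free_workers):
    """pending: list of (worker_rank, work_handle) pairs, each handle returned
    by dist.irecv on the small 'ack' tensor the worker sends when it is done."""
    still_pending = []
    for worker, req in pending:
        if req.is_completed():           # non-blocking completion check (NCCL)
            free_workers.append(worker)  # worker can take the next client
        else:
            still_pending.append((worker, req))
    return still_pending, free_workers
```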

![COMPARISON (2).png](https://msktg.visualstudio.com/c507252c-d1be-4d67-a4a1-03b0181c35c7/_apis/git/repositories/0392018c-4507-44bf-97e2-f2bb75d454f1/pullRequests/1272/attachments/COMPARISON%20%282%29.png)

However, Gloo does not have a native non-blocking way to check whether a recv/send request has completed (see details here: https://github.com/pytorch/pytorch/issues/30723). Therefore, when using Gloo the communication works synchronously.
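
With Gloo the code therefore falls back to a blocking collective instead; a minimal sketch of that idea, in the spirit of the `dist.all_gather_object` calls in _federated.py_:

```
import torch.distributed as dist

def gather_worker_outputs(local_output):
    """Blocking rendezvous used on the Gloo path: every rank contributes its
    output and the call only returns once all ranks have arrived."""
    outputs = [None] * dist.get_world_size()
    dist.all_gather_object(outputs, local_output)
    # Rank 0 (the server) keeps the workers' results; workers ignore the list.
    return outputs[1:] if dist.get_rank() == 0 else None
```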

I've also added a fix for the CUDA OOM issues I was hitting when running the bert experiment: the GPU memory was being overloaded during training. The comparison below shows MPI (https://aka.ms/amlt?q=dcbbn) vs. NCCL now; some cleanup is performed after the server receives the gradients.

![image (14).png](https://msktg.visualstudio.com/c507252c-d1be-4d67-a4a1-03b0181c35c7/_apis/git/repositories/0392018c-4507-44bf-97e2-f2bb75d454f1/pullRequests/1272/attachments/image%20%2814%29.png)
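
The cleanup essentially amounts to copying each received gradient off the GPU and releasing the CUDA buffers right away; a simplified sketch (hypothetical helper, loosely following `_recv_gradients` in _federated.py_):

```
import torch
import torch.distributed as dist

def recv_gradient_and_release(shape, src):
    """Receive one gradient tensor from a worker, keep only a CPU copy,
    and free the GPU buffer so memory does not accumulate across rounds."""
    buf = torch.zeros(shape, device="cuda")
    dist.recv(tensor=buf, src=src)
    grad_cpu = buf.detach().cpu()   # aggregation happens on the CPU copy
    del buf
    torch.cuda.empty_cache()        # return the freed blocks to the allocator
    return grad_cpu
```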

There are a couple of minor changes in _server_, _client_ and _evaluation_ as well. The main reason is that the Server no longer holds the list of clients; clients now live inside the Worker from the moment it is created, and the Server only passes client indexes to the Worker. The reason behind this change is that torch.distributed does not allow sending objects point-to-point, only tensors.
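
In other words, the Server now ships a plain integer index wrapped in a tensor, and the Worker builds the `Client` locally; a simplified sketch of both sides (names are illustrative, the real flow is in `Server.dispatch_clients` and `Worker.run`):

```
import torch
import torch.distributed as dist

def send_client_index(client_idx, worker_rank):
    """Server side: objects can't be sent P2P, so wrap the index in a tensor.
    (With NCCL the tensor would first be moved to the local GPU.)"""
    dist.send(torch.tensor([client_idx]), dst=worker_rank)

def recv_client_index(src=0):
    """Worker side: receive the index and instantiate the Client from it."""
    buf = torch.zeros(1, dtype=torch.int64)
    dist.recv(tensor=buf, src=src)
    return int(buf.item())
```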

The rest of the modified files only update the documentation and the testing file. I tried to be very explicit in each new function inside _federated.py_ to explain the new flow. Let me know if something is not clear enough.

I've already tested all experiments in the sandbox using NCCL and on my local machine (Windows) using Gloo (surprisingly, this case is not as slow as I was expecting, although I used some dummy datasets I had prepared). Pending task: compare the new performance using CPU.

So for now the only thing left is to run the sanity checks on AML; links below.

Sanity checks after cleanup:
- [x] nlg_gru: https://aka.ms/amlt?q=c9tih
- [x] mlm_bert: https://aka.ms/amlt?q=dbs8y
- [x] classif_cnn: https://aka.ms/amlt?q=da2qr
- [x] ecg: https://aka.ms/amlt?q=c9jof
- [x] cv: https://aka.ms/amlt?q=da2k4
This commit is contained in:
Mirian Hipolito Garcia 2022-08-26 14:54:27 +00:00
Parent 887c8ac74f
Commit 30e41400a4
23 changed files: 657 additions and 546 deletions

CHANGELOG.md (new file, +40 lines)

@ -0,0 +1,40 @@
# Changelog
All notable changes to this project will be documented in this file.
## [0.1.0] - 2021-11-22
We're super excited to announce FLUTE: Federated Learning Utilities for Testing and Experimentation, a platform for conducting high-performance federated learning simulations!
This first release focuses fully on enabling fast prototyping to validate different CL scenarios
in a federated environment.
### Features
- large scale simulation (millions of clients, sampling tens of thousands per round).
- multi-GPU and multi-node orchestration backed by MPI.
- local or global differential privacy.
- model quantization.
- a variety of standard optimizers and aggregation methods.
- most model types including CNNs, RNNs, and Huggingface Transformers.
- extensibility, enabling new models, dataloaders, optimizers, and aggregators.
- local or cloud-based job staging using AzureML.
## [1.0.0] - 2022-08-29
This release contains major changes in the communication backbone. In order
to run experiments you have already integrated in FLUTE, please make sure
to use `torch.distributed` instead of `MPI` to launch the jobs. For more documentation
about the new command, please refer to the [README](README.md).
### New features
- 🏎 Better performance: Support for NCCL and Gloo as backend communication protocols.
- Improvements in GPU utilization and overall communication speed (savings on the order of minutes!) for projects with huge models and datasets.
- 🌟 Removed the file-type dependency in client.py; FLUTE can now receive any kind of dataset and even download the data on the fly. Data instantiation is completely under the control of each task's dataset.
- In older versions FLUTE only allowed `json` and `hdf5` files, so that the client could recognize them.
- 🌟 Abstract classes for new models/dataloaders.
- 🌟 Allows Federated Learning with Personalization.
- Personalization allows you to leverage each client's local data to obtain models that are better adjusted to their own data distribution. You can run the `cv` task to try out this feature.


@ -24,15 +24,15 @@ conda create -n FLUTE python==3.8
pip install -r requirements.txt
```
You will also need some MPI runtime such as OpenMPI (on Linux) or MS-MPI (on Windows). There is no `setup.py` as FLUTE is not currently distributed as a package, but instead meant to run from the root of the repository.
FLUTE uses the torch.distributed API as its main communication backbone, supporting three built-in backends. For more information please refer to [Distributed Communication Package](https://pytorch.org/docs/stable/distributed.html). We highly suggest using the NCCL backend for distributed GPU training and Gloo for distributed CPU training. There is no `setup.py` as FLUTE is not currently distributed as a package, but instead meant to run from the root of the repository.
After this initial setup, you can use the data created for the integration test inside of `testing` for a first local run. Note that this data needs to be download manually, for more instructions please look at [the README file inside `testing`](testing/README.md).
```
mpiexec -n 3 python e2e_trainer.py -dataPath ./testing/mockup -outputPath scratch -config testing/configs/hello_world_local.yaml -task nlg_gru
python -m torch.distributed.run --nproc_per_node=3 e2e_trainer.py -dataPath ./testing/mockup -outputPath scratch -config testing/configs/hello_world_local.yaml -task nlg_gru -backend nccl
```
This config uses 1 MPI node with 3 workers (1 server, 2 clients). The config file `testing/configs/hello_world_local.yaml` has some comments explaining the major sections and some important details; essentially, it consists in a very short experiment where a couple of iterations are done for just a few clients. A `scratch` folder will be created containing detailed logs.
This config uses 1 node with 3 workers (1 server, 2 clients). The config file `testing/configs/hello_world_local.yaml` has some comments explaining the major sections and some important details; essentially, it consists in a very short experiment where a couple of iterations are done for just a few clients. A `scratch` folder will be created containing detailed logs.
## Documentation
@ -55,7 +55,7 @@ The core client/server training code is inside the `core` folder.
- Server-side federation and global DP application takes place in `server.py`, more specifically in the `OptimizationServer.train()` method.
- Client-side training updates take place in the static method `Client.process_round()`, inside `client.py`.
General FL orchestration code is in `federated.py`, but for most hub and spoke federation scenarios you won't need to touch this (unless you want to invest in optimizing MPI, which would be great!). Note that FLUTE does not implement secure aggregation since this is primarily a security feature for production scenarios; contributors are invited to add it for experimentation purposes.
General FL orchestration code is in `federated.py`, but for most hub and spoke federation scenarios you won't need to touch this (unless you want to invest in optimizing server-client calls, which would be great!). Note that FLUTE does not implement secure aggregation since this is primarily a security feature for production scenarios; contributors are invited to add it for experimentation purposes.
The primary entry point for an experiment is in the script `e2e_trainer.py`. Primary config scripts for experiments are in `configs`. For instance, a basic training scenario for a next-word prediction task is set up in `hello_world_nlg_gru_json.yaml`.
@ -88,14 +88,15 @@ command: >
apt -y install openmpi-bin libopenmpi-dev openssh-client &&
python3 -m pip install --upgrade pip &&
python3 -m pip install -r requirements.txt &&
mpiexec --allow-run-as-root -n 4 python e2e_trainer.py
python -m torch.distributed.run --nproc_per_node=4 e2e_trainer.py
-outputPath=./outputs
-dataPath={inputs.data}
-task=classif_cnn
-config=./experiments/classif_cnn/config.yaml
-backend=nccl
```
You should replace `compute` with the name of the one you created before, and adjust the path of the datastore containing the data -- in the example above, we created a datastore called `data` and added to it a folder called `cifar`, which contained the two HDF5 files. The command passed above will install dependencies and then launch an MPI job with 4 threads, for the experiment defined in `experiments/classif_cnn`. Details on how to run a job using the AzureML CLI are given [in its documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-train-cli), but typically it suffices to set up the environment and type `az ml job create -f <name-of-the-yaml-file>`. In the same page of the documentation, you can also find more info about how to set up the YAML file above, in case other changes are needed.
You should replace `compute` with the name of the one you created before, and adjust the path of the datastore containing the data -- in the example above, we created a datastore called `data` and added to it a folder called `cifar`, which contained the two HDF5 files. The command passed above will install dependencies and then launch a distributed job with 4 threads, for the experiment defined in `experiments/classif_cnn`. Details on how to run a job using the AzureML CLI are given [in its documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-train-cli), but typically it suffices to set up the environment and type `az ml job create -f <name-of-the-yaml-file>`. In the same page of the documentation, you can also find more info about how to set up the YAML file above, in case other changes are needed.
Note that the `local_path` above is relative to the location of the YAML file, so setting it to `.` assumes it is in the same folder as `e2e_trainer.py`. All files on this folder will be uploaded to Azure, including hidden folders such as `.git`, so make sure to temporarily get rid of large files and folders that are not needed.


@ -1,9 +1,9 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
'''
The Client object is short-lived, instantiated inside of worker 0 and moved to
workers 1 to N for processing a given client's data. It's main method is the
`process_round` function, used to update the model given a client's data.
The Client object is short-lived, instantiated inside workers 1 to N for
processing a given client's data. It's main method is the `process_round`
function, used to update the model given a client's data.
'''
import copy
@ -17,12 +17,7 @@ import numpy as np
import torch
# Internal imports
from core.globals import TRAINING_FRAMEWORK_TYPE
if TRAINING_FRAMEWORK_TYPE == 'mpi':
import core.federated as federated
else:
raise NotImplementedError('{} is not supported'.format(TRAINING_FRAMEWORK_TYPE))
from .strategies import select_strategy
from .trainer import (
Trainer,
@ -41,18 +36,20 @@ from utils.dataloaders_utils import (
make_train_dataloader,
make_val_dataloader,
make_test_dataloader,
get_dataset,
)
import extensions.privacy
from extensions.privacy import metrics as privacy_metrics
from experiments import make_model
global train_dataset
class Client:
# It's unclear why, but sphinx refuses to generate method docs
# if there is no docstring for this class.
"""Client class for specifying individual client training tasks"""
def __init__(self, client_id, config, send_gradients, dataloader):
def __init__(self, client_id, config, send_gradients):
'''
Client side processing: computing gradients, update the model and send them back to the server
@ -61,47 +58,38 @@ class Client:
config (dict): dictionary with parameters loaded from config file.
send_gradients (bool): if True, model gradients are sent back;
otherwise, model weights are sent back.
dataloader (torch.utils.data.DataLoader): dataloader that generates
training data for the client.
'''
super().__init__()
self.client_id = client_id
self.client_data = self.get_data(client_id, dataloader)
self.config = copy.deepcopy(config)
self.send_gradients = send_gradients
def get_client_data(self):
def get_client_data(self, dataset=None):
'''"Getter" method that returns all object's attributes at once.'''
return self.client_id, self.client_data, self.config, self.send_gradients
client_data = self.get_data(self.client_id, dataset)
return self.client_id, client_data, self.config, self.send_gradients
@staticmethod
def get_train_dataset(data_path, client_train_config, task):
'''This function will obtain the training dataset for all
'''This function will obtain the dataset for all training
users.
Args:
data_path (str): path to file containing taining data.
client_train_config (dict): trainig data config.
task (str): task name.
'''
try:
dir = os.path.join('experiments',task,'dataloaders','dataset.py')
loader = SourceFileLoader("Dataset",dir).load_module()
dataset = loader.Dataset
train_file = os.path.join(data_path, client_train_config['list_of_train_data']) if client_train_config['list_of_train_data'] != None else None
train_dataset = dataset(train_file, args=client_train_config)
num_users = len(train_dataset.user_list)
print_rank("Total amount of training users: {}".format(num_users))
except:
print_rank("Dataset not found, please make sure is located inside the experiment folder")
return num_users, train_dataset
global train_dataset
train_dataset = get_dataset(data_path, client_train_config, task, mode="train")
return len(train_dataset.user_list)
@staticmethod
def get_data(clients, dataset):
''' Create training dictionary'''
dataset = train_dataset if len(clients) ==1 else dataset # clients is an integer only for training mode
data_with_labels = hasattr(dataset,"user_data_label")
input_strct = {'users': [], 'num_samples': [],'user_data': dict(), 'user_data_label': dict()} if data_with_labels else {'users': [], 'num_samples': [],'user_data': dict()}
@ -226,7 +214,7 @@ class Client:
# Ensure the client is assigned to the correct GPU
if torch.cuda.is_available() and torch.cuda.device_count() == federated.size():
torch.cuda.set_device(federated.local_rank())
torch.cuda.set_device(federated.rank())
# Process inputs and initialize variables
client_id, data_strct, config, send_gradients = client_data


@ -10,16 +10,9 @@ import torch
import numpy as np
# Internal imports
from core.globals import TRAINING_FRAMEWORK_TYPE
if TRAINING_FRAMEWORK_TYPE == 'mpi':
import core.federated as federated
else:
raise NotImplementedError('{} is not supported'.format(TRAINING_FRAMEWORK_TYPE))
from core.client import Client
from utils import (
print_rank
)
from utils import print_rank
# AzureML-related libs
from azureml.core import Run
@ -27,14 +20,14 @@ run = Run.get_context()
class Evaluation():
def __init__(self, config, model_path, process_testvalidate, val_dataloader, test_dataloader):
def __init__(self, config, model_path, process_testvalidate, idx_val_clients, idx_test_clients):
self.config = config
self.model_path = model_path
self.process_testvalidate = process_testvalidate
self.test_dataloader = val_dataloader
self.val_dataloader = test_dataloader
self.server_type = config['server_config']['type']
self.idx_val_clients = idx_val_clients
self.idx_test_clients = idx_test_clients
super().__init__()
@ -108,37 +101,35 @@ class Evaluation():
def run_distributed_inference(self, mode):
'''Call `run_distributed_evaluation` specifically for test or validation.
This is just a helper function that fetches the dataloader depending on
the mode and calls `run_distributed_evaluation` using that dataloader.
This is just a helper function that fetches the clients depending on
the mode and calls `run_distributed_evaluation` using that list.
Args:
mode (str): `test` or `val`.
'''
if mode == 'val':
dataloader = self.val_dataloader
clients = self.idx_val_clients
elif mode == 'test':
dataloader = self.test_dataloader
clients = self.idx_test_clients
else:
raise NotImplementedError('Unsupported mode: {}'.format(mode))
return self.run_distributed_evaluation(dataloader, mode)
def run_distributed_evaluation(self, dataloader, mode):
return self.run_distributed_evaluation(mode, clients)
def run_distributed_evaluation(self, mode, clients):
'''Perform evaluation using available workers.
See also `process_test_validate` on federated.py.
Args:
dataloader (torch.utils.data.DataLoader): used to fetch data.
mode (str): `test` or `val`.
clients (list): clients for test/val round.
'''
val_clients = list(self.make_eval_clients(dataloader))
print_rank(f'mode: {mode} evaluation_clients {len(val_clients)}', loglevel=logging.DEBUG)
total = 0
self.logits = {'predictions': [], 'probabilities': [], 'labels': []}
server_data = (0.0, [p.data.to(torch.device('cpu')) for p in self.worker_trainer.model.parameters()])
for result in self.process_testvalidate(val_clients, server_data, mode):
for result in self.process_testvalidate(clients, server_data, mode):
output, metrics, count = result
val_metrics = {key: {'value':0, 'higher_is_better': False} for key in metrics.keys()} if total == 0 else val_metrics
@ -164,34 +155,35 @@ class Evaluation():
self.losses = [val_metrics['loss']['value'], val_metrics['acc']['value']] # For compatibility with Server
return val_metrics
def make_eval_clients(self, dataloader):
def make_eval_clients(dataset, config):
'''Generator that yields clients for evaluation, continuously.
Args:
dataloader (torch.utils.data.DataLoader): used to get client's data
dataset (torch.utils.data.Dataset): used to get client's data
config (dict): used for the client's constructor
'''
total = sum(dataloader.dataset.num_samples)
total = sum(dataset.num_samples)
clients = federated.size() - 1
delta = total / clients + 1
threshold = delta
current_users_idxs = list()
current_total = 0
if self.server_type == "personalization":
for i in range(len(dataloader.dataset.user_list)):
yield Client([i], self.config, False, dataloader.dataset)
if config["server_config"]["type"] == "personalization":
for i in range(len(dataset.user_list)):
yield Client([i], config, False)
else:
for i in range(len(dataloader.dataset.user_list)):
for i in range(len(dataset.user_list)):
current_users_idxs.append(i)
count = dataloader.dataset.num_samples[i]
count = dataset.num_samples[i]
current_total += count
if current_total > threshold:
print_rank(f'sending {len(current_users_idxs)} users', loglevel=logging.DEBUG)
yield Client(current_users_idxs, self.config, False, dataloader.dataset)
yield Client(current_users_idxs, config, False)
current_users_idxs = list()
current_total = 0
if len(current_users_idxs) != 0:
print_rank(f'sending {len(current_users_idxs)} users -- residual', loglevel=logging.DEBUG)
yield Client(current_users_idxs, self.config, False, dataloader.dataset)
yield Client(current_users_idxs, config, False)


@ -1,53 +1,242 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import cProfile
import gc
import os
import pickle
import logging
import torch
from mpi4py import MPI
import torch.distributed as dist
import numpy as np
from core.client import Client
from utils import (
print_rank,
print_profiler
print_profiler,
to_device,
)
from utils.queue import process_in_parallel
COMMAND_UPDATE = 0
COMMAND_TRAIN = 1
COMMAND_TERMINATE = 10
COMMAND_TESTVAL = 11
COMMAND_SYNC_NODES = 9
SPLIT_SIZE = 512 * 1024 * 1024 # messages above this size (in bytes) are split
def encode_string(word, string_to_int = True):
""" Encodes/Decodes the dictionary keys into an array of integers to be sent
as tensors of the same shape during NCCL/Gloo P2P communication.
COMMAND_UPDATE = "update"
COMMAND_TRAIN = "train"
COMMAND_TERMINATE = "terminate"
COMMAND_TESTVAL = "testvalidate"
Args:
word (string/array): key to be encoded/decoded.
string_to_int (bool): flag that indicates which action to perform.
"""
if string_to_int: # encode
word = word.ljust(8, ' ') if len(word) < 8 else word # padding -- 8 is max length, all tensors must have the same size during communication
word_encoded = [letter for letter in word.encode()]
return word_encoded
else: #decode
cleanup_array = [letter for letter in word if letter!= 32] # Remove padding
word_decoded = bytes(cleanup_array).decode()
return word_decoded
def rank():
"""Return rank of node"""
return MPI.COMM_WORLD.Get_rank()
def local_rank():
"""Return local rank of MPI node"""
assert (
"OMPI_COMM_WORLD_LOCAL_RANK" in os.environ
), "local rank can only be determined when using OpenMPI"
return int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
""" Return rank of node. """
return dist.get_rank()
def size():
"""Returns number of MPI nodes including server"""
return MPI.COMM_WORLD.Get_size()
""" Returns number of nodes in the distributed group, including server. """
return dist.get_world_size()
def _recv(x, src=0):
""" Receives tensors with a single element or a list of tensors
with the same shape during distributed communication. """
x = torch.tensor(x) if torch.is_tensor(x) == False else x
x = to_device(x)
dist.recv(tensor=x, src=src)
x.to('cpu')
try:
return x.item() # single element
except:
return x.tolist() # list of tensors
def _recv_gradients(src):
""" Receives a list of tensors with different shape during
distributed communication. """
n, n_dimensions, grads = 0, 0, [] # tensor initialization -- required by torch.
n = _recv(n,src)
for i in range(n):
n_dimensions = _recv(n_dimensions,src)
dimensions = [0 for i in range(n_dimensions)]
dimensions = _recv(dimensions, src)
print_rank(f"Received dimensions {dimensions}", loglevel=logging.DEBUG)
param = to_device(torch.zeros(dimensions))
print_rank(f"Shape assigned {param.shape}", loglevel=logging.DEBUG)
dist.recv(param,src)
grads.append(param.detach().cpu())
torch.cuda.empty_cache()
return grads
def _send(x, dst=0):
""" Send tensors with a single element or a list of tensors
with the same shape during distributed communication. """
x = torch.tensor(x)
x = to_device(x)
dist.send(x, dst)
del x
torch.cuda.empty_cache()
def _send_metrics(output):
""" Organize the keys and values from the resulting dictionary
from test/val rounds into arrays that are sent as independent
tensors during distributed communication. """
keys = [encode_string(key) for key in output.keys()]
values = [float(output[key]['value']) for key in output.keys()]
higher_is_better = [int(output[key]['higher_is_better']) for key in output.keys()] # send the boolean as int
_send(len(keys),0)
_send(keys)
_send(values)
_send(higher_is_better)
def _send_gradients(gradients, dst):
""" Send a list of tensors with different shape during
distributed communication. """
_send(len(gradients), dst)
for i in gradients:
dimensions = [int(d) for d in i.shape]
_send(len(dimensions),dst)
_send(dimensions,dst)
param = to_device(i)
dist.send(param,dst)
del param
torch.cuda.empty_cache()
def _send_train_output(output):
""" Organize the keys and values from the the returning ´client_output´
dictionary in ´Client.proces_round()´ function during training rounds,
into arrays that are sent as independent tensors during distributed
communication. """
cs_values = [float(cs_v) for cs_v in output['cs'].values()] # cs dict -- values are flatten in 1d array
pl_values = [float(output['pl']['weight'])] # pl dict
gradients = output['pl']['gradients'] # gradients are sent independently
values = cs_values + [float(output[key]) for key in output.keys() if key not in ['cs','pl']] + pl_values # reorganizing values in the order expected by the Server
# Send data
_send(values, 0)
_send_gradients(gradients, 0)
def build_grads_dict(node):
""" Reconstruct the dictionary ´client_output´ returned by
´Client.proces_round()´ function on the Server side during
distributed communication. """
# Initialize tensors
keys = ['cs','tl','mg','vg','ng','rg','ns','ts','pl']
values = [0.0 for i in range(11)] # initializing tensor shape -- 11 is fixed number of keys expected
# Read data
values = _recv(values,node)
grads = _recv_gradients(node)
# Rebuilding original dictionary
cs_values = [{key: values.pop(0) for key in ['setup','training','full cost']}] # recreating cs dict
pl_values = [{'weight':values.pop(), 'gradients': grads}] # recreating pl dict
values_list = cs_values + [values.pop(0) for i in range(7)] + pl_values # 7 is fixed length for remaining items
result = dict(zip(keys,values_list))
# Cast values to original type
for key in ['mg','vg','ng','rg']:
result[key] = np.float32(result[key])
result['ns'] = int(result['ns'] )
return result
def build_metrics_dict(node):
""" Reconstruct the dictionary returned during test/val rounds
on the Server side during distributed communication. """
# Initialize tensors
n = 0
n = _recv(n,node)
keys = [[0 for j in range(8)] for i in range(n)] # max_seq_len for metric name is 8
values = [0.0 for i in range(n)]
higher_is_better = [0 for i in range(n)]
# Read data
keys = _recv(keys,node)
values = _recv(values,node)
higher_is_better = _recv(higher_is_better,node)
# Reorganize output + decode dict keys
orig_keys = [encode_string(key, string_to_int=False) for key in keys]
values_dict = [{'value': float(v), 'higher_is_better': bool(higher_is_better[i])} for i, v in enumerate(values)]
metrics = dict(zip(orig_keys,values_dict))
num_instances = int(metrics.pop('num')['value'])
result = None, metrics, num_instances
return result
def receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes):
""" Receives the clients output on the Server side in async/sync mode.
Asynchronous mode is only enabled when using NCCL backend given that Gloo
does not provide native non-blocking implementation to check if the operation
has been completed during distributed training"""
if dist.get_backend() == "nccl": # Async
for node, req in node_request_map:
if req.is_completed():
result = build_metrics_dict(node) if command == COMMAND_TESTVAL else build_grads_dict(node)
results_list.append(result)
free_nodes.append(node)
node_request_map.remove((node,req))
print_rank(f"Finished releasing the nodes {free_nodes}", loglevel=logging.DEBUG)
else: # Sync
print_rank(f"Waiting for a workers", loglevel=logging.DEBUG)
gather_objects = [(None,None,None) for i in range(size())]
output = [None for _ in gather_objects]
dist.all_gather_object(output, gather_objects[rank()])
print_rank(f" All workers have finished ... taking the remaining clients {len(output)}", loglevel=logging.DEBUG)
output = [e for i,e in enumerate(output) if i not in idle_nodes ] # Cleanup for idle workers
results_list = results_list + output[1:]
free_nodes = list(range(1, size()))
return node_request_map, results_list, free_nodes
def append_async_requests(node_request_map, node):
""" Appends the asynchronous request sent to each worker during
asynchronous training. """
ack = to_device(torch.zeros(1))
req = dist.irecv(tensor=ack, src=node)
node_request_map.append((node,req))
return node_request_map
def sync_idle_nodes(client_queue, free_nodes):
""" Request dummy outputs to the odd (idle) nodes during synchronous training
to prevent them to get trapped in the state of the previous iterations """
idle_nodes = []
if len(client_queue) == 0:
print_rank(f"Free idle nodes {len(free_nodes)}", loglevel=logging.DEBUG)
while len(free_nodes) > 0:
node = free_nodes.pop()
idle_nodes.append(node)
_send(COMMAND_SYNC_NODES, node)
return idle_nodes
class Server:
"""Server object responsible for orchestration and aggregation.
The Server is one of the two objects that may exist inside of a thread, all
throughout its execution (the other being the Worker). At every round, the
Server samples clients and sends their data for an available Worker to process.
Server samples client ids and sends them to an available Worker to process.
The Workers then each produce a new model, and all models are sent to the Server
for aggregation.
@ -59,161 +248,137 @@ class Server:
It thus only serves the purpose of grouping the methods, but nothing
is actually stored inside of the object.
"""
@staticmethod
def dispatch_clients(clients, server_data, payload_fn, clients_in_parallel=None):
"""Perform execution of client code on the worker nodes.
def dispatch_clients(clients, server_data, command, mode=None, do_profiling=False):
"""Perform the orchestration between Clients and Workers.
This function does the following:
1. It sends the server_data to all workers
2. For each client:
2a. It sends the function process_round of the client
to a free worker.
2b. It calls get_client_data on the client.
2c. It triggers the execution of the payload_fn on the
worker with parameters server_data and client_data.
2. For each available Worker:
2a. It sends the index of the client to instantiate
2b. It triggers the execution of the command on the
Client.
3. Collect and return all client outputs.
Notes:
This function yields the gradients of different clients
as they are received. Therefore, the order of the results generally
does not correspond to the order of the clients.
All commands used during Server-Worker communication must be
float/integers, given that torch.distributed only allows sending
and receiving tensors.
Args:
clients (list): list of clients to be processed.
server_data (dict): server data sent to the workers and passed to
clients, typically includes the global model at that step.
payload_fn (callback): instructions for worker to execute.
clients_in_parallel (int or None): how many threads will be used for
processing clients, defaults to None in which case all of them
are processed on the same thread.
command (int): instruction for worker to execute on the Client.
mode (int): test/val only provided during evaluation rounds.
do_profiling (bool): enables profiler during communication.
Returns:
Generator of results sent by server via MPI.
Generator of results.
"""
# Send args to workers
data_pickled = pickle.dumps(server_data) # pickle once
for worker_rank in range(1, MPI.COMM_WORLD.Get_size()):
MPI.COMM_WORLD.send(COMMAND_UPDATE, worker_rank)
_send(data_pickled, worker_rank, pickled=True)
# Perform payload_fn on clients
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
# Initialize communication profiler
profiler = None
if do_profiling:
profiler = cProfile.Profile()
profiler.enable()
# Update lr + model parameters each round for all workers
lr, model_params = server_data
for worker_rank in range(1, size()):
_send(COMMAND_UPDATE, worker_rank)
_send(lr,worker_rank)
_send_gradients(model_params, worker_rank)
print_rank(f"Finished sending lr {lr} and n_params {len(model_params)} to worker {worker_rank}", loglevel=logging.DEBUG)
print_rank(f"Finished sending server_data to workers", loglevel=logging.DEBUG)
client_queue = clients.copy()
free_nodes = list(range(1, MPI.COMM_WORLD.Get_size()))
node_request_map = []
print_rank(f"Clients queue: {client_queue}", loglevel=logging.DEBUG)
free_nodes = list(range(1, size()))
results_list, node_request_map = [], []
# Initiate computation for all clients
while client_queue:
if clients_in_parallel is not None:
clients_to_process = [client_queue.pop() for _ in range(clients_in_parallel) if len(client_queue) > 0]
else:
clients_to_process = client_queue.pop()
print_rank(f"Queueing {clients_to_process}, {len(client_queue)} remaining", loglevel=logging.DEBUG)
# Wait for free worker node
if not free_nodes:
print_rank(f"Waiting for a worker", loglevel=logging.DEBUG)
assert(len(node_request_map) > 0)
status = MPI.Status()
ix, _ = MPI.Request.waitany(node_request_map, status=status)
# Collects worker output after processing has finished
output = _recv(status.source)
if isinstance(output, list):
yield from output
else:
yield output
free_nodes.append(status.source)
print_rank(f"Found free worker {ix}:{status.source}", loglevel=logging.DEBUG)
node_request_map.pop(ix)
# Run client computation on free worker node
print_rank(f"Clients queue: {client_queue}", loglevel=logging.DEBUG)
assert len(free_nodes) > 0
node = free_nodes.pop()
print_rank(f"Sending to worker {node}", loglevel=logging.DEBUG)
payload_fn(clients_to_process, node)
print_rank(f"Payload sent. Queueing irecv on {node}", loglevel=logging.DEBUG)
node_request_map.append(MPI.COMM_WORLD.irecv(source=node))
print_rank(f"Queued irecv for {node}", loglevel=logging.DEBUG)
index = len(client_queue)-1
client_to_process = client_queue.pop(index)
print_rank(f"Sending client {index} to worker {node}", loglevel=logging.DEBUG)
_send(command, node) # The command should indicate the worker which function to run on the client
print_rank(f"Done queuing clients. Waiting on workers")
if command == COMMAND_TESTVAL:
_send(mode,node) # Only for test/val has a value
_send(index, node) # Worker receives the index of the client to pop
elif command == COMMAND_TRAIN:
_send(client_to_process, node)
print_rank(f"Finished assigning worker {node}, free nodes {free_nodes}", loglevel=logging.DEBUG)
if dist.get_backend() == "nccl":
append_async_requests(node_request_map, node)
idle_nodes = None
else:
idle_nodes = sync_idle_nodes(client_queue, free_nodes)
# Waits until receive the output from all ranks
if not free_nodes:
print_rank(f"Waiting for a workers, free nodes {free_nodes}, reqs_lst {node_request_map}", loglevel=logging.DEBUG)
while len(free_nodes) == 0:
node_request_map, results_list, free_nodes = receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes)
# Wait for all workers to finish
for i, request in enumerate(node_request_map):
status = MPI.Status()
request.wait(status)
print_rank(f"Result for item {i}: source: {status.source}", loglevel=logging.DEBUG)
while (len(node_request_map)) != 0:
node_request_map, results_list, free_nodes = receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes)
print_rank(f"Calling _recv for {status.source}", loglevel=logging.DEBUG)
output = _recv(status.source)
if isinstance(output, list):
yield from output
else:
for output in results_list:
yield output
@staticmethod
def process_clients(clients, server_data, clients_in_parallel):
"""Ask workers to process client data.
if do_profiling:
profiler.disable()
print_profiler(profiler)
The payload function defined below will send a free worker instructions
on how to process the data of one or more clients. This payload function
is then passed to :code:`dispatch_clients`, which continuously looks for
free workers and sends them more clients to process.
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
@staticmethod
def process_clients(clients, server_data):
"""Ask workers to perform training on Clients.
Args:
clients (list): list of client.Client objects.
clients (list): list of client indexes sampled by the Server
object per iteration.
server_data (dict): dictionary containing model.
clients_in_parallel (None or int): how many threads to use for
processing the clients on a given worker.
Returns:
Generator of results sent by server via MPI.
Generator of results.
"""
def payload_fn(clients, node):
"""Payload function for a training round."""
# Send command for training and function to process round
MPI.COMM_WORLD.send(COMMAND_TRAIN, node)
# Loop through clients and send their data
if clients_in_parallel is None:
MPI.COMM_WORLD.send(clients.process_round, node)
MPI.COMM_WORLD.send(clients.get_client_data(), node)
else:
MPI.COMM_WORLD.send(clients[0].process_round, node) # clients is a list
MPI.COMM_WORLD.send(len(clients), node)
for client in clients:
MPI.COMM_WORLD.send(client.get_client_data(), node)
return Server.dispatch_clients(clients, server_data, payload_fn, clients_in_parallel=clients_in_parallel)
return Server.dispatch_clients(clients, server_data, COMMAND_TRAIN)
@staticmethod
def process_testvalidate(clients, server_data, mode):
"""Ask workers to use clients data to compute metrics.
Similar to :code:`process_round` but asks workers to
compute metrics instead, by using a different payload function.
"""Ask workers to perform test/val on Clients.
Args:
clients (list): list of client.Client objects.
clients (list): list of client indexes for test/val rounds.
server_data (dict): dictionary containing model.
mode(str): whether to :code:`test` or :code:`validate`.
mode (str): test/val.
Returns:
Generator of results sent by server via MPI.
Generator of results.
"""
def payload_fn(client, node):
"""Payload function for a test/validation round."""
MPI.COMM_WORLD.send(COMMAND_TESTVAL, node)
MPI.COMM_WORLD.send(client.run_testvalidate, node)
MPI.COMM_WORLD.send(client.get_client_data(), node)
MPI.COMM_WORLD.send(mode, node)
return Server.dispatch_clients(clients, server_data, payload_fn)
mode = [-2] if mode == "test" else [2]
return Server.dispatch_clients(clients, server_data, COMMAND_TESTVAL, mode)
@staticmethod
def terminate_workers(terminate=True):
@ -221,15 +386,15 @@ class Server:
if terminate:
print_rank("Terminating worker processes")
for worker_rank in range(1, MPI.COMM_WORLD.Get_size()):
MPI.COMM_WORLD.send(COMMAND_TERMINATE, worker_rank)
for worker_rank in range(1, size()):
_send(COMMAND_TERMINATE, worker_rank)
class Worker:
"""Worker object responsible for processing clients' data.
"""Worker object responsible for instantiate Clients based on incoming data
from the Server and perform train/eval functions on it.
Each worker lives on a different MPI thread and is assigned to a different
GPU. Via the :code:`dispatch_clients` function, the Server passes the
Each worker lives on a different NCCL/Gloo thread and is assigned to a different
GPU. Via the :code:`dispatch_clients` function, the Server passes to the
Worker specific instructions to process clients' data, typically in order
to generate a new model or to compute metrics.
@ -237,85 +402,89 @@ class Worker:
model (torch.nn.Module): model being trained.
data_path (str): path where all clients' data is located.
do_profiling (bool): if True, analyzes execution in depth.
clients_in_parallel (None or int): if not None, processes clients in
threads during training round.
server_data (dict): stores data received from Server when an update
command is received.
val_clients (list): clients list for validation rounds.
test_clients (list): clients list for testing rounds.
config (dict): clients configuration.
val_dataset (torch.utils.data.Dataset): validation dataset.
test_dataset (torch.utils.data.Dataset): testing dataset.
"""
def __init__(self, model=None, data_path=None, do_profiling=False, val_clients= None, \
test_clients=None, config=None, val_dataset = None, test_dataset = None):
def __init__(self, model=None, data_path=None, do_profiling=False, clients_in_parallel=None):
"""
Set the GPU workspace for the model to be exchanged between the server and clients
This prevents a model instance from being created on the GPU worker many times
Args:
model (torch.nn.Module, optional): model being trained, defaults to None.
data_path (str, optional): path where all clients' data is located,
defaults to None.
do_profiling (bool, optional): if True, analyzes execution in depth; defaults
to False.
clients_in_parallel (None or int, optional): if not None, processes clients in
threads during training round. Defaults to None.
"""
self.model = model
self.data_path = data_path
self.do_profiling = do_profiling
self.clients_in_parallel = clients_in_parallel
self.server_data = None
# For processing in different threads, we need copies of the model
if clients_in_parallel is not None:
device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
self.model_copies = [copy.deepcopy(model).to(device) for _ in range(clients_in_parallel)]
self.config = config
self.val_clients = val_clients
self.test_clients = test_clients
self.val_dataset = val_dataset
self.test_dataset = test_dataset
def run(self):
"""Main loop executed by worker nodes.
This method triggers the MPI communication between the worker and
This method handles the NCCL/Gloo communication between the worker and
the server. It keeps listening for commands from the Server,
and performs different actions depending on the command received.
and performs different actions on the Client assigned depending on
the command received.
"""
while True: # keeps listening for commands on MPI
command = MPI.COMM_WORLD.recv()
assert isinstance(command, str)
while True: # keeps listening for incoming server calls
# Initialize tensors -- required by torch.distributed
command, client_idx, mode = 0, 0, 0 # int
lr = torch.zeros(1) # float
# Read command
command = _recv(command)
print_rank(f"Command received {command} on worker {rank()}", loglevel=logging.DEBUG)
# Receive server data -- lr, model_params
if command == COMMAND_UPDATE:
self.server_data = _recv(0)
print_rank(f"COMMMAND_UPDATE received {rank()}", loglevel=logging.DEBUG)
lr = _recv(lr, 0)
model_params = _recv_gradients(0)
server_data = (lr, model_params)
print_rank(f"Received lr: {lr} and n_params: {len(model_params)}", loglevel=logging.DEBUG)
elif command == COMMAND_TRAIN:
print_rank(f"COMMMAND_TRAIN received {rank()}", loglevel=logging.DEBUG)
# Init profiler in training worker
profiler = None
if self.do_profiling:
profiler = cProfile.Profile()
profiler.enable()
client_fn = MPI.COMM_WORLD.recv() # NOTE: assumes function is same for all clients
# Receive client id from Server
client_idx = _recv(client_idx)
print_rank(f"Cliend idx received from Server: {client_idx}", loglevel=logging.DEBUG)
# Pick whether to do processing in batches or not
if self.clients_in_parallel is None:
client_data = MPI.COMM_WORLD.recv()
# Instantiate client
client_to_process = Client(
[client_idx],
self.config,
self.config['client_config']['type'] == 'optimization')
torch.cuda.empty_cache()
output = client_fn(client_data, self.server_data, self.model, self.data_path)
# Execute Client.get_data()
client_data = client_to_process.get_client_data()
# Execute Client.process_round()
output = client_to_process.process_round(client_data, server_data, self.model, self.data_path)
# Send output back to Server
if dist.get_backend() == "nccl":
# ASYNC mode -- enabled only for nccl backend
ack = to_device(torch.tensor(1))
dist.isend(tensor=ack, dst=0)
_send_train_output(output)
else:
n_clients = MPI.COMM_WORLD.recv()
client_data = [MPI.COMM_WORLD.recv() for _ in range(n_clients)]
torch.cuda.empty_cache()
output = process_in_parallel(client_fn, client_data, self.server_data, self.model_copies, self.data_path)
print_rank(f"Processed batch of size {len(client_data)}, got {len(output)} outputs", loglevel=logging.DEBUG)
# Wait for server to be available and send output(s)
MPI.COMM_WORLD.isend(None, 0).wait()
_send(output, 0)
# Make sure that memory is cleaned up
if self.clients_in_parallel is not None:
for args in client_data:
del args
del client_fn, client_data, output
# SYNC mode -- gloo backend does not have a non-blocking way to check if the operation is completed
gather_objects = [output for i in range(size())]
output = [None for _ in gather_objects]
dist.all_gather_object(output, gather_objects[rank()])
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
@ -324,29 +493,51 @@ class Worker:
print_profiler(profiler)
elif command == COMMAND_TESTVAL:
print_rank(f"COMMMAND_TESTVAL received {rank()}", loglevel=logging.DEBUG)
# Init profiler in validation worker
profiler = None
if self.do_profiling:
profiler = cProfile.Profile()
profiler.enable()
client_fn = MPI.COMM_WORLD.recv()
client_data = MPI.COMM_WORLD.recv()
client_mode = MPI.COMM_WORLD.recv()
# Receive mode and client id from Server
mode = _recv(mode)
mode = "test" if mode == -2 else "val"
client_idx = _recv(client_idx)
print_rank(f"Client idx received from Server: {client_idx}, {mode}", loglevel=logging.DEBUG)
# Clean up memory before client processing
torch.cuda.empty_cache()
# Get client and dataset
clients = self.val_clients if mode == "val" else self.test_clients
dataset = self.val_dataset if mode == "val" else self.test_dataset
clients_queue = clients.copy()
assert 0 <= client_idx < len(clients_queue)
client_to_process = clients_queue.pop(client_idx)
try:
output = client_fn(client_data, self.server_data, client_mode, self.model)
except RuntimeError as e:
_dump_tensors(gpu_only=True)
raise RuntimeError("Federated Error: {}".format(str(e)))
# Execute Client.get_data()
client_data = client_to_process.get_client_data(dataset)
MPI.COMM_WORLD.isend(None, 0).wait()
_send(output, 0)
# Execute Client.run_testvalidate()
output = client_to_process.run_testvalidate(client_data, server_data, mode, self.model)
# Make sure that memory is cleaned up
del client_fn, client_data, output
# Send output back to Server
if dist.get_backend() == "nccl":
# ASYNC mode -- enabled only for nccl backend
_, metrics, num_instances = output
metrics['num']= {'value': float(num_instances), 'higher_is_better': False}
output = metrics
print_rank(f"Worker {rank()} output {output}", loglevel=logging.DEBUG)
ack = to_device(torch.tensor(1))
dist.isend(tensor=ack, dst=0)
_send_metrics(output)
else:
# SYNC mode -- gloo backend does not have a non-blocking way to check if the operation is completed
gather_objects = [output for i in range(size())]
output = [None for _ in gather_objects]
dist.all_gather_object(output, gather_objects[rank()])
print_rank(f"Worker {rank()} sent output back", loglevel=logging.DEBUG)
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
@ -355,82 +546,23 @@ class Worker:
print_profiler(profiler)
elif command == COMMAND_TERMINATE:
print_rank(f"COMMMAND_TERMINATE received {rank()}", loglevel=logging.DEBUG)
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
return
elif command == COMMAND_SYNC_NODES: # Only for sync calls
print_rank(f"COMMMAND_SYNC_NODES received {rank()}", loglevel=logging.DEBUG)
gather_objects = [None for i in range(size())]
output = [None for _ in gather_objects]
dist.all_gather_object(output, gather_objects[rank()])
print_rank(f"Worker IDLE {rank()} sent dummy output back", loglevel=logging.DEBUG)
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
else:
assert False, "unknown command"
def _send(data, rank, pickled=False, verbose=False):
"""Send large object by chunking it into multiple MPI messages."""
# Pickle data
data_pickled = data
if not pickled:
data_pickled = pickle.dumps(data_pickled)
# Compute in how many chunks data will be sent
num_chunks = len(data_pickled) // SPLIT_SIZE + 1
if verbose:
print_rank(f"_send data_pickled size: {len(data_pickled)}, {num_chunks} chunks")
# Send data in chunks
MPI.COMM_WORLD.send(num_chunks, rank)
ix = 0
while len(data_pickled) - ix > SPLIT_SIZE:
MPI.COMM_WORLD.send(data_pickled[ix:ix+SPLIT_SIZE], rank)
ix += SPLIT_SIZE
MPI.COMM_WORLD.send(data_pickled[ix:], rank)
def _recv(rank):
"""Receive large object by chunking it into multiple MPI messages."""
num_chunks = MPI.COMM_WORLD.recv(source=rank)
pickled_chunks = []
for _ in range(num_chunks):
pickled_chunks.append(MPI.COMM_WORLD.recv(source=rank))
data_pickled = b"".join(pickled_chunks)
return pickle.loads(data_pickled)
def _dump_tensors(gpu_only=True):
"""Print a list of the Tensors being tracked by the garbage collector."""
def pretty_size(size):
"""Pretty prints a torch.Size object."""
assert(isinstance(size, torch.Size))
return " × ".join(map(str, size))
print_rank("Dump memory allocated")
print_rank(torch.cuda.memory_allocated())
print_rank("Dump max memory allocated")
print_rank(torch.cuda.max_memory_allocated())
print_rank("Dump memory cached")
print_rank(torch.cuda.memory_cached())
print_rank("Dump max memory cached")
print_rank(torch.cuda.max_memory_cached())
total_size = 0
for obj in gc.get_objects():
try:
if torch.is_tensor(obj):
if not gpu_only or obj.is_cuda:
print("%s:%s%s %s" % (type(obj).__name__,
" GPU" if obj.is_cuda else "",
" pinned" if obj.is_pinned else "",
pretty_size(obj.size())))
total_size += obj.numel()
elif hasattr(obj, "data") and torch.is_tensor(obj.data):
if not gpu_only or obj.is_cuda:
print("%s -> %s:%s%s%s%s %s" % (type(obj).__name__,
type(obj.data).__name__,
" GPU" if obj.is_cuda else "",
" pinned" if obj.data.is_pinned else "",
" grad" if obj.requires_grad else "",
" volatile" if obj.volatile else "",
pretty_size(obj.data.size())))
total_size += obj.data.numel()
except Exception as e:
pass
print_rank("Total size: {}".format(total_size))


@ -1,8 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
# Macro variable that sets which distributed trainig framework is used (e.g. mpi, syft, horovod)
TRAINING_FRAMEWORK_TYPE = 'mpi'
logging_level = logging.INFO # DEBUG | INFO


@ -19,11 +19,7 @@ import numpy as np
import torch
# Internal imports
from core.globals import TRAINING_FRAMEWORK_TYPE
if TRAINING_FRAMEWORK_TYPE == 'mpi':
import core.federated as federated
else:
raise NotImplementedError('{} is not supported'.format(TRAINING_FRAMEWORK_TYPE))
from core.evaluation import Evaluation
from core.client import Client
from .strategies import select_strategy
@ -49,8 +45,8 @@ run = Run.get_context()
class OptimizationServer(federated.Server):
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, train_dataloader, train_dataset,
val_dataloader, test_dataloader, config, config_server):
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, server_train_dataloader,
config, idx_val_clients, idx_test_clients):
'''Implement Server's orchestration and aggregation.
This is the main Server class, that actually implements orchestration
@ -67,11 +63,10 @@ class OptimizationServer(federated.Server):
ss_scheduler: scheduled sampling scheduler.
data_path (str): points to where data is.
model_path (str): points to where pretrained model is.
train_dataloader (torch.utils.data.DataLoader): dataloader for training
val_dataloader (torch.utils.data.DataLoader): dataloader for validation
test_dataloader (torch.utils.data.DataLoader): dataloader for test, can be None
server_train_dataloader (torch.utils.data.DataLoader): dataloader for training
config (dict): JSON style configuration parameters
config_server: deprecated, kept for API compatibility only.
idx_val_clients (list): validation client ids
idx_test_clients (list): testing clients ids
'''
super().__init__()
@ -92,7 +87,7 @@ class OptimizationServer(federated.Server):
self.val_freq = server_config['val_freq']
self.req_freq = server_config['rec_freq']
self.evaluation = Evaluation(config, model_path, self.process_testvalidate, val_dataloader, test_dataloader)
self.evaluation = Evaluation(config, model_path, self.process_testvalidate, idx_val_clients, idx_test_clients)
# TODO: does this need to be adjusted for custom metrics?
self.metrics = dict()
@ -122,8 +117,8 @@ class OptimizationServer(federated.Server):
model=model,
optimizer=optimizer,
ss_scheduler=ss_scheduler,
train_dataloader=train_dataloader if train_dataloader is not None else val_dataloader,
val_dataloader=val_dataloader,
train_dataloader=server_train_dataloader if server_train_dataloader is not None else None,
val_dataloader=None,
max_grad_norm=max_grad_norm,
anneal_config=server_config['annealing_config'],
model_type=self.model_type,
@ -133,8 +128,7 @@ class OptimizationServer(federated.Server):
# Creating an instance for the server-side trainer (runs mini-batch SGD)
self.server_replay_iterations = None
self.server_trainer = None
self.train_dataset = train_dataset
if train_dataloader is not None:
if server_train_dataloader is not None:
assert 'server_replay_config' in server_config, 'server_replay_config is not set'
assert 'optimizer_config' in server_config[
'server_replay_config'], 'server-side replay training optimizer is not set'
@ -145,7 +139,7 @@ class OptimizationServer(federated.Server):
model=model,
optimizer=None,
ss_scheduler=ss_scheduler,
train_dataloader=train_dataloader,
train_dataloader=server_train_dataloader,
server_replay_config=server_config['server_replay_config'],
val_dataloader=None,
max_grad_norm=server_config['server_replay_config']\
@ -180,9 +174,6 @@ class OptimizationServer(federated.Server):
self.do_profiling = server_config.get('do_profiling', False)
# Parallel processing
self.clients_in_parallel = config['client_config'].get('clients_in_parallel', None)
StrategyClass = select_strategy(config['strategy'])
self.strategy = StrategyClass('server', self.config, self.model_path)
print_rank(f'Server successfully instantiated strategy {self.strategy}', loglevel=logging.DEBUG)
@ -230,7 +221,7 @@ class OptimizationServer(federated.Server):
'secsPerClientFull': [],
'secsPerRoundHousekeeping': [],
'secsPerRoundTotal': [],
'mpiCosts': []
'communicationCosts': []
}
run.log('Max iterations', self.max_iteration)
@ -304,14 +295,6 @@ class OptimizationServer(federated.Server):
# Create the pool of clients -- sample from this pool to assign to workers
sampled_idx_clients = random.sample(self.client_idx_list,
num_clients_curr_iter) if num_clients_curr_iter > 0 else self.client_idx_list
sampled_clients = [
Client(
[client_id],
self.config,
self.config['client_config']['type'] == 'optimization',
self.train_dataset
) for client_id in sampled_idx_clients
]
# Initialize stats
clients_begin = time.time()
@ -326,7 +309,7 @@ class OptimizationServer(federated.Server):
self.run_stats['secsPerClientFull'].append([])
self.run_stats['secsPerClientTraining'].append([])
self.run_stats['secsPerClientSetup'].append([])
self.run_stats['mpiCosts'].append([])
self.run_stats['communicationCosts'].append([])
# Check if we want privacy metrics
apply_privacy_metrics = self.config.get('privacy_metrics_config', None) and \
@ -345,7 +328,8 @@ class OptimizationServer(federated.Server):
# Reset gradient for the model before assigning the new gradients
self.worker_trainer.model.zero_grad()
for client_output in self.process_clients(sampled_clients, server_data, self.clients_in_parallel):
print_rank(f"Clients sampled from server {sampled_idx_clients}", loglevel=logging.DEBUG)
for client_output in self.process_clients(sampled_idx_clients, server_data):
# Process client output
client_timestamp = client_output['ts']
client_stats = client_output['cs']
@ -361,7 +345,7 @@ class OptimizationServer(federated.Server):
for metric, value in privacy_stats.items():
privacy_metrics_stats[metric].append(value)
self.run_stats['mpiCosts'][-1].append(time.time() - client_timestamp)
self.run_stats['communicationCosts'][-1].append(time.time() - client_timestamp)
# Get actual pseudo-gradients for aggregation
payload_processed = self.strategy.process_individual_payload(self.worker_trainer, client_payload)
@ -513,7 +497,7 @@ class OptimizationServer(federated.Server):
'secsPerClientTraining',
'secsPerClientFull',
'secsPerClientSetup',
'mpiCosts',
'communicationCosts',
]
for metric in metrics_for_stats:


@ -10,13 +10,13 @@ Install the requirements stated inside of requirements.txt. Ideally this sould b
conda create -n FLUTE python==3.8
pip install -r requirements.txt
You will also need some MPI runtime such as OpenMPI (on Linux) or MS-MPI (on Windows). There is no setup.py as FLUTE is not currently distributed as a package, but instead meant to run from the root of the repository.
FLUTE uses the torch.distributed API as its main communication backbone, supporting three built-in backends. For more information please refer to [Distributed Communication Package](https://pytorch.org/docs/stable/distributed.html). We highly suggest using the NCCL backend for distributed GPU training and Gloo for distributed CPU training. There is no `setup.py` as FLUTE is not currently distributed as a package, but instead meant to run from the root of the repository.
After this initial setup you can use your data for launching a local run. However the following instructions will be adapted to run ``nlg_gru`` task. For running this example, you need to first download and preprocess the data. Instructions can be found `here`_. Once the data is available you can run FLUTE from root as follows:
.. code:: bash
mpiexec -n 3 python e2e_trainer.py -dataPath ./testing/mockup -outputPath scratch -config testing/configs/hello_world_local.yaml -task nlg_gru
python -m torch.distributed.run --nproc_per_node=3 e2e_trainer.py -dataPath ./testing/mockup -outputPath scratch -config testing/configs/hello_world_local.yaml -task nlg_gru -backend nccl
.. _here: https://github.com/microsoft/msrflute/tree/main/testing
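For a quick check that the launcher is wired correctly, the following hypothetical snippet (``check_dist.py``, not part of the repository) prints the rank layout: ``torch.distributed.run`` exports ``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT`` for every process it spawns, and ``dist.init_process_group`` with the default ``env://`` rendezvous reads them.

.. code:: python

    # check_dist.py -- hypothetical sanity check, launch with:
    #   python -m torch.distributed.run --nproc_per_node=3 check_dist.py
    import os
    import torch.distributed as dist

    dist.init_process_group(backend="gloo")  # use "nccl" when each process owns a GPU
    print(f"rank {dist.get_rank()} of {dist.get_world_size()}, "
          f"local rank {os.environ.get('LOCAL_RANK')}")
    dist.destroy_process_group()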
@ -56,14 +56,15 @@ For running experiments on AzureML, the CLI can help. You should first install t
apt -y install openmpi-bin libopenmpi-dev openssh-client &&
python3 -m pip install --upgrade pip &&
python3 -m pip install -r requirements.txt &&
mpiexec --allow-run-as-root -n 4 python e2e_trainer.py
python -m torch.distributed.run --nproc_per_node=4 e2e_trainer.py
-outputPath=./outputs
-dataPath={inputs.data}
-task=classif_cnn
-config=./experiments/classif_cnn/config.yaml|
-config=./experiments/classif_cnn/config.yaml
-backend=nccl
You should replace ``compute`` with the name of the one you created before, and adjust the path of the datastore containing the data. In the example above, we created a datastore called ``data`` and added to it a folder called ``cifar``, which contained the two HDF5 files. The command passed above will install dependencies and then launch an MPI job with 4 threads, for the experiment defined in ``experiments/classif_cnn``. Details on how to run a job using the AzureML CLI are given in its `documentation`_ , but typically it suffices to set up the environment and type ``az ml job create -f <name-of-the-yaml-file>``. In the same page of the documentation, you can also find more info about how to set up the YAML file above, in case other changes are needed.
You should replace ``compute`` with the name of the compute target you created before, and adjust the path of the datastore containing the data. In the example above, we created a datastore called ``data`` and added to it a folder called ``cifar``, which contained the two HDF5 files. The command above will install dependencies and then launch an NCCL job with 4 processes, for the experiment defined in ``experiments/classif_cnn``. Details on how to run a job using the AzureML CLI are given in its `documentation`_, but typically it suffices to set up the environment and type ``az ml job create -f <name-of-the-yaml-file>``. On the same documentation page you can also find more info on how to set up the YAML file above, in case other changes are needed.
.. _documentation: https://docs.microsoft.com/en-us/azure/machine-learning/how-to-train-cli


@ -27,7 +27,7 @@ Each worker>0 processes client tasks sequentially, consisting of data encoding a
:align: center
:width: 500
FLUTE uses a distributed processing architecture backed by OpenMPI.
FLUTE uses a distributed processing architecture backed by torch.distributed.
Execution runs for up to N training rounds. In each round the orchestrator may sample a subset of clients, and may also randomly delay pseudo-gradient updates from some clients to future rounds. The orchestrator will also periodically distribute evaluation tasks to determine model quality on validation and test data.
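To make the processing layout above concrete, here is a minimal hypothetical sketch (not the actual FLUTE implementation) of the rank-0 orchestrator / worker split under torch.distributed:

.. code:: python

    # Hypothetical sketch of the layout described above: rank 0 orchestrates,
    # every other rank processes client tasks.
    import torch.distributed as dist

    def orchestrate():
        print("rank 0: sample clients, aggregate pseudo-gradients, schedule eval")

    def process_client_tasks():
        print(f"rank {dist.get_rank()}: encode client data, train, send payload to rank 0")

    def main(backend="gloo"):
        dist.init_process_group(backend=backend)  # rendezvous set up by torch.distributed.run
        if dist.get_rank() == 0:
            orchestrate()
        else:
            process_client_tasks()
        dist.destroy_process_group()

    if __name__ == "__main__":
        main()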


@ -2,7 +2,7 @@
# Licensed under the MIT license.
'''
This is the main script to run on each MPI thread. It will spawn either a
This is the main script to run on each NCCL/Gloo process. It will spawn either a
Server or Worker object -- the former is responsible for orchestrating and
aggregating models, whereas the latter processes clients' data to generate
a new model. The Server lives on the very first process, whereas remaining
@ -13,16 +13,17 @@ import argparse
import os
import shutil
import yaml
import logging
from psutil import virtual_memory
import torch
import torch.distributed as dist
from azureml.core import Run
from core import federated
from core.config import FLUTEConfig
from core.server import select_server
from core.client import Client
from core.globals import TRAINING_FRAMEWORK_TYPE, logging_level
from experiments import make_model
from utils import (
make_optimizer,
@ -32,12 +33,9 @@ from utils import (
)
from utils.dataloaders_utils import (
make_train_dataloader,
make_val_dataloader,
make_test_dataloader,
get_dataset,
)
assert TRAINING_FRAMEWORK_TYPE == "mpi", "Unsupported platform {}".format(TRAINING_FRAMEWORK_TYPE)
from core.evaluation import make_eval_clients
def log_run_properties(config: FLUTEConfig):
"""Log parameters on AzureML.
@ -76,49 +74,64 @@ def log_run_properties(config: FLUTEConfig):
run.log(k, properties[k])
def run_worker(model_path, config, task, data_path, local_rank):
"""Spawn worker object that lives throughout MPI thread.
def run_worker(model_path, config, task, data_path, local_rank, backend):
"""Spawn worker object that lives throughout NCCL/GLOO thread.
Args:
model_path (str): path to the pretrained model.
config (dict): dictionary containing parameters.
task (str): what task to solve, must be a folder of :code:`experiments`.
data_path (str): path to data.
local_rank (int): the rank of the MPI thread.
local_rank (int): the rank of the NCCL/Gloo process.
backend (str): communication backend to use, either nccl or gloo.
"""
model_config = config["model_config"]
server_config = config["server_config"]
# Get the rank on MPI
# Backend initialization
print_rank(f"Backend: {backend}")
dist.init_process_group(backend=backend, init_method=None)
rank = dist.get_rank()
if torch.cuda.is_available():
torch.cuda.set_device(rank)
# Get the rank on NCCL/GLOO
rank = local_rank if local_rank > -1 else federated.rank()
# Assign MPI thread to a specific GPU
# Assign NCCL thread to a specific GPU
if torch.cuda.is_available():
n_gpus = torch.cuda.device_count()
torch.cuda.set_device(federated.local_rank() % n_gpus)
print_rank(f"Assigning worker to GPU {federated.local_rank() % n_gpus}")
torch.cuda.set_device(federated.rank() % n_gpus)
print_rank(f"Assigning worker to GPU {federated.rank() % n_gpus}")
# Make the Model to distribute to workers
model = make_model(model_config)
# Get evaluation datasets
data_config = config['server_config']['data_config']
val_dataset = get_dataset(data_path, data_config["val"], task, mode="val")
test_dataset = get_dataset(data_path, data_config["test"], task, mode="test")
# Create list of clients for test/val -- the Server needs the indexes and the Worker needs the clients list
val_clients = list(make_eval_clients(val_dataset, config))
test_clients = list(make_eval_clients(test_dataset, config))
# pre-cache the training data and capture the number of clients for sampling
client_train_config = config["client_config"]["data_config"]["train"]
num_clients = Client.get_train_dataset(data_path, client_train_config,task)
config["server_config"]["data_config"]["num_clients"] = num_clients
# Instantiate the Server object on the first thread
if rank == 0:
try:
print_rank('Server data preparation')
# pre-cache the training data and capture the number of clients for sampling
client_train_config = config["client_config"]["data_config"]["train"]
num_clients, train_dataset = Client.get_train_dataset(data_path, client_train_config,task)
config["server_config"]["data_config"]["num_clients"] = num_clients
# Make the Dataloaders
data_config = config['server_config']['data_config']
if 'train' in data_config:
server_train_dataloader = make_train_dataloader(data_config['train'], data_path, task=task, clientx=None)
else:
server_train_dataloader = None
val_dataloader = make_val_dataloader(data_config["val"], data_path, task=task)
test_dataloader = make_test_dataloader(data_config["test"], data_path, task=task)
idx_val_clients = list(range(len(val_clients))) # Generates indexes for val clients
idx_test_clients = list(range(len(test_clients))) # Generates indexes for test clients
print_rank("Prepared the dataloaders")
@ -135,18 +148,16 @@ def run_worker(model_path, config, task, data_path, local_rank):
server_type = server_config["type"]
server_setup = select_server(server_type) # Return the server class
server = server_setup(
data_config["num_clients"],
model,
optimizer,
None,
data_path,
model_path,
server_train_dataloader,
train_dataset,
val_dataloader,
test_dataloader,
config,
server_config
num_clients=data_config["num_clients"],
model=model,
optimizer=optimizer,
ss_scheduler=None,
data_path=data_path,
model_path=model_path,
server_train_dataloader=server_train_dataloader,
config=config,
idx_val_clients=idx_val_clients,
idx_test_clients=idx_test_clients,
)
log_run_properties(config)
@ -163,10 +174,14 @@ def run_worker(model_path, config, task, data_path, local_rank):
print_rank("Worker on node {}: process started".format(rank))
client_config = config["client_config"]
worker = federated.Worker(
model,
data_path,
model=model,
data_path=data_path,
do_profiling=client_config.get("do_profiling", False),
clients_in_parallel=client_config.get("clients_in_parallel", None),
val_clients=val_clients,
test_clients=test_clients,
val_dataset = val_dataset,
test_dataset = test_dataset,
config= config,
)
worker.run()
@ -178,6 +193,7 @@ if __name__ == "__main__":
parser.add_argument("-outputPath")
parser.add_argument("-dataPath", default=None)
parser.add_argument("-task", default=None, help="Define the task for the run")
parser.add_argument("-backend", default=None, help="Define the communication protocol")
parser.add_argument("-num_skip_decoding", default=-1, type=int, help="Skip decoding in unsupervised learning mode")
parser.add_argument("--local_rank", default=-1, type=int)
@ -185,6 +201,8 @@ if __name__ == "__main__":
data_path = args.dataPath
task = args.task
local_rank = args.local_rank
assert args.backend in ['nccl','gloo'], f"Backend {args.backend} not recognized, please select nccl or gloo"
backend = args.backend
# The mount point can also be retrieved from input_datasets of the run context
if data_path is None:
@ -195,6 +213,7 @@ if __name__ == "__main__":
id = Run.get_context().id
experiment_name = "-".join(id.split("-")[-4:-2])
experiment_root = os.path.join(args.outputPath, experiment_name)
os.makedirs(experiment_root, exist_ok=True)
model_path = os.path.join(experiment_root, "models")
log_path = os.path.join(experiment_root, "log")
@ -207,7 +226,7 @@ if __name__ == "__main__":
shutil.copyfile(args.config, cfg_out)
# Initialize logging
init_logging(log_path, loglevel=logging_level)
init_logging(log_path, loglevel=logging.INFO)
with open(args.config) as f:
@ -222,4 +241,4 @@ if __name__ == "__main__":
config.validate()
# Instantiate either Server or Worker on the thread
run_worker(model_path, config, task, data_path, local_rank)
run_worker(model_path, config, task, data_path, local_rank, backend)
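The new `-backend` flag is required and validated against `nccl`/`gloo` above. If a default were ever desired instead, a helper along these lines would do it; `resolve_backend` is a hypothetical sketch, not part of this PR:

```python
# Hypothetical helper (not in this PR): fall back to a sensible backend when
# -backend is omitted -- NCCL if CUDA is usable, Gloo otherwise.
from typing import Optional

import torch
import torch.distributed as dist

def resolve_backend(requested: Optional[str]) -> str:
    if requested is not None:
        assert requested in ("nccl", "gloo"), f"Backend {requested} not recognized, please select nccl or gloo"
        return requested
    return "nccl" if torch.cuda.is_available() and dist.is_nccl_available() else "gloo"
```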


@ -48,10 +48,10 @@ example is provided in `config.yaml`.
## Running the experiment
Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
script using MPI
script using torch.distributed.
```
mpiexec -n 4 python e2e_trainer.py -dataPath experiments/classif_cnn/utils/data -outputPath scratch -config experiments/classif_cnn/config.yaml -task classif_cnn
python -m torch.distributed.run --nproc_per_node=4 e2e_trainer.py -dataPath experiments/classif_cnn/utils/data -outputPath scratch -config experiments/classif_cnn/config.yaml -task classif_cnn -backend gloo
```
The `dataPath`, `outputPath` and `config` arguments should just specify the


@ -44,10 +44,10 @@ example is provided in `config.yaml`.
## Running the experiment
Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
script using MPI
script using torch.distributed.
```
mpiexec -n 4 python e2e_trainer.py -dataPath ./ -outputPath scratch -config experiments/classif_cnn/config.yaml -task cv
python -m torch.distributed.run --nproc_per_node=4 e2e_trainer.py -dataPath ./ -outputPath scratch -config experiments/classif_cnn/config.yaml -task cv -backend gloo
```
The `dataPath`, `outputPath` and `config` arguments should just specify the


@ -7,12 +7,12 @@ In this file, we define the local server that lives inside the client.
from core.server import OptimizationServer
class PersonalizationServer(OptimizationServer):
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, train_dataloader, train_dataset,
val_dataloader, test_dataloader, config, config_server):
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, server_train_dataloader,
config, idx_val_clients, idx_test_clients):
"""
Personalization Server.
Customized routines for server can be included here.
"""
super().__init__(num_clients, model, optimizer, ss_scheduler, data_path, model_path, train_dataloader, train_dataset,
val_dataloader, test_dataloader, config, config_server)
super().__init__(num_clients, model, optimizer, ss_scheduler, data_path, model_path, server_train_dataloader,
config, idx_val_clients, idx_test_clients)


@ -59,9 +59,9 @@ example is provided in `config.yaml`.
## Running the experiment locally
Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
script using MPI:
script using torch.distributed:
`mpiexec -n 2 python .\e2e_trainer.py -dataPath experiments/ecg_cnn/data -outputPath scratch -config experiments/ecg_cnn/config.yaml -task ecg_cnn `
`python -m torch.distributed.run --nproc_per_node=2 .\e2e_trainer.py -dataPath experiments/ecg_cnn/data -outputPath scratch -config experiments/ecg_cnn/config.yaml -task ecg_cnn -backend nccl`
The `dataPath`, `outputPath` and `config` arguments should just specify the
respective files or folders, as in the example above -- in this case, a folder
@ -90,11 +90,12 @@ command: >
apt -y install openmpi-bin libopenmpi-dev openssh-client &&
python3 -m pip install --upgrade pip &&
python3 -m pip install -r requirements.txt &&
mpiexec --allow-run-as-root -n 4 python e2e_trainer.py
python -m torch.distributed.run --nproc_per_node=4 e2e_trainer.py
-outputPath=./outputs
-dataPath={inputs.data}
-task=ecg_cnn
-config=./experiments/ecg_cnn/config.yaml
-backend=nccl
```
To run your job, you can then use the following command:


@ -24,10 +24,10 @@ the fields: list_of_train_data, test_data and val_data inside the config file.
## Running the experiment locally
Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
script using MPI:
script using torch.distributed:
```code
mpiexec -n 2 python .\e2e_trainer.py -dataPath data_folder -outputPath scratch -config configs\hello_world_mlm_bert_json.yaml -task mlm_bert
python -m torch.distributed.run --nproc_per_node=2 .\e2e_trainer.py -dataPath data_folder -outputPath scratch -config configs\hello_world_mlm_bert_json.yaml -task mlm_bert -backend nccl
```
For submitting jobs in Azure ML, we have included the instructions in the `Experiments`


@ -7,6 +7,8 @@ from transformers import AutoTokenizer
from transformers import DataCollatorForLanguageModeling
from experiments.mlm_bert.dataloaders.dataset import Dataset
from core.dataloader import BaseDataLoader
from utils import print_rank
import logging
class DataLoader(BaseDataLoader):
"""
@ -38,7 +40,7 @@ class DataLoader(BaseDataLoader):
else:
raise ValueError("You are instantiating a new tokenizer from scratch. This is not supported by this script.")
print("Tokenizer is: ",tokenizer)
print_rank("Tokenizer is: {}".format(tokenizer), loglevel=logging.DEBUG)
dataset = Dataset(
data,


@ -115,7 +115,7 @@ class Dataset(BaseDataset):
if len(self.utt_list) == 0:
self.utt_list = [{'src_text': 'N/A', 'duration': 0, 'loss_weight': 1.0}]
print_rank('Processing json-structure for User: {} Utterances Processed: {}'.format(self.user, len(self.utt_list)), loglevel=logging.INFO)
print_rank('Processing json-structure for User: {} Utterances Processed: {}'.format(self.user, len(self.utt_list)), loglevel=logging.DEBUG)
def process_user(self, user, user_data):
counter=0


@ -27,10 +27,10 @@ data to HDF5 format.
## Running the experiment
Finally, to launch the experiment locally, it suffices to launch the `e2e_trainer.py`
script using MPI, you can use as example the following line:
script using torch.distributed; you can use the following line as an example:
```code
mpiexec -n 3 python e2e_trainer.py -dataPath .\testing\mockup\ -outputPath scratch -config .\testing\configs\hello_world_nlg_gru.yaml -task nlg_gru
python -m torch.distributed.run --nproc_per_node=3 e2e_trainer.py -dataPath .\testing\mockup\ -outputPath scratch -config .\testing\configs\hello_world_nlg_gru.yaml -task nlg_gru -backend nccl
```
For submitting jobs in Azure ML, we have included the instructions in the `Experiments`


@ -59,6 +59,7 @@ class Dataset(BaseDataset):
self.user_data = orig_strct['user_data']
self.user = 'test_only' if self.test_only else self.user_list[user_idx]
if user_idx != -1:
self.process_x(self.user_data)
def process_x(self, user_data):


@ -30,13 +30,13 @@ def run_pipeline(data_path, output_path, config_path, task):
# Adjust command to the task and OS
sym = "&" if platform.system() == "Windows" else ";"
command = 'cd .. '+ sym +' mpiexec '+'-np '+'2 '+ 'python '+ 'e2e_trainer.py '+ \
command = 'cd .. '+ sym +' python '+'-m '+'torch.distributed.run '+ '--nproc_per_node=2 '+ 'e2e_trainer.py '+ \
'-dataPath '+ data_path+' -outputPath '+output_path+' -config ' +config_path +\
' -task '+ task
' -task '+ task + ' -backend '+ 'nccl'
# Execute e2e_trainer + stores the exit code
with open('logs.txt','w') as f:
process= subprocess.run(command, shell=True,stdout=f,text=True,timeout=2000)
process= subprocess.run(command, shell=True,stdout=f,text=True,timeout=900)
return_code=process.returncode
# Print logs
@ -53,6 +53,12 @@ def test_nlg_gru():
data_path, output_path, config_path = get_info(task)
assert run_pipeline(data_path, output_path, config_path, task)==0
def test_ecg_cnn():
task = 'ecg_cnn'
data_path, output_path, config_path = get_info(task)
assert run_pipeline(data_path, output_path, config_path, task)==0
@pytest.mark.xfail
def test_mlm_bert():
@ -68,12 +74,3 @@ def test_classif_cnn():
data_path, output_path, config_path = get_info(task)
assert run_pipeline(data_path, output_path, config_path, task)==0
print("PASSED")
def test_ecg_cnn():
task = 'ecg_cnn'
data_path, output_path, config_path = get_info(task)
assert run_pipeline(data_path, output_path, config_path, task)==0


@ -82,4 +82,17 @@ def make_test_dataloader(data_config, data_path, task=None, data_strct=None):
return test_dataloader
def get_dataset(data_path, data_config, task, mode):
""" Return the task train/val/test dataset """
dir = os.path.join('experiments',task,'dataloaders','dataset.py')
loader = SourceFileLoader("Dataset",dir).load_module()
dataset = loader.Dataset
data_file = "val_data" if mode == "val" else "test_data" if mode == "test" else "list_of_train_data"
data_file = data_config[data_file]
data_file = os.path.join(data_path, data_file) if data_file is not None else data_file
dataset = dataset(data_file, user_idx=-1, args=data_config)
return dataset
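A hypothetical usage of `get_dataset`: the path and config keys below are made up, but mirror the val/test entries expected under `server_config.data_config`; the call assumes it runs from the repository root so the `experiments/<task>/dataloaders/dataset.py` lookup resolves.

```python
# Hypothetical call -- assumes a val_data file exists under ./testing/mockup
# and that the nlg_gru Dataset accepts this minimal config.
val_config = {"val_data": "val_data.hdf5"}
val_dataset = get_dataset("./testing/mockup", val_config, task="nlg_gru", mode="val")
```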


@ -1,45 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from mpi4py import MPI
from concurrent.futures import ThreadPoolExecutor
from utils import print_rank
comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()
""" Here we have classes and functions that allow one to send and process multiple
messages in parallel on MPI. """
def process_in_parallel(client_fn, client_data, server_data, models, data_path):
""" Process multiple orders in parallel
Parameters
----------
client_fn: callback
Function we want to call.
client_data: list of tuples
Arguments that will be passed to function.
server_data: tuple
Data passed from server to update model parameters.
models: torch.nn.Module
Models we will send to the clients.
data_path: str
Path to data.
Returns
-------
list
Output of each callback in the list passed as input.
"""
with ThreadPoolExecutor(max_workers=len(client_data)) as pool:
requests = []
for k, args in enumerate(client_data):
requests.append(pool.submit(client_fn, args, server_data, models[k], data_path))
results = [request.result() for request in requests]
print_rank(f'finished processing batch of size {len(client_data)}')
return results


@ -18,19 +18,12 @@ from collections import OrderedDict
from utils.optimizers.lars import LarsSGD
from utils.optimizers.lamb import LAMB
from utils.optimizers.adamW import AdamW
from core.globals import TRAINING_FRAMEWORK_TYPE
from easydict import EasyDict as edict
from torch.optim.lr_scheduler import (
StepLR,
MultiStepLR,
ReduceLROnPlateau )
if TRAINING_FRAMEWORK_TYPE == 'mpi':
from mpi4py import MPI
else:
raise NotImplementedError('Training framework is not yet supported')
def make_optimizer(optimizer_config, model):
"""Initialization for optimizer."""