msrflute/core/federated.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import cProfile
import logging
import threading
import torch
import torch.distributed as dist
import numpy as np
from core.client import Client
from utils import (
print_rank,
print_profiler,
to_device,
)
COMMAND_UPDATE = 0
COMMAND_TRAIN = 1
COMMAND_TERMINATE = 10
COMMAND_TESTVAL = 11
COMMAND_SYNC_NODES = 9
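# Commands are exchanged as plain integers (wrapped into single-element tensors
# by _send/_recv below), since torch.distributed can only send/recv tensors
# between ranks.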
GLOBAL_MESSAGE = None
def encode_string(word, string_to_int = True):
""" Encodes/Decodes the dictionary keys into an array of integers to be sent
as tensors of the same shape during NCCL/Gloo P2P communication.
Args:
word (string/array): key to be encoded/decoded.
string_to_int (bool): flag that indicates which action to perform.
"""
if string_to_int: # encode
word = word.ljust(8, ' ') if len(word) < 8 else word # padding -- 8 is max length, all tensors must have the same size during communication
word_encoded = [letter for letter in word.encode()]
return word_encoded
else: #decode
cleanup_array = [letter for letter in word if letter != 32] # Remove padding
word_decoded = bytes(cleanup_array).decode()
return word_decoded
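# Illustrative round trip for the scheme above (a sketch, not executed here):
# names shorter than 8 characters are right-padded with spaces before being
# mapped to byte values, and the padding is stripped again on decode:
#     encode_string("loss")
#     -> [108, 111, 115, 115, 32, 32, 32, 32]
#     encode_string([108, 111, 115, 115, 32, 32, 32, 32], string_to_int=False)
#     -> 'loss'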
def rank():
""" Return rank of node. """
return int(os.environ['RANK'])
def local_rank():
""" Return local rank of node. """
return int(os.environ['LOCAL_RANK'])
def size():
""" Returns number of nodes in the distributed group, including server. """
return int(os.environ['WORLD_SIZE'])
def _recv(x, src=0):
""" Receives tensors with a single element or a list of tensors
with the same shape during distributed communication. """
x = x if torch.is_tensor(x) else torch.tensor(x)
x = to_device(x)
dist.recv(tensor=x, src=src)
x = x.to('cpu')
try:
return x.item() # single element
except:
return x.tolist() # list of tensors
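# _recv mirrors _send below: the receiving rank must pass a placeholder of the
# same shape as the tensor that was sent. For example (a sketch), a scalar sent
# with _send(3, dst=node) is read back with value = _recv(0, src=node), and a
# two-element list sent with _send([2, 3], dst=node) is read back with
# dims = _recv([0, 0], src=node), which returns a Python list.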
def _recv_gradients(src):
""" Receives a list of tensors with different shape during
distributed communication. """
n, n_dimensions, grads = 0, 0, [] # tensors initialization -- required by torch.
n = _recv(n,src)
for i in range(n):
n_dimensions = _recv(n_dimensions,src)
dimensions = [0 for i in range(n_dimensions)]
dimensions = _recv(dimensions, src)
print_rank(f"Received dimensions {dimensions}", loglevel=logging.DEBUG)
param = to_device(torch.zeros(dimensions))
print_rank(f"Shape assigned {param.shape}", loglevel=logging.DEBUG)
dist.recv(param,src)
grads.append(param.detach().cpu())
torch.cuda.empty_cache()
return grads
def _send(x, dst=0):
""" Send tensors with a single element or a list of tensors
with the same shape during distributed communication. """
x = torch.tensor(x)
x = to_device(x)
dist.send(x, dst)
del x
torch.cuda.empty_cache()
def _send_metrics(output):
""" Organize the keys and values from the resulting dictionary
from test/val rounds into arrays that are sent as independent
tensors during distributed communication. """
keys = [encode_string(key) for key in output.keys()]
values = [float(output[key]['value']) for key in output.keys()]
higher_is_better = [int(output[key]['higher_is_better']) for key in output.keys()] # send the boolean as int
_send(len(keys),0)
_send(keys)
_send(values)
_send(higher_is_better)
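# _send_metrics therefore emits four messages in order: the number of metrics,
# the encoded key arrays (each padded to 8 ints), the float values, and the
# higher_is_better flags cast to ints. build_metrics_dict() below performs the
# inverse on the Server side.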
def _send_gradients(gradients, dst):
""" Send a list of tensors with different shape during
distributed communication. """
_send(len(gradients), dst)
for i in gradients:
dimensions = [int(d) for d in i.shape]
_send(len(dimensions),dst)
_send(dimensions,dst)
param = to_device(i)
dist.send(param,dst)
del param
torch.cuda.empty_cache()
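# Wire layout shared by _recv_gradients/_send_gradients above: first the number
# of tensors, then, for each tensor, its number of dimensions, its shape, and
# finally the tensor data itself. For example, sending two tensors of shapes
# (2, 3) and (5,) transmits:
#     2                      # number of tensors
#     2, [2, 3], <2x3 data>  # first tensor
#     1, [5],    <5 data>    # second tensor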
def _send_train_output(output):
""" Organize the keys and values from the the returning ´client_output´
dictionary in ´Client.proces_round()´ function during training rounds,
into arrays that are sent as independent tensors during distributed
communication. """
cs_values = [float(cs_v) for cs_v in output['cs'].values()] # cs dict -- values are flattened into a 1d array
pl_values = [float(output['pl']['weight'])] # pl dict
gradients = output['pl']['gradients'] # gradients are sent independently
if len(output.keys()) > 9: # DP metrics
ps_values = [float(ps_v) for ps_v in output['ps'].values()]
values = cs_values + [float(output[key]) for key in output.keys() if key not in ['cs','pl','ps']] + pl_values + ps_values # reorganizing values in the order expected by the Server
else:
values = cs_values + [float(output[key]) for key in output.keys() if key not in ['cs','pl']] + pl_values # reorganizing values in the order expected by the Server
# Send data
_send(int(len(output.keys())),0) # Warn for number of keys
_send(values, 0)
_send_gradients(gradients, 0)
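# The flat `values` array is thus ordered as: the three `cs` entries (setup,
# training, full cost), the remaining scalar keys in their original dictionary
# order, the `pl` weight and, when privacy metrics are enabled, the `ps`
# entries; the gradients travel separately through _send_gradients.
# build_grads_dict() below relies on this ordering to rebuild the dictionary.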
def build_grads_dict(node):
""" Reconstruct the dictionary ´client_output´ returned by
´Client.proces_round()´ function on the Server side during
distributed communication. """
# Initialize tensors
n_keys = 0
n_keys = _recv(n_keys,node)
print_rank(f"Number of keys received: {n_keys}", loglevel=logging.DEBUG)
if n_keys == 9:
keys = ['cs','tl','mg','vg','ng','rg','ns','ts','pl']
values = [0.0 for i in range(11)] # initializing tensor shape -- 11 is the fixed number of values expected
elif n_keys == 10:
keys = ['cs','tl','mg','vg','ng','rg','ns','ts','pl','ps']
values = [0.0 for i in range(15)] # When the privacy metrics are enabled
elif n_keys == 11:
keys = ['cs','tl','mg','vg','ng','rg','ns','wt','ts','pl','ps']
values = [0.0 for i in range(16)] # When the privacy metrics are enabled
# Read data
values = _recv(values,node)
grads = _recv_gradients(node)
cs_values = [{key: values.pop(0) for key in ['setup','training','full cost']}] # recreating cs dict
# Rebuilding original dictionary
if n_keys == 9:
pl_values = [{'weight':values.pop(), 'gradients': grads}] # recreating pl dict
values_list = cs_values + [values.pop(0) for i in range(7)] + pl_values # 7 is fixed length for remaining items
else:
ps_values = [{key: values.pop() for key in ['Practical epsilon (Max leakage)','Words percentage above 9000 word rank','Extracted indices percentage','Dropped clients']}]
pl_values = [{'weight':values.pop(), 'gradients': grads}] # recreating pl dict
values_list = cs_values + [values.pop(0) for i in range(len(values))] + pl_values + ps_values
result = dict(zip(keys,values_list))
# Cast values to original type
for key in ['mg','vg','ng','rg']:
result[key] = np.float32(result[key])
result['ns'] = int(result['ns'])
return result
def build_metrics_dict(node):
""" Reconstruct the dictionary returned during test/val rounds
on the Server side during distributed communication. """
# Initialize tensors
n = 0
n = _recv(n,node)
keys = [[0 for j in range(8)] for i in range(n)] # max_seq_len for metric name is 8
values = [0.0 for i in range(n)]
higher_is_better = [0 for i in range(n)]
# Read data
keys = _recv(keys,node)
values = _recv(values,node)
higher_is_better = _recv(higher_is_better,node)
# Reorganize output + decode dict keys
orig_keys = [encode_string(key, string_to_int=False) for key in keys]
values_dict = [{'value': float(v), 'higher_is_better': bool(higher_is_better[i])} for i, v in enumerate(values)]
metrics = dict(zip(orig_keys,values_dict))
num_instances = int(metrics.pop('num')['value'])
result = None, metrics, num_instances
return result
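# For example, if a worker reported {'loss': ..., 'acc': ..., 'num': ...}, the
# reconstructed result is (None, {'loss': {...}, 'acc': {...}}, num_instances):
# the 'num' entry is popped out of the metrics and returned separately as the
# number of evaluated instances.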
def receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes):
""" Receives the clients output on the Server side in async/sync mode.
Asynchronous mode is only enabled when using NCCL backend given that Gloo
does not provide native non-blocking implementation to check if the operation
has been completed during distributed training"""
if dist.get_backend() == "nccl": # Async
for node, req in list(node_request_map): # iterate over a copy since completed requests are removed below
if req.is_completed():
result = build_metrics_dict(node) if command == COMMAND_TESTVAL else build_grads_dict(node)
results_list.append(result)
free_nodes.append(node)
node_request_map.remove((node,req))
print_rank(f"Finished releasing the nodes {free_nodes}", loglevel=logging.DEBUG)
else: # Sync
print_rank(f"Waiting for a workers", loglevel=logging.DEBUG)
gather_objects = [(None,None,None) for i in range(size())]
output = [None for _ in gather_objects]
dist.all_gather_object(output, gather_objects[rank()])
print_rank(f" All workers have finished ... taking the remaining clients {len(output)}", loglevel=logging.DEBUG)
output = [e for i, e in enumerate(output) if i not in idle_nodes] # Cleanup for idle workers
results_list = results_list + output[1:]
free_nodes = list(range(1, size()))
return node_request_map, results_list, free_nodes
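# In the Gloo (sync) path, dist.all_gather_object effectively acts as a
# barrier: the Server contributes placeholder (None, None, None) entries, each
# busy worker contributes its real output, and idle workers contribute None;
# the Server then drops the idle entries and its own entry before extending
# results_list.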
def append_async_requests(node_request_map, node):
""" Appends the asynchronous request sent to each worker during
asynchronous training. """
ack = to_device(torch.tensor(1))
req = dist.irecv(tensor=ack, src=node)
node_request_map.append((node,req))
return node_request_map
def sync_idle_nodes(client_queue, free_nodes):
""" Request dummy outputs to the odd (idle) nodes during synchronous training
to prevent them to get trapped in the state of the previous iterations """
idle_nodes = []
if len(client_queue) == 0:
print_rank(f"Free idle nodes {len(free_nodes)}", loglevel=logging.DEBUG)
while len(free_nodes) > 0:
node = free_nodes.pop()
idle_nodes.append(node)
_send(COMMAND_SYNC_NODES, node)
return idle_nodes
class Server:
"""Server object responsible for orchestration and aggregation.
The Server is one of the two objects that may exist inside of a thread, all
throughout its execution (the other being the Worker). At every round, the
Server samples client ids and sends their data to an available Worker to process.
The Workers then each produce a new model, and all models are sent to the Server
for aggregation.
The methods defined here are related to orchestration only; the aggregation
is done by a different object which inherits from this one.
Notes:
This class has no :code:`__init__` method, and all its methods are static.
It thus only serves the purpose of grouping the methods, but nothing
is actually stored inside of the object.
"""
@staticmethod
def dispatch_clients(clients, server_data, command, mode=None, do_profiling=False, single_worker=None):
"""Perform the orchestration between Clients and Workers.
This function does the following:
1. It sends the server_data to all workers
2. For each available Worker:
2a. It sends the index of the client to instantiate
2b. It triggers the execution of the command on the
Client.
3. It collects and returns all client outputs.
Notes:
This function yields the gradients of different clients
as they are received. Therefore, the order of the results generally
does not correspond to the order of the clients.
All commands used during Server-Worker communication must be
floats/integers, given that torch.distributed only allows
sending/receiving tensors.
Args:
clients (list): list of clients to be processed.
server_data (tuple): server data sent to the workers and passed to
clients, typically includes the global model at that step.
command (int): instruction for worker to execute on the Client.
mode (list): test/val flag, only provided during evaluation rounds.
do_profiling (bool): enables profiler during communication.
single_worker (Worker): worker used to run clients locally when
executing on a single GPU.
Returns:
Generator of results.
"""
# Single GPU flag
single_gpu = size() == 1
print_rank(f"Single GPU flag Server: {single_gpu}", loglevel=logging.DEBUG)
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
# Initialize communication profiler
profiler = None
if do_profiling:
profiler = cProfile.Profile()
profiler.enable()
# Update lr + model parameters each round for all workers
lr, model_params, nround = server_data
if not single_gpu:
for worker_rank in range(1, size()):
_send(COMMAND_UPDATE, worker_rank)
_send(lr,worker_rank)
_send_gradients(model_params, worker_rank)
_send(float(nround),worker_rank)
print_rank(f"Finished sending lr {lr} and n_params {len(model_params)} to worker {worker_rank} - round {nround}", loglevel=logging.DEBUG)
print_rank(f"Finished sending server_data to workers", loglevel=logging.DEBUG)
client_queue = clients.copy()
print_rank(f"Clients queue: {client_queue}", loglevel=logging.DEBUG)
free_nodes = list(range(1, size()))
results_list, node_request_map = [], []
# Initiate computation for all clients
while client_queue:
print_rank(f"Clients queue: {client_queue}", loglevel=logging.DEBUG)
assert len(free_nodes) > 0
node = free_nodes.pop()
index = len(client_queue)-1
client_to_process = client_queue.pop(index)
print_rank(f"Sending client {index} to worker {node}", loglevel=logging.DEBUG)
_send(command, node) # The command indicates to the worker which function to run on the client
if command == COMMAND_TESTVAL:
_send(mode, node) # mode only has a value for test/val rounds
_send(index, node) # Worker receives the index of the client to pop
elif command == COMMAND_TRAIN:
_send(client_to_process, node)
print_rank(f"Finished assigning worker {node}, free nodes {free_nodes}", loglevel=logging.DEBUG)
if dist.get_backend() == "nccl":
append_async_requests(node_request_map, node)
idle_nodes = None
else:
idle_nodes = sync_idle_nodes(client_queue, free_nodes)
# Wait until the output is received from all ranks
if not free_nodes:
print_rank(f"Waiting for a workers, free nodes {free_nodes}, reqs_lst {node_request_map}", loglevel=logging.DEBUG)
while len(free_nodes) == 0:
node_request_map, results_list, free_nodes = receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes)
for output in results_list:
yield output
results_list = []
# Wait for all workers to finish
while (len(node_request_map)) != 0:
node_request_map, results_list, free_nodes = receive_workers_output(node_request_map, results_list, free_nodes, command, idle_nodes)
for output in results_list:
yield output
results_list = []
else:
# For a single-GPU execution there is no P2P communication on the same GPU. Threads are used to coordinate instead.
global GLOBAL_MESSAGE
GLOBAL_MESSAGE = server_data
if command == COMMAND_TESTVAL:
t1 = threading.Thread(target=single_worker.trigger_evaluate)
t1.start()
t1.join()
yield GLOBAL_MESSAGE
elif command == COMMAND_TRAIN:
total_clients = clients.copy()
for client_id in total_clients:
GLOBAL_MESSAGE = lr, model_params, nround, client_id
t1 = threading.Thread(target=single_worker.trigger_train)
t1.start()
t1.join()
result = GLOBAL_MESSAGE
yield result
if do_profiling:
profiler.disable()
print_profiler(profiler)
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
@staticmethod
def process_clients(clients, server_data, single_worker):
"""Ask workers to perform training on Clients.
Args:
clients (list): list of client indexes sampled by the `Server.py`
object per iteration.
server_data (tuple): tuple containing the learning rate, model
parameters and round number.
Returns:
Generator of results.
"""
return Server.dispatch_clients(clients, server_data, COMMAND_TRAIN, single_worker=single_worker)
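# Typical consumption on the Server rank (a sketch -- client sampling, the
# construction of server_data and the `aggregate` step are only illustrative
# here and live outside this file):
#     for client_output in Server.process_clients(sampled_clients,
#                                                  (lr, model_params, nround),
#                                                  single_worker=worker):
#         aggregate(client_output)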
@staticmethod
def process_testvalidate(clients, server_data, mode, single_worker):
"""Ask workers to perform test/val on Clients.
Args:
clients (list): list of client indexes for test/val rounds.
server_data (tuple): tuple containing the learning rate, model
parameters and round number.
mode (str): test/val.
Returns:
Generator of results.
"""
mode = [-2] if mode == "test" else [2]
return Server.dispatch_clients(clients, server_data, COMMAND_TESTVAL, mode, single_worker=single_worker)
@staticmethod
def terminate_workers(terminate=True):
"""Terminate the execution of the workers."""
if terminate:
print_rank("Terminating worker processes")
for worker_rank in range(1, size()):
_send(COMMAND_TERMINATE, worker_rank)
class Worker:
"""Worker object responsible for instantiate Clients based on incoming data
from the Server and perform train/eval functions on it.
Each worker lives on a different NCCL/Gloo thread and is assigned to a different
GPU. Via the :code:`dispatch_clients` function, the Server passes to the
Worker specific instructions to process clients' data, typically in order
to generate a new model or to compute metrics.
Attributes:
model (torch.nn.Module): model being trained.
data_path (str): path where all clients' data is located.
do_profiling (bool): if True, analyzes execution in depth.
val_clients (list): clients list for validation rounds.
test_clients (list): clients list for testing rounds.
config (dict): clients configuration.
val_dataset (torch.utils.data.Dataset): validation dataset.
test_dataset (torch.utils.data.Dataset): testing dataset.
"""
def __init__(self, model=None, data_path=None, do_profiling=False, val_clients=None,
test_clients=None, config=None, val_dataset=None, test_dataset=None):
self.model = model
self.data_path = data_path
self.do_profiling = do_profiling
self.config = config
self.val_clients = val_clients
self.test_clients = test_clients
self.val_dataset = val_dataset
self.test_dataset = test_dataset
def run(self):
"""Main loop executed by worker nodes.
This method handles the NCCL/Gloo communication between the worker and
the server. It keeps listening for commands from the Server,
and performs different actions on the Client assigned depending on
the command received.
"""
# Single GPU flag
single_gpu = size() == 1
print_rank(f"Single GPU flag Client: {single_gpu}", loglevel=logging.DEBUG)
if not single_gpu:
while True: # keeps listening for incoming server calls
# Initialize tensors -- required by torch.distributed
command, client_idx, mode = 0, 0, 0 # int
lr, nround = torch.zeros(1), torch.zeros(1) # float
# Read command
command = _recv(command)
print_rank(f"Command received {command} on worker {rank()}", loglevel=logging.DEBUG)
# Receive server data -- lr, model_params
if command == COMMAND_UPDATE:
print_rank(f"COMMMAND_UPDATE received {rank()}", loglevel=logging.DEBUG)
lr = _recv(lr, 0)
model_params = _recv_gradients(0)
nround = _recv(nround, 0)
server_data = (lr, model_params, int(nround))
print_rank(f"Received lr: {lr} and n_params: {len(model_params)} - round {nround}", loglevel=logging.DEBUG)
elif command == COMMAND_TRAIN:
print_rank(f"COMMMAND_TRAIN received {rank()}", loglevel=logging.DEBUG)
# Init profiler in training worker
profiler = None
if self.do_profiling:
profiler = cProfile.Profile()
profiler.enable()
# Receive client id from Server
client_idx = _recv(client_idx)
print_rank(f"Cliend idx received from Server: {client_idx}", loglevel=logging.DEBUG)
# Instantiate client
client_to_process = Client(
[client_idx],
self.config,
self.config['client_config']['type'] == 'optimization')
# Execute Client.get_data()
client_data = client_to_process.get_client_data()
# Execute Client.process_round()
output = client_to_process.process_round(client_data, server_data, self.model, self.data_path)
# Send output back to Server
if dist.get_backend() == "nccl":
# ASYNC mode -- enabled only for nccl backend
ack = to_device(torch.tensor(1))
dist.isend(tensor=ack, dst=0)
_send_train_output(output)
else:
# SYNC mode -- gloo backend does not have a non-blocking way to check if the operation is completed
gather_objects = [output for i in range(size())]
output = [None for _ in gather_objects]
dist.all_gather_object(output, gather_objects[rank()])
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
if self.do_profiling:
profiler.disable()
print_profiler(profiler)
elif command == COMMAND_TESTVAL:
print_rank(f"COMMMAND_TESTVAL received {rank()}", loglevel=logging.DEBUG)
# Init profiler in validation worker
profiler = None
if self.do_profiling:
profiler = cProfile.Profile()
profiler.enable()
# Receive mode and client id from Server
mode = _recv(mode)
mode = "test" if mode == -2 else "val"
client_idx = _recv(client_idx)
print_rank(f"Client idx received from Server: {client_idx}, {mode}", loglevel=logging.DEBUG)
# Get client and dataset
clients = self.val_clients if mode == "val" else self.test_clients
dataset = self.val_dataset if mode == "val" else self.test_dataset
clients_queue = clients.copy()
assert 0 <= client_idx < len(clients_queue)
client_to_process = clients_queue.pop(client_idx)
# Execute Client.get_data()
client_data = client_to_process.get_client_data(dataset)
# Execute Client.run_testvalidate()
output = client_to_process.run_testvalidate(client_data, server_data, mode, self.model)
# Send output back to Server
if dist.get_backend() == "nccl":
# ASYNC mode -- enabled only for nccl backend
_, metrics, num_instances = output
metrics['num'] = {'value': float(num_instances), 'higher_is_better': False}
output = metrics
print_rank(f"Worker {rank()} output {output}", loglevel=logging.DEBUG)
ack = to_device(torch.tensor(1))
dist.isend(tensor=ack, dst=0)
_send_metrics(output)
else:
# SYNC mode -- gloo backend does not have a non-blocking way to check if the operation is completed
gather_objects = [output for i in range(size())]
output = [None for _ in gather_objects]
dist.all_gather_object(output, gather_objects[rank()])
print_rank(f"Worker {rank()} sent output back", loglevel=logging.DEBUG)
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
if self.do_profiling:
profiler.disable()
print_profiler(profiler)
elif command == COMMAND_TERMINATE:
print_rank(f"COMMMAND_TERMINATE received {rank()}", loglevel=logging.DEBUG)
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
return
elif command == COMMAND_SYNC_NODES: # Only for sync calls
print_rank(f"COMMMAND_SYNC_NODES received {rank()}", loglevel=logging.DEBUG)
gather_objects = [None for i in range(size())]
output = [None for _ in gather_objects]
dist.all_gather_object(output, gather_objects[rank()])
print_rank(f"Worker IDLE {rank()} sent dummy output back", loglevel=logging.DEBUG)
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
else:
assert False, "unknown command"
def trigger_evaluate(self):
global GLOBAL_MESSAGE
lr, model_params, nround = GLOBAL_MESSAGE
server_data = (lr, model_params, int(nround))
mode = "val"
# Get client and dataset
clients = self.val_clients if mode == "val" else self.test_clients
dataset = self.val_dataset if mode == "val" else self.test_dataset
clients_queue = clients.copy()
client_to_process = clients_queue.pop()
# Execute Client.get_data()
client_data = client_to_process.get_client_data(dataset)
# Execute Client.run_testvalidate()
output = client_to_process.run_testvalidate(client_data, server_data, mode, self.model)
_, metrics, num_instances = output
metrics['num'] = {'value': float(num_instances), 'higher_is_better': False}
GLOBAL_MESSAGE = (_, metrics, num_instances)
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
def trigger_train(self):
global GLOBAL_MESSAGE
lr, model_params, nround, client_idx = GLOBAL_MESSAGE
server_data = (lr, model_params, int(nround))
# Instantiate client
client_to_process = Client([client_idx], self.config, self.config['client_config']['type'] == 'optimization')
# Execute Client.get_data()
client_data = client_to_process.get_client_data()
# Execute Client.process_round()
GLOBAL_MESSAGE = client_to_process.process_round(client_data, server_data, self.model, self.data_path)
# Some cleanup
torch.cuda.empty_cache()
torch.cuda.synchronize() if torch.cuda.is_available() else None
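# Typical entry-point wiring (a sketch, assuming the usual torch.distributed
# launch where rank 0 acts as the Server and every other rank as a Worker; the
# actual initialization and orchestration loop live outside this file):
#
#     dist.init_process_group(backend="nccl")  # or "gloo"
#     worker = Worker(model=model, data_path=data_path, config=config)
#     if rank() == 0:
#         ...  # orchestrate rounds via Server.process_clients(...)
#     else:
#         worker.run()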