# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

'''
The Client object is short-lived, instantiated inside of worker 0 and moved to
workers 1 to N for processing a given client's data. Its main method is the
`process_round` function, used to update the model given a client's data.
'''
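
# A minimal sketch of the intended lifecycle (illustrative only, not part of
# the original module; the real orchestration lives in the server and worker
# code, and `server_data`, `model` and `data_path` below are assumed inputs):
#
#   client = Client(client_id=0, config=config, send_gradients=True,
#                   dataloader=train_dataloader)
#   client_data = client.get_client_data()
#   output = Client.process_round(client_data, server_data, model, data_path)
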
import copy
import json
import logging
import os
import time

from easydict import EasyDict as edict

import h5py
import numpy as np
import torch

# Internal imports
from core.globals import TRAINING_FRAMEWORK_TYPE
if TRAINING_FRAMEWORK_TYPE == 'mpi':
    import core.federated as federated
else:
    raise NotImplementedError('{} is not supported'.format(TRAINING_FRAMEWORK_TYPE))

from .strategies import select_strategy
from .trainer import (
    Trainer,
    run_validation_generic,
    set_component_wise_lr,
)
from utils import (
    ScheduledSamplingScheduler,
    make_optimizer,
    print_rank,
    scrub_empty_clients,
)
from utils.dataloaders_utils import (
    make_train_dataloader,
    make_val_dataloader,
    make_test_dataloader,
)

import extensions.privacy
from extensions.privacy import metrics as privacy_metrics
from experiments import make_model


# A per-process cache of the training data, so clients don't have to repeatedly re-load
# TODO: deprecate this in favor of passing dataloader around
_data_dict = None
_file_ext = None


class Client:
    # It's unclear why, but sphinx refuses to generate method docs
    # if there is no docstring for this class.
    """Client class for specifying individual client training tasks"""

    def __init__(self, client_id, config, send_gradients, dataloader):
        '''
        Client-side processing: computes gradients, updates the model and
        sends the result back to the server.

        Args:
            client_id (int): identifier for grabbing that client's data.
            config (dict): dictionary with parameters loaded from config file.
            send_gradients (bool): if True, model gradients are sent back;
                otherwise, model weights are sent back.
            dataloader (torch.utils.data.DataLoader): dataloader that generates
                training data for the client.
        '''
        super().__init__()

        self.client_id = client_id
        self.client_data = self.get_data(client_id, dataloader)
        self.config = copy.deepcopy(config)
        self.send_gradients = send_gradients

    def get_client_data(self):
        '''"Getter" method that returns all of the object's attributes at once.'''
        return self.client_id, self.client_data, self.config, self.send_gradients

    @staticmethod
    def get_num_users(filename):
        '''Count users given a JSON or HDF5 file.

        This function will fill the global data dict. Ideally we want data
        handling not to happen here and only at the dataloader; that will be
        the behavior in future releases.

        Args:
            filename (str): path to file containing data.
        '''

        global _data_dict
        global _file_ext
        _file_ext = filename.split('.')[-1]

        try:
            if _file_ext == 'json' or _file_ext == 'txt':
                if _data_dict is None:
                    print_rank('Reading training data dictionary from JSON')
                    with open(filename, 'r') as fid:
                        _data_dict = json.load(fid)  # pre-cache the training data
                    _data_dict = scrub_empty_clients(_data_dict)  # empty clients MUST be scrubbed here to match num_clients in the entry script
                    print_rank('Read training data dictionary', loglevel=logging.DEBUG)

            elif _file_ext == 'hdf5':
                print_rank('Reading training data dictionary from HDF5')
                _data_dict = h5py.File(filename, 'r')
                print_rank('Read training data dictionary', loglevel=logging.DEBUG)

        except Exception:
            raise ValueError('Error reading training file. Please make sure the format is allowed')

        num_users = len(_data_dict['users'])
        return num_users

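    # For reference, the training file is assumed to look roughly like the
    # hypothetical JSON fragment below (inferred from the key accesses in this
    # class; HDF5 files mirror the same layout, and 'user_data_label' is only
    # present for supervised problems):
    #
    #   {
    #     "users": ["user_00", "user_01"],
    #     "num_samples": [2, 1],
    #     "user_data": {"user_00": {"x": [...]}, "user_01": {"x": [...]}},
    #     "user_data_label": {"user_00": [...], "user_01": [...]}
    #   }
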
    @staticmethod
    def get_data(client_id, dataloader):
        '''Load data from the dataloader given the client's id.

        This function will load the global data dict. Ideally we want data
        handling not to happen here and only at the dataloader; that will be
        the behavior in future releases.

        Args:
            client_id (int or list): identifier(s) for grabbing client's data.
            dataloader (torch.utils.data.DataLoader): dataloader that
                provides the training data.
        '''

        # Auxiliary function for decoding only when necessary
        decode_if_str = lambda x: x.decode() if isinstance(x, bytes) else x

        # During training, client_id will always be an integer
        if isinstance(client_id, int):
            user_name = decode_if_str(_data_dict['users'][client_id])
            num_samples = _data_dict['num_samples'][client_id]

            if _file_ext == 'hdf5':
                arr_data = [decode_if_str(e) for e in _data_dict['user_data'][user_name]['x'][()]]
                user_data = {'x': arr_data}
            elif _file_ext == 'json' or _file_ext == 'txt':
                user_data = _data_dict['user_data'][user_name]

            if 'user_data_label' in _data_dict:  # supervised problem
                labels = _data_dict['user_data_label'][user_name]
                if _file_ext == 'hdf5':  # transforms HDF5 Dataset into Numpy array
                    labels = labels[()]

                return edict({'users': [user_name],
                              'user_data': {user_name: user_data},
                              'num_samples': [num_samples],
                              'user_data_label': {user_name: labels}})
            else:
                print_rank('no labels present, unsupervised problem', loglevel=logging.DEBUG)
                return edict({'users': [user_name],
                              'user_data': {user_name: user_data},
                              'num_samples': [num_samples]})

        # During validation and test, client_id might be a list of integers
        elif isinstance(client_id, list):
            if 'user_data_label' in _data_dict:
                users_dict = {'users': [], 'num_samples': [], 'user_data': {}, 'user_data_label': {}}
            else:
                users_dict = {'users': [], 'num_samples': [], 'user_data': {}}

            for client in client_id:
                user_name = decode_if_str(dataloader.dataset.user_list[client])
                users_dict['users'].append(user_name)
                users_dict['num_samples'].append(dataloader.dataset.num_samples[client])

                if _file_ext == 'hdf5':
                    arr_data = dataloader.dataset.user_data[user_name]['x']
                    arr_decoded = [decode_if_str(e) for e in arr_data]
                    users_dict['user_data'][user_name] = {'x': arr_decoded}
                elif _file_ext == 'json':
                    users_dict['user_data'][user_name] = {'x': dataloader.dataset.user_data[user_name]['x']}
                elif _file_ext == 'txt':  # using a different line for .txt since our files have a different structure
                    users_dict['user_data'][user_name] = dataloader.dataset.user_data[user_name]

                if 'user_data_label' in _data_dict:
                    labels = dataloader.dataset.user_data_label[user_name]
                    if _file_ext == 'hdf5':
                        labels = labels[()]
                    users_dict['user_data_label'][user_name] = labels

            return users_dict

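    # Note the dispatch above: during training get_data receives a single int
    # (one client), while validation/test pass a list of ints and get one
    # merged dict back. A hypothetical example:
    #
    #   single = Client.get_data(3, dataloader)        # edict for one user
    #   merged = Client.get_data([0, 1], dataloader)   # dict covering two users
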
    @staticmethod
    def run_testvalidate(client_data, server_data, mode, model):
        '''Called by worker to run a test/validation sample on a client.

        This function assumes set_model_for_round has already been called to
        push the model to the client (see federated.py).

        Args:
            client_data (tuple): client data and config. It is a tuple with 4
                components; importantly, the second component is a dict
                containing the data, and the third component is a dict with the
                config parsed from the YAML file.
            server_data (tuple): server data (model parameters mostly). It is
                a tuple with 3 components; importantly, the third component
                consists of the current model parameters.
            mode (str): either `test` or `val`.
            model (torch.nn.Module): actual model without parameters.
        '''

        # Process inputs and initialize variables
        _, data_strct, config, _ = client_data
        _, _, model_parameters = server_data
        config = copy.deepcopy(config)

        begin = time.time()

        # Use the server's data config since we're distributing test/validate from the server
        data_config = config['server_config']['data_config'][mode]
        want_logits = data_config.get('wantLogits', False)

        # Create dataloader
        dataloader = None
        print_rank('making dataloader with task {}'.format(config['server_config']['task']), loglevel=logging.DEBUG)
        if mode == 'test':
            dataloader = make_test_dataloader(data_config, data_path=None, task=config['server_config']['task'], data_strct=data_strct)
        elif mode == 'val':
            dataloader = make_val_dataloader(data_config, data_path=None, task=config['server_config']['task'], data_strct=data_strct)

        # Set model parameters
        n_layers, n_params = len([f for f in model.parameters()]), len(model_parameters)
        print_rank(f'Copying model parameters... {n_layers}/{n_params}', loglevel=logging.DEBUG)
        if torch.cuda.is_available():
            model.cuda()
        for p, data in zip(model.parameters(), model_parameters):
            p.data = data.detach().clone().cuda() if torch.cuda.is_available() else data.detach().clone()
        print_rank(f'Model setup complete. {time.time() - begin}s elapsed.', loglevel=logging.DEBUG)
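
        # Note: detach().clone() (here and again in process_round) copies the
        # server-provided tensors so the local model does not share storage or
        # autograd history with the parameters received from the server.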

        num_instances = sum(data_strct['num_samples'])
        print_rank(f'Validating {num_instances}', loglevel=logging.DEBUG)

        # Compute output and metrics on the test or validation data
        output, metrics = run_validation_generic(model, dataloader)
        output = None if not want_logits else output

        return output, metrics, num_instances

    @staticmethod
    def process_round(client_data, server_data, model, data_path, eps=1e-7):
        '''Compute gradients given client's data and update model.

        Args:
            client_data (tuple): client data and config. It is a tuple
                consisting of 4 components: an int indicating the client's id, a
                dict containing that client's data, a dict with the config
                parsed from the YAML file, and a bool indicating whether or not
                gradients should be sent.
            server_data (tuple): server data (model parameters mostly). It is
                a tuple consisting of 3 components; importantly, the first is
                a float giving the client's learning rate, and the third a list
                of torch.Tensors with the current model parameters. The second
                one is currently unused.
            model (torch.nn.Module): actual model without parameters.
            data_path (str): where to get data from.
            eps (float): lower bound for aggregation weights.
        '''

        # Ensure the client is assigned to the correct GPU
        if torch.cuda.is_available() and torch.cuda.device_count() == federated.size():
            torch.cuda.set_device(federated.local_rank())

        # Process inputs and initialize variables
        client_id, data_strct, config, send_gradients = client_data
        initial_lr, _, model_parameters = server_data
        config = copy.deepcopy(config)

        model_config = config['model_config']
        client_config = config['client_config']
        data_config = client_config['data_config']['train']
        task = client_config.get('task', {})
        trainer_config = client_config.get('trainer_config', {})
        privacy_metrics_config = config.get('privacy_metrics_config', None)
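
        # The config handed to this method is assumed to follow a YAML layout
        # roughly like the hypothetical fragment below (inferred from the keys
        # accessed here; names and values are illustrative only):
        #
        #   strategy: ...
        #   model_config: {...}
        #   privacy_metrics_config: {apply_metrics: false}
        #   client_config:
        #     task: ...
        #     optimizer_config: {lr: 0.1}
        #     annealing_config: {...}          # optional
        #     data_config:
        #       train:
        #         list_of_train_data: ...
        #         desired_max_samples: ...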

        StrategyClass = select_strategy(config['strategy'])
        strategy = StrategyClass('client', config)
        print_rank(f'Client successfully instantiated strategy {strategy}', loglevel=logging.DEBUG)

        begin = time.time()
        client_stats = {}

        # Update the location of the training file
        data_config['list_of_train_data'] = os.path.join(data_path, data_config['list_of_train_data'])

        user = data_strct['users'][0]
        if 'user_data_label' in data_strct.keys():  # supervised case
            input_strct = edict({
                'users': [user],
                'user_data': {user: data_strct['user_data'][user]},
                'num_samples': [data_strct['num_samples'][0]],
                'user_data_label': {user: data_strct['user_data_label'][user]}
            })
        else:
            input_strct = edict({
                'users': [user],
                'user_data': {user: data_strct['user_data'][user]},
                'num_samples': [data_strct['num_samples'][0]]
            })

        print_rank('Loading {}-th client with name: {}, {} samples, {}s elapsed'.format(
            client_id, user, data_strct['num_samples'][0], time.time() - begin), loglevel=logging.INFO)

        # Get dataloaders
        train_dataloader = make_train_dataloader(data_config, data_path, task=task, clientx=0, data_strct=input_strct)
        val_dataloader = make_val_dataloader(data_config, data_path)

        # Instantiate the model object
        if model is None:
            model = make_model(
                model_config,
                dataloader_type=train_dataloader.__class__.__name__,
                input_dim=data_config['input_dim'],
                vocab_size=train_dataloader.vocab_size,
            )

        # Set model parameters
        n_layers, n_params = len([f for f in model.parameters()]), len(model_parameters)
        print_rank(f'Copying model parameters... {n_layers}/{n_params}', loglevel=logging.DEBUG)
        if torch.cuda.is_available():
            model.cuda()
        for p, data in zip(model.parameters(), model_parameters):
            p.data = data.detach().clone().cuda() if torch.cuda.is_available() else data.detach().clone()
        print_rank(f'Model setup complete. {time.time() - begin}s elapsed.', loglevel=logging.DEBUG)

        # Fix parameters of layers
        if 'updatable_names' in trainer_config:
            set_component_wise_lr(model, client_config['optimizer_config'], trainer_config['updatable_names'])

        # Create the optimizer on the workers
        # NOTE: the server dictates the learning rate for the clients
        client_config['optimizer_config']['lr'] = initial_lr
        optimizer = make_optimizer(client_config['optimizer_config'], model)

        # Make the scheduled sampling scheduler
        ss_scheduler = None
        if 'ss_config' in client_config and client_config['ss_config'] is not None:
            ss_scheduler = ScheduledSamplingScheduler(model=model, **client_config['ss_config'])

        # Make the trainer
        trainer = Trainer(
            model=model,
            optimizer=optimizer,
            ss_scheduler=ss_scheduler,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            server_replay_config=client_config,
            max_grad_norm=client_config['data_config']['train'].get('max_grad_norm', None),
            anneal_config=client_config['annealing_config'] if 'annealing_config' in client_config else None,
            num_skips_threshold=client_config['num_skips_threshold'] if 'num_skips_threshold' in client_config else -1,
            ignore_subtask=client_config['ignore_subtask']
        )

        if trainer.optimizer is not None:
            initial_optimizer_state = copy.deepcopy(trainer.optimizer.state_dict())

        annealing_config = client_config['annealing_config'] if 'annealing_config' in client_config else None

        assert 'desired_max_samples' in client_config['data_config']['train'], 'Missing \'desired_max_samples\' entry in data config parameter'
        desired_max_samples = client_config['data_config']['train']['desired_max_samples']

        if trainer.optimizer is not None:  # reset the optimizer state
            if initial_lr > 0:
                trainer.optimizer.param_groups[0].update({'lr': initial_lr})
                initial_optimizer_state = copy.deepcopy(trainer.optimizer.state_dict())
            trainer.reset_optimizer(initial_optimizer_state, annealing_config)

        # Mark the end of setup
        end = time.time()
        client_stats['setup'] = end - begin
        print_rank(f'Client setup cost {client_stats["setup"]}s', loglevel=logging.DEBUG)
        begin_training = end

        # Training begins here
        trainer.model.train()
        trainer.model.zero_grad()

        # Save the client batches if we want to evaluate the privacy metrics
        apply_privacy_metrics = (False if privacy_metrics_config is None else privacy_metrics_config['apply_metrics'])

        # This is where training actually happens
        train_loss, num_samples = trainer.train_desired_samples(desired_max_samples=desired_max_samples, apply_privacy_metrics=apply_privacy_metrics)
        print_rank('client={}: training loss={}'.format(client_id, train_loss), loglevel=logging.DEBUG)

        # Estimate gradient magnitude mean/var
        trainer.sufficient_stats['mean'] = trainer.sufficient_stats['sum'] / trainer.sufficient_stats['n']
        trainer.sufficient_stats['mag'] = np.sqrt(trainer.sufficient_stats['sq_sum'] / trainer.sufficient_stats['n'])
        trainer.sufficient_stats['var'] = trainer.sufficient_stats['sq_sum'] / trainer.sufficient_stats['n'] - \
            trainer.sufficient_stats['mean'] ** 2
        trainer.sufficient_stats['norm'] = np.sqrt(trainer.sufficient_stats['sq_sum'])
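
        # In terms of the accumulated scalars (sum = total of gradient entries,
        # sq_sum = total of squared entries, n = number of entries):
        # mean = sum/n, mag = sqrt(sq_sum/n) is the RMS magnitude,
        # var = sq_sum/n - mean**2 is the usual E[g^2] - E[g]^2 identity,
        # and norm = sqrt(sq_sum) is the l2 norm of the gradient.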

        trainer.train_loss = train_loss
        trainer.num_samples = num_samples

        # Compute pseudo-gradient
        for p, data in zip(trainer.model.parameters(), model_parameters):
            data = data.cuda() if torch.cuda.is_available() else data
            p.grad = data - p.data
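
        # That is, p.grad = w_server - w_client: the negative of the client's
        # total update. A server-side SGD step w <- w - lr * grad with lr = 1
        # would land exactly on the client's updated weights, so treating this
        # difference as a "gradient" lets standard optimizers aggregate the
        # client updates.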

        payload = strategy.generate_client_payload(trainer) if send_gradients else None

        # Mark that training (including post-processing) is finished
        end = time.time()
        client_stats['training'] = end - begin_training
        client_stats['full cost'] = end - begin
        print_rank(f'Client training cost {end - begin_training}s', loglevel=logging.DEBUG)
        print_rank(f'Client full cost {end - begin}s', loglevel=logging.DEBUG)

        # Create dictionary that is sent back to server
        client_output = {
            'cs': client_stats,
            'tl': train_loss,
            'mg': trainer.sufficient_stats['mag'],
            'vg': trainer.sufficient_stats['var'],
            'ng': trainer.sufficient_stats['mean'],
            'rg': trainer.sufficient_stats['norm'],
            'ns': num_samples,
            'pl': payload,
        }
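
        # The short keys appear to abbreviate: cs = client stats, tl = train
        # loss, mg/vg/ng/rg = magnitude/variance/mean/norm of the gradient
        # statistics computed above, ns = number of samples, pl = payload
        # (and, set further below, wt = aggregation weight, ps = privacy
        # stats, ts = timestamp).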

        # Apply privacy metrics
        if privacy_metrics_config and privacy_metrics_config['apply_metrics']:
            print_rank('Applying privacy metrics', loglevel=logging.DEBUG)

            privacy_stats = {'Dropped clients': 0}
            batches = trainer.cached_batches
            trainer.cached_batches = []
            gradients = extensions.privacy.unroll_network(model.named_parameters(), select_grad=True)[0]

            if privacy_metrics_config['apply_indices_extraction']:
                allowed_word_rank = privacy_metrics_config.get('allowed_word_rank', 9000)
                embed_dim, vocab_size = model_config['embed_dim'], model_config['vocab_size']
                overlap, indices = privacy_metrics.extract_indices_from_embeddings(gradients, batches, embed_dim, vocab_size)

                max_overlap = privacy_metrics_config.get('max_allowed_overlap', None)
                if max_overlap is not None and overlap > max_overlap:
                    print_rank('Removing this client because we extracted {}% words and the maximum allowed is {}%'.format(overlap * 100, max_overlap * 100))
                    client_output['wt'] = 0.0
                    privacy_stats['Dropped clients'] = 1

                privacy_stats['Extracted indices percentage'] = overlap
                privacy_stats['Words percentage above ' + str(allowed_word_rank) + ' word rank'] = (indices > allowed_word_rank).mean() if len(indices) > 0 else 0

            if privacy_metrics_config['apply_leakage_metric']:
                print_rank('Applying leakage metric', loglevel=logging.DEBUG)

                orig_params = {n: p for (n, _), p in zip(trainer.model.named_parameters(), model_parameters)}
                max_ratio = np.exp(privacy_metrics_config['max_leakage'])
                optim_config = privacy_metrics_config['attacker_optimizer_config']
                is_leakage_weighted = privacy_metrics_config['is_leakage_weighted']

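                # Since the leakage is reported as a "practical epsilon", the
                # probability-ratio cap corresponding to max_leakage = eps is
                # e^eps, hence the np.exp above (an assumption: this mirrors
                # the ratio <= exp(eps) bound from differential privacy).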
                leakage = privacy_metrics.practical_epsilon_leakage(orig_params,
                    trainer.model, batches, is_leakage_weighted, max_ratio, optim_config)
                print_rank('privacy leakage: {}'.format(leakage), loglevel=logging.DEBUG)

                max_leakage = privacy_metrics_config.get('max_allowed_leakage', None)
                if max_leakage is not None and leakage > max_leakage:
                    print_rank('Removing this client because the information leakage/practical epsilon is {} and the maximum allowed is {}'.format(leakage, max_leakage))
                    client_output['wt'] = 0.0
                    privacy_stats['Dropped clients'] = 1

                privacy_stats['Practical epsilon (Max leakage)'] = leakage

            client_output['ps'] = privacy_stats

        client_output['ts'] = time.time()
        return client_output