# msrflute/core/server.py
#
# 598 lines
# 27 KiB
# Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
'''
In this file, we define the classes that live inside 'worker 0', the worker
responsible for orchestration and aggregation. The main class is the
OptimizationServer, which sends clients to the other workers to process and
combines the resulting models.
'''
import json
import logging
import os
import random
import shutil
import time
from collections import defaultdict
import numpy as np
import torch
# Internal imports
import core.federated as federated
from core.evaluation import Evaluation
from core.client import Client
from .strategies import select_strategy
from .trainer import (
ModelUpdater,
Trainer,
set_component_wise_lr,
)
from utils import (
get_lr,
print_rank,
update_json_log,
to_device,
)
# For profiling
import cProfile
import pstats
# AzureML-related libs
from azureml.core import Run
# Module-level AzureML run handle; the server reports its metrics through `run.log(name, value)`.
# NOTE(review): presumably this resolves to an offline run when executed outside AzureML -- confirm
# against the azureml-core documentation.
run = Run.get_context()
class OptimizationServer(federated.Server):
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, server_train_dataloader,
             config, idx_val_clients, idx_test_clients, single_worker):
    '''Implement Server's orchestration and aggregation.

    This is the main Server class, that actually implements orchestration
    and aggregation, inheriting from `federated.Server`, which deals with
    communication only.

    The `train` method is central in FLUTE, as it defines good part of what
    happens during training.

    Args:
        num_clients (int): total available clients.
        model (torch.nn.Module): neural network model.
        optimizer (torch.optim.Optimizer): optimizer.
        ss_scheduler: scheduled sampling scheduler.
        data_path (str): points to where data is.
        model_path (str): points to where pretrained model is; checkpoints
            and the status log are also written under this directory.
        server_train_dataloader (torch.utils.data.DataLoader): dataloader for
            server-side replay training; when None no server-side trainer is
            created.
        config (dict): JSON style configuration parameters.
        idx_val_clients (list): validation client ids.
        idx_test_clients (list): testing clients ids.
        single_worker: stored on the instance and forwarded to `Evaluation`
            and later to `process_clients`; presumably flags single-worker
            execution -- confirm against the caller.

    Raises:
        AssertionError: if `server_train_dataloader` is given but
            `server_replay_config` (or its `optimizer_config`) is missing.
        KeyError: if a mandatory configuration entry is absent.
    '''
    super().__init__()

    # Initialize all attributes from arguments
    self.client_idx_list = list(range(num_clients))
    self.config = config
    server_config = config['server_config']
    decoder_config = config.get('decoder_config', None)

    self.max_iteration = server_config['max_iteration']
    self.do_clustering = server_config.get('clustering', False)
    self.send_dicts = server_config.get('send_dicts', False)

    # Either a single number or a comma-separated string such as "10,20";
    # with two entries `train` draws a random client count in that range.
    clients_per_iteration = server_config['num_clients_per_iteration']
    if isinstance(clients_per_iteration, str):
        self.num_clients_per_iteration = [int(x) for x in clients_per_iteration.split(',')]
    else:
        self.num_clients_per_iteration = [clients_per_iteration]

    self.val_freq = server_config['val_freq']
    # Frequency of the 'test' evaluation; the attribute name is kept as-is
    # because `train` reads it under this name.
    self.req_freq = server_config['rec_freq']

    self.evaluation = Evaluation(config, model_path, self.process_testvalidate, idx_val_clients,
                                 idx_test_clients, single_worker)

    # TODO: does this need to be adjusted for custom metrics?
    self.metrics = dict()

    self.model_backup_freq = server_config.get('model_backup_freq', 100)
    self.worker_trainer_config = server_config.get('trainer_config', {})

    self.aggregate_median = server_config['aggregate_median']
    self.initial_lr_client = server_config.get('initial_lr_client', -1.0)
    self.lr_decay_factor = server_config.get('lr_decay_factor', 1.0)

    self.model_type = config['model_config']['model_type']
    self.quant_thresh = config['client_config'].get('quant_thresh', None)
    self.quant_bits = config['client_config'].get('quant_bits', 10)

    self.list_of_train_data = config['client_config']['data_config']['train']['list_of_train_data']
    self.data_path = data_path
    self.single_worker = single_worker

    # Get max grad norm from data config
    if 'train' in server_config['data_config']:
        max_grad_norm = server_config['data_config']['train'].get('max_grad_norm', None)
    else:
        max_grad_norm = None

    # Creating an instance to update the model with stats aggregated from workers
    self.worker_trainer = ModelUpdater(
        model=model,
        optimizer=optimizer,
        ss_scheduler=ss_scheduler,
        train_dataloader=server_train_dataloader,
        val_dataloader=None,
        max_grad_norm=max_grad_norm,
        anneal_config=server_config['annealing_config'],
        model_type=self.model_type,
        decoder_config=decoder_config
    )
    self.metrics['worker_trainer'] = self.worker_trainer

    # Creating an instance for the server-side trainer (runs mini-batch SGD)
    self.server_replay_iterations = None
    self.server_trainer = None
    if server_train_dataloader is not None:
        assert 'server_replay_config' in server_config, 'server_replay_config is not set'
        assert 'optimizer_config' in server_config[
            'server_replay_config'], 'server-side replay training optimizer is not set'
        replay_config = server_config['server_replay_config']
        self.server_optimizer_config = replay_config['optimizer_config']
        self.server_trainer_config = replay_config.get('trainer_config', {})
        self.server_replay_iterations = replay_config['server_iterations']
        self.server_trainer = Trainer(
            model=model,
            optimizer=None,
            ss_scheduler=ss_scheduler,
            train_dataloader=server_train_dataloader,
            server_replay_config=replay_config,
            # Fall back to the guarded value computed above instead of
            # re-reading data_config['train'], which raised KeyError when
            # the 'train' section was absent.
            max_grad_norm=replay_config.get('max_grad_norm', max_grad_norm),
            anneal_config=replay_config.get('annealing_config', None),
            ignore_subtask=replay_config.get('ignore_subtask', False)
        )

    self.skip_model_update = False  # will not update the model if True

    self.train_loss = 0.0
    self.model_path = model_path
    self.best_model_criterion = server_config['best_model_criterion']
    self.fall_back_to_best_model = server_config['fall_back_to_best_model']
    self.last_model_path = os.path.join(self.model_path, 'latest_model.tar')
    self.best_model_path = os.path.join(self.model_path,
                                        'best_val_{}_model.tar'.format(self.best_model_criterion))
    self.log_path = os.path.join(self.model_path, 'status_log.json')
    self.cur_iter_no = 0  # keep the iteration number for Tensor board plotting
    self.lr_weight = 1.0

    self.losses = []
    self.no_label_updates = 0  # no. label updates

    # Overwrite the defaults above with the saved checkpoint / status log
    # when resuming is requested
    if server_config.get('resume_from_checkpoint', False):
        self.load_saved_status()

    # Decoding config
    self.decoder_config = decoder_config
    self.spm_model = server_config['data_config']['test'].get('spm_model', None)

    self.do_profiling = server_config.get('do_profiling', False)

    StrategyClass = select_strategy(config['strategy'])
    self.strategy = StrategyClass('server', self.config, self.model_path)
    print_rank(f'Server successfully instantiated strategy {self.strategy}', loglevel=logging.DEBUG)
def load_saved_status(self):
    '''Restore model and training counters from a previous run, if present.

    Two independent files are looked up: the latest model checkpoint
    (`self.last_model_path`) and the JSON status log (`self.log_path`).
    Each one is applied only when it exists on disk.
    '''
    checkpoint = self.last_model_path
    if os.path.exists(checkpoint):
        print_rank('Resuming from checkpoint model {}'.format(checkpoint))
        self.worker_trainer.load(checkpoint, update_lr_scheduler=True, update_ss_scheduler=True)
        if self.server_trainer is not None:
            # Both trainers must point at the freshly loaded model
            self.server_trainer.model = self.worker_trainer.model

    if not os.path.exists(self.log_path):
        return

    # The status log carries the iteration number, best metrics and LR weight
    with open(self.log_path, 'r') as logfp:
        elems = json.load(logfp)

    self.cur_iter_no = elems.get('i', 0)
    metric_defaults = (
        ('best_val_loss', float('inf')),
        ('best_val_acc', 0),
        ('best_test_loss', float('inf')),
        ('best_test_acc', 0),
    )
    for key, fallback in metric_defaults:
        self.metrics[key] = elems.get(key, fallback)
    self.lr_weight = elems.get('weight', 1.0)
    self.no_label_updates = elems.get('num_label_updates', 0)
    print_rank(f'Resuming from status_log: cur_iter: {self.cur_iter_no}')
def run(self):
    '''Entry point of the server: delegate to `train`.

    Emits a message before and after the call so the server's lifetime is
    visible in the logs.
    '''
    print_rank('server started')
    self.train()
    print_rank('server terminated')
def train(self):
    '''Main method for training.

    Outline, as implemented below:

    1. When starting from iteration 0, optionally run an initial test/val
       evaluation (`initial_rec` / `initial_val` in the server config).
    2. Save the starting model under the best/latest tokens.
    3. For every iteration from `self.cur_iter_no` to `self.max_iteration`:
       send the model payload to a sample of clients, feed each client
       output to the strategy, combine the payloads, optionally run replay
       iterations on the server, periodically evaluate, back up models,
       optionally fall back to the best model and update the status log.
    4. Always terminate the workers, even if an error was raised.

    Per-round timings are collected in `self.run_stats`; per-round metrics
    are buffered locally and flushed to the AzureML run at the end of each
    iteration.
    '''
    self.run_stats = {
        'secsPerClientRound': [],
        'secsPerClient': [],
        'secsPerClientTraining': [],
        'secsPerClientSetup': [],
        'secsPerClientFull': [],
        'secsPerRoundHousekeeping': [],
        'secsPerRoundTotal': [],
        'communicationCosts': []
    }

    run.log('Max iterations', self.max_iteration)
    try:
        self.worker_trainer.model = to_device(self.worker_trainer.model)

        # Do an initial validation round to understand the pretrained model's validation accuracy
        # Skip if we resumed from a checkpoint (cur_iter_no > 0)
        eval_list = []
        if self.cur_iter_no == 0:
            if self.config['server_config']['initial_rec']:
                eval_list.append('test')
            if self.config['server_config']['initial_val']:
                eval_list.append('val')
            run.log('LR for agg. opt.', get_lr(self.worker_trainer.optimizer))

            print_rank("Running {} at itr={}".format(eval_list, self.cur_iter_no))
            self.metrics = self.evaluation.run(eval_list, self.metrics, metric_logger=run.log)
            eval_list = []  # some cleanup

        # Dump all the information in aggregate_metric
        print_rank('Saving Model Before Starting Training', loglevel=logging.INFO)
        for token in ['best_val_loss', 'best_val_acc', 'best_test_acc', 'latest']:
            self.worker_trainer.save(
                model_path=self.model_path,
                token=token,
                config=self.config['server_config']
            )

        # Training loop
        self.worker_trainer.model.train()
        for i in range(self.cur_iter_no, self.max_iteration):
            begin = time.time()
            # Metrics of this round are buffered and sent in one go at the end
            metrics_payload = {}

            def log_metric(k, v):
                metrics_payload[k] = v

            print_rank('==== iteration {}'.format(i))
            log_metric('Current iteration', i)

            # Initial value for the learning rate of the worker
            initial_lr = self.initial_lr_client * self.lr_weight
            print_rank('Client learning rate {}'.format(initial_lr))

            # Run training on clients
            self.worker_trainer.model.zero_grad()
            self.train_loss = []

            if self.send_dicts:  # Send state dictionaries
                glob_payload = [self.worker_trainer.model.state_dict()[param_key].to(torch.device('cpu')) for param_key in self.worker_trainer.model.state_dict()]
            else:  # Send parameters
                glob_payload = [p.data.to(torch.device('cpu')) for p in self.worker_trainer.model.parameters()]
            server_data = (initial_lr, glob_payload, i)

            # Random number of clients per iteration
            if len(self.num_clients_per_iteration) > 1:
                num_clients_curr_iter = random.randint(
                    self.num_clients_per_iteration[0],
                    self.num_clients_per_iteration[1]
                )
            else:
                num_clients_curr_iter = self.num_clients_per_iteration[0]
            log_metric('Clients for round', num_clients_curr_iter)

            # Perform annealing in quantization threshold
            if self.quant_thresh is not None:
                self.config['client_config']['quant_thresh'] *= self.config['client_config'].get('quant_anneal', 1.0)
                self.quant_thresh = self.config['client_config']['quant_thresh']
                log_metric('Quantization Thresh.', self.config['client_config']['quant_thresh'])

            # Create the pool of clients -- sample from this pool to assign to workers.
            # A non-positive count means "use every client".
            sampled_idx_clients = random.sample(self.client_idx_list,
                                                num_clients_curr_iter) if num_clients_curr_iter > 0 else self.client_idx_list

            # Initialize stats
            clients_begin = time.time()

            client_losses = []
            client_mag_grads = []
            client_mean_grads = []
            client_var_grads = []
            client_norm_grads = []

            self.run_stats['secsPerClient'].append([])
            self.run_stats['secsPerClientFull'].append([])
            self.run_stats['secsPerClientTraining'].append([])
            self.run_stats['secsPerClientSetup'].append([])
            self.run_stats['communicationCosts'].append([])

            # Check if we want privacy metrics
            apply_privacy_metrics = self.config.get('privacy_metrics_config', None) and \
                self.config['privacy_metrics_config']['apply_metrics']
            adaptive_leakage = apply_privacy_metrics and \
                self.config['privacy_metrics_config'].get('adaptive_leakage_threshold', None)
            if apply_privacy_metrics:
                privacy_metrics_stats = defaultdict(list)

            # Initialize profiler
            profiler = None
            if self.do_profiling:
                profiler = cProfile.Profile()
                profiler.enable()

            # Reset gradient for the model before assigning the new gradients
            self.worker_trainer.model.zero_grad()

            print_rank(f"Clients sampled from server {sampled_idx_clients}", loglevel=logging.DEBUG)
            for client_output in self.process_clients(sampled_idx_clients, server_data, self.single_worker):
                # Process client output
                client_timestamp = client_output['ts']
                client_stats = client_output['cs']
                client_loss = client_output['tl']
                client_mag_grad = client_output['mg']
                client_mean_grad = client_output['ng']
                client_var_grad = client_output['vg']
                client_norm_grad = client_output['rg']
                client_payload = client_output['pl']

                if apply_privacy_metrics:
                    privacy_stats = client_output['ps']
                    for metric, value in privacy_stats.items():
                        privacy_metrics_stats[metric].append(value)

                self.run_stats['communicationCosts'][-1].append(time.time() - client_timestamp)

                # Get actual pseudo-gradients for aggregation
                payload_processed = self.strategy.process_individual_payload(self.worker_trainer, client_payload)
                if not payload_processed:
                    # A rejected payload does not count towards this round
                    print_rank('Dropping client', loglevel=logging.DEBUG)
                    num_clients_curr_iter -= 1
                    continue

                # Aggregate stats
                self.train_loss.append(client_loss)
                client_losses.append(client_loss)
                client_mag_grads.append(client_mag_grad.item())
                client_mean_grads.append(client_mean_grad.item())
                client_var_grads.append(client_var_grad.item())
                client_norm_grads.append(client_norm_grad.item())

                # Mark the end of client processing
                client_end = time.time()

                self.run_stats['secsPerClientFull'][-1].append(client_stats['full cost'])
                self.run_stats['secsPerClientTraining'][-1].append(client_stats['training'])
                self.run_stats['secsPerClientSetup'][-1].append(client_stats['setup'])
                self.run_stats['secsPerClient'][-1].append(client_end - clients_begin)

            # Tear down profiler
            if self.do_profiling:
                profiler.disable()
                stats = pstats.Stats(profiler)
                stats.sort_stats('cumulative').print_stats()

            # Prepare output
            client_mag_grads = np.array(client_mag_grads)
            client_mean_grads = np.array(client_mean_grads)
            client_var_grads = np.array(client_var_grads)
            client_norm_grads = np.array(client_norm_grads)

            client_stats = (client_mag_grads, client_mean_grads, client_var_grads)

            dump_norm_stats = self.config.get('dump_norm_stats', False)
            if dump_norm_stats:
                with open(os.path.join(self.model_path, 'norm_stats.txt'), 'a', encoding='utf-8') as outF:
                    outF.write('{}\n'.format(json.dumps(list(client_norm_grads))))

            # Print the privacy metrics
            if apply_privacy_metrics:
                for metric, values in privacy_metrics_stats.items():
                    if metric == 'Dropped clients':
                        log_metric(metric, sum(values))
                    else:
                        log_metric(metric, max(values))

                if isinstance(adaptive_leakage, float):
                    values = sorted(privacy_metrics_stats['Practical epsilon (Max leakage)'])
                    if values:
                        # Clamp so that a fraction of 1.0 selects the largest
                        # value instead of indexing past the end of the list
                        threshold_idx = min(int(adaptive_leakage * len(values)), len(values) - 1)
                        new_threshold = values[threshold_idx]
                        print_rank('Updating leakage threshold to {}'.format(new_threshold))
                        self.config['privacy_metrics_config']['max_allowed_leakage'] = new_threshold

            # Mark that all clients have been processed
            end = time.time()
            self.run_stats['secsPerClientRound'].append(end - begin)
            begin = end

            # Log the training loss to tensorboard/AML
            log_metric('Training loss', sum(self.train_loss))

            # Combine payloads
            self.losses = self.strategy.combine_payloads(
                worker_trainer=self.worker_trainer,
                curr_iter=i,
                num_clients_curr_iter=num_clients_curr_iter,
                total_clients=len(self.client_idx_list),
                client_stats=client_stats,
                logger=log_metric,
            )

            # Run a couple of iterations of training data on the server
            if self.server_trainer is not None:
                print_rank('Running replay iterations on server')

                if 'updatable_names' in self.server_trainer_config:
                    set_component_wise_lr(
                        self.worker_trainer.model,
                        self.server_optimizer_config,
                        self.server_trainer_config['updatable_names']
                    )
                self.server_trainer.prepare_iteration(self.worker_trainer.model)
                self.server_trainer.train_desired_samples(self.server_replay_iterations)
                self.worker_trainer.model.load_state_dict(self.server_trainer.model.state_dict())
                torch.cuda.empty_cache()

            # Update a sampling scheduler
            print_rank('Run ss scheduler')
            self.worker_trainer.run_ss_scheduler()

            # Run inference and score on val/test depending on the iter. number
            if ((i + 1) % self.val_freq) == 0:
                eval_list.append("val")
            if ((i + 1) % self.req_freq) == 0:
                eval_list.append("test")

            # `eval_list` is cleared right after evaluation, so whether a
            # validation pass ran has to be remembered explicitly; testing
            # `'val' in eval_list` afterwards could never be true.
            ran_validation = False
            prev_best_val_loss = None
            if eval_list:
                print_rank('Running {} at itr={}'.format(eval_list, i + 1))
                self.metrics['worker_trainer'] = self.worker_trainer
                if hasattr(self.strategy, 'tmp_unsup'):
                    self.metrics['tmp_sup'] = self.strategy.tmp_sup
                    self.metrics['tmp_unsup'] = self.strategy.tmp_unsup
                ran_validation = 'val' in eval_list
                # Best loss *before* this evaluation.
                # NOTE(review): assumes `evaluation.run` refreshes
                # 'best_val_loss' in the returned metrics -- confirm.
                prev_best_val_loss = self.metrics.get('best_val_loss')
                self.metrics = self.evaluation.run(eval_list, self.metrics, metric_logger=run.log)
                self.losses = self.evaluation.losses
                eval_list = []

            # Create a schedule for the initial_lr (for the worker): decay the
            # client LR weight when validation did not improve on the previous best.
            # NOTE(review): assumes self.losses[0] is the validation loss -- confirm.
            if ran_validation:
                run.log('LR for agg. opt.', get_lr(self.worker_trainer.optimizer))
                if prev_best_val_loss is not None and not (self.losses[0] < prev_best_val_loss):
                    self.lr_weight *= self.lr_decay_factor
                print_rank('LOG: Client weight of learning rate {}..'.format(self.lr_weight))

            # Backup the current best models
            self.backup_models(i)

            # Fall back to the best model if the option is enabled
            self.fall_back_to_prev_best_status()

            # Logging the latest best values only after the 1st val/test round has been executed
            if len(self.metrics) > 1:
                update_json_log(
                    self.log_path,
                    {
                        'i': i + 1,
                        'best_val_loss': float(self.metrics['best_val_loss']),
                        'best_val_acc': float(self.metrics['best_val_acc']),
                        'best_test_loss': float(self.metrics['best_test_loss']),
                        'best_test_acc': float(self.metrics['best_test_acc']),
                        'weight': float(self.lr_weight),
                        'num_label_updates': int(self.no_label_updates)
                    },
                )

            end = time.time()

            # Aggregate stats
            self.run_stats['secsPerRoundHousekeeping'].append(end - begin)
            self.run_stats['secsPerRoundTotal'].append(self.run_stats['secsPerClientRound'][-1] +
                                                       self.run_stats['secsPerRoundHousekeeping'][-1])

            log_metric('secsPerRoundTotal', self.run_stats['secsPerRoundTotal'][-1])
            if self.do_profiling:
                log_metric('secsPerClientRound', self.run_stats['secsPerClientRound'][-1])
                log_metric('secsPerRoundHousekeeping', self.run_stats['secsPerRoundHousekeeping'][-1])

                # These entries hold one list per round; the others hold one number per round
                metrics_for_stats = [
                    'secsPerClient',
                    'secsPerClientTraining',
                    'secsPerClientFull',
                    'secsPerClientSetup',
                    'communicationCosts',
                ]

                for metric in metrics_for_stats:
                    log_metric(f'{metric}Mean', np.mean(self.run_stats[metric][-1]))
                    log_metric(f'{metric}Median', np.median(self.run_stats[metric][-1]))
                    log_metric(f'{metric}Max', max(self.run_stats[metric][-1]))

                for k in self.run_stats:
                    if k in metrics_for_stats:
                        print_rank('{}: {}'.format(k, max(self.run_stats[k][-1])), loglevel=logging.DEBUG)
                    else:
                        print_rank('{}: {}'.format(k, self.run_stats[k][-1]), loglevel=logging.DEBUG)

            # Log all the metrics
            for k, v in metrics_payload.items():
                run.log(k, v)

    finally:  # perform cleanup even if error was raised above
        self.terminate_workers(terminate=(not self.do_clustering))
def backup_models(self, i):
    '''Persist the latest model and, periodically, per-iteration snapshots.

    The 'latest' model is written on every call. When `i` is a multiple of
    `self.model_backup_freq`, the current model is also saved under an
    'epoch{i}' token and any existing best_val_acc / best_val_loss /
    best_test_acc model files are copied to 'epoch{i}_'-prefixed names.

    Args:
        i: no. of iterations.
    '''
    server_config = self.config['server_config']

    # Always save the latest model
    self.worker_trainer.save(model_path=self.model_path, token='latest', config=server_config)

    if (i % self.model_backup_freq) != 0:
        return

    # Periodic snapshot of the current model ...
    self.worker_trainer.save(model_path=self.model_path, token='epoch{}'.format(i), config=server_config)

    # ... and of whichever best models are already on disk
    for bodyname in ('best_val_acc', 'best_val_loss', 'best_test_acc'):
        source = os.path.join(self.model_path, '{}_model.tar'.format(bodyname))
        if not os.path.exists(source):
            continue
        target = os.path.join(self.model_path, 'epoch{}_{}_model.tar'.format(i, bodyname))
        shutil.copyfile(source, target)
        print_rank('Saved {}'.format(target))
def fall_back_to_prev_best_status(self):
    '''Reload the best saved model while keeping the current learning rate.

    Does nothing unless `self.fall_back_to_best_model` is set.
    '''
    if not self.fall_back_to_best_model:
        return

    print_rank('falling back to model {}'.format(self.best_model_path))

    optimizer_lr = get_lr(self.worker_trainer.optimizer)  # remember LR before loading

    # Load the best model without touching the LR / scheduled-sampling schedulers
    self.worker_trainer.load(self.best_model_path, update_lr_scheduler=False, update_ss_scheduler=False)

    # Put the remembered learning rate back on every parameter group
    for param_group in self.worker_trainer.optimizer.param_groups:
        param_group['lr'] = optimizer_lr

    if self.server_trainer is not None:
        # Keep the server-side trainer on the same model object
        self.server_trainer.model = self.worker_trainer.model
def select_server(server_type):
    '''Map a server-type string to the server class to instantiate.

    Args:
        server_type (str): indicates server choice. "personalization"
            selects `PersonalizationServer` (imported lazily from
            `experiments.cv.server`); any other value selects
            `OptimizationServer`.

    Returns:
        The selected server class (not an instance).
    '''
    if server_type != "personalization":
        return OptimizationServer

    # Imported here so the experiment module is only needed when requested
    from experiments.cv.server import PersonalizationServer
    return PersonalizationServer