# Mirror of https://github.com/microsoft/msrflute.git
# 598 lines, 27 KiB, Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT license.
|
|
'''
|
|
In this file, we define the classes that live inside 'worker 0', the worker
|
|
responsible for orchestration and aggregation. The main class is the
|
|
OptimizationServer, which sends clients to the other workers to process and
|
|
combines the resulting models.
|
|
'''
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import random
|
|
import shutil
|
|
import time
|
|
from collections import defaultdict
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
# Internal imports
|
|
import core.federated as federated
|
|
from core.evaluation import Evaluation
|
|
from core.client import Client
|
|
from .strategies import select_strategy
|
|
from .trainer import (
|
|
ModelUpdater,
|
|
Trainer,
|
|
set_component_wise_lr,
|
|
)
|
|
from utils import (
|
|
get_lr,
|
|
print_rank,
|
|
update_json_log,
|
|
to_device,
|
|
)
|
|
|
|
# For profiling
|
|
import cProfile
|
|
import pstats
|
|
|
|
# AzureML-related libs
|
|
from azureml.core import Run
|
|
# AzureML run context; `OptimizationServer.train` reports metrics through `run.log`.
run = Run.get_context()
|
|
|
|
|
|
class OptimizationServer(federated.Server):
    '''Server living on worker 0: orchestrates client rounds and aggregates payloads.

    Communication is inherited from `federated.Server`; this class implements the
    federated training loop (`train`), evaluation scheduling and checkpointing.
    '''

    def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, server_train_dataloader,
                 config, idx_val_clients, idx_test_clients, single_worker):
        '''Implement Server's orchestration and aggregation.

        This is the main Server class, that actually implements orchestration
        and aggregation, inheriting from `federated.Server`, which deals with
        communication only.

        The `train` method is central in FLUTE, as it defines good part of what
        happens during training.

        Args:
            num_clients (int): total available clients.
            model (torch.nn.Module): neural network model.
            optimizer (torch.optim.Optimizer): optimizer.
            ss_scheduler: scheduled sampling scheduler.
            data_path (str): points to where data is.
            model_path (str): points to where pretrained model is; checkpoints and
                ``status_log.json`` are also written under this directory.
            server_train_dataloader (torch.utils.data.DataLoader): dataloader for
                server-side replay training, or None to disable replay.
            config (dict): JSON style configuration parameters
            idx_val_clients (list): validation client ids
            idx_test_clients (list): testing clients ids
            single_worker: forwarded to `Evaluation` and to `process_clients`;
                presumably selects single-process execution -- TODO confirm.
        '''
        super().__init__()

        # Initialize all attributes from arguments
        self.client_idx_list = list(range(num_clients))
        self.config = config
        server_config = config['server_config']
        decoder_config = config.get('decoder_config', None)

        self.max_iteration = server_config['max_iteration']
        self.do_clustering = server_config.get('clustering', False)
        self.send_dicts = server_config.get('send_dicts', False)

        # Either a single int or a "min,max" string; with a range, `train` draws a
        # random client count in [min, max] every round.
        num_clients_cfg = server_config['num_clients_per_iteration']
        if isinstance(num_clients_cfg, str):
            self.num_clients_per_iteration = [int(x) for x in num_clients_cfg.split(',')]
        else:
            self.num_clients_per_iteration = [num_clients_cfg]

        self.val_freq = server_config['val_freq']
        # Test-evaluation frequency (config key 'rec_freq'); attribute name kept for compatibility.
        self.req_freq = server_config['rec_freq']

        self.evaluation = Evaluation(config, model_path, self.process_testvalidate, idx_val_clients, idx_test_clients, single_worker)

        # TODO: does this need to be adjusted for custom metrics?
        self.metrics = dict()

        self.model_backup_freq = server_config.get('model_backup_freq', 100)
        self.worker_trainer_config = server_config.get('trainer_config', {})

        self.aggregate_median = server_config['aggregate_median']
        self.initial_lr_client = server_config.get('initial_lr_client', -1.0)
        self.lr_decay_factor = server_config.get('lr_decay_factor', 1.0)

        self.model_type = config['model_config']['model_type']
        self.quant_thresh = config['client_config'].get('quant_thresh', None)
        self.quant_bits = config['client_config'].get('quant_bits', 10)

        self.list_of_train_data = config['client_config']['data_config']['train']['list_of_train_data']
        self.data_path = data_path
        self.single_worker = single_worker

        # Get max grad norm from data config (None when there is no 'train' section)
        if 'train' in server_config['data_config']:
            max_grad_norm = server_config['data_config']['train'].get('max_grad_norm', None)
        else:
            max_grad_norm = None

        # Creating an instance to update the model with stats aggregated from workers
        self.worker_trainer = ModelUpdater(
            model=model,
            optimizer=optimizer,
            ss_scheduler=ss_scheduler,
            train_dataloader=server_train_dataloader,
            val_dataloader=None,
            max_grad_norm=max_grad_norm,
            anneal_config=server_config['annealing_config'],
            model_type=self.model_type,
            decoder_config=decoder_config
        )
        self.metrics['worker_trainer'] = self.worker_trainer

        # Creating an instance for the server-side trainer (runs mini-batch SGD)
        self.server_replay_iterations = None
        self.server_trainer = None
        if server_train_dataloader is not None:
            assert 'server_replay_config' in server_config, 'server_replay_config is not set'
            assert 'optimizer_config' in server_config[
                'server_replay_config'], 'server-side replay training optimizer is not set'
            replay_config = server_config['server_replay_config']
            self.server_optimizer_config = replay_config['optimizer_config']
            self.server_trainer_config = replay_config.get('trainer_config', {})
            self.server_replay_iterations = replay_config['server_iterations']
            self.server_trainer = Trainer(
                model=model,
                optimizer=None,
                ss_scheduler=ss_scheduler,
                train_dataloader=server_train_dataloader,
                server_replay_config=replay_config,
                # Fall back to the guarded data-config value computed above; indexing
                # data_config['train'] directly raised KeyError when it was absent.
                max_grad_norm=replay_config.get('max_grad_norm', max_grad_norm),
                anneal_config=replay_config.get('annealing_config', None),
                ignore_subtask=replay_config.get('ignore_subtask', False)
            )

        self.skip_model_update = False  # will not update the model if True

        self.train_loss = 0.0
        self.model_path = model_path
        self.best_model_criterion = server_config['best_model_criterion']
        self.fall_back_to_best_model = server_config['fall_back_to_best_model']
        self.last_model_path = os.path.join(self.model_path, 'latest_model.tar')
        self.best_model_path = os.path.join(self.model_path,
            'best_val_{}_model.tar'.format(self.best_model_criterion))
        self.log_path = os.path.join(self.model_path, 'status_log.json')
        self.cur_iter_no = 0  # keep the iteration number for Tensor board plotting
        self.lr_weight = 1.0

        self.losses = []
        self.no_label_updates = 0  # no. label updates

        # Restore the attributes above from the checkpoint and status log, if requested
        if server_config.get('resume_from_checkpoint', False):
            self.load_saved_status()

        # Decoding config
        self.decoder_config = decoder_config
        self.spm_model = server_config['data_config']['test'].get('spm_model', None)

        self.do_profiling = server_config.get('do_profiling', False)

        StrategyClass = select_strategy(config['strategy'])
        self.strategy = StrategyClass('server', self.config, self.model_path)
        print_rank(f'Server successfully instantiated strategy {self.strategy}', loglevel=logging.DEBUG)

    def load_saved_status(self):
        '''Load checkpoint from disk.

        Restores the latest model (with LR and scheduled-sampling schedulers) when
        ``latest_model.tar`` exists, and the iteration counter, best metrics,
        client LR weight and label-update count when ``status_log.json`` exists.
        '''
        # Check if model is on disk, if so loads it onto trainer
        if os.path.exists(self.last_model_path):
            print_rank('Resuming from checkpoint model {}'.format(self.last_model_path))
            self.worker_trainer.load(self.last_model_path, update_lr_scheduler=True, update_ss_scheduler=True)
            if self.server_trainer is not None:
                self.server_trainer.model = self.worker_trainer.model  # make sure that the models are in sync

        # Check if log is on disk, if so loads it onto current stats
        if os.path.exists(self.log_path):
            with open(self.log_path, 'r') as logfp:  # loading the iteration no., best loss and CER
                elems = json.load(logfp)
            self.cur_iter_no = elems.get('i', 0)
            self.metrics['best_val_loss'] = elems.get('best_val_loss', float('inf'))
            self.metrics['best_val_acc'] = elems.get('best_val_acc', 0)
            self.metrics['best_test_loss'] = elems.get('best_test_loss', float('inf'))
            self.metrics['best_test_acc'] = elems.get('best_test_acc', 0)
            self.lr_weight = elems.get('weight', 1.0)
            self.no_label_updates = elems.get('num_label_updates', 0)
            print_rank(f'Resuming from status_log: cur_iter: {self.cur_iter_no}')

    def run(self):
        '''Trigger training.

        This is a simple wrapper to the `train` method.
        '''
        print_rank('server started')
        self.train()
        print_rank('server terminated')

    def _global_payload(self):
        '''Return the current global model as a list of CPU tensors.

        With ``send_dicts`` every ``state_dict()`` entry is sent (in state-dict
        order); otherwise only ``model.parameters()``.
        '''
        model = self.worker_trainer.model
        cpu = torch.device('cpu')
        if self.send_dicts:  # Send state dictionaries
            # Build the state dict once rather than once per key.
            return [tensor.to(cpu) for tensor in model.state_dict().values()]
        return [p.data.to(cpu) for p in model.parameters()]  # Send parameters

    def _num_clients_for_round(self):
        '''Number of clients requested this round (random in [min, max] if a range is configured).'''
        if len(self.num_clients_per_iteration) > 1:
            return random.randint(self.num_clients_per_iteration[0], self.num_clients_per_iteration[1])
        return self.num_clients_per_iteration[0]

    def train(self):
        '''Main method for training.

        Runs rounds ``cur_iter_no .. max_iteration - 1``. Each round sends the
        global model to the sampled clients, hands each client payload to the
        strategy, combines the payloads, optionally runs server-side replay
        training, evaluates on val/test at the configured frequencies and
        checkpoints. Workers are terminated in ``finally`` even on error.
        '''
        self.run_stats = {
            'secsPerClientRound': [],
            'secsPerClient': [],
            'secsPerClientTraining': [],
            'secsPerClientSetup': [],
            'secsPerClientFull': [],
            'secsPerRoundHousekeeping': [],
            'secsPerRoundTotal': [],
            'communicationCosts': []
        }

        run.log('Max iterations', self.max_iteration)
        try:
            self.worker_trainer.model = to_device(self.worker_trainer.model)

            # Do an initial validation round to understand the pretrained model's validation accuracy
            # Skip if we resumed from a checkpoint (cur_iter_no > 0)
            eval_list = []
            if self.cur_iter_no == 0:
                if self.config['server_config']['initial_rec']:
                    eval_list.append('test')
                if self.config['server_config']['initial_val']:
                    eval_list.append('val')
                run.log('LR for agg. opt.', get_lr(self.worker_trainer.optimizer))

                print_rank("Running {} at itr={}".format(eval_list, self.cur_iter_no))
                self.metrics = self.evaluation.run(eval_list, self.metrics, metric_logger=run.log)
                eval_list = []  # some cleanup

                # Dump all the information in aggregate_metric
                print_rank('Saving Model Before Starting Training', loglevel=logging.INFO)
                for token in ['best_val_loss', 'best_val_acc', 'best_test_acc', 'latest']:
                    self.worker_trainer.save(
                        model_path=self.model_path,
                        token=token,
                        config=self.config['server_config']
                    )

            # Training loop
            self.worker_trainer.model.train()
            for i in range(self.cur_iter_no, self.max_iteration):
                begin = time.time()
                metrics_payload = {}

                def log_metric(k, v):
                    metrics_payload[k] = v

                print_rank('==== iteration {}'.format(i))
                log_metric('Current iteration', i)

                # Initial value for the learning rate of the worker
                initial_lr = self.initial_lr_client * self.lr_weight
                print_rank('Client learning rate {}'.format(initial_lr))

                # Run training on clients
                self.worker_trainer.model.zero_grad()
                self.train_loss = []
                server_data = (initial_lr, self._global_payload(), i)

                # Random number of clients per iteration
                num_clients_curr_iter = self._num_clients_for_round()
                log_metric('Clients for round', num_clients_curr_iter)

                # Perform annealing in quantization threshold
                if self.quant_thresh is not None:
                    client_config = self.config['client_config']
                    client_config['quant_thresh'] *= client_config.get('quant_anneal', 1.0)
                    self.quant_thresh = client_config['quant_thresh']
                    log_metric('Quantization Thresh.', client_config['quant_thresh'])

                # Create the pool of clients -- sample from this pool to assign to workers.
                # A non-positive count means "use every client".
                if num_clients_curr_iter > 0:
                    sampled_idx_clients = random.sample(self.client_idx_list, num_clients_curr_iter)
                else:
                    sampled_idx_clients = self.client_idx_list

                # Initialize stats
                clients_begin = time.time()

                client_mag_grads = []
                client_mean_grads = []
                client_var_grads = []
                client_norm_grads = []

                for key in ('secsPerClient', 'secsPerClientFull', 'secsPerClientTraining',
                            'secsPerClientSetup', 'communicationCosts'):
                    self.run_stats[key].append([])

                # Check if we want privacy metrics
                privacy_config = self.config.get('privacy_metrics_config', None)
                apply_privacy_metrics = privacy_config and privacy_config['apply_metrics']
                adaptive_leakage = apply_privacy_metrics and \
                    privacy_config.get('adaptive_leakage_threshold', None)
                if apply_privacy_metrics:
                    privacy_metrics_stats = defaultdict(list)

                # Initialize profiler
                profiler = None
                if self.do_profiling:
                    profiler = cProfile.Profile()
                    profiler.enable()

                # Reset gradient for the model before assigning the new gradients
                self.worker_trainer.model.zero_grad()

                print_rank(f"Clients sampled from server {sampled_idx_clients}", loglevel=logging.DEBUG)
                for client_output in self.process_clients(sampled_idx_clients, server_data, self.single_worker):
                    # Process client output
                    client_timestamp = client_output['ts']
                    client_timing = client_output['cs']
                    client_loss = client_output['tl']
                    client_mag_grad = client_output['mg']
                    client_mean_grad = client_output['ng']
                    client_var_grad = client_output['vg']
                    client_norm_grad = client_output['rg']
                    client_payload = client_output['pl']

                    if apply_privacy_metrics:
                        for metric, value in client_output['ps'].items():
                            privacy_metrics_stats[metric].append(value)

                    self.run_stats['communicationCosts'][-1].append(time.time() - client_timestamp)

                    # Get actual pseudo-gradients for aggregation
                    payload_processed = self.strategy.process_individual_payload(self.worker_trainer, client_payload)
                    if not payload_processed:
                        print_rank('Dropping client', loglevel=logging.DEBUG)
                        num_clients_curr_iter -= 1
                        continue

                    # Aggregate stats
                    self.train_loss.append(client_loss)
                    client_mag_grads.append(client_mag_grad.item())
                    client_mean_grads.append(client_mean_grad.item())
                    client_var_grads.append(client_var_grad.item())
                    client_norm_grads.append(client_norm_grad.item())

                    # Mark the end of client processing
                    client_end = time.time()

                    self.run_stats['secsPerClientFull'][-1].append(client_timing['full cost'])
                    self.run_stats['secsPerClientTraining'][-1].append(client_timing['training'])
                    self.run_stats['secsPerClientSetup'][-1].append(client_timing['setup'])
                    self.run_stats['secsPerClient'][-1].append(client_end - clients_begin)

                # Tear down profiler
                if self.do_profiling:
                    profiler.disable()
                    stats = pstats.Stats(profiler)
                    stats.sort_stats('cumulative').print_stats()

                # Prepare output
                client_mag_grads = np.array(client_mag_grads)
                client_mean_grads = np.array(client_mean_grads)
                client_var_grads = np.array(client_var_grads)
                client_norm_grads = np.array(client_norm_grads)

                client_stats = (client_mag_grads, client_mean_grads, client_var_grads)

                dump_norm_stats = self.config.get('dump_norm_stats', False)
                if dump_norm_stats:
                    with open(os.path.join(self.model_path, 'norm_stats.txt'), 'a', encoding='utf-8') as outF:
                        outF.write('{}\n'.format(json.dumps(list(client_norm_grads))))

                # Print the privacy metrics
                if apply_privacy_metrics:
                    for metric, values in privacy_metrics_stats.items():
                        if metric == 'Dropped clients':
                            log_metric(metric, sum(values))
                        else:
                            log_metric(metric, max(values))

                    if isinstance(adaptive_leakage, float):
                        values = sorted(privacy_metrics_stats['Practical epsilon (Max leakage)'])
                        if values:
                            # Clamp so a threshold of 1.0 picks the maximum instead of
                            # indexing one past the end of the list.
                            idx = min(int(adaptive_leakage * len(values)), len(values) - 1)
                            new_threshold = values[idx]
                            print_rank('Updating leakage threshold to {}'.format(new_threshold))
                            self.config['privacy_metrics_config']['max_allowed_leakage'] = new_threshold

                # Mark that all clients have been processed
                end = time.time()
                self.run_stats['secsPerClientRound'].append(end - begin)
                begin = end

                # Log the training loss to tensorboard/AML
                log_metric('Training loss', sum(self.train_loss))

                # Combine payloads
                self.losses = self.strategy.combine_payloads(
                    worker_trainer=self.worker_trainer,
                    curr_iter=i,
                    num_clients_curr_iter=num_clients_curr_iter,
                    total_clients=len(self.client_idx_list),
                    client_stats=client_stats,
                    logger=log_metric,
                )

                # Run a couple of iterations of training data on the server
                if self.server_trainer is not None:
                    print_rank('Running replay iterations on server')

                    if 'updatable_names' in self.server_trainer_config:
                        set_component_wise_lr(
                            self.worker_trainer.model,
                            self.server_optimizer_config,
                            self.server_trainer_config['updatable_names']
                        )
                    self.server_trainer.prepare_iteration(self.worker_trainer.model)
                    self.server_trainer.train_desired_samples(self.server_replay_iterations)
                    self.worker_trainer.model.load_state_dict(self.server_trainer.model.state_dict())
                    torch.cuda.empty_cache()

                # Update a sampling scheduler
                print_rank('Run ss scheduler')
                self.worker_trainer.run_ss_scheduler()

                # Run inference and score on val/test depending on the iter. number
                if ((i + 1) % self.val_freq) == 0:
                    eval_list.append("val")
                if ((i + 1) % self.req_freq) == 0:
                    eval_list.append("test")

                ran_val = False
                prev_best_val_loss = float('inf')
                if len(eval_list) > 0:
                    ran_val = 'val' in eval_list
                    # Snapshot the best val loss before evaluation can update it, so the
                    # LR schedule below compares against the *previous* best.
                    prev_best_val_loss = self.metrics.get('best_val_loss', float('inf'))
                    print_rank('Running {} at itr={}'.format(eval_list, i + 1))
                    self.metrics['worker_trainer'] = self.worker_trainer
                    if hasattr(self.strategy, 'tmp_unsup'):
                        self.metrics['tmp_sup'] = self.strategy.tmp_sup
                        self.metrics['tmp_unsup'] = self.strategy.tmp_unsup
                    self.metrics = self.evaluation.run(eval_list, self.metrics, metric_logger=run.log)
                    self.losses = self.evaluation.losses
                    eval_list = []

                # Create a schedule for the initial_lr (for the worker). This used to test
                # `'val' in eval_list` after the list had been cleared, so it never ran.
                if ran_val:
                    run.log('LR for agg. opt.', get_lr(self.worker_trainer.optimizer))
                    # assumes self.losses[0] is the current validation loss from Evaluation -- TODO confirm
                    if not (self.losses[0] < prev_best_val_loss):
                        self.lr_weight *= self.lr_decay_factor
                        print_rank('LOG: Client weight of learning rate {}..'.format(self.lr_weight))

                # Backup the current best models
                self.backup_models(i)

                # Fall back to the best model if the option is enabled
                self.fall_back_to_prev_best_status()

                # Logging the latest best values only after the 1st val/test round has been executed
                if len(self.metrics) > 1:
                    update_json_log(
                        self.log_path,
                        {
                            'i': i + 1,
                            'best_val_loss': float(self.metrics['best_val_loss']),
                            'best_val_acc': float(self.metrics['best_val_acc']),
                            'best_test_loss': float(self.metrics['best_test_loss']),
                            'best_test_acc': float(self.metrics['best_test_acc']),
                            'weight': float(self.lr_weight),
                            'num_label_updates': int(self.no_label_updates)
                        },
                    )

                end = time.time()

                # Aggregate stats
                self.run_stats['secsPerRoundHousekeeping'].append(end - begin)
                self.run_stats['secsPerRoundTotal'].append(
                    self.run_stats['secsPerClientRound'][-1] + self.run_stats['secsPerRoundHousekeeping'][-1])

                log_metric('secsPerRoundTotal', self.run_stats['secsPerRoundTotal'][-1])
                if self.do_profiling:
                    log_metric('secsPerClientRound', self.run_stats['secsPerClientRound'][-1])
                    log_metric('secsPerRoundHousekeeping', self.run_stats['secsPerRoundHousekeeping'][-1])

                metrics_for_stats = [
                    'secsPerClient',
                    'secsPerClientTraining',
                    'secsPerClientFull',
                    'secsPerClientSetup',
                    'communicationCosts',
                ]

                for metric in metrics_for_stats:
                    round_values = self.run_stats[metric][-1]
                    if not round_values:
                        # e.g. every client was dropped this round: max() would raise
                        continue
                    log_metric(f'{metric}Mean', np.mean(round_values))
                    log_metric(f'{metric}Median', np.median(round_values))
                    log_metric(f'{metric}Max', max(round_values))

                for k, history in self.run_stats.items():
                    latest = history[-1]
                    if k in metrics_for_stats:
                        latest = max(latest, default=None)
                    print_rank('{}: {}'.format(k, latest), loglevel=logging.DEBUG)

                # Log all the metrics
                for k, v in metrics_payload.items():
                    run.log(k, v)

        finally:  # perform cleanup even if error was raised above
            self.terminate_workers(terminate=(not self.do_clustering))

    def backup_models(self, i):
        '''Save the latest model and, periodically, snapshot the best models.

        The ``latest`` checkpoint is written every call. When ``i`` is a multiple
        of ``model_backup_freq`` an ``epoch{i}`` checkpoint is also saved, and any
        existing ``best_val_acc`` / ``best_val_loss`` / ``best_test_acc`` model
        files are copied to ``epoch{i}_<name>_model.tar``.

        Args:
            i: no. of iterations.
        '''
        # Always save the latest model
        self.worker_trainer.save(
            model_path=self.model_path,
            token='latest',
            config=self.config['server_config'],
        )

        if (i % self.model_backup_freq) == 0:  # save the current best models
            self.worker_trainer.save(
                model_path=self.model_path,
                token='epoch{}'.format(i),
                config=self.config['server_config']
            )

            for bodyname in ['best_val_acc', 'best_val_loss', 'best_test_acc']:
                src_model_path = os.path.join(self.model_path, '{}_model.tar'.format(bodyname))
                if os.path.exists(src_model_path):
                    dst_model_path = os.path.join(self.model_path, 'epoch{}_{}_model.tar'.format(i, bodyname))
                    shutil.copyfile(src_model_path, dst_model_path)
                    print_rank('Saved {}'.format(dst_model_path))

    def fall_back_to_prev_best_status(self):
        '''Go back to the past best status and switch to the recent best model.

        Only acts when ``fall_back_to_best_model`` is set. The optimizer's current
        learning rate is preserved across the reload.
        '''
        if self.fall_back_to_best_model:
            print_rank('falling back to model {}'.format(self.best_model_path))

            # Save current learning rate
            tmp_lr = get_lr(self.worker_trainer.optimizer)

            # Load previous best model
            self.worker_trainer.load(self.best_model_path, update_lr_scheduler=False, update_ss_scheduler=False)

            # Update previous learning rate on optimizer
            for g in self.worker_trainer.optimizer.param_groups:
                g['lr'] = tmp_lr

            if self.server_trainer is not None:
                self.server_trainer.model = self.worker_trainer.model  # make sure that the models are in sync
|
|
|
|
def select_server(server_type):
    '''Select a server class from its type string.

    Args:
        server_type (str): indicates server choice. ``"personalization"``
            selects `PersonalizationServer` from ``experiments.cv.server``;
            any other value selects `OptimizationServer`.

    Returns:
        The selected server class (not an instance).
    '''
    if server_type == "personalization":
        # Local import: experiments.cv is only loaded when this server type is requested.
        from experiments.cv.server import PersonalizationServer
        return PersonalizationServer
    return OptimizationServer
|