Merged PR 1083: Move config validation to FLUTEConfig.validate(). Fix some bugs.

what it says on the tin
Robert Sim 2022-03-28 18:03:00 +00:00, committed by Andre Manoel
Parent f97162fd0a
Commit a8a20c2c7d
8 changed files: 246 additions and 224 deletions
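The gist of the change: the free-standing check_server_config/check_client_config helpers (and the default dictionaries the entry-point script built for them) move into FLUTEConfig, so a caller now validates in one step. A minimal sketch of the new flow, assuming FLUTEConfig lives in core/config.py (module path assumed) and using hypothetical paths and task names:

    import yaml
    from core.config import FLUTEConfig  # module path assumed

    with open('config.yaml') as f:
        config = FLUTEConfig.from_dict(yaml.safe_load(f))
    config['data_path'] = '/mnt/data'            # hypothetical
    config['output_path'] = './outputs'          # hypothetical
    config['experiment_name'] = 'my_experiment'  # hypothetical
    config['server_config']['task'] = config['client_config']['task'] = 'my_task'
    config.validate()  # merges defaults, adjusts paths, raises on invalid settings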

View file

@@ -1,116 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
A collection of functions for checking the format of configuration values
"""
import os
def check_server_config(config, default_server_conf):
    assert "server_config" in config, "server config setting is missing"

    # Checking parameters for server-side training
    if "train" in config["server_config"]["data_config"]:
        if "train_data_server" in config["server_config"]["data_config"]["train"]:
            assert "server_replay_config" in config["server_config"], "Training dataset is defined on the server but training parameters are not set"
            assert "optimizer_config" in config["server_config"]["server_replay_config"], "Missing \"optimizer_config\" in server_replay server training config"
            assert "server_iterations" in config["server_config"]["server_replay_config"], "Missing \"server_iterations\" in server_replay server training config"

    # Setting the default values if missing
    for key in default_server_conf.keys():
        if not key in config["server_config"]:
            config["server_config"][key] = default_server_conf[key]

    server_type = config["server_config"]["type"]
    if not (server_type == "model_averaging" or \
            server_type == "optimization" or \
            server_type == "model_optimization" or \
            server_type == "cluster_finetuning" or \
            server_type == "cluster_parallel"):
        raise ValueError("Invalid server type {} in federated learning config".format(server_type))

    assert "best_model_criterion" in config["server_config"], "Missing \"best_model_criterion\" in server config"
    if server_type == "model_optimization" or server_type == "cluster_finetuning" or server_type == "cluster_parallel":
        assert "initial_lr_client" in config["server_config"], "Missing \"initial_lr_client\" in server config"
        assert "lr_decay_factor" in config["server_config"], "Missing \"lr_decay_factor\" in server config"
        assert "aggregate_median" in config["server_config"], "Missing \"aggregate_median\" in server config"

    if "nbest_task_scheduler" in config["server_config"]:
        assert "num_tasks" in config["server_config"]["nbest_task_scheduler"], "Define \"num_tasks\" in [\"nbest_task_scheduler\"]"
        assert "iteration_per_task" in config["server_config"]["nbest_task_scheduler"], "Define \"iteration_per_task\" in [\"nbest_task_scheduler\"]"
        assert len(config["server_config"]["nbest_task_scheduler"]["num_tasks"]) == len(config["server_config"]["nbest_task_scheduler"]["iteration_per_task"]), \
            "Length mismatched: {}!={}".format(len(config["server_config"]["nbest_task_scheduler"]["num_tasks"]), len(config["server_config"]["nbest_task_scheduler"]["iteration_per_task"]))

    data_path = config['data_path']
    if 'vocab_dict' in config["server_config"]["data_config"]["val"]:
        config["server_config"]["data_config"]["val"]["vocab_dict"] = os.path.join(data_path, config["server_config"]["data_config"]["val"]["vocab_dict"])
    if 'vocab_dict' in config["server_config"]["data_config"]["test"]:
        config["server_config"]["data_config"]["test"]["vocab_dict"] = os.path.join(data_path, config["server_config"]["data_config"]["test"]["vocab_dict"])
    if 'vocab_dict' in config["server_config"]["data_config"]["train"]:
        config["server_config"]["data_config"]["train"]["vocab_dict"] = os.path.join(data_path, config["server_config"]["data_config"]["train"]["vocab_dict"])
    # BERT specific parameters
    if 'model_config' in config and 'BERT' in config['model_config']:
        if 'model_name_or_path' in config['model_config']['BERT']['model']:
            config['server_config']['data_config']['val']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name_or_path']
            config['server_config']['data_config']['test']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name_or_path']
        else:
            config['server_config']['data_config']['val']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']
            config['server_config']['data_config']['test']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']
        if 'process_line_by_line' in config['model_config']['BERT']['model']:
            config['server_config']['data_config']['val']['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']
            config['server_config']['data_config']['test']['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']
if "initial_val" in config['server_config']:
config['server_config']['initial_val'] = config['server_config']['initial_val']
else:
config['server_config']['initial_val'] = False
if "initial_rec" in config['server_config']:
config['server_config']['initial_rec'] = config['server_config']['initial_rec']
else:
config['server_config']['initial_rec'] = False
return config
def check_client_config(config, default_client_conf):
    assert "client_config" in config, "client config setting is missing"

    # Setting the default values if missing
    for key in default_client_conf.keys():
        if not key in config["client_config"]:
            config["client_config"][key] = default_client_conf[key]

    client_type = config["client_config"]["type"]
    if not (client_type == "gradient_computation" or client_type == "optimization"):
        raise ValueError("Invalid client option {} in federated learning config".format(client_type))

    if not "ss_config" in config["client_config"]:
        config["client_config"]["ss_config"] = None

    if "list_of_train_data" in config["client_config"]["data_config"]["train"] and "train_data" in config["client_config"]["data_config"]["train"]:
        raise ValueError("\"list_of_train_data\" and \"train_data\" cannot be defined at the same time")
assert "list_of_train_data" in config["client_config"]["data_config"]["train"] or "train_data" in config["client_config"]["data_config"]["train"], "Define either \"list_of_train_data\" and \"train_data\""
    # Adjust path to vocab_dict
    data_path = config['data_path']
    if 'vocab_dict' in config["client_config"]["data_config"]["train"]:
        config["client_config"]["data_config"]["train"]["vocab_dict"] = os.path.join(data_path, config["client_config"]["data_config"]["train"]["vocab_dict"])

    # BERT specific parameters
    if 'model_config' in config and 'train' in config['client_config']['data_config'] and 'BERT' in config['model_config']:
        if 'model_name_or_path' in config['model_config']['BERT']['model']:
            config['client_config']['data_config']['train']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name_or_path']
        else:
            config['client_config']['data_config']['train']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']
        if 'process_line_by_line' in config['model_config']['BERT']['model']:
            config['client_config']['data_config']['train']['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']

    return config

View file

@@ -390,11 +390,14 @@ class Client:
        print_rank('client={}: training loss={}'.format(client_id, train_loss), loglevel=logging.DEBUG)

        # Estimate gradient magnitude mean/var
        trainer.sufficient_stats['mean'] = trainer.sufficient_stats['sum'] / trainer.sufficient_stats['n']
        trainer.sufficient_stats['mag'] = np.sqrt(trainer.sufficient_stats['sq_sum'] / trainer.sufficient_stats['n'])
        trainer.sufficient_stats['var'] = trainer.sufficient_stats['sq_sum'] / trainer.sufficient_stats['n'] - \
            trainer.sufficient_stats['mag'] ** 2
        trainer.sufficient_stats['norm'] = np.sqrt(trainer.sufficient_stats['sq_sum'])
        # Now computed when the sufficient stats are updated.
        assert 'sum' in trainer.sufficient_stats
        assert 'mean' in trainer.sufficient_stats
        # trainer.sufficient_stats['mean'] = trainer.sufficient_stats['sum'] / trainer.sufficient_stats['n']
        # trainer.sufficient_stats['mag'] = np.sqrt(trainer.sufficient_stats['sq_sum'] / trainer.sufficient_stats['n'])
        # trainer.sufficient_stats['var'] = trainer.sufficient_stats['sq_sum'] / trainer.sufficient_stats['n'] - \
        #     trainer.sufficient_stats['mag'] ** 2
        # trainer.sufficient_stats['norm'] = np.sqrt(trainer.sufficient_stats['sq_sum'])

        trainer.train_loss = train_loss
        trainer.num_samples = num_samples

View file

@@ -3,6 +3,7 @@
from __future__ import annotations
from dataclasses import dataclass
from collections.abc import MutableMapping
import os
# TODO everywhere: choose reasonable defaults.
@@ -613,8 +614,8 @@ class ServerConfig(Config):
    annealing_config: AnnealingConfig = None
    val_freq: int | None = None
    rec_freq: int | None = None
    initial_val: bool = False
    initial_rec: bool = False
    initial_val: bool = True
    initial_rec: bool = True
    wantRL: bool = False
    RL: RLConfig = None
    data_config: DataConfig = None
@@ -739,6 +740,176 @@ class FLUTEConfig(Config):
    server_config: ServerConfig = None
    client_config: ClientConfig = None

    # TODO: clean up all this validation code.
    def _check_server_config(config, default_server_conf):
        assert "server_config" in config, "server config setting is missing"

        # Checking parameters for server-side training
        if "train" in config["server_config"]["data_config"]:
            if "train_data_server" in config["server_config"]["data_config"]["train"]:
                assert "server_replay_config" in config["server_config"], "Training dataset is defined on the server but training parameters are not set"
                assert "optimizer_config" in config["server_config"]["server_replay_config"], "Missing \"optimizer_config\" in server_replay server training config"
                assert "server_iterations" in config["server_config"]["server_replay_config"], "Missing \"server_iterations\" in server_replay server training config"

        # Setting the default values if missing
        for key in default_server_conf.keys():
            if key not in config["server_config"]:
                config["server_config"][key] = default_server_conf[key]

        server_type = config["server_config"]["type"]
        if not (server_type == "model_averaging" or \
                server_type == "optimization" or \
                server_type == "model_optimization" or \
                server_type == "cluster_finetuning" or \
                server_type == "cluster_parallel"):
            raise ValueError("Invalid server type {} in federated learning config".format(server_type))

        assert "best_model_criterion" in config["server_config"], "Missing \"best_model_criterion\" in server config"
        if server_type == "model_optimization" or server_type == "cluster_finetuning" or server_type == "cluster_parallel":
            assert "initial_lr_client" in config["server_config"], "Missing \"initial_lr_client\" in server config"
            assert "lr_decay_factor" in config["server_config"], "Missing \"lr_decay_factor\" in server config"
            assert "aggregate_median" in config["server_config"], "Missing \"aggregate_median\" in server config"

        if "nbest_task_scheduler" in config["server_config"]:
            assert "num_tasks" in config["server_config"]["nbest_task_scheduler"], "Define \"num_tasks\" in [\"nbest_task_scheduler\"]"
            assert "iteration_per_task" in config["server_config"]["nbest_task_scheduler"], "Define \"iteration_per_task\" in [\"nbest_task_scheduler\"]"
            assert len(config["server_config"]["nbest_task_scheduler"]["num_tasks"]) == len(config["server_config"]["nbest_task_scheduler"]["iteration_per_task"]), \
                "Length mismatched: {}!={}".format(len(config["server_config"]["nbest_task_scheduler"]["num_tasks"]), len(config["server_config"]["nbest_task_scheduler"]["iteration_per_task"]))

        data_path = config['data_path']
        if 'vocab_dict' in config["server_config"]["data_config"]["val"]:
            config["server_config"]["data_config"]["val"]["vocab_dict"] = os.path.join(data_path, config["server_config"]["data_config"]["val"]["vocab_dict"])
        if 'vocab_dict' in config["server_config"]["data_config"]["test"]:
            config["server_config"]["data_config"]["test"]["vocab_dict"] = os.path.join(data_path, config["server_config"]["data_config"]["test"]["vocab_dict"])
        if 'vocab_dict' in config["server_config"]["data_config"]["train"]:
            config["server_config"]["data_config"]["train"]["vocab_dict"] = os.path.join(data_path, config["server_config"]["data_config"]["train"]["vocab_dict"])
        # BERT specific parameters
        if 'model_config' in config and 'BERT' in config['model_config']:
            if 'model_name_or_path' in config['model_config']['BERT']['model']:
                config['server_config']['data_config']['val']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name_or_path']
                config['server_config']['data_config']['test']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name_or_path']
            else:
                config['server_config']['data_config']['val']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']
                config['server_config']['data_config']['test']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']
            if 'process_line_by_line' in config['model_config']['BERT']['model']:
                config['server_config']['data_config']['val']['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']
                config['server_config']['data_config']['test']['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']
if "initial_val" in config['server_config']:
config['server_config']['initial_val'] = config['server_config']['initial_val']
else:
config['server_config']['initial_val'] = False
if "initial_rec" in config['server_config']:
config['server_config']['initial_rec'] = config['server_config']['initial_rec']
else:
config['server_config']['initial_rec'] = False
return config
    def _check_client_config(config, default_client_conf):
        assert "client_config" in config, "client config setting is missing"

        # Setting the default values if missing
        for key in default_client_conf.keys():
            if key not in config["client_config"]:
                config["client_config"][key] = default_client_conf[key]

        client_type = config["client_config"]["type"]
        if not (client_type == "gradient_computation" or client_type == "optimization"):
            raise ValueError("Invalid client option {} in federated learning config".format(client_type))

        if "ss_config" not in config["client_config"]:
            config["client_config"]["ss_config"] = None

        if "list_of_train_data" in config["client_config"]["data_config"]["train"] and "train_data" in config["client_config"]["data_config"]["train"]:
            raise ValueError("\"list_of_train_data\" and \"train_data\" cannot be defined at the same time")
assert "list_of_train_data" in config["client_config"]["data_config"]["train"] or "train_data" in config["client_config"]["data_config"]["train"], "Define either \"list_of_train_data\" and \"train_data\""
        # Adjust path to vocab_dict
        data_path = config['data_path']
        if 'vocab_dict' in config["client_config"]["data_config"]["train"]:
            config["client_config"]["data_config"]["train"]["vocab_dict"] = os.path.join(data_path, config["client_config"]["data_config"]["train"]["vocab_dict"])

        # BERT specific parameters
        if 'model_config' in config and 'train' in config['client_config']['data_config'] and 'BERT' in config['model_config']:
            if 'model_name_or_path' in config['model_config']['BERT']['model']:
                config['client_config']['data_config']['train']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name_or_path']
            else:
                config['client_config']['data_config']['train']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']
            if 'process_line_by_line' in config['model_config']['BERT']['model']:
                config['client_config']['data_config']['train']['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']

        return config
    def validate(config):
        # Create dictionaries w/ default parameters
        default_data_conf = {
            "input_dim": 300,
            "batch_size": 40,
            "loader_type": "text",
            "prepend_datapath": False,
            "pin_memory": True,
            "num_frames": 0,
            "desired_max_samples": 300,
            "max_grad_norm": 5.0,  # max gradient norm for gradient clipping
            "num_workers": 1,
            "max_batch_size": 0,  # maximum batch size; if 0, no limit is applied
            "unsorted_batch": False  # do not sort samples when batching; less efficient padding-wise, but can help accuracy
        }
        default_server_conf = {
            "val_freq": 1,
            "rec_freq": 8,
            "max_iteration": 100000000,
            "type": "optimization",
            "data_config": default_data_conf,
            "aggregate_median": None,
            "best_model_criterion": "loss",
            "fall_back_to_best_model": False,
            "num_clients_per_iteration": -1
        }

        default_client_conf = {
            "copying_train_jsonls": True,
            "type": "gradient_computation",
            "data_config": default_data_conf,
        }

        assert "data_path" in config, "data_path is missing from config"
        assert "experiment_name" in config, "experiment_name is missing from config"
        assert "output_path" in config, "output_path is missing from config"
        assert "task" in config["server_config"], "task is missing from server_config"
        assert "task" in config["client_config"], "task is missing from client_config"
        assert "num_clients" not in config["server_config"]["data_config"], "Remove \"num_clients\" from server data_config since this is a reserved key"
        assert "num_clients" not in config["client_config"]["data_config"], "Remove \"num_clients\" from client data_config since this is a reserved key"

        # Make sure the pretrained model is found in the correct place
        if "pretrained_model_path" in config["model_config"]["model_type"]:
            config["model_config"]["model_type"]["pretrained_model_path"] = os.path.join(config["data_path"], config["model_config"]["model_type"]["pretrained_model_path"])
        if "pretrained_model_path" in config["model_config"]:
            config["model_config"]["pretrained_model_path"] = os.path.join(config["data_path"], config["model_config"]["pretrained_model_path"])

        config._check_server_config(default_server_conf)
        config._check_client_config(default_client_conf)

        # RL-related options
        if config["server_config"].get("wantRL", False):
            if config["server_config"]["RL"].get("RL_path_global", True):
                config["server_config"]["RL"]["RL_path"] = os.path.join(config["output_path"],
                                                                        config["server_config"]["RL"]["RL_path"])
            else:
                config["server_config"]["RL"]["RL_path"] = os.path.join(config["output_path"], config["experiment_name"],
                                                                        config["server_config"]["RL"]["RL_path"])
    @staticmethod
    def from_dict(config) -> FLUTEConfig:
        dp_config = \

View file

@@ -92,7 +92,14 @@ class OptimizationServer(federated.Server):
        self.req_freq = server_config['rec_freq']

        self.evaluation = Evaluation(config, model_path, self.process_testvalidate, val_dataloader, test_dataloader)

        self.metrics = dict()
        # TODO: does this need to be adjusted for custom metrics?
        self.metrics = {
            'best_val_loss': float('inf'),
            'best_val_acc': 0.0,
            'best_test_loss': float('inf'),
            'best_test_acc': 0.0
        }

        self.model_backup_freq = server_config.get('model_backup_freq', 100)
        self.worker_trainer_config = server_config.get('trainer_config', {})
@@ -198,17 +205,17 @@ class OptimizationServer(federated.Server):
            with open(self.log_path, 'r') as logfp:  # loading the iteration no., best loss and CER
                elems = json.load(logfp)
                self.cur_iter_no = elems.get('i', 0)
                self.metrics['best_val_loss'] = elems.get('best_val_loss', float('inf'))
                self.metrics['best_val_acc'] = elems.get('best_val_acc', float('inf'))
                self.metrics['best_val_loss'] = elems.get('best_val_loss', float('inf'))
                self.metrics['best_val_acc'] = elems.get('best_val_acc', 0)
                self.metrics['best_test_loss'] = elems.get('best_test_loss', float('inf'))
                self.metrics['best_test_acc'] = elems.get('best_test_acc', float('inf'))
                self.metrics['best_test_acc'] = elems.get('best_test_acc', 0)
                self.lr_weight = elems.get('weight', 1.0)
                self.no_label_updates = elems.get('num_label_updates', 0)
            print_rank(f'Resuming from status_log: cur_iter: {self.cur_iter_no}')
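For orientation, the status_log consumed here is the JSON written by the checkpointing call further down in this diff; a typical file (values hypothetical) might look like:

    {"i": 42, "best_val_loss": 2.31, "best_val_acc": 0.57, "best_test_loss": 2.40,
     "best_test_acc": 0.55, "weight": 1.0, "num_label_updates": 0}

Note the changed fallbacks: missing accuracy keys now resume as 0 rather than float('inf'), so an "is this accuracy better than the best so far" comparison starts from a beatable baseline.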
    def run(self):
        '''Trigger training.

        This is a simple wrapper to the `train` method.
        '''
        print_rank('server started')
@@ -237,7 +244,7 @@ class OptimizationServer(federated.Server):
        # Skip if we resumed from a checkpoint (cur_iter_no > 0)
        eval_list = []
        if self.cur_iter_no == 0:
            if self.config['server_config']['initial_rec']:
                eval_list.append('test')
            if self.config['server_config']['initial_val']:
@@ -246,7 +253,7 @@ class OptimizationServer(federated.Server):
            print_rank("Running {} at itr={}".format(eval_list, self.cur_iter_no))
            self.metrics = self.evaluation.run(eval_list, self.metrics, metric_logger=run.log)
            eval_list=[] # some cleanup
            eval_list = []  # some cleanup

        # Dump all the information in aggregate_metric
        print_rank('Saving Model Before Starting Training', loglevel=logging.INFO)
@@ -257,7 +264,7 @@ class OptimizationServer(federated.Server):
            config=self.config['server_config']
        )

        # Training loop
        # Training loop
        self.worker_trainer.model.train()
        for i in range(self.cur_iter_no, self.max_iteration):
            begin = time.time()
@@ -330,7 +337,7 @@ class OptimizationServer(federated.Server):
            adaptive_leakage = apply_privacy_metrics and \
                self.config['privacy_metrics_config'].get('adaptive_leakage_threshold', None)
            if apply_privacy_metrics:
                privacy_metrics_stats = defaultdict(list)
                privacy_metrics_stats = defaultdict(list)

            # Initialize profiler
            profiler = None
@@ -343,20 +350,20 @@ class OptimizationServer(federated.Server):
            for client_output in self.process_clients(sampled_clients, server_data, self.clients_in_parallel):
                # Process client output
                client_timestamp = client_output['ts']
                client_timestamp = client_output['ts']
                client_stats = client_output['cs']
                client_loss = client_output['tl']
                client_mag_grad = client_output['mg']
                client_mag_grad = client_output['mg']
                client_mean_grad = client_output['ng']
                client_var_grad = client_output['vg']
                client_norm_grad = client_output['rg']
                client_payload = client_output['pl']

                if apply_privacy_metrics:
                    privacy_stats = client_output['ps']
                    for metric, value in privacy_stats.items():
                        privacy_metrics_stats[metric].append(value)

                self.run_stats['mpiCosts'][-1].append(time.time() - client_timestamp)

                # Get actual pseudo-gradients for aggregation
@@ -395,7 +402,7 @@ class OptimizationServer(federated.Server):
                client_norm_grads = np.array(client_norm_grads)
                client_stats = (client_mag_grads, client_mean_grads, client_var_grads)

                dump_norm_stats = self.config.get('dump_norm_stats', False)
                if dump_norm_stats:
                    with open(os.path.join(self.model_path, 'norm_stats.txt'), 'a', encoding='utf-8') as outF:
@@ -463,11 +470,11 @@ class OptimizationServer(federated.Server):
                self.metrics = self.evaluation.run(eval_list, self.metrics, metric_logger=run.log)
                self.losses = self.evaluation.losses
                eval_list = []

            # Create a schedule for the initial_lr (for the worker)
            if 'val' in eval_list:
                run.log('LR for agg. opt.', get_lr(self.worker_trainer.optimizer))
                if not (self.losses[0] < self.metrics['best_val_loss']):
                if not (self.losses[0] < self.metrics['best_val_loss']):
                    self.lr_weight *= self.lr_decay_factor
                    print_rank('LOG: Client weight of learning rate {}..'.format(self.lr_weight))
@@ -482,10 +489,10 @@ class OptimizationServer(federated.Server):
                self.log_path,
                {
                    'i': i + 1,
                    'best_val_loss': float(self.metrics['best_val_loss']),
                    'best_val_acc': float(self.metrics['best_val_acc']),
                    'best_test_loss': float(self.metrics['best_test_loss']),
                    'best_test_acc': float(self.metrics['best_test_acc']),
                    'best_val_loss': float(self.metrics['best_val_loss']),
                    'best_val_acc': float(self.metrics['best_val_acc']),
                    'best_test_loss': float(self.metrics['best_test_loss']),
                    'best_test_acc': float(self.metrics['best_test_acc']),
                    'weight': float(self.lr_weight),
                    'num_label_updates': int(self.no_label_updates)
                },
@@ -531,7 +538,7 @@ class OptimizationServer(federated.Server):
    def backup_models(self, i):
        '''Save the current best models.

        Save CER model, the best loss model and the best WER model. This occurs
        at a specified period.
@@ -586,10 +593,10 @@ def select_server(server_type, config):
    Right now this just returns `OptimizationServer`, but this
    function could be useful when there are multiple choices of
    server.

    Args:
        server_type (str): indicates server choice.
        config (dict): config parsed from YAML, passed so that
            parameters can be used to select a given server.
    '''
    return OptimizationServer
    return OptimizationServer
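Callers treat select_server as a factory lookup; a sketch of the intended usage (constructor arguments elided, since they are not part of this diff):

    server_cls = select_server(config['server_config']['type'], config)
    server = server_cls(...)  # instantiate with the usual OptimizationServer arguments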

View file

@@ -24,7 +24,7 @@ from utils import \

class TrainerBase:
    """Abstract class defining Trainer objects' common interface.

    Args:
        model (torch.nn.Module): model to be trained.
        train_dataloader (torch.utils.data.DataLoader): dataloader that
@@ -195,7 +195,7 @@ class ModelUpdater(TrainerBase):

class Trainer(TrainerBase):
    """Perform training step for any given client.

    The main method to be called for triggering a training step is
    The main method to be called for triggering a training step is
    :code:`train_desired_samples`, which on its turn relies on
    :code:`run_train_epoch`.
@@ -293,7 +293,21 @@ class Trainer(TrainerBase):
        """Compute statistics about the gradients."""
        sum_mean_grad, sum_mean_grad2, n = self.accumulate_gradient_power()

        self.sufficient_stats = {"n": n, "sum": sum_mean_grad, "sq_sum": sum_mean_grad2}
        mean_grad = sum_mean_grad / n
        mag_grad = np.sqrt(sum_mean_grad2 / n)
        var_grad = sum_mean_grad2 / n - mag_grad**2
        norm_grad = np.sqrt(sum_mean_grad2)
        self.sufficient_stats = {
            "n": n,
            "sum": sum_mean_grad,
            "sq_sum": sum_mean_grad2,
            "var": var_grad,
            "mean": mean_grad,
            "mag": mag_grad,
            "norm": norm_grad
        }
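As a sanity check on the formulas above, here is a tiny standalone sketch (not FLUTE code), assuming accumulate_gradient_power() returns the sum of gradient entries, the sum of their squares, and the entry count:

    import numpy as np

    grad = np.array([0.5, -1.0, 2.0, 0.0])  # toy per-parameter gradients
    sum_mean_grad = grad.sum()              # 1.5
    sum_mean_grad2 = (grad ** 2).sum()      # 5.25
    n = grad.size                           # 4

    mean = sum_mean_grad / n                # 0.375  -> stats['mean']
    mag = np.sqrt(sum_mean_grad2 / n)       # ~1.146 -> stats['mag'], RMS gradient magnitude
    norm = np.sqrt(sum_mean_grad2)          # ~2.291 -> stats['norm'], gradient L2 norm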
    def train_desired_samples(self, desired_max_samples=None, apply_privacy_metrics=False):
        """Triggers training step.
@@ -451,7 +465,7 @@ def run_validation_generic(model, val_dataloader):
        f"len_sampler: {len(val_loader._index_sampler)}",
        loglevel=logging.DEBUG
    )

    try:
        from core.globals import task
        loader = SourceFileLoader("CustomMetrics", str("./experiments/" + task + "/custom_metrics.py")).load_module()
@@ -460,9 +474,9 @@ def run_validation_generic(model, val_dataloader):
    except:
        metrics_cl = Metrics()
        print_rank("Loading default metrics")

    return metrics_cl.compute_metrics(dataloader=val_loader, model=model)
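The try/except above lets a task override metrics by shipping a custom_metrics.py in its experiment folder. A hypothetical sketch (class and method names are inferred from the fallback path, which expects a Metrics-like object exposing compute_metrics(dataloader=..., model=...)):

    # ./experiments/<task>/custom_metrics.py -- hypothetical example
    class CustomMetrics:
        def compute_metrics(self, dataloader, model):
            # iterate the dataloader, run the model, aggregate scores
            return {'loss': 0.0, 'acc': 0.0}  # placeholder values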
def set_component_wise_lr(model, optimizer_config, updatable_names):
    """Set zero learning rate for layers in order to freeze the update.
@@ -495,9 +509,9 @@ def save_model(model_path, config, model, optimizer, lr_scheduler, ss_scheduler,
    """Save a model as well as training information."""
    save_state = {
        "model_state_dict" : model.state_dict(),
        "optimizer_state_dict" : optimizer.state_dict() if optimizer is not None else None,
        "lr_scheduler_state_dict" : lr_scheduler.state_dict() if lr_scheduler is not None else None
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
        "lr_scheduler_state_dict": lr_scheduler.state_dict() if lr_scheduler is not None else None
    }
    if ss_scheduler is not None:
        save_state["ss_scheduler_state_dict"] = ss_scheduler.state_dict()

View file

@@ -35,10 +35,6 @@ from utils.dataloaders_utils import (
    make_val_dataloader,
    make_test_dataloader,
)
from config_file_parser import (
    check_server_config,
    check_client_config
)

assert TRAINING_FRAMEWORK_TYPE == "mpi", "Unsupported platform {}".format(TRAINING_FRAMEWORK_TYPE)
@@ -189,40 +185,7 @@ if __name__ == "__main__":
    task = args.task
    local_rank = args.local_rank

    # Create dictionaries w/ parameters
    default_data_conf = {
        "input_dim": 300,
        "batch_size": 40,
        "loader_type": "text",
        "prepend_datapath": False,
        "pin_memory": True,
        "num_frames": 0,
        "desired_max_samples": 300,
        "max_grad_norm": 5.0,  # max_grad_norm for gradient clipping
        "num_workers": 1,
        "max_batch_size": 0,  # maximum number of batch size; if 0, no limitation is applied
        "unsorted_batch": False  # do not sort when making batch; this is inefficient in terms of batch, but could be efficient in terms of accuracy
    }

    default_server_conf = {
        "val_freq": 1,
        "rec_freq": 8,
        "max_iteration": 100000000,
        "type": "optimization",
        "data_config": default_data_conf,
        "aggregate_median": None,
        "best_model_criterion": "loss",
        "fall_back_to_best_model": False,
        "num_clients_per_iteration": -1
    }

    default_client_conf = {
        "copying_train_jsonls": True,
        "type": "gradient_computation",
        "data_config": default_data_conf,
    }

    # The mount point can also be retrieved from input_datasets of the run context
    # The mount point can also be retrieved from input_datasets of the run context
    if data_path is None:
        data_path = Run.get_context().input_datasets["input"]
    print("The data can be found here: ", data_path)
@@ -241,41 +204,20 @@ if __name__ == "__main__":
    cfg_out = os.path.join(experiment_root, "FLUTE_config.yaml")
    if local_rank <= 0:
        shutil.copyfile(args.config, cfg_out)
        print("Copy created")

    # Initialize logging
    init_logging(log_path, loglevel=logging_level)

    with open(args.config) as f:
        cfg_dict = yaml.safe_load(f)
    config = FLUTEConfig.from_dict(cfg_dict)

    assert "num_clients" not in config["server_config"]["data_config"], "Remove \"num_clients\" from server data_config since this is a reserved key"
    assert "num_clients" not in config["client_config"]["data_config"], "Remove \"num_clients\" from client data_config since this is a reserved key"

    # Make sure the pretrained model is found in the correct place
    if "pretrained_model_path" in config["model_config"]["model_type"]:
        config["model_config"]["model_type"]["pretrained_model_path"] = os.path.join(data_path, config["model_config"]["model_type"]["pretrained_model_path"])
    if "pretrained_model_path" in config["model_config"]:
        config["model_config"]["pretrained_model_path"] = os.path.join(data_path, config["model_config"]["pretrained_model_path"])

    config["data_path"] = data_path
    config = check_server_config(config, default_server_conf)
    config = check_client_config(config, default_client_conf)

    # Add task specification to client configuration
    config["output_path"] = args.outputPath
    config["experiment_name"] = experiment_name
    config["client_config"]["task"] = task
    config["server_config"]["task"] = task

    # RL-related options
    if config["server_config"].get("wantRL", False):
        if config["server_config"]["RL"].get("RL_path_global", True):
            config["server_config"]["RL"]["RL_path"] = os.path.join(args.outputPath,
                                                                    config["server_config"]["RL"]["RL_path"])
        else:
            config["server_config"]["RL"]["RL_path"] = os.path.join(args.outputPath, experiment_name,
                                                                    config["server_config"]["RL"]["RL_path"])

    config.validate()

    # Instantiate either Server or Worker on the thread
    run_worker(model_path, config, task, data_path, local_rank)

View file

@@ -23,12 +23,13 @@ def extract_indices_from_embeddings(gradients, batch, embed_size, vocab_size):

def compute_perplexity(encoded_batch, model):
    outputs = model.inference(encoded_batch)
    (batch_size, seq_len, vocab_size) = outputs[0].shape
    perplex = T.nn.functional.log_softmax(outputs[0], dim=-1)
    outputs = model.inference(encoded_batch)
    (batch_size, seq_len, vocab_size) = outputs['output'].shape
    perplex = T.nn.functional.log_softmax(outputs['output'], dim=-1)
    return perplex.reshape(-1, vocab_size)[np.arange(batch_size * seq_len),
                                           encoded_batch.reshape(-1)].reshape(batch_size, seq_len)
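The indexing in the return statement picks, for every position in the batch, the log-probability the model assigned to the token that actually occurred. A small self-contained sketch of the same gather (shapes assumed; not FLUTE code):

    import torch as T

    batch_size, seq_len, vocab_size = 2, 3, 5
    logits = T.randn(batch_size, seq_len, vocab_size)
    tokens = T.randint(vocab_size, (batch_size, seq_len))

    logp = T.nn.functional.log_softmax(logits, dim=-1)
    per_token = logp.reshape(-1, vocab_size)[T.arange(batch_size * seq_len),
                                             tokens.reshape(-1)].reshape(batch_size, seq_len)
    # Equivalent pure-gather form:
    same = logp.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    assert T.allclose(per_token, same)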
def practical_epsilon_leakage(original_params, model, encoded_batches, is_weighted_leakage=True,
                              max_ratio=1e9, optimizer_config=None):
    # Copy the gradients and save the model.

View file

@@ -143,7 +143,7 @@ server_config:
      unsorted_batch: true
  type: model_optimization
  aggregate_median: softmax # the FL aggregation method
  weight_train_loss: grad_mean_loss #train_loss #grad_mean_loss #or train_loss - how each client's weight is determined.
  weight_train_loss: mag_mean_loss  # or train_loss or mag_var_loss - how each client's weight is determined
  softmax_beta: 1000
  initial_lr_client: 0.1
  lr_decay_factor: 1.0