Mirror of https://github.com/microsoft/msrflute.git

Merged PR 1083: Move config validation to FLUTEConfig.validate(). Fix some bugs.

what it says on the tin

This commit is contained in:
Parent: f97162fd0a
Commit: a8a20c2c7d
@@ -1,116 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
A collection of functions for checking the format of configuration values
"""
import os


def check_server_config(config, default_server_conf):

    assert "server_config" in config, "server config setting is missing"

    # Checking parameters for server-side training
    if "train" in config["server_config"]["data_config"]:
        if "train_data_server" in config["server_config"]["data_config"]["train"]:
            assert "server_replay_config" in config["server_config"], "Training dataset is defined on the server but training parameters are not set"
            assert "optimizer_config" in config["server_config"]["server_replay_config"], "Missing \"optimizer_config\" in server_replay server training config"
            assert "server_iterations" in config["server_config"]["server_replay_config"], "Missing \"server_iterations\" in server_replay server training config"

    # Setting the default values if missing
    for key in default_server_conf.keys():
        if not key in config["server_config"]:
            config["server_config"][key] = default_server_conf[key]

    server_type = config["server_config"]["type"]
    if not (server_type == "model_averaging" or \
            server_type == "optimization" or \
            server_type == "model_optimization" or \
            server_type == "cluster_finetuning" or \
            server_type == "cluster_parallel") :
        raise ValueError("Invalid server type {} in federated learning config".format(server_type))

    assert "best_model_criterion" in config["server_config"], "Missing \"best_model_criterion\" in server config"

    if server_type == "model_optimization" or server_type == "cluster_finetuning" or server_type == "cluster_parallel":
        assert "initial_lr_client" in config["server_config"], "Missing \"initial_lr_client\" in server config"
        assert "lr_decay_factor" in config["server_config"], "Missing \"lr_decay_factor\" in server config"
        assert "aggregate_median" in config["server_config"], "Missing \"aggregate_median\" in server config"

    if "nbest_task_scheduler" in config["server_config"]:
        assert "num_tasks" in config["server_config"]["nbest_task_scheduler"], "Define \"num_tasks\" in [\"nbest_task_scheduler\"]"
        assert "iteration_per_task" in config["server_config"]["nbest_task_scheduler"], "Define \"iteration_per_task\" in [\"nbest_task_scheduler\"]"
        assert len(config["server_config"]["nbest_task_scheduler"]["num_tasks"]) == len(config["server_config"]["nbest_task_scheduler"]["iteration_per_task"]), \
            "Length mismatched: {}!={}".format(len(config["server_config"]["nbest_task_scheduler"]["num_tasks"]), len(config["server_config"]["nbest_task_scheduler"]["iteration_per_task"]))

    data_path = config['data_path']
    if 'vocab_dict' in config["server_config"]["data_config"]["val"]:
        config["server_config"]["data_config"]["val"]["vocab_dict"] = os.path.join(data_path, config["server_config"]["data_config"]["val"]["vocab_dict"])
    if 'vocab_dict' in config["server_config"]["data_config"]["test"]:
        config["server_config"]["data_config"]["test"]["vocab_dict"] = os.path.join(data_path, config["server_config"]["data_config"]["test"]["vocab_dict"])
    if 'vocab_dict' in config["server_config"]["data_config"]["test"]:
        config["server_config"]["data_config"]["train"]["vocab_dict"] = os.path.join(data_path, config["server_config"]["data_config"]["train"]["vocab_dict"])

    # BERT specific parameters
    if 'model_config' in config and 'BERT' in config['model_config']:
        if 'model_name_or_path' in config['model_config']['BERT']['model']:
            config['server_config']['data_config']['val']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name_or_path']
            config['server_config']['data_config']['test']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name_or_path']
        else:
            config['server_config']['data_config']['val']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']
            config['server_config']['data_config']['test']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']

        if 'process_line_by_line' in config['model_config']['BERT']['model']:
            config['server_config']['data_config']['val']['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']
            config['server_config']['data_config']['test']['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']

    if "initial_val" in config['server_config']:
        config['server_config']['initial_val'] = config['server_config']['initial_val']
    else:
        config['server_config']['initial_val'] = False

    if "initial_rec" in config['server_config']:
        config['server_config']['initial_rec'] = config['server_config']['initial_rec']
    else:
        config['server_config']['initial_rec'] = False

    return config


def check_client_config(config, default_client_conf):

    assert "client_config" in config, "client config setting is missing"

    # Setting the default values if missing
    for key in default_client_conf.keys():
        if not key in config["client_config"]:
            config["client_config"][key] = default_client_conf[key]

    client_type = config["client_config"]["type"]
    if not (client_type == "gradient_computation" or client_type == "optimization"):
        raise ValueError("Invalid client option {} in federated learning config".format(client_type))

    if not "ss_config" in config["client_config"]:
        config["client_config"]["ss_config"] = None

    if "list_of_train_data" in config["client_config"]["data_config"]["train"] and "train_data" in config["client_config"]["data_config"]["train"]:
        raise ValueError("\"list_of_train_data\" and \"train_data\" cannot be defined at the same time")

    assert "list_of_train_data" in config["client_config"]["data_config"]["train"] or "train_data" in config["client_config"]["data_config"]["train"], "Define either \"list_of_train_data\" and \"train_data\""

    # Adjust path to vocab_dict
    data_path = config['data_path']
    if 'vocab_dict' in config["client_config"]["data_config"]["train"]:
        config["client_config"]["data_config"]["train"]["vocab_dict"] = os.path.join(data_path, config["client_config"]["data_config"]["train"]["vocab_dict"])

    # BERT specific parameters
    if 'model_config' in config and 'train' in config['client_config']['data_config'] and 'BERT' in config['model_config']:
        if 'model_name_or_path' in config['model_config']['BERT']['model']:
            config['client_config']['data_config']['train']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name_or_path']
        else:
            config['client_config']['data_config']['train']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']
        if 'process_line_by_line' in config['model_config']['BERT']['model']:
            config['client_config']['data_config']['train']['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']

    return config
@@ -390,11 +390,14 @@ class Client:
            print_rank('client={}: training loss={}'.format(client_id, train_loss), loglevel=logging.DEBUG)

            # Estimate gradient magnitude mean/var
            trainer.sufficient_stats['mean'] = trainer.sufficient_stats['sum'] / trainer.sufficient_stats['n']
            trainer.sufficient_stats['mag'] = np.sqrt(trainer.sufficient_stats['sq_sum'] / trainer.sufficient_stats['n'])
            trainer.sufficient_stats['var'] = trainer.sufficient_stats['sq_sum'] / trainer.sufficient_stats['n'] - \
                trainer.sufficient_stats['mag'] ** 2
            trainer.sufficient_stats['norm'] = np.sqrt(trainer.sufficient_stats['sq_sum'])
            # Now computed when the sufficient stats are updated.
            assert 'sum' in trainer.sufficient_stats
            assert 'mean' in trainer.sufficient_stats
            # trainer.sufficient_stats['mean'] = trainer.sufficient_stats['sum'] / trainer.sufficient_stats['n']
            # trainer.sufficient_stats['mag'] = np.sqrt(trainer.sufficient_stats['sq_sum'] / trainer.sufficient_stats['n'])
            # trainer.sufficient_stats['var'] = trainer.sufficient_stats['sq_sum'] / trainer.sufficient_stats['n'] - \
            #     trainer.sufficient_stats['mag'] ** 2
            # trainer.sufficient_stats['norm'] = np.sqrt(trainer.sufficient_stats['sq_sum'])

            trainer.train_loss = train_loss
            trainer.num_samples = num_samples
core/config.py
@@ -3,6 +3,7 @@
from __future__ import annotations
from dataclasses import dataclass
from collections.abc import MutableMapping
import os


# TODO everywhere: choose reasonable defaults.
@@ -613,8 +614,8 @@ class ServerConfig(Config):
    annealing_config: AnnealingConfig = None
    val_freq: int | None = None
    rec_freq: int | None = None
    initial_val: bool = False
    initial_rec: bool = False
    initial_val: bool = True
    initial_rec: bool = True
    wantRL: bool = False
    RL: RLConfig = None
    data_config: DataConfig = None
@@ -739,6 +740,176 @@ class FLUTEConfig(Config):
    server_config: ServerConfig = None
    client_config: ClientConfig = None

    # TODO: clean up all this validation code.
    def _check_server_config(config, default_server_conf):

        assert "server_config" in config, "server config setting is missing"

        # Checking parameters for server-side training
        if "train" in config["server_config"]["data_config"]:
            if "train_data_server" in config["server_config"]["data_config"]["train"]:
                assert "server_replay_config" in config["server_config"], "Training dataset is defined on the server but training parameters are not set"
                assert "optimizer_config" in config["server_config"]["server_replay_config"], "Missing \"optimizer_config\" in server_replay server training config"
                assert "server_iterations" in config["server_config"]["server_replay_config"], "Missing \"server_iterations\" in server_replay server training config"

        # Setting the default values if missing
        for key in default_server_conf.keys():
            if key not in config["server_config"]:
                config["server_config"][key] = default_server_conf[key]

        server_type = config["server_config"]["type"]
        if not (server_type == "model_averaging" or \
                server_type == "optimization" or \
                server_type == "model_optimization" or \
                server_type == "cluster_finetuning" or \
                server_type == "cluster_parallel"):
            raise ValueError("Invalid server type {} in federated learning config".format(server_type))

        assert "best_model_criterion" in config["server_config"], "Missing \"best_model_criterion\" in server config"

        if server_type == "model_optimization" or server_type == "cluster_finetuning" or server_type == "cluster_parallel":
            assert "initial_lr_client" in config["server_config"], "Missing \"initial_lr_client\" in server config"
            assert "lr_decay_factor" in config["server_config"], "Missing \"lr_decay_factor\" in server config"
            assert "aggregate_median" in config["server_config"], "Missing \"aggregate_median\" in server config"

        if "nbest_task_scheduler" in config["server_config"]:
            assert "num_tasks" in config["server_config"]["nbest_task_scheduler"], "Define \"num_tasks\" in [\"nbest_task_scheduler\"]"
            assert "iteration_per_task" in config["server_config"]["nbest_task_scheduler"], "Define \"iteration_per_task\" in [\"nbest_task_scheduler\"]"
            assert len(config["server_config"]["nbest_task_scheduler"]["num_tasks"]) == len(config["server_config"]["nbest_task_scheduler"]["iteration_per_task"]), \
                "Length mismatched: {}!={}".format(len(config["server_config"]["nbest_task_scheduler"]["num_tasks"]), len(config["server_config"]["nbest_task_scheduler"]["iteration_per_task"]))

        data_path = config['data_path']
        if 'vocab_dict' in config["server_config"]["data_config"]["val"]:
            config["server_config"]["data_config"]["val"]["vocab_dict"] = os.path.join(data_path, config["server_config"]["data_config"]["val"]["vocab_dict"])
        if 'vocab_dict' in config["server_config"]["data_config"]["test"]:
            config["server_config"]["data_config"]["test"]["vocab_dict"] = os.path.join(data_path, config["server_config"]["data_config"]["test"]["vocab_dict"])
        if 'vocab_dict' in config["server_config"]["data_config"]["train"]:
            config["server_config"]["data_config"]["train"]["vocab_dict"] = os.path.join(data_path, config["server_config"]["data_config"]["train"]["vocab_dict"])

        # BERT specific parameters
        if 'model_config' in config and 'BERT' in config['model_config']:
            if 'model_name_or_path' in config['model_config']['BERT']['model']:
                config['server_config']['data_config']['val']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name_or_path']
                config['server_config']['data_config']['test']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name_or_path']
            else:
                config['server_config']['data_config']['val']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']
                config['server_config']['data_config']['test']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']

            if 'process_line_by_line' in config['model_config']['BERT']['model']:
                config['server_config']['data_config']['val']['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']
                config['server_config']['data_config']['test']['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']

        if "initial_val" in config['server_config']:
            config['server_config']['initial_val'] = config['server_config']['initial_val']
        else:
            config['server_config']['initial_val'] = False

        if "initial_rec" in config['server_config']:
            config['server_config']['initial_rec'] = config['server_config']['initial_rec']
        else:
            config['server_config']['initial_rec'] = False

        return config

    def _check_client_config(config, default_client_conf):

        assert "client_config" in config, "client config setting is missing"

        # Setting the default values if missing
        for key in default_client_conf.keys():
            if key not in config["client_config"]:
                config["client_config"][key] = default_client_conf[key]

        client_type = config["client_config"]["type"]
        if not (client_type == "gradient_computation" or client_type == "optimization"):
            raise ValueError("Invalid client option {} in federated learning config".format(client_type))

        if "ss_config" not in config["client_config"]:
            config["client_config"]["ss_config"] = None

        if "list_of_train_data" in config["client_config"]["data_config"]["train"] and "train_data" in config["client_config"]["data_config"]["train"]:
            raise ValueError("\"list_of_train_data\" and \"train_data\" cannot be defined at the same time")

        assert "list_of_train_data" in config["client_config"]["data_config"]["train"] or "train_data" in config["client_config"]["data_config"]["train"], "Define either \"list_of_train_data\" or \"train_data\""

        # Adjust path to vocab_dict
        data_path = config['data_path']
        if 'vocab_dict' in config["client_config"]["data_config"]["train"]:
            config["client_config"]["data_config"]["train"]["vocab_dict"] = os.path.join(data_path, config["client_config"]["data_config"]["train"]["vocab_dict"])

        # BERT specific parameters
        if 'model_config' in config and 'train' in config['client_config']['data_config'] and 'BERT' in config['model_config']:
            if 'model_name_or_path' in config['model_config']['BERT']['model']:
                config['client_config']['data_config']['train']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name_or_path']
            else:
                config['client_config']['data_config']['train']['model_name_or_path'] = config['model_config']['BERT']['model']['model_name']
            if 'process_line_by_line' in config['model_config']['BERT']['model']:
                config['client_config']['data_config']['train']['process_line_by_line'] = config['model_config']['BERT']['model']['process_line_by_line']

        return config

    def validate(config):

        # Create dictionaries w/ parameters
        default_data_conf = {
            "input_dim": 300,
            "batch_size": 40,
            "loader_type": "text",
            "prepend_datapath": False,
            "pin_memory": True,
            "num_frames": 0,
            "desired_max_samples": 300,
            "max_grad_norm": 5.0,  # max_grad_norm for gradient clipping
            "num_workers": 1,
            "max_batch_size": 0,  # maximum number of batch size; if 0, no limitation is applied
            "unsorted_batch": False  # do not sort when making batch; this is inefficient in terms of batch, but could be efficient in terms of accuracy
        }

        default_server_conf = {
            "val_freq": 1,
            "rec_freq": 8,
            "max_iteration": 100000000,
            "type": "optimization",
            "data_config": default_data_conf,
            "aggregate_median": None,
            "best_model_criterion": "loss",
            "fall_back_to_best_model": False,
            "num_clients_per_iteration": -1
        }

        default_client_conf = {
            "copying_train_jsonls": True,
            "type": "gradient_computation",
            "data_config": default_data_conf,
        }

        assert "data_path" in config, "data_path is missing from config"
        assert "experiment_name" in config, "experiment_name is missing from config"
        assert "output_path" in config, "output_path is missing from config"
        assert "task" in config["server_config"], "task is missing from server_config"
        assert "task" in config["client_config"], "task is missing from client_config"

        assert "num_clients" not in config["server_config"]["data_config"], "Remove \"num_clients\" from server data_config since this is a reserved key"
        assert "num_clients" not in config["client_config"]["data_config"], "Remove \"num_clients\" from client data_config since this is a reserved key"

        # Make sure the pretrained model is found in the correct place
        if "pretrained_model_path" in config["model_config"]["model_type"]:
            config["model_config"]["model_type"]["pretrained_model_path"] = os.path.join(config["data_path"], config["model_config"]["model_type"]["pretrained_model_path"])
        if "pretrained_model_path" in config["model_config"]:
            config["model_config"]["pretrained_model_path"] = os.path.join(config["data_path"], config["model_config"]["pretrained_model_path"])

        config._check_server_config(default_server_conf)
        config._check_client_config(default_client_conf)

        # RL-related options
        if config["server_config"].get("wantRL", False):
            if config["server_config"]["RL"].get("RL_path_global", True):
                config["server_config"]["RL"]["RL_path"] = os.path.join(config["output_path"],
                                                                        config["server_config"]["RL"]["RL_path"])
            else:
                config["server_config"]["RL"]["RL_path"] = os.path.join(config["output_path"], config["experiment_name"],
                                                                        config["server_config"]["RL"]["RL_path"])

    @staticmethod
    def from_dict(config) -> FLUTEConfig:
        dp_config = \
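The call-site change later in this diff (the e2e trainer now calls config.validate() instead of the old check_server_config/check_client_config helpers) suggests the intended usage is roughly the minimal sketch below. The module path is assumed from the file location core/config.py, and the YAML path, mount point, and task name are placeholders, not values taken from this commit.

import yaml
from core.config import FLUTEConfig

with open("experiments/example_task/config.yaml") as f:   # placeholder path
    cfg_dict = yaml.safe_load(f)
config = FLUTEConfig.from_dict(cfg_dict)

# Fields the driver script fills in before validation (placeholder values).
config["data_path"] = "/mnt/data"
config["output_path"] = "./outputs"
config["experiment_name"] = "example_run"
config["server_config"]["task"] = "example_task"
config["client_config"]["task"] = "example_task"

# validate() now applies defaults, rewrites relative paths, and raises
# AssertionError/ValueError on malformed configs, all in one place.
config.validate()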
@@ -92,7 +92,14 @@ class OptimizationServer(federated.Server):
        self.req_freq = server_config['rec_freq']

        self.evaluation = Evaluation(config, model_path, self.process_testvalidate, val_dataloader, test_dataloader)
        self.metrics = dict()

        # TODO: does this need to be adjusted for custom metrics?
        self.metrics = {
            'best_val_loss': float('inf'),
            'best_val_acc': 0.0,
            'best_test_loss': float('inf'),
            'best_test_acc': 0.0
        }

        self.model_backup_freq = server_config.get('model_backup_freq', 100)
        self.worker_trainer_config = server_config.get('trainer_config', {})

@@ -198,17 +205,17 @@ class OptimizationServer(federated.Server):
        with open(self.log_path, 'r') as logfp:  # loading the iteration no., best loss and CER
            elems = json.load(logfp)
            self.cur_iter_no = elems.get('i', 0)
            self.metrics['best_val_loss'] = elems.get('best_val_loss', float('inf'))
            self.metrics['best_val_acc'] = elems.get('best_val_acc', float('inf'))
            self.metrics['best_val_loss'] = elems.get('best_val_loss', float('inf'))
            self.metrics['best_val_acc'] = elems.get('best_val_acc', 0)
            self.metrics['best_test_loss'] = elems.get('best_test_loss', float('inf'))
            self.metrics['best_test_acc'] = elems.get('best_test_acc', float('inf'))
            self.metrics['best_test_acc'] = elems.get('best_test_acc', 0)
            self.lr_weight = elems.get('weight', 1.0)
            self.no_label_updates = elems.get('num_label_updates', 0)
        print_rank(f'Resuming from status_log: cur_iter: {self.cur_iter_no}')

    def run(self):
        '''Trigger training.

        This is a simple wrapper to the `train` method.
        '''
        print_rank('server started')
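One of the "fix some bugs" items is visible here: when a metric is missing from the status log, a best loss should fall back to +inf while a best accuracy should fall back to 0, so that the first real measurement always counts as an improvement. A tiny illustration with invented numbers:

# Illustration only (values invented), showing why the fallback defaults differ.
metrics = {'best_val_loss': float('inf'), 'best_val_acc': 0}
first_val_loss, first_val_acc = 2.31, 0.42
assert first_val_loss < metrics['best_val_loss']   # inf default: any real loss is "better"
assert first_val_acc > metrics['best_val_acc']     # 0 default: any real accuracy is "better"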
@@ -237,7 +244,7 @@ class OptimizationServer(federated.Server):
        # Skip if we resumed from a checkpoint (cur_iter_no > 0)
        eval_list = []
        if self.cur_iter_no == 0:

            if self.config['server_config']['initial_rec']:
                eval_list.append('test')
            if self.config['server_config']['initial_val']:

@@ -246,7 +253,7 @@ class OptimizationServer(federated.Server):

            print_rank("Running {} at itr={}".format(eval_list, self.cur_iter_no))
            self.metrics = self.evaluation.run(eval_list, self.metrics, metric_logger=run.log)
            eval_list=[] # some cleanup
            eval_list = [] # some cleanup

            # Dump all the information in aggregate_metric
            print_rank('Saving Model Before Starting Training', loglevel=logging.INFO)

@@ -257,7 +264,7 @@ class OptimizationServer(federated.Server):
            config=self.config['server_config']
        )

        # Training loop
        # Training loop
        self.worker_trainer.model.train()
        for i in range(self.cur_iter_no, self.max_iteration):
            begin = time.time()

@@ -330,7 +337,7 @@ class OptimizationServer(federated.Server):
            adaptive_leakage = apply_privacy_metrics and \
                self.config['privacy_metrics_config'].get('adaptive_leakage_threshold', None)
            if apply_privacy_metrics:
                privacy_metrics_stats = defaultdict(list)
                privacy_metrics_stats = defaultdict(list)

            # Initialize profiler
            profiler = None

@@ -343,20 +350,20 @@ class OptimizationServer(federated.Server):

            for client_output in self.process_clients(sampled_clients, server_data, self.clients_in_parallel):
                # Process client output
                client_timestamp = client_output['ts']
                client_timestamp = client_output['ts']
                client_stats = client_output['cs']
                client_loss = client_output['tl']
                client_mag_grad = client_output['mg']
                client_mag_grad = client_output['mg']
                client_mean_grad = client_output['ng']
                client_var_grad = client_output['vg']
                client_norm_grad = client_output['rg']
                client_payload = client_output['pl']

                if apply_privacy_metrics:
                    privacy_stats = client_output['ps']
                    for metric, value in privacy_stats.items():
                        privacy_metrics_stats[metric].append(value)

                self.run_stats['mpiCosts'][-1].append(time.time() - client_timestamp)

                # Get actual pseudo-gradients for aggregation

@@ -395,7 +402,7 @@ class OptimizationServer(federated.Server):
                client_norm_grads = np.array(client_norm_grads)

                client_stats = (client_mag_grads, client_mean_grads, client_var_grads)

                dump_norm_stats = self.config.get('dump_norm_stats', False)
                if dump_norm_stats:
                    with open(os.path.join(self.model_path, 'norm_stats.txt'), 'a', encoding='utf-8') as outF:

@@ -463,11 +470,11 @@ class OptimizationServer(federated.Server):
                self.metrics = self.evaluation.run(eval_list, self.metrics, metric_logger=run.log)
                self.losses = self.evaluation.losses
                eval_list = []

            # Create a schedule for the initial_lr (for the worker)
            if 'val' in eval_list:
                run.log('LR for agg. opt.', get_lr(self.worker_trainer.optimizer))
                if not (self.losses[0] < self.metrics['best_val_loss']):
                if not (self.losses[0] < self.metrics['best_val_loss']):
                    self.lr_weight *= self.lr_decay_factor
                    print_rank('LOG: Client weight of learning rate {}..'.format(self.lr_weight))

@@ -482,10 +489,10 @@ class OptimizationServer(federated.Server):
                self.log_path,
                {
                    'i': i + 1,
                    'best_val_loss': float(self.metrics['best_val_loss']),
                    'best_val_acc': float(self.metrics['best_val_acc']),
                    'best_test_loss': float(self.metrics['best_test_loss']),
                    'best_test_acc': float(self.metrics['best_test_acc']),
                    'best_val_loss': float(self.metrics['best_val_loss']),
                    'best_val_acc': float(self.metrics['best_val_acc']),
                    'best_test_loss': float(self.metrics['best_test_loss']),
                    'best_test_acc': float(self.metrics['best_test_acc']),
                    'weight': float(self.lr_weight),
                    'num_label_updates': int(self.no_label_updates)
                },

@@ -531,7 +538,7 @@ class OptimizationServer(federated.Server):

    def backup_models(self, i):
        '''Save the current best models.

        Save CER model, the best loss model and the best WER model. This occurs
        at a specified period.

@@ -586,10 +593,10 @@ def select_server(server_type, config):
    Right now this just returns `OptimizationServer`, but this
    function could be useful when there are multiple choices of
    server.

    Args:
        server_type (str): indicates server choice.
        config (dict): config parsed from YAML, passed so that
            parameters can be used to select a given server.
    '''
    return OptimizationServer
    return OptimizationServer
@@ -24,7 +24,7 @@ from utils import \

class TrainerBase:
    """Abstract class defining Trainer objects' common interface.

    Args:
        model (torch.nn.Module): model to be trained.
        train_dataloader (torch.utils.data.DataLoader): dataloader that

@@ -195,7 +195,7 @@ class ModelUpdater(TrainerBase):
class Trainer(TrainerBase):
    """Perform training step for any given client.

    The main method to be called for triggering a training step is
    The main method to be called for triggering a training step is
    :code:`train_desired_samples`, which in turn relies on
    :code:`run_train_epoch`.
@@ -293,7 +293,21 @@ class Trainer(TrainerBase):
        """Compute statistics about the gradients."""

        sum_mean_grad, sum_mean_grad2, n = self.accumulate_gradient_power()
        self.sufficient_stats = {"n": n, "sum": sum_mean_grad, "sq_sum": sum_mean_grad2}

        mean_grad = sum_mean_grad / n
        mag_grad = np.sqrt(sum_mean_grad2 / n)
        var_grad = sum_mean_grad2 / n - mag_grad**2
        norm_grad = np.sqrt(sum_mean_grad2)

        self.sufficient_stats = {
            "n": n,
            "sum": sum_mean_grad,
            "sq_sum": sum_mean_grad2,
            "var": var_grad,
            "mean": mean_grad,
            "mag": mag_grad,
            "norm": norm_grad
        }

    def train_desired_samples(self, desired_max_samples=None, apply_privacy_metrics=False):
        """Triggers training step.
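For reference, the accumulated "sum" and "sq_sum" are enough to recover the reported statistics; a toy NumPy check with invented gradient values (not part of the commit):

import numpy as np

g = np.array([0.5, -1.0, 2.0, 0.25])          # invented gradient entries
n, s, sq = g.size, g.sum(), np.sum(g ** 2)

assert np.isclose(s / n, g.mean())                            # 'mean'
assert np.isclose(np.sqrt(sq / n), np.sqrt((g ** 2).mean()))  # 'mag' (RMS of the entries)
assert np.isclose(np.sqrt(sq), np.linalg.norm(g))             # 'norm' (overall L2 norm)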
@@ -451,7 +465,7 @@ def run_validation_generic(model, val_dataloader):
        f"len_sampler: {len(val_loader._index_sampler)}",
        loglevel=logging.DEBUG
    )

    try:
        from core.globals import task
        loader = SourceFileLoader("CustomMetrics", str("./experiments/"+task+"/custom_metrics.py")).load_module()

@@ -460,9 +474,9 @@ def run_validation_generic(model, val_dataloader):
    except:
        metrics_cl = Metrics()
        print_rank("Loading default metrics")

    return metrics_cl.compute_metrics(dataloader=val_loader, model=model)


def set_component_wise_lr(model, optimizer_config, updatable_names):
    """Set zero learning rate for layers in order to freeze the update.
@@ -495,9 +509,9 @@ def save_model(model_path, config, model, optimizer, lr_scheduler, ss_scheduler,
    """Save a model as well as training information."""

    save_state = {
        "model_state_dict" : model.state_dict(),
        "optimizer_state_dict" : optimizer.state_dict() if optimizer is not None else None,
        "lr_scheduler_state_dict" : lr_scheduler.state_dict() if lr_scheduler is not None else None
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
        "lr_scheduler_state_dict": lr_scheduler.state_dict() if lr_scheduler is not None else None
    }
    if ss_scheduler is not None:
        save_state["ss_scheduler_state_dict"] = ss_scheduler.state_dict()
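A minimal sketch of the matching load path, assuming the caller rebuilds the same model and optimizer; the tiny module, the optimizer choice, and the file name below are placeholders, not something this diff defines:

import torch

# Placeholder objects standing in for the real model/optimizer.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Round trip: write a checkpoint shaped like save_state above, then restore it.
torch.save({
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "lr_scheduler_state_dict": None,
}, "model.tar")                                    # placeholder file name

save_state = torch.load("model.tar", map_location="cpu")
model.load_state_dict(save_state["model_state_dict"])
if save_state["optimizer_state_dict"] is not None:
    optimizer.load_state_dict(save_state["optimizer_state_dict"])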
@@ -35,10 +35,6 @@ from utils.dataloaders_utils import (
    make_val_dataloader,
    make_test_dataloader,
)
from config_file_parser import (
    check_server_config,
    check_client_config
)

assert TRAINING_FRAMEWORK_TYPE == "mpi", "Unsupported platform {}".format(TRAINING_FRAMEWORK_TYPE)

@@ -189,40 +185,7 @@ if __name__ == "__main__":
    task = args.task
    local_rank = args.local_rank

    # Create dictionaries w/ parameters
    default_data_conf = {
        "input_dim": 300,
        "batch_size": 40,
        "loader_type": "text",
        "prepend_datapath": False,
        "pin_memory": True,
        "num_frames": 0,
        "desired_max_samples": 300,
        "max_grad_norm": 5.0,  # max_grad_norm for gradient clipping
        "num_workers": 1,
        "max_batch_size": 0,  # maximum number of batch size; if 0, no limitation is applied
        "unsorted_batch": False  # do not sort when making batch; this is inefficient in terms of batch, but could be efficient in terms of accuracy
    }

    default_server_conf = {
        "val_freq": 1,
        "rec_freq": 8,
        "max_iteration": 100000000,
        "type": "optimization",
        "data_config": default_data_conf,
        "aggregate_median": None,
        "best_model_criterion": "loss",
        "fall_back_to_best_model": False,
        "num_clients_per_iteration": -1
    }

    default_client_conf = {
        "copying_train_jsonls": True,
        "type": "gradient_computation",
        "data_config": default_data_conf,
    }

    # The mount point can also be retrieved from input_datasets of the run context
    # The mount point can also be retrieved from input_datasets of the run context
    if data_path is None:
        data_path = Run.get_context().input_datasets["input"]
    print("The data can be found here: ", data_path)
@@ -241,41 +204,20 @@ if __name__ == "__main__":
    cfg_out = os.path.join(experiment_root, "FLUTE_config.yaml")
    if local_rank <= 0:
        shutil.copyfile(args.config, cfg_out)
        print("Copy created")

    # Initialize logging
    init_logging(log_path, loglevel=logging_level)

    with open(args.config) as f:
        cfg_dict = yaml.safe_load(f)
        config = FLUTEConfig.from_dict(cfg_dict)

    assert "num_clients" not in config["server_config"]["data_config"], "Remove \"num_clients\" from server data_config since this is a reserved key"
    assert "num_clients" not in config["client_config"]["data_config"], "Remove \"num_clients\" from client data_config since this is a reserved key"

    # Make sure the pretrained model is found in the correct place
    if "pretrained_model_path" in config["model_config"]["model_type"]:
        config["model_config"]["model_type"]["pretrained_model_path"] = os.path.join(data_path, config["model_config"]["model_type"]["pretrained_model_path"])
    if "pretrained_model_path" in config["model_config"]:
        config["model_config"]["pretrained_model_path"] = os.path.join(data_path, config["model_config"]["pretrained_model_path"])

    config["data_path"] = data_path

    config = check_server_config(config, default_server_conf)
    config = check_client_config(config, default_client_conf)

    # Add task specification to client configuration
    config["output_path"] = args.outputPath
    config["experiment_name"] = experiment_name
    config["client_config"]["task"] = task
    config["server_config"]["task"] = task

    # RL-related options
    if config["server_config"].get("wantRL", False):
        if config["server_config"]["RL"].get("RL_path_global", True):
            config["server_config"]["RL"]["RL_path"] = os.path.join(args.outputPath,
                                                                    config["server_config"]["RL"]["RL_path"])
        else:
            config["server_config"]["RL"]["RL_path"] = os.path.join(args.outputPath, experiment_name,
                                                                    config["server_config"]["RL"]["RL_path"])
    config.validate()

    # Instantiate either Server or Worker on the thread
    run_worker(model_path, config, task, data_path, local_rank)
@@ -23,12 +23,13 @@ def extract_indices_from_embeddings(gradients, batch, embed_size, vocab_size):


def compute_perplexity(encoded_batch, model):
    outputs = model.inference(encoded_batch)
    (batch_size, seq_len, vocab_size) = outputs[0].shape
    perplex = T.nn.functional.log_softmax(outputs[0], dim=-1)
    outputs = model.inference(encoded_batch)
    (batch_size, seq_len, vocab_size) = outputs['output'].shape
    perplex = T.nn.functional.log_softmax(outputs['output'], dim=-1)
    return perplex.reshape(-1, vocab_size)[np.arange(batch_size * seq_len),
                                           encoded_batch.reshape(-1)].reshape(batch_size, seq_len)


def practical_epsilon_leakage(original_params, model, encoded_batches, is_weighted_leakage=True,
                              max_ratio=1e9, optimizer_config=None):
    # Copy the gradients and save the model.
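The reshape-and-index expression at the end of compute_perplexity gathers, for every position, the log-probability of the token that actually occurred. A standalone illustration of the same pattern with dummy shapes (invented, not taken from the commit):

import torch as T

batch_size, seq_len, vocab_size = 2, 3, 5
logits = T.randn(batch_size, seq_len, vocab_size)              # stand-in for model output
encoded_batch = T.randint(vocab_size, (batch_size, seq_len))   # stand-in for reference tokens

log_probs = T.nn.functional.log_softmax(logits, dim=-1)
token_ll = log_probs.reshape(-1, vocab_size)[T.arange(batch_size * seq_len),
                                             encoded_batch.reshape(-1)].reshape(batch_size, seq_len)
assert token_ll.shape == (batch_size, seq_len)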
@@ -143,7 +143,7 @@ server_config:
    unsorted_batch: true
  type: model_optimization
  aggregate_median: softmax # the FL aggregation method
  weight_train_loss: grad_mean_loss #train_loss #grad_mean_loss #or train_loss - how each client's weight is determined.
  weight_train_loss: mag_mean_loss #or train_loss or mag_var_loss - how each client's weight is determined.
  softmax_beta: 1000
  initial_lr_client: 0.1
  lr_decay_factor: 1.0
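For intuition about the aggregate_median: softmax and softmax_beta knobs above, a heavily hedged sketch of softmax weighting over per-client scores; the sign convention and the exact score fed in (train loss, gradient magnitude, etc.) are assumptions of this illustration, not something the YAML or this diff defines.

import numpy as np

def softmax_weights(scores, softmax_beta=1000.0):
    # Temperature-scaled softmax over per-client scores (assumed convention).
    z = softmax_beta * np.asarray(scores, dtype=float)
    z -= z.max()               # subtract the max for numerical stability
    w = np.exp(z)
    return w / w.sum()

print(softmax_weights([-0.9, -1.1, -1.0], softmax_beta=10.0))  # invented per-client scores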