зеркало из https://github.com/microsoft/msrflute.git
First commit
Co-authored-by: Mirian Hipolito Garcia <mirianh@microsoft.com>
This commit is contained in:
Коммит
cf9d22f748
|
@ -0,0 +1,2 @@
|
|||
[flake8]
|
||||
ignore = E501
|
|
@ -0,0 +1,6 @@
|
|||
__pycache__/
|
||||
.vscode/
|
||||
doc/sphinx/_build
|
||||
testing/logs.txt
|
||||
testing/outputs
|
||||
testing/mockup
|
|
@ -0,0 +1,3 @@
|
|||
[submodule "dp-accountant"]
|
||||
path = dp-accountant
|
||||
url = https://github.com/microsoft/prv_accountant
|
|
@ -0,0 +1,9 @@
|
|||
# Microsoft Open Source Code of Conduct
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
|
||||
Resources:
|
||||
|
||||
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
|
||||
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
||||
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
|
|
@ -0,0 +1,19 @@
|
|||
# Contributing
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to
|
||||
agree to a Contributor License Agreement (CLA) declaring that you have the right to,
|
||||
and actually do, grant us the rights to use your contribution. For details, visit
|
||||
https://cla.microsoft.com.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
|
||||
or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
|
||||
### Pull Requests
|
||||
|
||||
Submit pull requests to **branch contribution**. PR's in any other branch will not be accepted.
|
||||
|
||||
When you submit a pull request, a CLA-bot will automatically determine whether you need
|
||||
to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the
|
||||
instructions provided by the bot. You will only need to do this once across all repositories using our CLA.
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
Copyright (c) Microsoft Corporation.
|
||||
|
||||
MIT License
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,21 @@
|
|||
THIRD-PARTY SOFTWARE NOTICES AND INFORMATION
|
||||
Do Not Translate or Localize
|
||||
|
||||
This software incorporates components from the projects listed below. The original copyright notices
|
||||
and the licenses under which Microsoft received such components are set forth below and are provided for
|
||||
informational purposes only. Microsoft reserves all rights not expressly granted herein, whether by
|
||||
implication, estoppel or otherwise.
|
||||
|
||||
This software includes parts of the Huggingface/Transformers Library (https://github.com/huggingface/transformers).
|
||||
State-of-the-art of Natural Language Processing for Jax, PyTorch and TensorFlow. Huggingface/Transformers library is
|
||||
licensed under Apache License 2.0, you can find a copy of this license at https://github.com/huggingface/transformers/blob/master/LICENSE
|
||||
|
||||
This software includes parts of the Tensorflow/Privacy Library (https://github.com/tensorflow/privacy).
|
||||
A library that includes implementations of TensorFlow optimizers for training machine learning models with
|
||||
differential privacy. The Tensorflow/Privacy library is licensed under Apache License 2.0,
|
||||
you can find a copy of this license at https://github.com/tensorflow/privacy/blob/master/LICENSE
|
||||
|
||||
This software includes parts of LEAF Library (https://github.com/TalwalkarLab/leaf).
|
||||
A Benchmark for Federated Settings. LEAF library is licensed under BSD 2-Clause License, you can find a copy
|
||||
of this license at https://github.com/TalwalkarLab/leaf/blob/master/LICENSE.md
|
||||
|
|
@ -0,0 +1,122 @@
|
|||
# FLUTE
|
||||
|
||||
Welcome to FLUTE (Federated Learning Utilities for Testing and Experimentation), a platform for conducting high-performance federated learning simulations.
|
||||
|
||||
## Quick Start
|
||||
|
||||
Install the requirements stated inside of `requirements.txt`. Ideally this sould be done inside of a virtual environment, for instance, using Anaconda.
|
||||
|
||||
```
|
||||
conda create -n FLUTE python==3.8
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
You will also need some MPI runtime such as OpenMPI (on Linux) or MS-MPI (on Windows). There is no `setup.py` as FLUTE is not currently distributed as a package, but instead meant to run from the root of the repository.
|
||||
|
||||
After this initial setup, you can use the data created for the integration test inside of `testing` for a first local run. Note that this data needs to be download manually, for more instructions please look at [the README file inside `testing`](testing/README.md).
|
||||
|
||||
```
|
||||
mpiexec -n 3 python e2e_trainer.py -dataPath ./testing/mockup -outputPath scratch -config testing/configs/hello_world_local.yaml -task nlg_gru
|
||||
```
|
||||
|
||||
This config uses 1 MPI node with 3 workers (1 server, 2 clients). The config file `testing/configs/hello_world_local.yaml` has some comments explaining the major sections and some important details; essentially, it consists in a very short experiment where a couple of iterations are done for just a few clients. A `scratch` folder will be created containing detailed logs.
|
||||
|
||||
## Documentation
|
||||
|
||||
The documentation is inside the `doc/sphinx` folder. To build the docs on Linux:
|
||||
|
||||
```
|
||||
$ pip install sphinx
|
||||
$ cd doc/sphinx
|
||||
$ make html
|
||||
```
|
||||
|
||||
On Windows, you can use the `make.bat` script.
|
||||
|
||||
## Architecture
|
||||
|
||||
The core client/server training code is inside the `core` folder.
|
||||
|
||||
- Server-side federation and global DP application takes place in `server.py`, more specifically in the `OptimizationServer.train()` method.
|
||||
- Client-side training updates take place in the static method `Client.process_round()`, inside `client.py`.
|
||||
|
||||
General FL orchestration code is in `federated.py`, but for most hub and spoke federation scenarios you won't need to touch this (unless you want to invest in optimizing MPI, which would be great!). Note that FLUTE does not implement secure aggregation since this is primarily a security feature for production scenarios; contributors are invited to add it for experimentation purposes.
|
||||
|
||||
The primary entry point for an experiment is in the script `e2e_trainer.py`. Primary config scripts for experiments are in `configs`. For instance, a basic training scenario for a next-word prediction task is set up in `hello_world_nlg_gru_json.yaml`.
|
||||
|
||||
Privacy accounting is expensive so the main parameters are logged and the actual accounting can be done offline. RDP privacy accounting is in `extensions/privacy/analysis.py`. A better accounting method is in the `dp-accountant` submodule.
|
||||
|
||||
## Customization
|
||||
|
||||
See `experiments` folder for illustrations of how dataloaders and models are customized. In order to in include a new experiment, the new scenario must be added following the same folder structure as `nlg_gru` and `mlm_bert`, naming the folder with the task.
|
||||
|
||||
## Experiments
|
||||
|
||||
Experiments are defined by YAML files, examples are provided in the `configs` folder. These can be run either locally or on AzureML.
|
||||
|
||||
For running experiments on AzureML, the CLI can help. You should first [install the CLI](https://docs.microsoft.com/en-us/azure/machine-learning/reference-azure-machine-learning-cli) (make sure you have v2) and [create a resource group and workspace](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-manage-workspace-cli?tabs=createnewresources%2Cvnetpleconfigurationsv1cli). You can then create a compute cluster, type `az ml compute create -h` for more info. Afterwards, you should write an YAML file with instructions for the job; we provide a simple example below
|
||||
|
||||
```yaml
|
||||
experiment_name: basic_example
|
||||
description: Basic example of AML config for submitting FLUTE jobs
|
||||
code:
|
||||
local_path: .
|
||||
compute: azureml:Test
|
||||
environment:
|
||||
image: pytorch/pytorch:1.9.0-cuda10.2-cudnn7-devel
|
||||
inputs:
|
||||
data:
|
||||
folder: azureml://datastores/data/paths/cifar
|
||||
mode: rw_mount
|
||||
command: >
|
||||
apt -y update &&
|
||||
apt -y install openmpi-bin libopenmpi-dev openssh-client &&
|
||||
python3 -m pip install --upgrade pip &&
|
||||
python3 -m pip install -r requirements.txt &&
|
||||
mpiexec --allow-run-as-root -n 4 python e2e_trainer.py
|
||||
-outputPath=./outputs
|
||||
-dataPath={inputs.data}
|
||||
-task=classif_cnn
|
||||
-config=./experiments/classif_cnn/config.yaml
|
||||
```
|
||||
|
||||
You should replace `compute` with the name of the one you created before, and adjust the path of the datastore containing the data -- in the example above, we created a datastore called `data` and added to it a folder called `cifar`, which contained the two HDF5 files. The command passed above will install dependencies and then launch an MPI job with 4 threads, for the experiment defined in `experiments/classif_cnn`. Details on how to run a job using the AzureML CLI are given [in its documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-train-cli), but typically it suffices to set up the environment and type `az ml job create -f <name-of-the-yaml-file>`. In the same page of the documentation, you can also find more info about how to set up the YAML file above, in case other changes are needed.
|
||||
|
||||
Note that the `local_path` above is relative to the location of the YAML file, so setting it to `.` assumes it is in the same folder as `e2e_trainer.py`. All files on this folder will be uploaded to Azure, including hidden folders such as `.git`, so make sure to temporarily get rid of large files and folders that are not needed.
|
||||
|
||||
After launching the experiment, you can follow it on AzureML Studio, which prints logs, plots metrics and makes the output easily available after the experiment is finished.
|
||||
|
||||
## Privacy Accounting
|
||||
|
||||
Accounting is expensive, so we log all the privacy parameters so that accounting can be run offline. Best run on a Linux box with a GPU.
|
||||
In particular, we use a DP accountant from another Microsoft repository, which is included in ours as a submodule. For using this accountant, just follow the instructions below:
|
||||
|
||||
```
|
||||
$ git submodule update --init --recursive
|
||||
$ cd dp-accountant
|
||||
$ python setup.py install
|
||||
$ ./bin/compute-dp-epsilon --help
|
||||
usage: compute-dp-epsilon [-h] -p SAMPLING_PROBABILITY -s NOISE_MULTIPLIER -i ITERATIONS -d DELTA
|
||||
```
|
||||
## Third Party Notice
|
||||
|
||||
This software includes the files listed below from the Huggingface/Transformers Library (https://github.com/huggingface/transformers) as part of task performance and preprocessing pretrained models.
|
||||
|
||||
experiments/mlm_bert
|
||||
└── utils
|
||||
├── trainer_pt_utils.py
|
||||
└── trainer_utils.py
|
||||
|
||||
This software includes the file extensions/privacy/analysis.py from the Tensorflow/Privacy Library (https://github.com/tensorflow/privacy) as part of Renyi Differential Privacy implementation.
|
||||
|
||||
This software includes the script testing/build_vocab.py from LEAF Library (https://github.com/TalwalkarLab/leaf) to create the vocabulary needed to run a testing job.
|
||||
|
||||
For more information about third-party OSS licence, please refer to [NOTICE.txt](NOTICE.txt).
|
||||
|
||||
## Support
|
||||
|
||||
You are welcome to open issues on this repository related to bug reports and feature requests.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcomed and encouraged. For details on how to contribute, please see [CONTRIBUTING.md](CONTRIBUTING.md).
|
|
@ -0,0 +1,108 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""
|
||||
A collection of functions for checking the format of configuration values
|
||||
"""
|
||||
import os
|
||||
|
||||
def check_server_config(config, default_server_conf):
|
||||
|
||||
assert "server_config" in config, "server config setting is missing"
|
||||
|
||||
# Checking parameters for server-side training
|
||||
if "train" in config["server_config"]["data_config"]:
|
||||
if "train_data_server" in config["server_config"]["data_config"]["train"]:
|
||||
assert "server_replay_config" in config["server_config"], "Training dataset is defined on the server but training parameters are not set"
|
||||
assert "optimizer_config" in config["server_config"]["server_replay_config"], "Missing \"optimizer_config\" in server_replay server training config"
|
||||
assert "server_iterations" in config["server_config"]["server_replay_config"], "Missing \"server_iterations\" in server_replay server training config"
|
||||
|
||||
# Setting the default values if missing
|
||||
for key in default_server_conf.keys():
|
||||
if not key in config["server_config"]:
|
||||
config["server_config"][key] = default_server_conf[key]
|
||||
|
||||
server_type = config["server_config"]["type"]
|
||||
if not (server_type == "model_averaging" or \
|
||||
server_type == "optimization" or \
|
||||
server_type == "model_optimization" or \
|
||||
server_type == "cluster_finetuning" or \
|
||||
server_type == "cluster_parallel") :
|
||||
raise ValueError("Invalid server type {} in federated learning config".format(server_type))
|
||||
|
||||
assert config["server_config"]["best_model_criterion"] == "loss" or \
|
||||
config["server_config"]["best_model_criterion"] == "cer", \
|
||||
"Invalid model criterion {}".format(config["server_config"]["best_model_criterion"])
|
||||
|
||||
if server_type == "model_optimization" or server_type == "cluster_finetuning" or server_type == "cluster_parallel":
|
||||
assert "initial_lr_client" in config["server_config"], "Missing \"initial_lr_client\" in server config"
|
||||
assert "lr_decay_factor" in config["server_config"], "Missing \"lr_decay_factor\" in server config"
|
||||
assert "aggregate_median" in config["server_config"], "Missing \"aggregate_median\" in server config"
|
||||
|
||||
if "nbest_task_scheduler" in config["server_config"]:
|
||||
assert "num_tasks" in config["server_config"]["nbest_task_scheduler"], "Define \"num_tasks\" in [\"nbest_task_scheduler\"]"
|
||||
assert "iteration_per_task" in config["server_config"]["nbest_task_scheduler"], "Define \"iteration_per_task\" in [\"nbest_task_scheduler\"]"
|
||||
assert len(config["server_config"]["nbest_task_scheduler"]["num_tasks"]) == len(config["server_config"]["nbest_task_scheduler"]["iteration_per_task"]), \
|
||||
"Length mismatched: {}!={}".format(len(config["server_config"]["nbest_task_scheduler"]["num_tasks"]), len(config["server_config"]["nbest_task_scheduler"]["iteration_per_task"]))
|
||||
|
||||
data_path = config['data_path']
|
||||
if 'vocab_dict' in config["server_config"]["data_config"]["val"]:
|
||||
config["server_config"]["data_config"]["val"]["vocab_dict"]=os.path.join(data_path, config["server_config"]["data_config"]["val"]["vocab_dict"])
|
||||
if 'vocab_dict' in config["server_config"]["data_config"]["test"]:
|
||||
config["server_config"]["data_config"]["test"]["vocab_dict"]=os.path.join(data_path, config["server_config"]["data_config"]["test"]["vocab_dict"])
|
||||
if 'vocab_dict' in config["server_config"]["data_config"]["test"]:
|
||||
config["server_config"]["data_config"]["train"]["vocab_dict"]=os.path.join(data_path, config["server_config"]["data_config"]["train"]["vocab_dict"])
|
||||
|
||||
|
||||
# BERT specific parameters
|
||||
if 'model_config' in config and 'BERT' in config['model_config']:
|
||||
if 'model_name_or_path' in config['model_config']['BERT']['model']:
|
||||
config['server_config']['data_config']['val']['model_name_or_path'] =config['model_config']['BERT']['model']['model_name_or_path']
|
||||
config['server_config']['data_config']['test']['model_name_or_path']=config['model_config']['BERT']['model']['model_name_or_path']
|
||||
else:
|
||||
config['server_config']['data_config']['val']['model_name_or_path'] =config['model_config']['BERT']['model']['model_name']
|
||||
config['server_config']['data_config']['test']['model_name_or_path']=config['model_config']['BERT']['model']['model_name']
|
||||
|
||||
if 'process_line_by_line' in config['model_config']['BERT']['model']:
|
||||
config['server_config']['data_config']['val']['process_line_by_line'] =config['model_config']['BERT']['model']['process_line_by_line']
|
||||
config['server_config']['data_config']['test']['process_line_by_line']=config['model_config']['BERT']['model']['process_line_by_line']
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def check_client_config(config, default_client_conf):
|
||||
|
||||
assert "client_config" in config, "client config setting is missing"
|
||||
|
||||
# Setting the default values if missing
|
||||
for key in default_client_conf.keys():
|
||||
if not key in config["client_config"]:
|
||||
config["client_config"][key] = default_client_conf[key]
|
||||
|
||||
client_type = config["client_config"]["type"]
|
||||
if not (client_type == "gradient_computation" or client_type == "optimization"):
|
||||
raise ValueError("Invalid client option {} in federated learning config".format(client_type))
|
||||
|
||||
if not "ss_config" in config["client_config"]:
|
||||
config["client_config"]["ss_config"] = None
|
||||
|
||||
if "list_of_train_data" in config["client_config"]["data_config"]["train"] and "train_data" in config["client_config"]["data_config"]["train"]:
|
||||
raise ValueError("\"list_of_train_data\" and \"train_data\" cannot be defined at the same time")
|
||||
|
||||
assert "list_of_train_data" in config["client_config"]["data_config"]["train"] or "train_data" in config["client_config"]["data_config"]["train"], "Define either \"list_of_train_data\" and \"train_data\""
|
||||
|
||||
# Adjust path to vocab_dict
|
||||
data_path = config['data_path']
|
||||
if 'vocab_dict' in config["client_config"]["data_config"]["train"]:
|
||||
config["client_config"]["data_config"]["train"]["vocab_dict"]=os.path.join(data_path, config["client_config"]["data_config"]["train"]["vocab_dict"])
|
||||
|
||||
# BERT specific parameters
|
||||
if 'model_config' in config and 'train' in config['client_config']['data_config'] and 'BERT' in config['model_config']:
|
||||
if 'model_name_or_path' in config['model_config']['BERT']['model']:
|
||||
config['client_config']['data_config']['train']['model_name_or_path']=config['model_config']['BERT']['model']['model_name_or_path']
|
||||
else:
|
||||
config['client_config']['data_config']['train']['model_name_or_path']=config['model_config']['BERT']['model']['model_name']
|
||||
if 'process_line_by_line' in config['model_config']['BERT']['model']:
|
||||
config['client_config']['data_config']['train']['process_line_by_line'] =config['model_config']['BERT']['model']['process_line_by_line']
|
||||
|
||||
return config
|
|
@ -0,0 +1,158 @@
|
|||
# Basic configuration file for running mlm_bert example using json files in Azure ML.
|
||||
model_config:
|
||||
model_type: BERT
|
||||
model_folder: experiments/mlm_bert/model.py
|
||||
BERT:
|
||||
loader_type: text
|
||||
model:
|
||||
model_name: roberta-large
|
||||
cache_dir: ./cache_dir
|
||||
use_fast_tokenizer: False
|
||||
mask_token: <mask>
|
||||
task: mlm
|
||||
past_index: -1
|
||||
prediction_loss_only: false
|
||||
process_line_by_line: false
|
||||
training:
|
||||
seed: 12345
|
||||
label_smoothing_factor: 0
|
||||
batch_size: 64
|
||||
max_seq_length: 256
|
||||
|
||||
dp_config:
|
||||
enable_local_dp: false
|
||||
enable_global_dp: false
|
||||
global_sigma: 0.35
|
||||
grad_dir_eps: -1
|
||||
grad_mag_eps: -1
|
||||
weight_eps: 100
|
||||
max_grad: 0.008
|
||||
min_grad: 0.000001
|
||||
max_weight: 0.5
|
||||
min_weight: 0.0000001
|
||||
weight_dp: scalarDP
|
||||
|
||||
server_config:
|
||||
resume_from_checkpoint: true
|
||||
do_profiling: false
|
||||
fast_aggregation: true
|
||||
wantRL: false
|
||||
RL:
|
||||
RL_path_global: false
|
||||
marginal_update_RL: true
|
||||
RL_path: ./RL_models
|
||||
model_descriptor_RL: marginalUpdate
|
||||
network_params: 300,128,128,128,64,100
|
||||
initial_epsilon: 0.5
|
||||
final_epsilon: 0.0001
|
||||
epsilon_gamma: 0.90
|
||||
max_replay_memory_size: 1000
|
||||
minibatch_size: 16
|
||||
gamma: 0.99
|
||||
optimizer_config:
|
||||
lr: 0.0003
|
||||
type: adam
|
||||
amsgrad: true
|
||||
annealing_config:
|
||||
type: step_lr
|
||||
step_interval: epoch
|
||||
step_size: 1
|
||||
gamma: 0.95
|
||||
optimizer_config:
|
||||
lr: 0.00001
|
||||
weight_decay: 0.01
|
||||
type: adamW
|
||||
annealing_config:
|
||||
type: step_lr
|
||||
step_interval: epoch
|
||||
gamma: 1.0
|
||||
step_size: 1000
|
||||
val_freq: 4
|
||||
rec_freq: 16
|
||||
max_iteration: 10000
|
||||
num_clients_per_iteration: 200
|
||||
data_config:
|
||||
val:
|
||||
loader_type: text
|
||||
val_data: <add path to data here>
|
||||
task: mlm
|
||||
mlm_probability: 0.25
|
||||
tokenizer_type_fast: False
|
||||
batch_size: 128
|
||||
max_seq_length: 256
|
||||
min_words_per_utt: 5
|
||||
max_samples_per_user: 5000
|
||||
mask_token: <mask>
|
||||
num_workers: 0
|
||||
prepend_datapath: false
|
||||
cache_dir: ./cache_dir
|
||||
train:
|
||||
loader_type: text
|
||||
train_data: null
|
||||
train_data_server: null
|
||||
desired_max_samples: null
|
||||
test:
|
||||
loader_type: text
|
||||
test_data: <add path to data here>
|
||||
task: mlm
|
||||
mlm_probability: 0.25
|
||||
tokenizer_type_fast: False
|
||||
batch_size: 128
|
||||
max_seq_length: 256
|
||||
max_samples_per_user: 5000
|
||||
mask_token: <mask>
|
||||
num_workers: 0
|
||||
prepend_datapath: false
|
||||
cache_dir: ./cache_dir
|
||||
type: model_optimization
|
||||
aggregate_median: softmax
|
||||
weight_train_loss: grad_mean_loss
|
||||
softmax_beta: 1.00
|
||||
initial_lr_client: 0.00001
|
||||
lr_decay_factor: 1.0
|
||||
best_model_criterion: loss
|
||||
fall_back_to_best_model: false
|
||||
server_replay_config:
|
||||
server_iterations: 50
|
||||
optimizer_config:
|
||||
lr: 0.00002
|
||||
amsgrad: true
|
||||
type: adam
|
||||
|
||||
client_config:
|
||||
meta_learning: basic
|
||||
stats_on_smooth_grad: true
|
||||
ignore_subtask: false
|
||||
copying_train_data: false
|
||||
do_profiling: false
|
||||
data_config:
|
||||
train:
|
||||
loader_type: text
|
||||
list_of_train_data: <add path to data here>
|
||||
task: mlm
|
||||
mlm_probability: 0.25
|
||||
tokenizer_type_fast: False
|
||||
batch_size: 24
|
||||
max_seq_length: 256
|
||||
min_words_per_utt: 5
|
||||
desired_max_samples: 5000
|
||||
mask_token: <mask>
|
||||
num_workers: 0
|
||||
num_frames: 0
|
||||
max_grad_norm: 15.0
|
||||
prepend_datapath: false
|
||||
cache_dir: ./cache_dir
|
||||
pin_memory: true
|
||||
type: optimization
|
||||
meta_optimizer_config:
|
||||
lr: 0.01
|
||||
type: adam
|
||||
optimizer_config:
|
||||
type: adamW
|
||||
weight_decay: 0.01
|
||||
amsgrad: true
|
||||
annealing_config:
|
||||
type: step_lr
|
||||
step_interval: epoch
|
||||
step_size: 2
|
||||
gamma: 1.0
|
|
@ -0,0 +1,125 @@
|
|||
# Basic configuration file for running locally nlg_gru example using json files.
|
||||
model_config:
|
||||
model_type: GRU
|
||||
model_folder: experiments/nlg_gru/model.py
|
||||
pretrained_model_path: <add path to pretrained weights here>
|
||||
embed_dim: 160
|
||||
vocab_size: 10000
|
||||
hidden_dim: 512
|
||||
OOV_correct: false
|
||||
|
||||
dp_config:
|
||||
enable_local_dp: false
|
||||
|
||||
privacy_metrics_config:
|
||||
apply_metrics: false
|
||||
|
||||
server_config:
|
||||
wantRL: false
|
||||
resume_from_checkpoint: true
|
||||
do_profiling: false
|
||||
optimizer_config:
|
||||
type: lamb
|
||||
lr: 0.1
|
||||
weight_decay: 0.005
|
||||
annealing_config:
|
||||
type: step_lr
|
||||
step_interval: epoch
|
||||
gamma: 1.0
|
||||
step_size: 100
|
||||
val_freq: 2
|
||||
rec_freq: 4
|
||||
max_iteration: 11
|
||||
num_clients_per_iteration: 10
|
||||
data_config:
|
||||
val:
|
||||
batch_size: 2048
|
||||
loader_type: text
|
||||
tokenizer_type: not_applicable
|
||||
prepend_datapath: false
|
||||
val_data: <add path to data here>
|
||||
vocab_dict: <add path to vocab here>
|
||||
pin_memory: true
|
||||
num_workers: 0
|
||||
num_frames: 2400
|
||||
max_batch_size: 2048
|
||||
max_num_words: 25
|
||||
unsorted_batch: true
|
||||
train:
|
||||
batch_size: 128
|
||||
loader_type: text
|
||||
tokenizer_type: not_applicable
|
||||
prepend_datapath: false
|
||||
train_data: null
|
||||
train_data_server: null
|
||||
vocab_dict: <add path to vocab here>
|
||||
pin_memory: true
|
||||
num_workers: 0
|
||||
num_frames: 2400
|
||||
desired_max_samples: 500
|
||||
max_grad_norm: 10.0
|
||||
max_batch_size: 128
|
||||
max_num_words: 25
|
||||
unsorted_batch: true
|
||||
test:
|
||||
batch_size: 2048
|
||||
loader_type: text
|
||||
tokenizer_type: not_applicable
|
||||
prepend_datapath: false
|
||||
train_data: null
|
||||
train_data_server: null
|
||||
test_data: <add path to data here>
|
||||
vocab_dict: <add path to vocab here>
|
||||
pin_memory: true
|
||||
num_workers: 0
|
||||
max_batch_size: 2048
|
||||
max_num_words: 25
|
||||
unsorted_batch: true
|
||||
type: model_optimization
|
||||
aggregate_median: softmax
|
||||
weight_train_loss: train_loss
|
||||
softmax_beta: 20.0
|
||||
initial_lr_client: 1.0
|
||||
lr_decay_factor: 1.0
|
||||
best_model_criterion: loss
|
||||
fall_back_to_best_model: false
|
||||
server_replay_config:
|
||||
server_iterations: 50
|
||||
optimizer_config:
|
||||
type: adam
|
||||
lr: 0.00002
|
||||
amsgrad: true
|
||||
|
||||
client_config:
|
||||
meta_learning: basic
|
||||
stats_on_smooth_grad: true
|
||||
ignore_subtask: false
|
||||
num_skips_threshold: 10
|
||||
copying_train_data: false
|
||||
do_profiling: false
|
||||
data_config:
|
||||
train:
|
||||
batch_size: 64
|
||||
loader_type: text
|
||||
tokenizer_type: not_applicable
|
||||
prepend_datapath: false
|
||||
list_of_train_data: <add path to data here>
|
||||
vocab_dict: <add path to vocab here>
|
||||
pin_memory: true
|
||||
num_workers: 0
|
||||
desired_max_samples: 50000
|
||||
max_grad_norm: 20.0
|
||||
max_batch_size: 128
|
||||
max_num_words: 25
|
||||
unsorted_batch: true
|
||||
type: optimization
|
||||
meta_optimizer_config:
|
||||
lr: 1.0
|
||||
type: sgd
|
||||
optimizer_config:
|
||||
type: sgd
|
||||
annealing_config:
|
||||
type: step_lr
|
||||
step_interval: epoch
|
||||
step_size: 1
|
||||
gamma: 1.0
|
|
@ -0,0 +1,569 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
'''
|
||||
The Client object is short-lived, instantiated inside of worker 0 and moved to
|
||||
workers 1 to N for processing a given client's data. It's main method is the
|
||||
`process_round` function, used to update the model given a client's data.
|
||||
'''
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from easydict import EasyDict as edict
|
||||
|
||||
import h5py
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Internal imports
|
||||
from core.globals import TRAINING_FRAMEWORK_TYPE
|
||||
if TRAINING_FRAMEWORK_TYPE == 'mpi':
|
||||
import core.federated as federated
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported'.format(TRAINING_FRAMEWORK_TYPE))
|
||||
|
||||
from .trainer import (
|
||||
Trainer,
|
||||
run_validation_generic,
|
||||
set_component_wise_lr,
|
||||
)
|
||||
from utils import (
|
||||
ScheduledSamplingScheduler,
|
||||
make_optimizer,
|
||||
print_rank,
|
||||
scrub_empty_clients,
|
||||
)
|
||||
from utils.dataloaders_utils import (
|
||||
make_train_dataloader,
|
||||
make_val_dataloader,
|
||||
make_test_dataloader,
|
||||
)
|
||||
|
||||
import extensions.privacy
|
||||
from extensions.privacy import metrics as privacy_metrics
|
||||
from extensions import quant_model
|
||||
from experiments import make_model
|
||||
|
||||
|
||||
# A per-process cache of the training data, so clients don't have to repeatedly re-load
|
||||
# TODO: deprecate this in favor of passing dataloader around
|
||||
_data_dict = None
|
||||
_file_ext = None
|
||||
|
||||
class Client:
|
||||
def __init__(self, client_id, config, send_gradients, dataloader):
|
||||
'''
|
||||
Client side processing: computing gradients, update the model and send them back to the server
|
||||
|
||||
Args:
|
||||
client_id (int): identifier for grabbing that client's data.
|
||||
config (dict): dictionary with parameters loaded from config file.
|
||||
send_gradients (bool): if True, model gradients are sent back;
|
||||
otherwise, model weights are sent back.
|
||||
dataloader (torch.utils.data.DataLoader): dataloader that generates
|
||||
training data for the client.
|
||||
'''
|
||||
super().__init__()
|
||||
|
||||
self.client_id = client_id
|
||||
self.client_data = self.get_data(client_id,dataloader)
|
||||
self.config = copy.deepcopy(config)
|
||||
self.send_gradients = send_gradients
|
||||
|
||||
def get_client_data(self):
|
||||
'''"Getter" method that returns all object's attributes at once.'''
|
||||
return self.client_id, self.client_data, self.config, self.send_gradients
|
||||
|
||||
@staticmethod
|
||||
def get_num_users(filename):
|
||||
'''Count users given a JSON or HDF5 file.
|
||||
|
||||
This function will fill the global data dict. Ideally we want data
|
||||
handling not to happen here and only at the dataloader, that will be the
|
||||
behavior in future releases.
|
||||
|
||||
Args:
|
||||
filename (str): path to file containing data.
|
||||
'''
|
||||
|
||||
global _data_dict
|
||||
global _file_ext
|
||||
_file_ext = filename.split('.')[-1]
|
||||
|
||||
try:
|
||||
if _file_ext == 'json' or _file_ext == 'txt':
|
||||
if _data_dict is None:
|
||||
print_rank('Reading training data dictionary from JSON')
|
||||
with open(filename,'r') as fid:
|
||||
_data_dict = json.load(fid) # pre-cache the training data
|
||||
_data_dict = scrub_empty_clients(_data_dict) # empty clients MUST be scrubbed here to match num_clients in the entry script
|
||||
print_rank('Read training data dictionary', loglevel=logging.DEBUG)
|
||||
|
||||
elif _file_ext == 'hdf5':
|
||||
print_rank('Reading training data dictionary from HDF5')
|
||||
_data_dict = h5py.File(filename, 'r')
|
||||
print_rank('Read training data dictionary', loglevel=logging.DEBUG)
|
||||
|
||||
except:
|
||||
raise ValueError('Error reading training file. Please make sure the format is allowed')
|
||||
|
||||
num_users = len(_data_dict['users'])
|
||||
return num_users
|
||||
|
||||
@staticmethod
|
||||
def get_data(client_id, dataloader):
|
||||
'''Load data from the dataloader given the client's id.
|
||||
|
||||
This function will load the global data dict. Ideally we want data
|
||||
handling not to happen here and only at the dataloader, that will be the
|
||||
behavior in future releases.
|
||||
|
||||
Args:
|
||||
client_id (int or list): identifier(s) for grabbing client's data.
|
||||
dataloader (torch.utils.data.DataLoader): dataloader that
|
||||
provides the trianing
|
||||
'''
|
||||
|
||||
# Auxiliary function for decoding only when necessary
|
||||
decode_if_str = lambda x: x.decode() if isinstance(x, bytes) else x
|
||||
|
||||
# During training, client_id will be always an integer
|
||||
if isinstance(client_id, int):
|
||||
user_name = decode_if_str(_data_dict['users'][client_id])
|
||||
num_samples = _data_dict['num_samples'][client_id]
|
||||
|
||||
if _file_ext == 'hdf5':
|
||||
arr_data = [decode_if_str(e) for e in _data_dict['user_data'][user_name]['x'][()]]
|
||||
user_data = {'x': arr_data}
|
||||
elif _file_ext == 'json' or _file_ext == 'txt':
|
||||
user_data = _data_dict['user_data'][user_name]
|
||||
|
||||
if 'user_data_label' in _data_dict: # supervised problem
|
||||
labels = _data_dict['user_data_label'][user_name]
|
||||
if _file_ext == 'hdf5': # transforms HDF5 Dataset into Numpy array
|
||||
labels = labels[()]
|
||||
|
||||
return edict({'users': [user_name],
|
||||
'user_data': {user_name: user_data},
|
||||
'num_samples': [num_samples],
|
||||
'user_data_label': {user_name: labels}})
|
||||
else:
|
||||
print_rank('no labels present, unsupervised problem', loglevel=logging.DEBUG)
|
||||
return edict({'users': [user_name],
|
||||
'user_data': {user_name: user_data},
|
||||
'num_samples': [num_samples]})
|
||||
|
||||
# During validation and test, client_id might be a list of integers
|
||||
elif isinstance(client_id, list):
|
||||
if 'user_data_label' in _data_dict:
|
||||
users_dict = {'users': [], 'num_samples': [], 'user_data': {}, 'user_data_label': {}}
|
||||
else:
|
||||
users_dict = {'users': [], 'num_samples': [], 'user_data': {}}
|
||||
|
||||
for client in client_id:
|
||||
user_name = decode_if_str(dataloader.dataset.user_list[client])
|
||||
users_dict['users'].append(user_name)
|
||||
users_dict['num_samples'].append(dataloader.dataset.num_samples[client])
|
||||
|
||||
if _file_ext == 'hdf5':
|
||||
arr_data = dataloader.dataset.user_data[user_name]['x']
|
||||
arr_decoded = [decode_if_str(e) for e in arr_data]
|
||||
users_dict['user_data'][user_name] = {'x': arr_decoded}
|
||||
elif _file_ext == 'json':
|
||||
users_dict['user_data'][user_name] = {'x': dataloader.dataset.user_data[user_name]['x']}
|
||||
elif _file_ext == 'txt': # using a different line for .txt since our files have a different structure
|
||||
users_dict['user_data'][user_name] = dataloader.dataset.user_data[user_name]
|
||||
|
||||
if 'user_data_label' in _data_dict:
|
||||
labels = dataloader.dataset.user_data_label[user_name]
|
||||
if _file_ext == 'hdf5':
|
||||
labels = labels[()]
|
||||
users_dict['user_data_label'][user_name] = labels
|
||||
|
||||
return users_dict
|
||||
|
||||
@staticmethod
|
||||
def run_testvalidate(client_data, server_data, mode, model):
|
||||
'''Called by worker to run test/validation sample on a client.
|
||||
|
||||
This functions assumes set_model_for_round has already been called to
|
||||
push the model to the client (see federated.py).
|
||||
|
||||
Args:
|
||||
client_data (tuple): client data and config. It is a tuple with 4
|
||||
components; importantly, the second component is a dict
|
||||
containing the data, and the third component is a dict with the
|
||||
config parsed from the YAML file.
|
||||
server_data (tuple): server data (model parameters mostly). It is
|
||||
a tuple with 3 components; importantly, the third component
|
||||
consists of the current model parameters.
|
||||
mode (str): whether to `test` or `validate`.
|
||||
model (torch.nn.Module): actual model without parameters.
|
||||
'''
|
||||
|
||||
# Process inputs and initialize variables
|
||||
_, data_strct, config, _ = client_data
|
||||
_, _, model_parameters = server_data
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
begin = time.time()
|
||||
|
||||
# Use the server's data config since we're distributing test/validate from the server
|
||||
data_config = config['server_config']['data_config'][mode]
|
||||
want_logits = data_config.get('wantLogits', False)
|
||||
|
||||
# Create dataloader
|
||||
dataloader = None
|
||||
print_rank('making dataloader with task {}'.format(config['server_config']['task']), loglevel=logging.DEBUG)
|
||||
if mode == 'test':
|
||||
dataloader = make_test_dataloader(data_config, data_path=None, task=config['server_config']['task'], data_strct=data_strct)
|
||||
elif mode == 'val':
|
||||
dataloader = make_val_dataloader(data_config, data_path=None, task=config['server_config']['task'], data_strct=data_strct)
|
||||
|
||||
# Set model parameters
|
||||
n_layers, n_params = len([f for f in model.parameters()]), len(model_parameters)
|
||||
print_rank(f'Copying model parameters... {n_layers}/{n_params}', loglevel=logging.DEBUG)
|
||||
model.cuda() if torch.cuda.is_available() else model
|
||||
for p, data in zip(model.parameters(), model_parameters):
|
||||
p.data = data.detach().clone().cuda() if torch.cuda.is_available() else data.detach().clone()
|
||||
print_rank(f'Model setup complete. {time.time() - begin}s elapsed.', loglevel=logging.DEBUG)
|
||||
|
||||
# Compute output and metrics on the test or validation data
|
||||
num_instances = sum(data_strct['num_samples'])
|
||||
print_rank(f'Validating {num_instances}', loglevel=logging.DEBUG)
|
||||
output, loss, cer = run_validation_generic(model, dataloader)
|
||||
if not want_logits:
|
||||
output = None
|
||||
|
||||
return output, (loss, cer, num_instances)
|
||||
|
||||
@staticmethod
|
||||
def process_round(client_data, server_data, model, data_path, eps=1e-7):
|
||||
'''Compute gradients given client's data and update model.
|
||||
|
||||
Args:
|
||||
client_data (tuple): client data and config. It is a tuple
|
||||
consisting of 4 components: an int indicating the client's id, a
|
||||
dict containing that client's data, a dict with the config
|
||||
parsed from the YAML file, and a bool indicating whether or not
|
||||
gradients should be sent.
|
||||
server_data (tuple): server data (model parameters mostly). It is
|
||||
a tuple consisting of 3 components; importantly, the first is
|
||||
a float giving the client's learning rate, and the third a list
|
||||
of torch.Tensor's with current model parameters. The second one
|
||||
is not used, right now.
|
||||
model (torch.nn.Module): actual model without parameters.
|
||||
data_path (str): where to get data from.
|
||||
eps (float): lower bound for aggregation weights.
|
||||
'''
|
||||
|
||||
# Ensure the client is assigned to the correct GPU
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == federated.size():
|
||||
torch.cuda.set_device(federated.local_rank())
|
||||
|
||||
# Process inputs and initialize variables
|
||||
client_id, data_strct, config, send_gradients = client_data
|
||||
initial_lr, _, model_parameters = server_data
|
||||
config = copy.deepcopy(config)
|
||||
|
||||
model_config = config['model_config']
|
||||
client_config = config['client_config']
|
||||
dp_config = config.get('dp_config', None)
|
||||
data_config = client_config['data_config']['train']
|
||||
task = client_config.get('task', {})
|
||||
quant_threshold = client_config.get('quant_thresh', None)
|
||||
quant_bits = client_config.get('quant_bits', 10)
|
||||
trainer_config = client_config.get('trainer_config', {})
|
||||
privacy_metrics_config = config.get('privacy_metrics_config', None)
|
||||
|
||||
begin = time.time()
|
||||
client_stats = {}
|
||||
|
||||
# Update the location of the training file
|
||||
data_config['list_of_train_data'] = os.path.join(data_path, data_config['list_of_train_data'])
|
||||
|
||||
user = data_strct['users'][0]
|
||||
if 'user_data_label' in data_strct.keys(): # supervised case
|
||||
input_strct = edict({
|
||||
'users': [user],
|
||||
'user_data': {user: data_strct['user_data'][user]},
|
||||
'num_samples': [data_strct['num_samples'][0]],
|
||||
'user_data_label': {user: data_strct['user_data_label'][user]}
|
||||
})
|
||||
else:
|
||||
input_strct = edict({
|
||||
'users': [user],
|
||||
'user_data': {user: data_strct['user_data'][user]},
|
||||
'num_samples': [data_strct['num_samples'][0]]
|
||||
})
|
||||
|
||||
print_rank('Loading : {}-th client with name: {}, {} samples, {}s elapsed'.format(
|
||||
client_id, user, data_strct['num_samples'][0], time.time() - begin), loglevel=logging.INFO)
|
||||
|
||||
# Estimate stats on the smooth gradient
|
||||
stats_on_smooth_grad = client_config.get('stats_on_smooth_grad', False)
|
||||
|
||||
# Get dataloaders
|
||||
train_dataloader = make_train_dataloader(data_config, data_path, task=task, clientx=0, data_strct=input_strct)
|
||||
val_dataloader = make_val_dataloader(data_config, data_path)
|
||||
|
||||
# Instantiate the model object
|
||||
if model is None:
|
||||
model = make_model(model_config,
|
||||
dataloader_type=train_dataloader.__class__.__name__,
|
||||
input_dim=data_config['input_dim'],
|
||||
vocab_size=train_dataloader.vocab_size)
|
||||
|
||||
# Set model parameters
|
||||
n_layers, n_params = len([f for f in model.parameters()]), len(model_parameters)
|
||||
print_rank(f'Copying model parameters... {n_layers}/{n_params}', loglevel=logging.DEBUG)
|
||||
model.cuda() if torch.cuda.is_available() else model
|
||||
for p, data in zip(model.parameters(), model_parameters):
|
||||
p.data = data.detach().clone().cuda() if torch.cuda.is_available() else data.detach().clone()
|
||||
print_rank(f'Model setup complete. {time.time() - begin}s elapsed.', loglevel=logging.DEBUG)
|
||||
|
||||
# Fix parameters of layers
|
||||
if 'updatable_names' in trainer_config:
|
||||
set_component_wise_lr(model, client_config['optimizer_config'], trainer_config['updatable_names'])
|
||||
|
||||
# Create the optimizer on the workers
|
||||
# NOTE: the server dictates the learning rate for the clients
|
||||
client_config['optimizer_config']['lr'] = initial_lr
|
||||
optimizer = make_optimizer(client_config['optimizer_config'], model)
|
||||
|
||||
# Make the scheduled sampling scheduler
|
||||
ss_scheduler = None
|
||||
if 'ss_config' in client_config and client_config['ss_config'] is not None:
|
||||
ss_scheduler = ScheduledSamplingScheduler(model=model, **client_config['ss_config'])
|
||||
|
||||
# Make the trainer
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
ss_scheduler=ss_scheduler,
|
||||
train_dataloader=train_dataloader,
|
||||
val_dataloader=val_dataloader,
|
||||
server_replay_config =client_config,
|
||||
max_grad_norm=client_config['data_config']['train'].get('max_grad_norm', None),
|
||||
anneal_config=client_config['annealing_config'] if 'annealing_config' in client_config else None,
|
||||
num_skips_threshold=client_config['num_skips_threshold'] if 'num_skips_threshold' in client_config else -1,
|
||||
ignore_subtask=client_config['ignore_subtask']
|
||||
)
|
||||
|
||||
if trainer.optimizer is not None:
|
||||
initial_optimizer_state = copy.deepcopy(trainer.optimizer.state_dict())
|
||||
|
||||
annealing_config = client_config['annealing_config'] if 'annealing_config' in client_config else None
|
||||
|
||||
assert 'desired_max_samples' in client_config['data_config']['train'], 'Missing \'desired_max_samples\' entry in data config parameter'
|
||||
desired_max_samples = client_config['data_config']['train']['desired_max_samples']
|
||||
|
||||
if trainer.optimizer is not None: # reset the optimizer state
|
||||
if initial_lr > 0:
|
||||
trainer.optimizer.param_groups[0].update({'lr': initial_lr})
|
||||
initial_optimizer_state = copy.deepcopy(trainer.optimizer.state_dict())
|
||||
trainer.reset_optimizer(initial_optimizer_state, annealing_config)
|
||||
|
||||
# Mark the end of setup
|
||||
end = time.time()
|
||||
client_stats['setup'] = end-begin
|
||||
print_rank(f'Client setup cost {client_stats["setup"]}s', loglevel=logging.DEBUG)
|
||||
begin_training = end
|
||||
|
||||
# Training begins here
|
||||
trainer.model.train()
|
||||
trainer.model.zero_grad()
|
||||
|
||||
# Save the client batches if we want to evaluate the privacy metrics
|
||||
apply_privacy_metrics = (False if privacy_metrics_config is None else privacy_metrics_config['apply_metrics'])
|
||||
|
||||
# This is where training actually happens
|
||||
train_loss, num_samples = trainer.train_desired_samples(desired_max_samples=desired_max_samples, apply_privacy_metrics=apply_privacy_metrics)
|
||||
print_rank('client={}: training loss={}'.format(client_id, train_loss), loglevel=logging.DEBUG)
|
||||
|
||||
# From now on we just post-process the results of the training:
|
||||
# get weights for aggregation, do quantization/DP, compute statistics...
|
||||
|
||||
# Estimate the pseudo-gradient
|
||||
for p, data in zip(trainer.model.parameters(), model_parameters):
|
||||
data = data.cuda() if torch.cuda.is_available() else data
|
||||
p.grad = data - p.data
|
||||
|
||||
# Get weights for aggregation, potentially using DGA
|
||||
weight = 1.0
|
||||
add_weight_noise = False
|
||||
|
||||
def filter_weight(weight):
|
||||
'''Handles aggregation weights if something messed them up'''
|
||||
print_rank('Client Weight BEFORE filtering: {}'.format(weight), loglevel=logging.DEBUG)
|
||||
if np.isnan(weight) or not np.isfinite(weight):
|
||||
weight = 0.0
|
||||
elif weight > 100:
|
||||
weight = 100
|
||||
print_rank('Client Weights AFTER filtering: {}'.format(weight), loglevel=logging.DEBUG)
|
||||
return weight
|
||||
|
||||
# Reset gradient stats and recalculate them on the smooth/pseudo gradient
|
||||
if stats_on_smooth_grad:
|
||||
trainer.reset_gradient_power()
|
||||
trainer.estimate_sufficient_stats()
|
||||
|
||||
# Estimate gradient magnitude mean/var
|
||||
mean_grad = trainer.sufficient_stats['sum'] / trainer.sufficient_stats['n']
|
||||
mag_grad = np.sqrt(trainer.sufficient_stats['sq_sum'] / trainer.sufficient_stats['n'])
|
||||
var_grad = trainer.sufficient_stats['sq_sum'] / trainer.sufficient_stats['n'] - mag_grad ** 2
|
||||
norm_grad = np.sqrt(trainer.sufficient_stats['sq_sum'])
|
||||
|
||||
# If we are using softmax based on training loss, it needs DP noise
|
||||
if send_gradients and config['server_config']['aggregate_median'] == 'softmax':
|
||||
# This matters when DP is required
|
||||
add_weight_noise = True
|
||||
|
||||
if 'weight_train_loss' not in config['server_config'] or config['server_config']['weight_train_loss'] == 'train_loss':
|
||||
training_weight = train_loss / num_samples
|
||||
elif config['server_config']['weight_train_loss'] == 'mag_var_loss':
|
||||
training_weight = var_grad
|
||||
elif config['server_config']['weight_train_loss'] == 'mag_mean_loss':
|
||||
training_weight = mean_grad
|
||||
else:
|
||||
training_weight = mag_grad
|
||||
|
||||
try:
|
||||
weight = math.exp(-config['server_config']['softmax_beta']*training_weight)
|
||||
except:
|
||||
print_rank('There is an issue with the weight -- Reverting to {}'.format(eps), loglevel=logging.DEBUG)
|
||||
weight = eps # TODO: set to min_weight?
|
||||
weight = filter_weight(weight)
|
||||
|
||||
# Add local DP noise here. Note at this point the model parameters are the gradient (diff_data above)
|
||||
# When weight == 0, something went wrong. So we'll skip adding noise and return a zero gradient.
|
||||
if weight > 0.0 and dp_config is not None and dp_config.get('enable_local_dp', False):
|
||||
# Unroll the network grads as 1D vectors
|
||||
flat_grad, params_ids = extensions.privacy.unroll_network(trainer.model.named_parameters(), select_grad=True)
|
||||
grad_norm = flat_grad.norm().cpu().item()
|
||||
|
||||
if dp_config['eps'] < 0:
|
||||
# clip, but don't add noise
|
||||
if grad_norm > dp_config['max_grad']:
|
||||
flat_grad = flat_grad * (dp_config['max_grad'] / grad_norm)
|
||||
extensions.privacy.update_network(trainer.model.named_parameters(), params_ids, flat_grad, apply_to_grad=True)
|
||||
|
||||
else:
|
||||
# Get Gaussian LDP noise
|
||||
dp_eps = dp_config['eps']
|
||||
delta = dp_config.get('delta', 1e-7) # TODO pre-compute in config
|
||||
weight_ = weight
|
||||
|
||||
# Scaling the weight down so we don't impact the noise too much
|
||||
weight = dp_config.get('weight_scaler', 1) * weight
|
||||
weight = min(dp_config['max_weight'], weight)
|
||||
flat_noisy_grad = dp_config['max_grad'] * (flat_grad / flat_grad.norm())
|
||||
max_sensitivity = np.sqrt(dp_config['max_grad']**2 + (dp_config['max_weight']**2 if add_weight_noise else 0.0))
|
||||
flat_noisy_grad = torch.cat([flat_noisy_grad, torch.tensor([weight], device=flat_noisy_grad.device)], dim=0)
|
||||
flat_noisy_grad, sigma = extensions.privacy.add_gaussian_noise(flat_noisy_grad, dp_eps, max_sensitivity, delta)
|
||||
weight = min(max(flat_noisy_grad[-1].item(), dp_config['min_weight']), dp_config['max_weight'])
|
||||
|
||||
# Scaling the weight back up after noise addition (This is a DP-protect transformation)
|
||||
weight = weight / dp_config.get('weight_scaler', 1)
|
||||
if not add_weight_noise:
|
||||
weight = weight_
|
||||
flat_noisy_grad = flat_noisy_grad[:-1]
|
||||
|
||||
print_rank('Cosine error from noise {}'.format(torch.nn.functional.cosine_similarity(flat_grad, flat_noisy_grad, dim=0)), loglevel=logging.DEBUG)
|
||||
print_rank('Error from noise is {}'.format((flat_grad-flat_noisy_grad).norm()), loglevel=logging.DEBUG)
|
||||
print_rank('weight is {} and noisy weight is {}'.format(weight_, weight), loglevel=logging.DEBUG)
|
||||
|
||||
# Return back to the network
|
||||
extensions.privacy.update_network(trainer.model.named_parameters(), params_ids, flat_noisy_grad, apply_to_grad=True)
|
||||
|
||||
# In all other cases we can compute the weight after adding noise
|
||||
if send_gradients and not add_weight_noise:
|
||||
assert config['server_config']['aggregate_median'] == 'mean'
|
||||
assert weight == 1.0
|
||||
|
||||
if send_gradients:
|
||||
# Weight the gradient and remove gradients of the layers we want to freeze
|
||||
for n, p in trainer.model.named_parameters():
|
||||
p.grad = weight * p.grad
|
||||
if model_config.get('freeze_layer', None) and n == model_config['freeze_layer']:
|
||||
print_rank('Setting gradient to zero for layer: {}'.format(n), loglevel=logging.INFO)
|
||||
p.grad.mul_(0)
|
||||
|
||||
# Gradient quantization step -- if quant_threshold is None, the code returns without doing anything
|
||||
quant_model(trainer.model, quant_threshold=quant_threshold, quant_bits=quant_bits, global_stats=False)
|
||||
|
||||
# Mark that training (including post-processing) is finished
|
||||
end = time.time()
|
||||
client_stats['training'] = end - begin_training
|
||||
client_stats['full cost'] = end - begin
|
||||
print_rank(f'Client training cost {end - begin_training}s', loglevel=logging.DEBUG)
|
||||
print_rank(f'Client full cost {end - begin}s', loglevel=logging.DEBUG)
|
||||
|
||||
# Create dictionary that is sent back to server
|
||||
client_output = {
|
||||
'cs': client_stats,
|
||||
'tl': train_loss,
|
||||
'wt': weight,
|
||||
'mg': mag_grad,
|
||||
'vg': var_grad,
|
||||
'ng': mean_grad,
|
||||
'rg': norm_grad,
|
||||
'ns': num_samples
|
||||
}
|
||||
|
||||
# Apply privacy metrics
|
||||
if privacy_metrics_config and privacy_metrics_config['apply_metrics']:
|
||||
print_rank('Applying privacy metrics', loglevel=logging.DEBUG)
|
||||
|
||||
privacy_stats = {'Dropped clients': 0}
|
||||
batches = trainer.cached_batches
|
||||
trainer.cached_batches = []
|
||||
gradients = extensions.privacy.unroll_network(model.named_parameters(), select_grad=True)[0]
|
||||
|
||||
if privacy_metrics_config['apply_indices_extraction']:
|
||||
allowed_word_rank = privacy_metrics_config.get('allowed_word_rank', 9000)
|
||||
embed_dim, vocab_size = model_config['embed_dim'], model_config['vocab_size']
|
||||
overlap, indices = privacy_metrics.extract_indices_from_embeddings(gradients, batches, embed_dim, vocab_size)
|
||||
|
||||
max_overlap = privacy_metrics_config.get('max_allowed_overlap', None)
|
||||
if max_overlap is not None and overlap > max_overlap:
|
||||
print_rank('Removing this client because we extracted {}% words and the maximum allowed is {}%'.format(overlap * 100, max_overlap * 100))
|
||||
client_output['wt'] = 0.0
|
||||
privacy_stats['Dropped clients'] = 1
|
||||
|
||||
privacy_stats['Extracted indices percentage'] = overlap
|
||||
privacy_stats['Words percentage above ' + str(allowed_word_rank) + ' word rank'] = (indices > allowed_word_rank).mean() if len(indices)>0 else 0
|
||||
|
||||
if privacy_metrics_config['apply_leakage_metric']:
|
||||
print_rank('Applying leakage metric', loglevel=logging.DEBUG)
|
||||
|
||||
orig_params = {n: p for (n, _), p in zip(trainer.model.named_parameters(), model_parameters)}
|
||||
max_ratio = np.exp(privacy_metrics_config['max_leakage'])
|
||||
optim_config = privacy_metrics_config['attacker_optimizer_config']
|
||||
is_leakage_weighted = privacy_metrics_config['is_leakage_weighted']
|
||||
|
||||
leakage = privacy_metrics.practical_epsilon_leakage(orig_params,
|
||||
trainer.model, batches, is_leakage_weighted, max_ratio, optim_config)
|
||||
print_rank('privacy leakage: {}'.format(leakage), loglevel=logging.DEBUG)
|
||||
|
||||
max_leakage = privacy_metrics_config.get('max_allowed_leakage', None)
|
||||
if max_leakage is not None and leakage > max_leakage:
|
||||
print_rank('Removing this client because the information leakage/practical epsilon is {} and the maximum allowed is {}'.format(leakage, max_leakage))
|
||||
client_output['wt'] = 0.0
|
||||
privacy_stats['Dropped clients'] = 1
|
||||
|
||||
privacy_stats['Practical epsilon (Max leakage)'] = leakage
|
||||
|
||||
client_output['ps'] = privacy_stats
|
||||
|
||||
# Finally, we add the actual model gradients or parameters to the output dictionary
|
||||
if send_gradients:
|
||||
client_output['gr'] = [p.grad.to(torch.device('cpu')) for p in trainer.model.parameters()]
|
||||
else:
|
||||
client_output['pm'] = [p.data.to(torch.device('cpu')) for p in trainer.model.parameters()]
|
||||
|
||||
client_output['ts'] = time.time()
|
||||
return client_output
|
|
@ -0,0 +1,436 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import copy
|
||||
import cProfile
|
||||
import gc
|
||||
import os
|
||||
import pickle
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from mpi4py import MPI
|
||||
|
||||
from utils import (
|
||||
print_rank,
|
||||
print_profiler
|
||||
)
|
||||
from utils.queue import process_in_parallel
|
||||
|
||||
|
||||
SPLIT_SIZE = 512 * 1024 * 1024 # messages above this size (in bytes) are split
|
||||
|
||||
COMMAND_UPDATE = "update"
|
||||
COMMAND_TRAIN = "train"
|
||||
COMMAND_TERMINATE = "terminate"
|
||||
COMMAND_TESTVAL = "testvalidate"
|
||||
|
||||
|
||||
def rank():
|
||||
"""Return rank of node"""
|
||||
return MPI.COMM_WORLD.Get_rank()
|
||||
|
||||
def local_rank():
|
||||
"""Return local rank of MPI node"""
|
||||
assert (
|
||||
"OMPI_COMM_WORLD_LOCAL_RANK" in os.environ
|
||||
), "local rank can only be determined when using OpenMPI"
|
||||
return int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
|
||||
|
||||
def size():
|
||||
"""Returns number of MPI nodes including server"""
|
||||
return MPI.COMM_WORLD.Get_size()
|
||||
|
||||
|
||||
class Server:
|
||||
"""Server object responsible for orchestration and aggregation.
|
||||
|
||||
The Server is one of the two objects that may exist inside of a thread, all
|
||||
throughout its execution (the other being the Worker). At every round, the
|
||||
Server samples clients and sends their data for an available Worker to process.
|
||||
The Workers then each produce a new model, and all models are sent to the Server
|
||||
for aggregation.
|
||||
|
||||
The methods defined here are related to orchestration only, the aggregation
|
||||
will be done by a different object which inherits from this one.
|
||||
|
||||
Notes:
|
||||
This class has no :code`__init__` method, and all its methods are static.
|
||||
It thus only serves the purpose of grouping the methods, but nothing
|
||||
is actually stored inside of the object.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def dispatch_clients(clients, server_data, payload_fn, clients_in_parallel=None):
|
||||
"""Perform execution of client code on the worker nodes.
|
||||
|
||||
This function does the following:
|
||||
1. It sends the server_data to all workers
|
||||
2. For each client:
|
||||
2a. It sends the function process_round of the client
|
||||
to a free worker.
|
||||
2b. It calls get_client_data on the client.
|
||||
2c. It triggers the execution of the payload_fn on the
|
||||
worker with parameters server_data and client_data.
|
||||
|
||||
Notes:
|
||||
This function yields the gradients of different clients
|
||||
as they are received. Therefore, the order of the results generally
|
||||
does not correspond to the order of the clients.
|
||||
|
||||
Args:
|
||||
clients (list): list of clients to be processed.
|
||||
server_data (dict): server data sent to the workers and passed to
|
||||
clients, typically includes the global model at that step.
|
||||
payload_fn (callback): instructions for worker to execute.
|
||||
clients_in_parallel (int or None): how many threads will be used for
|
||||
processing clients, defaults to None in which case all of them
|
||||
are processed on the same thread.
|
||||
|
||||
Returns:
|
||||
Generator of results sent by server via MPI.
|
||||
"""
|
||||
# Send args to workers
|
||||
data_pickled = pickle.dumps(server_data) # pickle once
|
||||
for worker_rank in range(1, MPI.COMM_WORLD.Get_size()):
|
||||
MPI.COMM_WORLD.send(COMMAND_UPDATE, worker_rank)
|
||||
_send(data_pickled, worker_rank, pickled=True)
|
||||
|
||||
# Perform payload_fn on clients
|
||||
client_queue = clients.copy()
|
||||
free_nodes = list(range(1, MPI.COMM_WORLD.Get_size()))
|
||||
node_request_map = []
|
||||
|
||||
# Initiate computation for all clients
|
||||
while client_queue:
|
||||
if clients_in_parallel is not None:
|
||||
clients_to_process = [client_queue.pop() for _ in range(clients_in_parallel) if len(client_queue) > 0]
|
||||
else:
|
||||
clients_to_process = client_queue.pop()
|
||||
|
||||
print_rank(f"Queueing {clients_to_process}, {len(client_queue)} remaining", loglevel=logging.DEBUG)
|
||||
|
||||
# Wait for free worker node
|
||||
if not free_nodes:
|
||||
print_rank(f"Waiting for a worker", loglevel=logging.DEBUG)
|
||||
assert(len(node_request_map) > 0)
|
||||
status = MPI.Status()
|
||||
ix, _ = MPI.Request.waitany(node_request_map, status=status)
|
||||
|
||||
# Collects worker output after processing has finished
|
||||
output = _recv(status.source)
|
||||
if isinstance(output, list):
|
||||
yield from output
|
||||
else:
|
||||
yield output
|
||||
|
||||
free_nodes.append(status.source)
|
||||
print_rank(f"Found free worker {ix}:{status.source}", loglevel=logging.DEBUG)
|
||||
node_request_map.pop(ix)
|
||||
|
||||
# Run client computation on free worker node
|
||||
assert len(free_nodes) > 0
|
||||
node = free_nodes.pop()
|
||||
print_rank(f"Sending to worker {node}", loglevel=logging.DEBUG)
|
||||
payload_fn(clients_to_process, node)
|
||||
print_rank(f"Payload sent. Queueing irecv on {node}", loglevel=logging.DEBUG)
|
||||
node_request_map.append(MPI.COMM_WORLD.irecv(source=node))
|
||||
print_rank(f"Queued irecv for {node}", loglevel=logging.DEBUG)
|
||||
|
||||
print_rank(f"Done queuing clients. Waiting on workers")
|
||||
|
||||
# Wait for all workers to finish
|
||||
for i, request in enumerate(node_request_map):
|
||||
status = MPI.Status()
|
||||
request.wait(status)
|
||||
print_rank(f"Result for item {i}: source: {status.source}", loglevel=logging.DEBUG)
|
||||
|
||||
print_rank(f"Calling _recv for {status.source}", loglevel=logging.DEBUG)
|
||||
output = _recv(status.source)
|
||||
if isinstance(output, list):
|
||||
yield from output
|
||||
else:
|
||||
yield output
|
||||
|
||||
@staticmethod
|
||||
def process_clients(clients, server_data, clients_in_parallel):
|
||||
"""Ask workers to process client data.
|
||||
|
||||
The payload function defined below will send a free worker instructions
|
||||
on how to process the data of one or more clients. This payload function
|
||||
is then passed to :code:`dispatch_clients`, which continuously looks for
|
||||
free workers and sends them more clients to process.
|
||||
|
||||
Args:
|
||||
clients (list): list of client.Client objects.
|
||||
server_data (dict): dictionary containing model.
|
||||
clients_in_parallel (None or int): how many threads to use for
|
||||
processing the clients on a given worker.
|
||||
|
||||
Returns:
|
||||
Generator of results sent by server via MPI.
|
||||
"""
|
||||
|
||||
def payload_fn(clients, node):
|
||||
"""Payload function for a training round."""
|
||||
|
||||
# Send command for training and function to process round
|
||||
MPI.COMM_WORLD.send(COMMAND_TRAIN, node)
|
||||
|
||||
# Loop through clients and send their data
|
||||
if clients_in_parallel is None:
|
||||
MPI.COMM_WORLD.send(clients.process_round, node)
|
||||
MPI.COMM_WORLD.send(clients.get_client_data(), node)
|
||||
else:
|
||||
MPI.COMM_WORLD.send(clients[0].process_round, node) # clients is a list
|
||||
MPI.COMM_WORLD.send(len(clients), node)
|
||||
for client in clients:
|
||||
MPI.COMM_WORLD.send(client.get_client_data(), node)
|
||||
|
||||
return Server.dispatch_clients(clients, server_data, payload_fn, clients_in_parallel=clients_in_parallel)
|
||||
|
||||
@staticmethod
|
||||
def process_testvalidate(clients, server_data, mode):
|
||||
"""Ask workers to use clients data to compute metrics.
|
||||
|
||||
Similar to :code:`process_round` but asks workers to
|
||||
compute metrics instead, by using a different payload function.
|
||||
|
||||
Args:
|
||||
clients (list): list of client.Client objects.
|
||||
server_data (dict): dictionary containing model.
|
||||
mode(str): whether to :code:`test` or :code:`validate`.
|
||||
|
||||
Returns:
|
||||
Generator of results sent by server via MPI.
|
||||
"""
|
||||
|
||||
def payload_fn(client, node):
|
||||
"""Payload function for a test/validation round."""
|
||||
|
||||
MPI.COMM_WORLD.send(COMMAND_TESTVAL, node)
|
||||
MPI.COMM_WORLD.send(client.run_testvalidate, node)
|
||||
MPI.COMM_WORLD.send(client.get_client_data(), node)
|
||||
MPI.COMM_WORLD.send(mode, node)
|
||||
|
||||
return Server.dispatch_clients(clients, server_data, payload_fn)
|
||||
|
||||
@staticmethod
|
||||
def terminate_workers(terminate=True):
|
||||
"""Terminate the execution of the workers."""
|
||||
|
||||
if terminate:
|
||||
print_rank("Terminating worker processes")
|
||||
for worker_rank in range(1, MPI.COMM_WORLD.Get_size()):
|
||||
MPI.COMM_WORLD.send(COMMAND_TERMINATE, worker_rank)
|
||||
|
||||
|
||||
class Worker:
|
||||
"""Worker object responsible for processing clients' data.
|
||||
|
||||
Each worker lives on a different MPI thread and is assigned to a different
|
||||
GPU. Via the :code:`dispatch_clients` function, the Server passes the
|
||||
Worker specific instructions to process clients' data, typically in order
|
||||
to generate a new model or to compute metrics.
|
||||
|
||||
Attributes:
|
||||
model (torch.nn.Module): model being trained.
|
||||
data_path (str): path where all clients' data is located.
|
||||
do_profiling (bool): if True, analyzes execution in depth.
|
||||
clients_in_parallel (None or int): if not None, processes clients in
|
||||
threads during training round.
|
||||
server_data (dict): stores data received from Server when an update
|
||||
command is received.
|
||||
"""
|
||||
|
||||
def __init__(self, model=None, data_path=None, do_profiling=False, clients_in_parallel=None):
|
||||
"""
|
||||
Set the GPU workspace for the model to be exchanged between the server and clients
|
||||
This prevents a model instance from being created on the GPU worker many time
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module, optional): model being trained, defaults to None.
|
||||
data_path (str, optional): path where all clients' data is located,
|
||||
defaults to None.
|
||||
do_profiling (bool, optional): if True, analyzes execution in depth; defaults
|
||||
to False.
|
||||
clients_in_parallel (None or int, optional): if not None, processes clients in
|
||||
threads during training round. Defaults to None.
|
||||
"""
|
||||
self.model = model
|
||||
self.data_path = data_path
|
||||
self.do_profiling = do_profiling
|
||||
self.clients_in_parallel = clients_in_parallel
|
||||
|
||||
self.server_data = None
|
||||
|
||||
# For processing in different threads, we need copies of the model
|
||||
if clients_in_parallel is not None:
|
||||
device = f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu"
|
||||
self.model_copies = [copy.deepcopy(model).to(device) for _ in range(clients_in_parallel)]
|
||||
|
||||
def run(self):
|
||||
"""Main loop executed by worker nodes.
|
||||
|
||||
This method triggers the MPI communication between the worker and
|
||||
the server. It keeps listening for commands from the Server,
|
||||
and performs different actions depending on the command received.
|
||||
"""
|
||||
|
||||
while True: # keeps listening for commands on MPI
|
||||
command = MPI.COMM_WORLD.recv()
|
||||
assert isinstance(command, str)
|
||||
|
||||
if command == COMMAND_UPDATE:
|
||||
self.server_data = _recv(0)
|
||||
|
||||
elif command == COMMAND_TRAIN:
|
||||
profiler = None
|
||||
if self.do_profiling:
|
||||
profiler = cProfile.Profile()
|
||||
profiler.enable()
|
||||
|
||||
client_fn = MPI.COMM_WORLD.recv() # NOTE: assumes function is same for all clients
|
||||
|
||||
# Pick whether to do processing in batches or not
|
||||
if self.clients_in_parallel is None:
|
||||
client_data = MPI.COMM_WORLD.recv()
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
output = client_fn(client_data, self.server_data, self.model, self.data_path)
|
||||
else:
|
||||
n_clients = MPI.COMM_WORLD.recv()
|
||||
client_data = [MPI.COMM_WORLD.recv() for _ in range(n_clients)]
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
output = process_in_parallel(client_fn, client_data, self.server_data, self.model_copies, self.data_path)
|
||||
print_rank(f"Processed batch of size {len(client_data)}, got {len(output)} outputs", loglevel=logging.DEBUG)
|
||||
|
||||
# Wait for server to be available and send output(s)
|
||||
MPI.COMM_WORLD.isend(None, 0).wait()
|
||||
_send(output, 0)
|
||||
|
||||
# Make sure that memory is cleaned up
|
||||
if self.clients_in_parallel is not None:
|
||||
for args in client_data:
|
||||
del args
|
||||
del client_fn, client_data, output
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||
|
||||
if self.do_profiling:
|
||||
profiler.disable()
|
||||
print_profiler(profiler)
|
||||
|
||||
elif command == COMMAND_TESTVAL:
|
||||
profiler = None
|
||||
if self.do_profiling:
|
||||
profiler = cProfile.Profile()
|
||||
profiler.enable()
|
||||
|
||||
client_fn = MPI.COMM_WORLD.recv()
|
||||
client_data = MPI.COMM_WORLD.recv()
|
||||
client_mode = MPI.COMM_WORLD.recv()
|
||||
|
||||
# Clean up memory before client processing
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
try:
|
||||
output = client_fn(client_data, self.server_data, client_mode, self.model)
|
||||
except RuntimeError as e:
|
||||
_dump_tensors(gpu_only=True)
|
||||
raise RuntimeError("Federated Error: {}".format(str(e)))
|
||||
|
||||
MPI.COMM_WORLD.isend(None, 0).wait()
|
||||
_send(output, 0)
|
||||
|
||||
# Make sure that memory is cleaned up
|
||||
del client_fn, client_data, output
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize() if torch.cuda.is_available() else None
|
||||
|
||||
if self.do_profiling:
|
||||
profiler.disable()
|
||||
print_profiler(profiler)
|
||||
|
||||
elif command == COMMAND_TERMINATE:
|
||||
return
|
||||
|
||||
else:
|
||||
assert False, "unknown command"
|
||||
|
||||
|
||||
def _send(data, rank, pickled=False, verbose=False):
|
||||
"""Send large object by chunking it into multiple MPI messages."""
|
||||
|
||||
# Pickle data
|
||||
data_pickled = data
|
||||
if not pickled:
|
||||
data_pickled = pickle.dumps(data_pickled)
|
||||
|
||||
# Compute in how many chunks data will be sent
|
||||
num_chunks = len(data_pickled) // SPLIT_SIZE + 1
|
||||
if verbose:
|
||||
print_rank(f"_send data_pickled size: {len(data_pickled)}, {num_chunks} chunks")
|
||||
|
||||
# Send data in chunks
|
||||
MPI.COMM_WORLD.send(num_chunks, rank)
|
||||
|
||||
ix = 0
|
||||
while len(data_pickled) - ix > SPLIT_SIZE:
|
||||
MPI.COMM_WORLD.send(data_pickled[ix:ix+SPLIT_SIZE], rank)
|
||||
ix += SPLIT_SIZE
|
||||
MPI.COMM_WORLD.send(data_pickled[ix:], rank)
|
||||
|
||||
def _recv(rank):
|
||||
"""Receive large object by chunking it into multiple MPI messages."""
|
||||
|
||||
num_chunks = MPI.COMM_WORLD.recv(source=rank)
|
||||
pickled_chunks = []
|
||||
for _ in range(num_chunks):
|
||||
pickled_chunks.append(MPI.COMM_WORLD.recv(source=rank))
|
||||
data_pickled = b"".join(pickled_chunks)
|
||||
return pickle.loads(data_pickled)
|
||||
|
||||
def _dump_tensors(gpu_only=True):
|
||||
"""Print a list of the Tensors being tracked by the garbage collector."""
|
||||
|
||||
def pretty_size(size):
|
||||
"""Pretty prints a torch.Size object."""
|
||||
assert(isinstance(size, torch.Size))
|
||||
return " × ".join(map(str, size))
|
||||
|
||||
print_rank("Dump memory allocated")
|
||||
print_rank(torch.cuda.memory_allocated())
|
||||
print_rank("Dump max memory allocated")
|
||||
print_rank(torch.cuda.max_memory_allocated())
|
||||
print_rank("Dump memory cached")
|
||||
print_rank(torch.cuda.memory_cached())
|
||||
print_rank("Dump max memory cached")
|
||||
print_rank(torch.cuda.max_memory_cached())
|
||||
|
||||
total_size = 0
|
||||
for obj in gc.get_objects():
|
||||
try:
|
||||
if torch.is_tensor(obj):
|
||||
if not gpu_only or obj.is_cuda:
|
||||
print("%s:%s%s %s" % (type(obj).__name__,
|
||||
" GPU" if obj.is_cuda else "",
|
||||
" pinned" if obj.is_pinned else "",
|
||||
pretty_size(obj.size())))
|
||||
total_size += obj.numel()
|
||||
elif hasattr(obj, "data") and torch.is_tensor(obj.data):
|
||||
if not gpu_only or obj.is_cuda:
|
||||
print("%s -> %s:%s%s%s%s %s" % (type(obj).__name__,
|
||||
type(obj.data).__name__,
|
||||
" GPU" if obj.is_cuda else "",
|
||||
" pinned" if obj.data.is_pinned else "",
|
||||
" grad" if obj.requires_grad else "",
|
||||
" volatile" if obj.volatile else "",
|
||||
pretty_size(obj.data.size())))
|
||||
total_size += obj.data.numel()
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
print_rank("Total size: {}".format(total_size))
|
|
@ -0,0 +1,18 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
import os
|
||||
# Macro variable that sets which distributed trainig framework is used (e.g. mpi, syft, horovod)
|
||||
TRAINING_FRAMEWORK_TYPE = 'mpi'
|
||||
logging_level = logging.INFO # DEBUG | INFO
|
||||
file_type = None
|
||||
|
||||
|
||||
def define_file_type (data_path,config):
|
||||
global file_type
|
||||
|
||||
filename = os.path.join(data_path, config["client_config"]["data_config"]["train"]["list_of_train_data"])
|
||||
arr_filename = filename.split(".")
|
||||
file_type = arr_filename[-1]
|
||||
print(" File_type has ben assigned to: {}".format(file_type))
|
|
@ -0,0 +1,956 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
'''
|
||||
In this file, we define the classes that live inside 'worker 0', the worker
|
||||
responsible for orchestration and aggregation. The main class is the
|
||||
OptimizationServer, which sends clients to the other workers to process and
|
||||
combines the resulting models.
|
||||
'''
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
# Internal imports
|
||||
from core.globals import TRAINING_FRAMEWORK_TYPE
|
||||
if TRAINING_FRAMEWORK_TYPE == 'mpi':
|
||||
import core.federated as federated
|
||||
else:
|
||||
raise NotImplementedError('{} is not supported'.format(TRAINING_FRAMEWORK_TYPE))
|
||||
|
||||
from core.client import Client
|
||||
from .trainer import (
|
||||
ModelUpdater,
|
||||
Trainer,
|
||||
set_component_wise_lr,
|
||||
)
|
||||
from utils import (
|
||||
compute_grad_cosines,
|
||||
get_lr,
|
||||
print_rank,
|
||||
update_json_log,
|
||||
)
|
||||
from utils.utils import _to_cuda
|
||||
|
||||
from extensions import (
|
||||
RL,
|
||||
privacy,
|
||||
)
|
||||
|
||||
# For profiling
|
||||
import cProfile
|
||||
import pstats
|
||||
|
||||
# AzureML-related libs
|
||||
from azureml.core import Run
|
||||
run = Run.get_context()
|
||||
|
||||
|
||||
class OptimizationServer(federated.Server):
|
||||
def __init__(self, num_clients, model, optimizer, ss_scheduler, data_path, model_path, train_dataloader,
|
||||
val_dataloader, test_dataloader, config, config_server):
|
||||
'''Implement Server's orchestration and aggregation.
|
||||
|
||||
This is the main Server class, that actually implements orchestration
|
||||
and aggregation, inheriting from `federated.Server`, which deals with
|
||||
communication only.
|
||||
|
||||
The `train` method is central in FLUTE, as it defines good part of what
|
||||
happens during training.
|
||||
|
||||
Args:
|
||||
num_clients (int): total available clients.
|
||||
model (torch.nn.Module): neural network model.
|
||||
optimizer (torch.optim.Optimizer): optimizer.
|
||||
ss_scheduler: scheduled sampling scheduler.
|
||||
data_path (str): points to where data is.
|
||||
model_path (str): points to where pretrained model is.
|
||||
train_dataloader (torch.utils.data.DataLoader): dataloader for training
|
||||
val_dataloader (torch.utils.data.DataLoader): dataloader for validation
|
||||
test_dataloader (torch.utils.data.DataLoader): dataloader for test, can be None
|
||||
config (dict): JSON style configuration parameters
|
||||
config_server: deprecated, kept for API compatibility only.
|
||||
'''
|
||||
|
||||
super().__init__()
|
||||
|
||||
# Initialize all attributes from arguments
|
||||
self.client_idx_list = list(range(num_clients))
|
||||
self.config = config
|
||||
server_config = config['server_config']
|
||||
decoder_config = config.get('decoder_config', None)
|
||||
|
||||
self.max_iteration = server_config['max_iteration']
|
||||
self.do_clustering = server_config.get('clustering', False)
|
||||
|
||||
self.num_clients_per_iteration = [int(x) for x in server_config['num_clients_per_iteration'].split(',')] \
|
||||
if isinstance(server_config['num_clients_per_iteration'], str) \
|
||||
else [server_config['num_clients_per_iteration']]
|
||||
|
||||
self.val_freq = server_config['val_freq']
|
||||
self.rec_freq = server_config['rec_freq']
|
||||
self.model_backup_freq = server_config.get('model_backup_freq', 100)
|
||||
self.worker_trainer_config = server_config.get('trainer_config', {})
|
||||
|
||||
self.aggregate_median = server_config['aggregate_median']
|
||||
self.initial_lr_client = server_config.get('initial_lr_client', -1.0)
|
||||
self.lr_decay_factor = server_config.get('lr_decay_factor', 1.0)
|
||||
|
||||
self.model_type = config['model_config']['model_type']
|
||||
self.quant_thresh = config['client_config'].get('quant_thresh', None)
|
||||
self.quant_bits = config['client_config'].get('quant_bits', 10)
|
||||
|
||||
self.list_of_train_data = config['client_config']['data_config']['train']['list_of_train_data']
|
||||
self.data_path = data_path
|
||||
|
||||
# Get max grad norm from data config
|
||||
if 'train' in server_config['data_config']:
|
||||
max_grad_norm = server_config['data_config']['train'].get('max_grad_norm', None)
|
||||
else:
|
||||
max_grad_norm = None
|
||||
|
||||
# Creating an instance to update the model with stats aggregated from workers
|
||||
self.worker_trainer = ModelUpdater(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
ss_scheduler=ss_scheduler,
|
||||
train_dataloader=train_dataloader if train_dataloader is not None else val_dataloader,
|
||||
val_dataloader=val_dataloader,
|
||||
max_grad_norm=max_grad_norm,
|
||||
anneal_config=server_config['annealing_config'],
|
||||
model_type=self.model_type,
|
||||
decoder_config=decoder_config
|
||||
)
|
||||
|
||||
self.val_dataloader = val_dataloader
|
||||
|
||||
# Creating an instance for the server-side trainer (runs mini-batch SGD)
|
||||
self.server_replay_iterations = None
|
||||
self.server_trainer = None
|
||||
if train_dataloader is not None:
|
||||
assert 'server_replay_config' in server_config, 'server_replay_config is not set'
|
||||
assert 'optimizer_config' in server_config[
|
||||
'server_replay_config'], 'server-side replay training optimizer is not set'
|
||||
self.server_optimizer_config = server_config['server_replay_config']['optimizer_config']
|
||||
self.server_trainer_config = server_config['server_replay_config'].get('trainer_config', {})
|
||||
self.server_replay_iterations = server_config['server_replay_config']['server_iterations']
|
||||
self.server_trainer = Trainer(
|
||||
model=model,
|
||||
optimizer=None,
|
||||
ss_scheduler=ss_scheduler,
|
||||
train_dataloader=train_dataloader,
|
||||
server_replay_config=server_config['server_replay_config'],
|
||||
val_dataloader=None,
|
||||
max_grad_norm=server_config['server_replay_config']\
|
||||
.get('max_grad_norm',server_config['data_config']['train']\
|
||||
.get('max_grad_norm',None)),
|
||||
anneal_config=server_config['server_replay_config'].get('annealing_config', None)
|
||||
)
|
||||
|
||||
self.skip_model_update = False # will not update the model if True
|
||||
|
||||
self.train_loss = 0.0
|
||||
self.model_path = model_path
|
||||
self.best_model_criterion = server_config['best_model_criterion']
|
||||
self.fall_back_to_best_model = server_config['fall_back_to_best_model']
|
||||
self.last_model_path = os.path.join(self.model_path, 'latest_model.tar')
|
||||
self.best_model_path = os.path.join(self.model_path,
|
||||
'best_val_{}_model.tar'.format(self.best_model_criterion))
|
||||
self.log_path = os.path.join(self.model_path, 'status_log.json')
|
||||
self.cur_iter_no = 0 # keep the iteration number for Tensor board plotting
|
||||
self.best_val_loss= float('inf')
|
||||
self.best_val_acc = -1.0
|
||||
self.best_test_loss= float('inf')
|
||||
self.best_test_acc= -1.0
|
||||
self.lr_weight = 1.0
|
||||
|
||||
self.weight_sum_stale = 0.0
|
||||
self.client_parameters_stack_stale = []
|
||||
self.stale_prob = server_config.get('stale_prob', 0.0)
|
||||
|
||||
self.losses = []
|
||||
self.no_label_updates = 0 # no. label updates
|
||||
|
||||
# Update the parameters above if the log file
|
||||
if server_config.get('resume_from_checkpoint', False):
|
||||
self.load_saved_status()
|
||||
|
||||
# Decoding config
|
||||
self.test_dataloader = test_dataloader
|
||||
self.decoder_config = decoder_config
|
||||
self.spm_model = server_config['data_config']['test'].get('spm_model', None)
|
||||
|
||||
self.do_profiling = server_config.get('do_profiling', False)
|
||||
|
||||
self.wantRL = server_config.get('wantRL', False)
|
||||
self.aggregate_fast = server_config.get('fast_aggregation', False)
|
||||
if self.aggregate_fast:
|
||||
print_rank('It is NOT possible to enable RL with fast_aggregation, RL is set to False', loglevel=logging.INFO)
|
||||
self.wantRL = False
|
||||
print_rank('It is NOT possible in Current Implementation to have stale gradients with fast_aggregation, stale_prob is set to 0.0', loglevel=logging.INFO)
|
||||
self.stale_prob = 0.0
|
||||
|
||||
if self.wantRL:
|
||||
self.RL = RL(config=server_config)
|
||||
|
||||
# Parallel processing
|
||||
self.clients_in_parallel = config['client_config'].get('clients_in_parallel', None)
|
||||
|
||||
def load_saved_status(self):
|
||||
'''Load checkpoint from disk'''
|
||||
|
||||
# Check if model is on disk, if so loads it onto trainer
|
||||
if os.path.exists(self.last_model_path):
|
||||
print_rank('Resuming from checkpoint model {}'.format(self.last_model_path))
|
||||
self.worker_trainer.load(self.last_model_path, update_lr_scheduler=True, update_ss_scheduler=True)
|
||||
if self.server_trainer is not None:
|
||||
self.server_trainer.model = self.worker_trainer.model # make sure that the models are in sync
|
||||
|
||||
# Check if log is on disk, if so loads it onto current stats
|
||||
if os.path.exists(self.log_path):
|
||||
with open(self.log_path, 'r') as logfp: # loading the iteration no., best loss and CER
|
||||
elems = json.load(logfp)
|
||||
self.cur_iter_no = elems.get('i', 0)
|
||||
self.best_val_loss = elems.get('best_val_loss', float('inf'))
|
||||
self.best_val_acc = elems.get('best_val_acc', float('inf'))
|
||||
self.best_test_loss = elems.get('best_test_loss', float('inf'))
|
||||
self.best_test_acc = elems.get('best_test_acc', float('inf'))
|
||||
self.lr_weight = elems.get('weight', 1.0)
|
||||
self.no_label_updates = elems.get('num_label_updates', 0)
|
||||
print_rank(f'Resuming from status_log: cur_iter: {self.cur_iter_no}')
|
||||
|
||||
def make_eval_clients(self, dataloader):
|
||||
'''Generator that yields clients for evaluation, continuously.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader): used to get client's data
|
||||
'''
|
||||
|
||||
total = sum(dataloader.dataset.num_samples)
|
||||
clients = federated.size() - 1
|
||||
delta = total / clients + 1
|
||||
threshold = delta
|
||||
current_users_idxs = list()
|
||||
current_total = 0
|
||||
|
||||
# Accumulate users until a threshold is reached to form client
|
||||
for i in range(len(dataloader.dataset.user_list)):
|
||||
current_users_idxs.append(i)
|
||||
count = dataloader.dataset.num_samples[i]
|
||||
current_total += count
|
||||
if current_total > threshold:
|
||||
print_rank(f'sending {len(current_users_idxs)} users', loglevel=logging.DEBUG)
|
||||
yield Client(current_users_idxs, self.config, False, dataloader)
|
||||
current_users_idxs = list()
|
||||
current_total = 0
|
||||
|
||||
if len(current_users_idxs) != 0:
|
||||
print_rank(f'sending {len(current_users_idxs)} users -- residual', loglevel=logging.DEBUG)
|
||||
yield Client(current_users_idxs, self.config, False, dataloader)
|
||||
|
||||
def run_distributed_evaluation(self, dataloader, mode):
|
||||
'''Perform evaluation using available workers.
|
||||
|
||||
See also `process_test_validate` on federated.py.
|
||||
|
||||
Args:
|
||||
dataloader (torch.utils.data.DataLoader): used to fetch data.
|
||||
mode (str): `test` or `val`.
|
||||
'''
|
||||
val_clients = list(self.make_eval_clients(dataloader))
|
||||
print_rank(f'mode: {mode} evaluation_clients {len(val_clients)}', loglevel=logging.DEBUG)
|
||||
|
||||
usl_json = None # NOTE: deprecated
|
||||
val_loss = val_acc = total = 0
|
||||
self.logits = {'predictions': [], 'probabilities': [], 'labels': []}
|
||||
server_data = (0.0, usl_json, [p.data.to(torch.device('cpu')) for p in self.worker_trainer.model.parameters()])
|
||||
|
||||
for result in self.process_testvalidate(val_clients, server_data, mode):
|
||||
output, (loss, cer, count) = result
|
||||
val_loss += loss * count
|
||||
val_acc += cer * count
|
||||
total += count
|
||||
|
||||
if output is not None:
|
||||
self.logits['predictions'].append(output['predictions'])
|
||||
self.logits['probabilities'].append(output['probabilities'])
|
||||
self.logits['labels'].append(output['labels'])
|
||||
|
||||
if self.logits['probabilities'] and self.logits['predictions'] and self.logits['labels']:
|
||||
self.logits['predictions'] = np.concatenate(self.logits['predictions'])
|
||||
self.logits['probabilities'] = np.concatenate(self.logits['probabilities'])
|
||||
self.logits['labels'] = np.concatenate(self.logits['labels'])
|
||||
|
||||
return val_loss / total, val_acc / total
|
||||
|
||||
def run_distributed_inference(self, mode):
|
||||
'''Call `run_distributed_evaluation` specifically for test or validation.
|
||||
|
||||
This is just a helper function that fetches the dataloader depending on
|
||||
the mode and calls `run_distributed_evaluation` using that dataloader.
|
||||
|
||||
Args:
|
||||
mode (str): `test` or `val`.
|
||||
'''
|
||||
if mode == 'val':
|
||||
dataloader = self.val_dataloader
|
||||
elif mode == 'test':
|
||||
dataloader = self.test_dataloader
|
||||
else:
|
||||
raise NotImplementedError('Unsupported mode: {}'.format(mode))
|
||||
return self.run_distributed_evaluation(dataloader, mode)
|
||||
|
||||
def run(self):
|
||||
'''Trigger training.
|
||||
|
||||
This is a simple wrapper to the `train` method.
|
||||
'''
|
||||
print_rank('server started')
|
||||
self.train()
|
||||
print_rank('server terminated')
|
||||
|
||||
def train(self):
|
||||
'''Main method for training.'''
|
||||
|
||||
self.run_stats = {
|
||||
'secsPerClientRound': [],
|
||||
'secsPerClient': [],
|
||||
'secsPerClientTraining': [],
|
||||
'secsPerClientSetup': [],
|
||||
'secsPerClientFull': [],
|
||||
'secsPerRoundHousekeeping': [],
|
||||
'secsPerRoundTotal': [],
|
||||
'mpiCosts': []
|
||||
}
|
||||
|
||||
run.log('Max iterations', self.max_iteration)
|
||||
try:
|
||||
self.worker_trainer.model.cuda() if torch.cuda.is_available() else None
|
||||
|
||||
# Do an initial validation round to understand the pretrained model's validation accuracy
|
||||
# Skip if we resumed from a checkpoint (cur_iter_no > 0)
|
||||
if self.cur_iter_no == 0:
|
||||
self.run_val_test(0, metric_logger=run.log)
|
||||
|
||||
# Dump all the information in aggregate_metric
|
||||
print_rank('Saving Model Before Starting Training', loglevel=logging.INFO)
|
||||
for token in ['best_val_loss', 'best_val_acc', 'best_test_acc', 'latest']:
|
||||
self.worker_trainer.save(
|
||||
model_path=self.model_path,
|
||||
token=token,
|
||||
config=self.config['server_config']
|
||||
)
|
||||
|
||||
# Training loop
|
||||
self.worker_trainer.model.train()
|
||||
for i in range(self.cur_iter_no, self.max_iteration):
|
||||
begin = time.time()
|
||||
metrics_payload = {}
|
||||
|
||||
def log_metric(k, v):
|
||||
metrics_payload[k] = v
|
||||
|
||||
print_rank('==== iteration {}'.format(i))
|
||||
log_metric('Current iteration', i)
|
||||
usl_json = None # deprecated
|
||||
|
||||
# Initial value for the learning rate of the worker
|
||||
initial_lr = self.initial_lr_client * self.lr_weight
|
||||
print_rank('Client learning rate {}'.format(initial_lr))
|
||||
|
||||
# Run training on clients
|
||||
self.worker_trainer.model.zero_grad()
|
||||
self.train_loss = []
|
||||
server_data = (
|
||||
initial_lr,
|
||||
usl_json,
|
||||
[p.data.to(torch.device('cpu')) for p in self.worker_trainer.model.parameters()]
|
||||
)
|
||||
|
||||
# Random number of clients per iteration
|
||||
if len(self.num_clients_per_iteration) > 1:
|
||||
num_clients_curr_iter = random.randint(
|
||||
self.num_clients_per_iteration[0],
|
||||
self.num_clients_per_iteration[1]
|
||||
)
|
||||
else:
|
||||
num_clients_curr_iter = self.num_clients_per_iteration[0]
|
||||
log_metric('Clients for round', num_clients_curr_iter)
|
||||
|
||||
# Perform annealing in quantization threshold
|
||||
if self.quant_thresh is not None:
|
||||
self.config['client_config']['quant_thresh'] *= self.config['client_config'].get('quant_anneal', 1.0)
|
||||
self.quant_thresh = self.config['client_config']['quant_thresh']
|
||||
log_metric('Quantization Thresh.', self.config['client_config']['quant_thresh'])
|
||||
|
||||
# Create the pool of clients -- sample from this pool to assign to workers
|
||||
sampled_idx_clients = random.sample(self.client_idx_list,
|
||||
num_clients_curr_iter) if num_clients_curr_iter > 0 else self.client_idx_list
|
||||
sampled_clients = [
|
||||
Client(
|
||||
client_id,
|
||||
self.config,
|
||||
self.config['client_config']['type'] == 'optimization',
|
||||
None
|
||||
) for client_id in sampled_idx_clients
|
||||
]
|
||||
|
||||
# Initialize stats
|
||||
clients_begin = time.time()
|
||||
|
||||
client_losses = []
|
||||
client_weights = []
|
||||
client_mag_grads = []
|
||||
client_mean_grads = []
|
||||
client_var_grads = []
|
||||
client_norm_grads = []
|
||||
|
||||
self.client_parameters_stack = []
|
||||
self.run_stats['secsPerClient'].append([])
|
||||
self.run_stats['secsPerClientFull'].append([])
|
||||
self.run_stats['secsPerClientTraining'].append([])
|
||||
self.run_stats['secsPerClientSetup'].append([])
|
||||
self.run_stats['mpiCosts'].append([])
|
||||
|
||||
# Check if we want privacy metrics
|
||||
apply_privacy_metrics = self.config.get('privacy_metrics_config', None) and \
|
||||
self.config['privacy_metrics_config']['apply_metrics']
|
||||
adaptive_leakage = apply_privacy_metrics and \
|
||||
self.config['privacy_metrics_config'].get('adaptive_leakage_threshold', None)
|
||||
if apply_privacy_metrics:
|
||||
privacy_metrics_stats = defaultdict(list)
|
||||
|
||||
# Initialize profiler
|
||||
profiler = None
|
||||
if self.do_profiling:
|
||||
profiler = cProfile.Profile()
|
||||
profiler.enable()
|
||||
|
||||
# Reset gradient for the model before assigning the new gradients
|
||||
self.worker_trainer.model.zero_grad()
|
||||
|
||||
for client_output in self.process_clients(sampled_clients, server_data, self.clients_in_parallel):
|
||||
# Process client output
|
||||
client_timestamp = client_output['ts']
|
||||
client_stats = client_output['cs']
|
||||
client_loss = client_output['tl']
|
||||
client_weight = client_output['wt']
|
||||
client_mag_grad = client_output['mg']
|
||||
client_var_grad = client_output['vg']
|
||||
client_mean_grad = client_output['ng']
|
||||
client_norm_grad = client_output['rg']
|
||||
num_samples = client_output['ns']
|
||||
|
||||
# Client_output may contain 'gr' or 'pm' for grads or params.
|
||||
# For the time being we just support gradients.
|
||||
client_parameters = client_output['gr']
|
||||
|
||||
if apply_privacy_metrics:
|
||||
privacy_stats = client_output['ps']
|
||||
for metric, value in privacy_stats.items():
|
||||
privacy_metrics_stats[metric].append(value)
|
||||
|
||||
self.run_stats['mpiCosts'][-1].append(time.time() - client_timestamp)
|
||||
|
||||
# Ignore clients with agg. weight == 0.0
|
||||
if client_weight == 0.0:
|
||||
print_rank('Dropping client Due to issues with weighting', loglevel=logging.DEBUG)
|
||||
num_clients_curr_iter -= 1
|
||||
continue
|
||||
|
||||
# Get actual pseudo-gradients for aggregation
|
||||
if self.aggregate_fast:
|
||||
self.aggregate_gradients_inplace(client_parameters)
|
||||
else:
|
||||
self.client_parameters_stack.append(client_parameters)
|
||||
|
||||
# Aggregate stats
|
||||
self.train_loss.append(client_loss)
|
||||
client_losses.append(client_loss)
|
||||
client_weights.append(client_weight)
|
||||
client_mean_grads.append(client_mean_grad.item())
|
||||
client_var_grads.append(client_var_grad.item())
|
||||
client_norm_grads.append(client_norm_grad.item())
|
||||
|
||||
# Mark the end of client processing
|
||||
client_end = time.time()
|
||||
|
||||
self.run_stats['secsPerClientFull'][-1].append(client_stats['full cost'])
|
||||
self.run_stats['secsPerClientTraining'][-1].append(client_stats['training'])
|
||||
self.run_stats['secsPerClientSetup'][-1].append(client_stats['setup'])
|
||||
self.run_stats['secsPerClient'][-1].append(client_end - clients_begin)
|
||||
|
||||
# Tear down profiler
|
||||
if self.do_profiling:
|
||||
profiler.disable()
|
||||
stats = pstats.Stats(profiler)
|
||||
stats.sort_stats('cumulative').print_stats()
|
||||
|
||||
# Prepare output
|
||||
client_weights = np.array(client_weights)
|
||||
client_mag_grads = np.array(client_mag_grads)
|
||||
client_mean_grads = np.array(client_mean_grads)
|
||||
client_var_grads = np.array(client_var_grads)
|
||||
client_norm_grads = np.array(client_norm_grads)
|
||||
|
||||
dump_norm_stats = self.config.get('dump_norm_stats', False)
|
||||
if dump_norm_stats:
|
||||
with open(os.path.join(self.model_path, 'norm_stats.txt'), 'a', encoding='utf-8') as outF:
|
||||
outF.write('{}\n'.format(json.dumps(list(client_norm_grads))))
|
||||
|
||||
# Print the privacy metrics
|
||||
if apply_privacy_metrics:
|
||||
for metric, values in privacy_metrics_stats.items():
|
||||
if metric == 'Dropped clients':
|
||||
log_metric(metric, sum(values))
|
||||
else:
|
||||
log_metric(metric, max(values))
|
||||
|
||||
if type(adaptive_leakage) is float:
|
||||
values = privacy_metrics_stats['Practical epsilon (Max leakage)']
|
||||
new_threshold = list(sorted(values))[int(adaptive_leakage*len(values))]
|
||||
print_rank('Updating leakage threshold to {}'.format(new_threshold))
|
||||
self.config['privacy_metrics_config']['max_allowed_leakage'] = new_threshold
|
||||
|
||||
# Mark that all clients have been processed
|
||||
end = time.time()
|
||||
self.run_stats['secsPerClientRound'].append(end - begin)
|
||||
begin = end
|
||||
|
||||
if self.wantRL:
|
||||
rl_model = self.run_RL_inference(client_weights, client_mag_grads, client_mean_grads, client_var_grads)
|
||||
|
||||
# Aggregation step
|
||||
if dump_norm_stats:
|
||||
cps_copy = [[g.clone().detach() for g in x] for x in self.client_parameters_stack]
|
||||
weight_sum = self.aggregate_gradients(num_clients_curr_iter, client_weights, metric_logger=log_metric)
|
||||
print_rank('Sum of weights: {}'.format(weight_sum), loglevel=logging.DEBUG)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Normalize with weight_sum
|
||||
for p in self.worker_trainer.model.parameters():
|
||||
p.grad /= weight_sum
|
||||
|
||||
if dump_norm_stats:
|
||||
cosines = compute_grad_cosines(cps_copy, [p.grad.clone().detach() for p in self.worker_trainer.model.parameters()])
|
||||
with open(os.path.join(self.model_path, 'cosines.txt'), 'a', encoding='utf-8') as outfile:
|
||||
outfile.write('{}\n'.format(json.dumps(cosines)))
|
||||
|
||||
# DP-specific steps
|
||||
privacy.apply_global_dp(self.config, self.worker_trainer.model, num_clients_curr_iter=num_clients_curr_iter, select_grad=True, metric_logger=log_metric)
|
||||
eps = privacy.update_privacy_accountant(self.config, len(self.client_idx_list), curr_iter=i, num_clients_curr_iter=num_clients_curr_iter)
|
||||
if eps:
|
||||
print_rank(f'DP result: {eps}')
|
||||
|
||||
# Log the training loss to tensorboard/AML
|
||||
log_metric('Training loss', sum(self.train_loss))
|
||||
|
||||
if self.skip_model_update is True:
|
||||
print_rank('Skipping model update')
|
||||
continue
|
||||
|
||||
# Run optimization with gradient/model aggregated from clients
|
||||
print_rank('Updating model')
|
||||
self.worker_trainer.update_model()
|
||||
print_rank('Updating learning rate scheduler')
|
||||
self.losses = self.worker_trainer.run_lr_scheduler(force_run_val=False)
|
||||
|
||||
if self.wantRL:
|
||||
self.run_RL_training(i, rl_model, client_weights, client_mag_grads, client_mean_grads, client_var_grads, log_metric)
|
||||
|
||||
# Run a couple of iterations of training data on the server
|
||||
if self.server_trainer is not None:
|
||||
print_rank('Running replay iterations on server')
|
||||
|
||||
if 'updatable_names' in self.server_trainer_config:
|
||||
set_component_wise_lr(
|
||||
self.worker_trainer.model,
|
||||
self.server_optimizer_config,
|
||||
self.server_trainer_config['updatable_names']
|
||||
)
|
||||
self.server_trainer.prepare_iteration(self.worker_trainer.model)
|
||||
self.server_trainer.train_desired_samples(self.server_replay_iterations)
|
||||
self.worker_trainer.model.load_state_dict(self.server_trainer.model.state_dict())
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Update a sampling scheduler
|
||||
print_rank('Run ss scheduler')
|
||||
self.worker_trainer.run_ss_scheduler()
|
||||
|
||||
# Run inference and score on val/test depending on the iter. number
|
||||
self.run_val_test(i + 1, metric_logger=log_metric)
|
||||
|
||||
# Backup the current best models
|
||||
self.backup_models(i)
|
||||
|
||||
# Fall back to the best model if the option is enabled
|
||||
self.fall_back_to_prev_best_status()
|
||||
|
||||
# Logging the latest best values
|
||||
update_json_log(
|
||||
self.log_path,
|
||||
{
|
||||
'i': i + 1,
|
||||
'best_val_loss': float(self.best_val_loss),
|
||||
'best_val_acc': float(self.best_val_acc),
|
||||
'best_test_loss': float(self.best_test_loss),
|
||||
'best_test_acc': float(self.best_test_acc),
|
||||
'weight': float(self.lr_weight),
|
||||
'num_label_updates': int(self.no_label_updates)
|
||||
},
|
||||
)
|
||||
|
||||
end = time.time()
|
||||
|
||||
# Aggregate stats
|
||||
self.run_stats['secsPerRoundHousekeeping'].append(end - begin)
|
||||
self.run_stats['secsPerRoundTotal'].append(self.run_stats['secsPerClientRound'][-1] + \
|
||||
self.run_stats['secsPerRoundHousekeeping'][-1])
|
||||
|
||||
log_metric('secsPerRoundTotal', self.run_stats['secsPerRoundTotal'][-1])
|
||||
if self.do_profiling:
|
||||
log_metric('secsPerClientRound', self.run_stats['secsPerClientRound'][-1])
|
||||
log_metric('secsPerRoundHousekeeping', self.run_stats['secsPerRoundHousekeeping'][-1])
|
||||
|
||||
metrics_for_stats = [
|
||||
'secsPerClient',
|
||||
'secsPerClientTraining',
|
||||
'secsPerClientFull',
|
||||
'secsPerClientSetup'
|
||||
'mpiCosts',
|
||||
]
|
||||
|
||||
for metric in metrics_for_stats:
|
||||
log_metric(f'{metric}Mean', np.mean(self.run_metrics[metric][-1]))
|
||||
log_metric(f'{metric}Median', np.median(self.run_metrics[metric][-1]))
|
||||
log_metric(f'{metric}Max', max(self.run_metrics[metric][-1]))
|
||||
|
||||
for k in self.run_metrics:
|
||||
if k in metrics_for_stats:
|
||||
print_rank('{}: {}'.format(k, max(self.run_stats[k][-1])), loglevel=logging.DEBUG)
|
||||
else:
|
||||
print_rank('{}: {}'.format(k, self.run_stats[k][-1]), loglevel=logging.DEBUG)
|
||||
|
||||
# Log all the metrics
|
||||
for k in metrics_payload:
|
||||
run.log(k, metrics_payload[k])
|
||||
|
||||
finally: # perform cleanup even if error was raised above
|
||||
self.terminate_workers(terminate=(not self.do_clustering))
|
||||
|
||||
def backup_models(self, i):
|
||||
'''Save the current best models.
|
||||
|
||||
Save CER model, the best loss model and the best WER model. This occurs
|
||||
at a specified period.
|
||||
|
||||
Args:
|
||||
i: no. of iterations.
|
||||
'''
|
||||
|
||||
# Always save the latest model
|
||||
self.worker_trainer.save(
|
||||
model_path=self.model_path,
|
||||
token='latest',
|
||||
config=self.config['server_config'],
|
||||
)
|
||||
|
||||
if (i % self.model_backup_freq) == 0: # save the current best models
|
||||
self.worker_trainer.save(
|
||||
model_path=self.model_path,
|
||||
token='epoch{}'.format(i),
|
||||
config=self.config['server_config']
|
||||
)
|
||||
|
||||
for bodyname in ['best_val_acc', 'best_val_loss', 'best_test_acc']:
|
||||
src_model_path = os.path.join(self.model_path, '{}_model.tar'.format(bodyname))
|
||||
if os.path.exists(src_model_path):
|
||||
dst_model_path = os.path.join(self.model_path, 'epoch{}_{}_model.tar'.format(i, bodyname))
|
||||
shutil.copyfile(src_model_path, dst_model_path)
|
||||
print_rank('Saved {}'.format(dst_model_path))
|
||||
|
||||
def fall_back_to_prev_best_status(self):
|
||||
'''Go back to the past best status and switch to the recent best model.'''
|
||||
|
||||
if self.fall_back_to_best_model:
|
||||
print_rank('falling back to model {}'.format(self.best_model_path))
|
||||
|
||||
# Save current learning rate
|
||||
tmp_lr = get_lr(self.worker_trainer.optimizer)
|
||||
|
||||
# Load previous best model
|
||||
self.worker_trainer.load(self.best_model_path, update_lr_scheduler=False, update_ss_scheduler=False)
|
||||
|
||||
# Update previous learning rate on optimizer
|
||||
for g in self.worker_trainer.optimizer.param_groups:
|
||||
g['lr'] = tmp_lr
|
||||
|
||||
if self.server_trainer is not None:
|
||||
self.server_trainer.model = self.worker_trainer.model # make sure that the models are in sync
|
||||
|
||||
def run_RL_inference(self, client_weights, client_mag_grads, client_mean_grads, client_var_grads):
|
||||
'''Uses RL to estimate weights, using DGA.
|
||||
|
||||
Args:
|
||||
client_weights (numpy.ndarray): original weights for aggregation.
|
||||
client_mag_grads (numpy.ndarray): gradient stats for RL (magnitudes).
|
||||
client_mean_grads (numpy.ndarray): gradient stats for RL (means).
|
||||
client_var_grads (numpy.ndarray): gradient stats for RL (vars).
|
||||
|
||||
Returns:
|
||||
list of torch.Tensor: parameters of model used to perform RL.
|
||||
'''
|
||||
|
||||
weight_sum = 0
|
||||
original_model = copy.copy([p for p in self.worker_trainer.model.parameters()])
|
||||
|
||||
# Reinforcement learning for estimating weights
|
||||
print_rank('RL estimation of the aggregation weights', loglevel=logging.INFO)
|
||||
rl_weights = self.RL.forward(
|
||||
np.concatenate((client_weights, client_mag_grads, client_mean_grads, client_var_grads), axis=0)).cpu().detach().np()
|
||||
if rl_weights.ndim > 1:
|
||||
rl_weights = rl_weights[-1, :]
|
||||
rl_weights = np.exp(rl_weights)
|
||||
|
||||
print_rank('RL Weights BEFORE filtering: {}'.format(rl_weights), loglevel=logging.DEBUG)
|
||||
index = np.argwhere(np.isnan(rl_weights))
|
||||
rl_weights[index] = 0
|
||||
index = np.argwhere(np.isinf(rl_weights))
|
||||
rl_weights[index] = 0
|
||||
print_rank('RL Weights AFTER filtering: {}'.format(rl_weights), loglevel=logging.DEBUG)
|
||||
|
||||
for client_parameters, orig_weight, rl_weight in zip(self.client_parameters_stack, client_weights, rl_weights):
|
||||
# Model parameters are already multiplied with weight on client, we only have to sum them up
|
||||
for p, client_grad in zip(self.worker_trainer.model.parameters(), client_parameters):
|
||||
if p.grad is None:
|
||||
p.grad = _to_cuda(client_grad) * rl_weight / orig_weight
|
||||
else:
|
||||
p.grad += _to_cuda(client_grad) * rl_weight / orig_weight
|
||||
weight_sum += rl_weight
|
||||
|
||||
# Normalize with weight_sum
|
||||
for p in self.worker_trainer.model.parameters():
|
||||
p.grad /= weight_sum
|
||||
|
||||
# Run optimization with gradient/model aggregated from clients
|
||||
self.worker_trainer.update_model()
|
||||
|
||||
# Get the validation result back
|
||||
(rl_val_loss, rl_val_acc) = self.worker_trainer.run_lr_scheduler(force_run_val=True)
|
||||
|
||||
# Save model and revert to previous one
|
||||
rl_model = copy.copy([p.data for p in self.worker_trainer.model.parameters()])
|
||||
for p, p_ in zip(self.worker_trainer.model.parameters(), original_model):
|
||||
p.data = p_.data.detach().clone()
|
||||
|
||||
# Set the current set of weights
|
||||
self.RL.set_weights(rl_weights)
|
||||
self.RL.set_losses((rl_val_loss, rl_val_acc))
|
||||
|
||||
# Return the resulting RL-based model
|
||||
return rl_model
|
||||
|
||||
def run_RL_training(self, iter, rl_model, client_weights, client_mag_grads, client_mean_grads, client_var_grads, metric_logger):
|
||||
'''Trains RL for estimating weights, following DGA recipe.
|
||||
|
||||
Args:
|
||||
iter (int): current iteration.
|
||||
rl_model (list of torch.Tensor): parameters of model used to perform RL.
|
||||
client_weights (numpy.ndarray): original weights for aggregation.
|
||||
client_mag_grads (numpy.ndarray): gradient stats for RL (magnitudes).
|
||||
client_mean_grads (numpy.ndarray): gradient stats for RL (means).
|
||||
client_var_grads (numpy.ndarray): gradient stats for RL (vars).
|
||||
metric_logger (callback, optional): callback used for logging.
|
||||
Defaults to None, in which case AML logger is used.
|
||||
'''
|
||||
|
||||
# Get the validation result back
|
||||
if None in self.losses:
|
||||
self.losses = self.run_distributed_inference(mode='val')
|
||||
|
||||
# Expected structure of batch
|
||||
print_rank('Performing RL training on the aggregation weights')
|
||||
if abs(self.losses[1] - self.RL.rl_losses[1]) < 0.001:
|
||||
reward = 0.1
|
||||
print_rank(
|
||||
'Iter:{} val_ACC={} rl_val_ACC={} reward={}'.format(iter, self.losses[1], self.RL.rl_losses[1], reward))
|
||||
if 'marginal_update_RL' in self.config['server_config'] and \
|
||||
self.config['server_config']['marginal_update_RL']:
|
||||
self.losses = self.RL.rl_losses
|
||||
for p, p_ in zip(self.worker_trainer.model.parameters(), rl_model):
|
||||
p.data= p_.data.detach().clone()
|
||||
|
||||
elif (self.losses[1] - self.RL.rl_losses[1]) > 0:
|
||||
reward = 1.0
|
||||
print_rank(
|
||||
'Iter:{} val_ACC={} rl_val_ACC={} reward={}'.format(iter, self.losses[1], self.RL.rl_losses[1], reward))
|
||||
self.losses = self.RL.rl_losses
|
||||
for p, p_ in zip(self.worker_trainer.model.parameters(), rl_model):
|
||||
p.data = p_.data.detach().clone()
|
||||
|
||||
else:
|
||||
reward = -1.0
|
||||
print_rank(
|
||||
'Iter:{} val_ACC={} rl_val_ACC={} reward={}'.format(iter, self.losses[1], self.RL.rl_losses[1], reward))
|
||||
|
||||
# Taking the policy from a game-based RL
|
||||
batch = (
|
||||
(np.concatenate((client_weights, client_mag_grads, client_mean_grads, client_var_grads), axis=0)),
|
||||
(self.RL.rl_weights),
|
||||
[reward]
|
||||
)
|
||||
|
||||
print_rank('RL Model Update -- Training')
|
||||
self.RL.train(batch)
|
||||
|
||||
print_rank('RL State Saving')
|
||||
self.RL.save(iter)
|
||||
|
||||
print_rank('RL logging')
|
||||
metric_logger('RL Running Loss', self.RL.runningLoss)
|
||||
metric_logger('RL Rewards', reward)
|
||||
|
||||
def run_val_test(self, i, metric_logger=None):
|
||||
'''Run validation or test, depending on current iteration i.
|
||||
|
||||
Args:
|
||||
i (int): current iteration.
|
||||
metric_logger (callback, optional): callback used for logging.
|
||||
Defaults to None, in which case AML logger is used.
|
||||
'''
|
||||
|
||||
if metric_logger is None:
|
||||
metric_logger = run.log
|
||||
|
||||
# Run validation and update the LR scheduler
|
||||
if (i % self.val_freq) == 0: # print loss info to Tensorboard on Philly
|
||||
if 'wantRL' not in self.config['server_config'] or not self.config['server_config']['wantRL']:
|
||||
print_rank('Running validation at itr={}'.format(i))
|
||||
self.losses = self.run_distributed_inference(mode='val')
|
||||
|
||||
# Log changes
|
||||
metric_logger('LR for agg. opt.', get_lr(self.worker_trainer.optimizer))
|
||||
metric_logger('Val Loss', self.losses[0])
|
||||
metric_logger('Val Acc', self.losses[1])
|
||||
|
||||
print_rank('LOG: val_loss={}: best_val_loss={}'.format(self.losses[0], self.best_val_loss))
|
||||
print_rank('LOG: val_acc={}: best_val_acc={}'.format(self.losses[1], self.best_val_acc))
|
||||
|
||||
if self.losses[0] < self.best_val_loss: # save the model when loss is improved
|
||||
self.worker_trainer.save(
|
||||
model_path=self.model_path,
|
||||
token='best_val_loss',
|
||||
config=self.config['server_config']
|
||||
)
|
||||
self.best_val_loss = self.losses[0]
|
||||
else:
|
||||
# Create a schedule for the initial_lr (for the worker)
|
||||
self.lr_weight *= self.lr_decay_factor
|
||||
print_rank('LOG: Client weight of learning rate {}..'.format(self.lr_weight))
|
||||
|
||||
if self.losses[1] > self.best_val_acc: # save the model when CER is improved
|
||||
self.worker_trainer.save(
|
||||
model_path=self.model_path,
|
||||
token='best_val_acc',
|
||||
config=self.config['server_config']
|
||||
)
|
||||
self.best_val_acc = self.losses[1]
|
||||
|
||||
# Run full testing
|
||||
if (i % self.rec_freq) == 0 and self.test_dataloader is not None:
|
||||
print_rank('Running Testing at itr={}'.format(i))
|
||||
|
||||
aggregated_metrics = self.run_distributed_inference(mode='test')
|
||||
|
||||
metric_logger('Test Loss', aggregated_metrics[0])
|
||||
metric_logger('Test Acc', aggregated_metrics[1])
|
||||
print_rank('LOG: test_loss={}: best_test_loss={}'.format(aggregated_metrics[0], self.best_test_loss))
|
||||
print_rank('LOG: test_acc={}: best_test_acc={}'.format(aggregated_metrics[1], self.best_test_acc))
|
||||
|
||||
if aggregated_metrics[0] < self.best_test_loss:
|
||||
self.best_test_loss=aggregated_metrics[0]
|
||||
|
||||
if aggregated_metrics[1] > self.best_test_acc:
|
||||
self.best_test_acc = aggregated_metrics[1]
|
||||
self.worker_trainer.save(
|
||||
model_path=self.model_path,
|
||||
token='best_test_acc',
|
||||
config=self.config['server_config'],
|
||||
)
|
||||
|
||||
def aggregate_gradients_inplace(self, client_parameters):
|
||||
'''Aggregate list of tensors into model gradients.
|
||||
|
||||
Args:
|
||||
client_parameters (list): list of tensors to aggregate.
|
||||
'''
|
||||
for p, client_grad in zip(self.worker_trainer.model.parameters(), client_parameters):
|
||||
if p.grad is None:
|
||||
p.grad = _to_cuda(client_grad)
|
||||
else:
|
||||
p.grad += _to_cuda(client_grad)
|
||||
|
||||
def aggregate_gradients(self, num_clients_curr_iter, client_weights, metric_logger=None):
|
||||
'''Go through stored gradients, aggregate and put them inside model.
|
||||
|
||||
Args:
|
||||
num_clients_curr_iter (int): how many clients were processed.
|
||||
client_weights: weight for each client.
|
||||
metric_logger (callback, optional): callback used for logging.
|
||||
Defaults to None, in which case AML logger is used.
|
||||
|
||||
Returns:
|
||||
float: sum of weights for all clients.
|
||||
'''
|
||||
|
||||
weight_sum = 0
|
||||
if metric_logger is None:
|
||||
metric_logger = run.log
|
||||
|
||||
if not self.aggregate_fast:
|
||||
metric_logger('Stale Gradients Ratio', len(self.client_parameters_stack_stale) / num_clients_curr_iter)
|
||||
if len(self.client_parameters_stack_stale) > 0:
|
||||
weight_sum = self.weight_sum_stale
|
||||
for client_parameters in self.client_parameters_stack_stale:
|
||||
# Model parameters are already multiplied with weight on client, we only have to sum them up
|
||||
self.aggregate_gradients_inplace(client_parameters)
|
||||
self.client_parameters_stack_stale = []
|
||||
self.weight_sum_stale = 0
|
||||
|
||||
for client_weight, client_parameters in zip(client_weights, self.client_parameters_stack):
|
||||
if np.random.random() > self.stale_prob:
|
||||
# Model parameters are already multiplied with weight on client, we only have to sum them up
|
||||
self.aggregate_gradients_inplace(client_parameters)
|
||||
else:
|
||||
self.weight_sum_stale += client_weight
|
||||
self.client_parameters_stack_stale.append(client_parameters)
|
||||
|
||||
# Some cleaning
|
||||
self.client_parameters_stack = []
|
||||
|
||||
weight_sum += sum(client_weights) - self.weight_sum_stale
|
||||
return weight_sum
|
||||
|
||||
|
||||
def select_server(server_type, config):
|
||||
'''Select a server type using different possible strings.
|
||||
|
||||
Right now this just returns `OptimizationServer`, but this
|
||||
function could be useful when there are multiple choices of
|
||||
server.
|
||||
|
||||
Args:
|
||||
server_type (str): indicates server choice.
|
||||
config (dict): config parsed from YAML, passed so that
|
||||
parameters can be used to select a given server.
|
||||
'''
|
||||
return OptimizationServer
|
|
@ -0,0 +1,536 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from utils import \
|
||||
get_lr, \
|
||||
get_lr_all, \
|
||||
make_optimizer, \
|
||||
make_lr_scheduler, \
|
||||
print_rank, \
|
||||
torch_save, \
|
||||
try_except_save, \
|
||||
write_yaml
|
||||
|
||||
|
||||
class TrainerBase:
|
||||
"""Abstract class defining Trainer objects' common interface.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model to be trained.
|
||||
train_dataloader (torch.utils.data.DataLoader): dataloader that
|
||||
provides the training data.
|
||||
optimizer: (torch.optim.Optimizer): optimizer that will be used to
|
||||
update the model.
|
||||
max_grad_norm (float): if not None, avg gradients are clipped to this
|
||||
norm; defaults to None.
|
||||
ignore_subtask (bool): ignore subtasks, defaults to True.
|
||||
model_type (str): what kind of model is used, defaults to
|
||||
:code:`LanguageModel`.
|
||||
decoder_config (dict or None): config for decoder, defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
train_dataloader,
|
||||
optimizer,
|
||||
max_grad_norm=None,
|
||||
ignore_subtask=True,
|
||||
model_type="LanguageModel",
|
||||
decoder_config=None
|
||||
):
|
||||
|
||||
self.model = model
|
||||
self.train_dataloader = train_dataloader
|
||||
self.optimizer = optimizer
|
||||
self.max_grad_norm = max_grad_norm
|
||||
self.model_type = model_type
|
||||
self.decoder_config = decoder_config
|
||||
|
||||
self.step = 0 # count how many batches are processed
|
||||
self.ignore_subtask = ignore_subtask # ignore subtasks even if there are multiple task branches
|
||||
|
||||
def epoch_boundary(self):
|
||||
'''Check if we are at the end of any given epoch.'''
|
||||
return self.step % len(self.train_dataloader.create_loader()) == 0 and self.step != 0
|
||||
|
||||
def train_desired_samples(self, desired_max_samples, apply_privacy_metrics):
|
||||
pass
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
|
||||
class ModelUpdater(TrainerBase):
|
||||
"""Update the model, given the already computed gradient.
|
||||
|
||||
This is a special kind of trainer, that actually does not use any data.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model to be updated.
|
||||
optimizer (torch.optim.Optimizer): optimizer that will be used to
|
||||
update the model.
|
||||
ss_scheduler: scheduled sampler.
|
||||
train_dataloader: train dataloader, this is not actually used.
|
||||
val_dataloader: val dataloader, this is not actually used.
|
||||
max_grad_norm (float): avg gradients are clipped to this norm.
|
||||
anneal_config (dict): annealing configuration.
|
||||
model_type (str): what kind of model is used, defaults to
|
||||
:code:`LanguageModel`.
|
||||
decoder_config (dict): config for decoder, defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
optimizer,
|
||||
ss_scheduler,
|
||||
train_dataloader,
|
||||
val_dataloader,
|
||||
max_grad_norm,
|
||||
anneal_config,
|
||||
model_type="LanguageModel",
|
||||
decoder_config=None
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
train_dataloader=train_dataloader,
|
||||
optimizer=optimizer,
|
||||
max_grad_norm=max_grad_norm,
|
||||
model_type=model_type,
|
||||
decoder_config=decoder_config
|
||||
)
|
||||
|
||||
self.val_dataloader = val_dataloader
|
||||
self.annealing_type = anneal_config["type"] if anneal_config is not None else None
|
||||
self.lr_scheduler = make_lr_scheduler(anneal_config, self.optimizer)
|
||||
self.ss_scheduler = ss_scheduler
|
||||
|
||||
def update_model(self):
|
||||
"""Update model parameters using pre-computed gradients."""
|
||||
|
||||
# Apply gradient clipping
|
||||
if self.max_grad_norm is not None:
|
||||
grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
|
||||
print_rank(f"clipped norm: {grad_norm} to {min(grad_norm,self.max_grad_norm)}", logging.DEBUG)
|
||||
|
||||
# Do optimizer step
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
def run_lr_scheduler(self, force_run_val=False):
|
||||
"""Update learning rate using scheduler."""
|
||||
|
||||
val_loss = val_acc = None
|
||||
if force_run_val is True or self.annealing_type == "val_loss":
|
||||
_, val_loss, val_acc = run_validation_generic(self.model, self.val_dataloader)
|
||||
|
||||
# Do LR scheduling
|
||||
print_rank(f"LR all: {list(get_lr_all(self.optimizer))}", loglevel=logging.DEBUG)
|
||||
print_rank("LR BEFORE lr_scheduler step: {}".format(get_lr(self.optimizer)))
|
||||
if self.annealing_type == "val_loss":
|
||||
self.lr_scheduler.step(val_loss)
|
||||
else:
|
||||
self.lr_scheduler.step()
|
||||
print_rank("LR AFTER lr_scheduler step: {}".format(get_lr(self.optimizer)), loglevel=logging.DEBUG)
|
||||
|
||||
return (val_loss, val_acc)
|
||||
|
||||
def run_ss_scheduler(self):
|
||||
"""Do scheduled sampling."""
|
||||
|
||||
if self.ss_scheduler is not None:
|
||||
self.ss_scheduler.step()
|
||||
|
||||
def save(self, model_path, token=None, config=None):
|
||||
"""Save model to disk."""
|
||||
|
||||
save_model(
|
||||
model_path=model_path,
|
||||
config=config,
|
||||
model=self.model,
|
||||
optimizer=self.optimizer,
|
||||
lr_scheduler=self.lr_scheduler,
|
||||
ss_scheduler=self.ss_scheduler,
|
||||
token=token
|
||||
)
|
||||
|
||||
def load(self, save_path, update_lr_scheduler, update_ss_scheduler):
|
||||
"""Load model from disk.
|
||||
|
||||
If save_path is given, load from there. If not, then resume training
|
||||
from current model dir. If at any point the save_path is not present on
|
||||
the disk, it won't be loaded.
|
||||
"""
|
||||
|
||||
if os.path.isfile(save_path):
|
||||
print_rank("Loading checkpoint: {}".format(save_path))
|
||||
checkpoint = torch.load(save_path)
|
||||
self.model.load_state_dict(checkpoint["model_state_dict"])
|
||||
if self.optimizer is not None:
|
||||
self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
|
||||
|
||||
anl_st_dict = checkpoint.get("lr_scheduler_state_dict")
|
||||
if anl_st_dict and self.lr_scheduler is not None and update_lr_scheduler is True:
|
||||
self.lr_scheduler.load_state_dict(anl_st_dict)
|
||||
|
||||
sss_st_dict = checkpoint.get("ss_scheduler_state_dict")
|
||||
if sss_st_dict and self.ss_scheduler is not None and update_lr_scheduler is True:
|
||||
self.ss_scheduler.load_state_dict(sss_st_dict)
|
||||
|
||||
|
||||
class Trainer(TrainerBase):
|
||||
"""Perform training step for any given client.
|
||||
|
||||
The main method to be called for triggering a training step is
|
||||
:code:`train_desired_samples`, which on its turn relies on
|
||||
:code:`run_train_epoch`.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model to be trained.
|
||||
ss_scheduler: scheduled sampler.
|
||||
train_dataloader (torch.data.utils.DataLoader): dataloader that
|
||||
provides the training data.
|
||||
val_dataloader (torch.data.utils.DataLoader): provides val data.
|
||||
server_replay_config (dict or None): config for replaying training;
|
||||
defaults to None, in which case no replaying happens.
|
||||
optimizer (torch.optim.Optimizer or None): optimizer that will be used
|
||||
to update the model. If :code:`None`, skip optimization.
|
||||
max_grad_norm (float or None): if not None, avg gradients are clipped
|
||||
to this norm; defaults to None.
|
||||
anneal_config (dict or None): annealing configuration.
|
||||
num_skips_threshold (int): previously used to skip users, deprecated.
|
||||
ignore_subtask (bool): ignore subtasks, defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
ss_scheduler,
|
||||
train_dataloader,
|
||||
val_dataloader,
|
||||
server_replay_config=None,
|
||||
optimizer=None,
|
||||
max_grad_norm=None,
|
||||
anneal_config=None,
|
||||
num_skips_threshold=-1,
|
||||
ignore_subtask=True
|
||||
):
|
||||
super().__init__(
|
||||
model=model,
|
||||
train_dataloader=train_dataloader,
|
||||
optimizer=optimizer,
|
||||
max_grad_norm=max_grad_norm,
|
||||
ignore_subtask=ignore_subtask
|
||||
)
|
||||
|
||||
self.server_replay_config=None
|
||||
if server_replay_config is not None:
|
||||
self.server_replay_config = server_replay_config
|
||||
|
||||
self.anneal_config=None
|
||||
if anneal_config is not None:
|
||||
self.anneal_config = anneal_config
|
||||
|
||||
self.lr_scheduler = None
|
||||
if self.optimizer is None and self.server_replay_config is not None and "optimizer" in self.server_replay_config:
|
||||
self.optimizer = make_optimizer(self.server_replay_config["optimizer_config"], model)
|
||||
|
||||
if self.optimizer is not None and self.anneal_config is not None:
|
||||
self.lr_scheduler = make_lr_scheduler(
|
||||
self.anneal_config,
|
||||
self.optimizer)
|
||||
|
||||
self.val_dataloader = val_dataloader
|
||||
self.cached_batches = []
|
||||
self.ss_scheduler = ss_scheduler
|
||||
|
||||
def reset_gradient_power(self):
|
||||
"""Reset the sum of gradient power.
|
||||
|
||||
This is used to compute statistics about the gradients.
|
||||
"""
|
||||
|
||||
self.sum_grad = self.sum_grad2 = self.counter = 0
|
||||
|
||||
def accumulate_gradient_power(self):
|
||||
"""Compute sum of gradient power.
|
||||
|
||||
This is used to compute statistics about the gradients.
|
||||
"""
|
||||
|
||||
for p in self.model.parameters():
|
||||
if p.grad is None:
|
||||
continue
|
||||
|
||||
grad = p.grad.detach().clone().cpu().numpy()
|
||||
p1 = np.sum(grad)
|
||||
p2 = np.sum(grad ** 2)
|
||||
n = p.grad.numel()
|
||||
|
||||
self.sum_grad += p1
|
||||
self.sum_grad2 += p2
|
||||
self.counter += n
|
||||
|
||||
print_rank("Magn. Grad. Squared: {}".format(self.sum_grad2), loglevel=logging.DEBUG)
|
||||
print_rank("Magn. Grad.: {}".format(self.sum_grad), loglevel=logging.DEBUG)
|
||||
return self.sum_grad, self.sum_grad2, self.counter
|
||||
|
||||
def estimate_sufficient_stats(self):
|
||||
"""Compute statistics about the gradients."""
|
||||
|
||||
sum_mean_grad, sum_mean_grad2, n = self.accumulate_gradient_power()
|
||||
self.sufficient_stats = {"n": n, "sum": sum_mean_grad, "sq_sum": sum_mean_grad2}
|
||||
|
||||
def train_desired_samples(self, desired_max_samples=None, apply_privacy_metrics=False):
|
||||
"""Triggers training step.
|
||||
|
||||
Args:
|
||||
desired_max_samples (int): number of samples that you would like to process.
|
||||
apply_privacy_metrics (bool): whether to save the batches used for the round for privacy metrics evaluation.
|
||||
|
||||
Returns:
|
||||
2-tuple of (float, int): total training loss and number of processed samples.
|
||||
"""
|
||||
|
||||
num_samples = 0
|
||||
total_train_loss = 0
|
||||
|
||||
num_samples_per_epoch, train_loss_per_epoch = self.run_train_epoch(desired_max_samples, apply_privacy_metrics)
|
||||
|
||||
num_samples += num_samples_per_epoch
|
||||
total_train_loss += train_loss_per_epoch
|
||||
|
||||
return total_train_loss, num_samples
|
||||
|
||||
def run_train_epoch(self, desired_max_samples=None, apply_privacy_metrics=False):
|
||||
"""Implementation example for training the model.
|
||||
|
||||
The training process should stop after the desired number of samples is processed.
|
||||
|
||||
Args:
|
||||
desired_max_samples (int): number of samples that you would like to process.
|
||||
apply_privacy_metrics (bool): whether to save the batches used for the round for privacy metrics evaluation.
|
||||
|
||||
Returns:
|
||||
2-tuple of (int, float): number of processed samples and total training loss.
|
||||
"""
|
||||
|
||||
sum_train_loss = 0.0
|
||||
num_samples = 0
|
||||
self.reset_gradient_power()
|
||||
|
||||
# Reset gradient just in case
|
||||
self.model.zero_grad()
|
||||
|
||||
train_loader = self.train_dataloader.create_loader()
|
||||
for batch in train_loader:
|
||||
if desired_max_samples is not None and num_samples >= desired_max_samples:
|
||||
break
|
||||
|
||||
# Compute loss
|
||||
if self.optimizer is not None:
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
if self.ignore_subtask is True:
|
||||
loss = self.model.single_task_loss(batch)
|
||||
else:
|
||||
if apply_privacy_metrics:
|
||||
if "x" in batch:
|
||||
indices = batch["x"].cuda() if torch.cuda.is_available() else batch["x"]
|
||||
elif "input_ids" in batch:
|
||||
indices = batch["input_ids"].cuda() if torch.cuda.is_available() else batch["input_ids"]
|
||||
self.cached_batches.append(indices)
|
||||
loss = self.model.loss(batch)
|
||||
loss.backward()
|
||||
|
||||
# Apply gradient clipping
|
||||
if self.max_grad_norm is not None:
|
||||
grad_norm = nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
|
||||
|
||||
# Sum up the gradient power
|
||||
self.estimate_sufficient_stats()
|
||||
|
||||
# Now that the gradients have been scaled, we can apply them
|
||||
if self.optimizer is not None:
|
||||
self.optimizer.step()
|
||||
|
||||
print_rank("step: {}, loss: {}".format(self.step, loss.item()), loglevel=logging.DEBUG)
|
||||
|
||||
# Post-processing in this loop
|
||||
# Sum up the loss
|
||||
sum_train_loss += loss.item()
|
||||
|
||||
# Increment the number of frames processed already
|
||||
if "attention_mask" in batch:
|
||||
num_samples += torch.sum(batch["attention_mask"].detach().cpu() == 1).item()
|
||||
elif "total_frames" in batch:
|
||||
num_samples += batch["total_frames"]
|
||||
else:
|
||||
num_samples += len(batch["x"])
|
||||
|
||||
# Update the counters
|
||||
self.step += 1
|
||||
|
||||
# Take a step in lr_scheduler
|
||||
if self.lr_scheduler is not None:
|
||||
self.lr_scheduler.step()
|
||||
|
||||
return num_samples, sum_train_loss
|
||||
|
||||
def prepare_iteration(self, model=None):
|
||||
"""Steps to run before iteration begins."""
|
||||
|
||||
if model is not None:
|
||||
self.model.load_state_dict(model.state_dict())
|
||||
|
||||
self.lr_scheduler = None
|
||||
if self.optimizer is None and self.server_replay_config is not None and \
|
||||
"optimizer_config" in self.server_replay_config:
|
||||
print_rank("Creating server-side replay training optimizer", loglevel=logging.DEBUG)
|
||||
self.optimizer = make_optimizer(self.server_replay_config["optimizer_config"], self.model)
|
||||
|
||||
if self.optimizer is not None and self.anneal_config is not None:
|
||||
print_rank("Creating server-side replay-training lr_scheduler", loglevel=logging.DEBUG)
|
||||
self.lr_scheduler = make_lr_scheduler(self.anneal_config, self.optimizer)
|
||||
|
||||
def reset_optimizer(self, optimizer_state_dict, annealing_config=None):
|
||||
"""Re-load optimizer."""
|
||||
|
||||
assert self.optimizer is not None, "This trainer does not have an optimizer"
|
||||
|
||||
# Load optimizer on state dict
|
||||
self.optimizer.load_state_dict(optimizer_state_dict)
|
||||
|
||||
# Set learning rate scheduler
|
||||
self.lr_scheduler = None
|
||||
if annealing_config is not None:
|
||||
self.lr_scheduler = make_lr_scheduler(annealing_config, self.optimizer)
|
||||
|
||||
|
||||
def run_validation_generic(model, val_dataloader):
|
||||
"""Perform a validation step.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): model to be validated.
|
||||
val_dataloader (torch.data.utils.DataLoader): provides val data.
|
||||
|
||||
Returns:
|
||||
Average validation loss.
|
||||
"""
|
||||
|
||||
print_rank("run_validation_generic", loglevel=logging.DEBUG)
|
||||
|
||||
val_losses, val_accuracies = list(), list()
|
||||
counter = 0
|
||||
model.set_eval()
|
||||
print_rank("set_eval", loglevel=logging.DEBUG)
|
||||
|
||||
# Initialize dataloader etc.
|
||||
val_loader = val_dataloader.create_loader()
|
||||
print_rank(
|
||||
f"created loader {val_loader.num_workers}, " + \
|
||||
f"users: {len(val_dataloader.dataset.user_list)} " + \
|
||||
f"examples: {sum(val_dataloader.dataset.num_samples)} " + \
|
||||
f"lendata: {len(val_loader)} ",
|
||||
loglevel=logging.DEBUG
|
||||
)
|
||||
|
||||
print_rank(
|
||||
f"drop_last: {val_loader.drop_last} " + \
|
||||
f"len_sampler: {len(val_loader._index_sampler)}",
|
||||
loglevel=logging.DEBUG
|
||||
)
|
||||
|
||||
# Perform inference and compute metrics
|
||||
output_tot = {"probabilities": [], "predictions": [], "labels":[]}
|
||||
with torch.no_grad():
|
||||
for _, batch in enumerate(val_loader):
|
||||
val_loss = model.loss(batch).item()
|
||||
output, val_acc, batch_size = model.inference(batch)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output_tot["probabilities"].append(output["probabilities"])
|
||||
output_tot["predictions"].append(output["predictions"])
|
||||
output_tot["labels"].append(output["labels"])
|
||||
|
||||
val_losses.append(val_loss * batch_size)
|
||||
val_accuracies.append(val_acc * batch_size)
|
||||
counter += batch_size
|
||||
|
||||
output_tot["probabilities"] = np.concatenate(output_tot["probabilities"]) if output_tot["probabilities"] else []
|
||||
output_tot["predictions"] = np.concatenate(output_tot["predictions"]) if output_tot["predictions"] else []
|
||||
output_tot["labels"] = np.concatenate(output_tot["labels"]) if output_tot["labels"] else []
|
||||
|
||||
# Post-processing of metrics
|
||||
print_rank(f"validation complete {counter}", loglevel=logging.DEBUG)
|
||||
|
||||
model.set_train()
|
||||
avg_val_loss = sum(val_losses) / counter
|
||||
avg_val_acc = sum(val_accuracies) / counter
|
||||
print_rank(f"validation examples {counter}", loglevel=logging.DEBUG)
|
||||
|
||||
return output_tot, avg_val_loss, avg_val_acc
|
||||
|
||||
def set_component_wise_lr(model, optimizer_config, updatable_names):
|
||||
"""Set zero learning rate for layers in order to freeze the update.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module):
|
||||
optimizer_config (string):
|
||||
updatable_names (list): ["^dec_rnn", "^fc"]
|
||||
"""
|
||||
|
||||
def name_matched(name, updatable_names):
|
||||
for updatable_name in updatable_names:
|
||||
if re.match(updatable_name, name) is not None:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
# Set learning rate to zero in layers which name does not follow regex
|
||||
parameters = []
|
||||
for name, params in model.named_parameters():
|
||||
if name_matched(name, updatable_names) is True:
|
||||
print_rank("updating {} with lr = {}".format(name, optimizer_config["lr"]))
|
||||
parameters.append({"params": params, "lr":optimizer_config["lr"]})
|
||||
else:
|
||||
print_rank("freezing {}".format(name))
|
||||
parameters.append({"params": params, "lr": 0.0})
|
||||
|
||||
return parameters
|
||||
|
||||
def save_model(model_path, config, model, optimizer, lr_scheduler, ss_scheduler, token=None):
|
||||
"""Save a model as well as training information."""
|
||||
|
||||
save_state = {
|
||||
"model_state_dict" : model.state_dict(),
|
||||
"optimizer_state_dict" : optimizer.state_dict() if optimizer is not None else None,
|
||||
"lr_scheduler_state_dict" : lr_scheduler.state_dict() if lr_scheduler is not None else None
|
||||
}
|
||||
if ss_scheduler is not None:
|
||||
save_state["ss_scheduler_state_dict"] = ss_scheduler.state_dict()
|
||||
|
||||
if token: # just save as "best" and return
|
||||
save_path = os.path.join(model_path, "{}_model.tar".format(token))
|
||||
else:
|
||||
save_path = os.path.join(model_path, "model.tar")
|
||||
|
||||
print_rank("Saving model to: {}".format(save_path))
|
||||
try_except_save(torch_save, state_or_model=save_state, save_path=save_path)
|
||||
|
||||
# Write out the config to model_dir
|
||||
if config is not None:
|
||||
try_except_save(write_yaml, config=config,
|
||||
save_path=os.path.join(model_path, "config.yaml"))
|
|
@ -0,0 +1,20 @@
|
|||
# Minimal makefile for Sphinx documentation
|
||||
#
|
||||
|
||||
# You can set these variables from the command line, and also
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= sphinx-build
|
||||
SOURCEDIR = .
|
||||
BUILDDIR = _build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
|
@ -0,0 +1,12 @@
|
|||
Advanced Topics
|
||||
===============
|
||||
|
||||
Privacy
|
||||
-------
|
||||
|
||||
Aggregation Options
|
||||
-------------------
|
||||
|
||||
|
||||
Optimizer Options
|
||||
-----------------
|
|
@ -0,0 +1,58 @@
|
|||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# This file only contains a selection of the most common options. For a full
|
||||
# list see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
# -- Path setup --------------------------------------------------------------
|
||||
|
||||
# If extensions (or modules to document with autodoc) are in another directory,
|
||||
# add these directories to sys.path here. If the directory is relative to the
|
||||
# documentation root, use os.path.abspath to make it absolute, like shown here.
|
||||
#
|
||||
# import os
|
||||
# import sys
|
||||
# sys.path.insert(0, os.path.abspath('.'))
|
||||
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
|
||||
project = 'FLUTE'
|
||||
copyright = '2021, Microsoft Research'
|
||||
author = 'Microsoft Research'
|
||||
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
|
||||
# Add any Sphinx extension module names here, as strings. They can be
|
||||
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
|
||||
# ones.
|
||||
extensions = [
|
||||
]
|
||||
|
||||
# Add any paths that contain templates here, relative to this directory.
|
||||
templates_path = ['_templates']
|
||||
|
||||
# List of patterns, relative to source directory, that match files and
|
||||
# directories to ignore when looking for source files.
|
||||
# This pattern also affects html_static_path and html_extra_path.
|
||||
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
|
||||
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
|
||||
# The theme to use for HTML and HTML Help pages. See the documentation for
|
||||
# a list of builtin themes.
|
||||
#
|
||||
#html_theme = 'alabaster'
|
||||
|
||||
# Add any paths that contain custom static files (such as style sheets) here,
|
||||
# relative to this directory. They are copied after the builtin static files,
|
||||
# so a file named "default.css" will overwrite the builtin "default.css".
|
||||
html_static_path = ['_static']
|
||||
|
||||
import sphinx_rtd_theme
|
||||
|
||||
html_theme = 'sphinx_rtd_theme'
|
||||
|
||||
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
|
|
@ -0,0 +1,30 @@
|
|||
Data Preparation
|
||||
================
|
||||
|
||||
TODO: formatting for other data loaders.
|
||||
|
||||
Here is a sample data blob for language model training.
|
||||
|
||||
.. code:: json
|
||||
|
||||
{
|
||||
"users": ["bert","elmo"],
|
||||
"user_data": {
|
||||
"bert": {"x": ["my name is Bert.", "I live with Ernie."]},
|
||||
"elmo": {"x": ["Big Bird is my friend."]}
|
||||
},
|
||||
"num_samples": [2, 1]
|
||||
}
|
||||
|
||||
The blob consists of three fields. The ``users`` field indicates a unique id for each user in the training data. Users are sampled uniformly to create client tasks during training. There could be many more users than client tasks per round or even over all client tasks over all rounds. The ``user_data`` field contains user-indexed training data. Each user's data is a dictionary of the form ``{"x": [list of examples]}``. Finally, the ``num_samples`` field indicates the number of samples for each user, in order of the ``users`` list. That is, for any index ``i`` in ``range(len(data['users']))``:
|
||||
|
||||
.. code:: python
|
||||
|
||||
data['num_samples'][i] == len(data['user_data'][data['users'][i]]['x'])
|
||||
|
||||
|
||||
Test and validation data is formatted similarly.
|
||||
|
||||
.. note::
|
||||
|
||||
Test/validate data is dispatched to workers by partitioning on users. If your test data isn't user-partitioned, we recommend partitioning it uniformly using some dummy user ids.
|
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 68 KiB |
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 80 KiB |
Двоичный файл не отображается.
|
@ -0,0 +1,25 @@
|
|||
.. FLUTE documentation master file, created by
|
||||
sphinx-quickstart on Sat Jun 19 09:15:36 2021.
|
||||
You can adapt this file completely to your liking, but it should at least
|
||||
contain the root `toctree` directive.
|
||||
|
||||
Welcome to FLUTE documentation!
|
||||
===============================
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 2
|
||||
:caption: Contents:
|
||||
|
||||
overview
|
||||
quickstart
|
||||
dataprep
|
||||
scenarios
|
||||
advanced
|
||||
reference
|
||||
|
||||
Indices and tables
|
||||
==================
|
||||
|
||||
* :ref:`genindex`
|
||||
* :ref:`modindex`
|
||||
* :ref:`search`
|
|
@ -0,0 +1,35 @@
|
|||
@ECHO OFF
|
||||
|
||||
pushd %~dp0
|
||||
|
||||
REM Command file for Sphinx documentation
|
||||
|
||||
if "%SPHINXBUILD%" == "" (
|
||||
set SPHINXBUILD=sphinx-build
|
||||
)
|
||||
set SOURCEDIR=.
|
||||
set BUILDDIR=_build
|
||||
|
||||
if "%1" == "" goto help
|
||||
|
||||
%SPHINXBUILD% >NUL 2>NUL
|
||||
if errorlevel 9009 (
|
||||
echo.
|
||||
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
||||
echo.installed, then set the SPHINXBUILD environment variable to point
|
||||
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
||||
echo.may add the Sphinx directory to PATH.
|
||||
echo.
|
||||
echo.If you don't have Sphinx installed, grab it from
|
||||
echo.http://sphinx-doc.org/
|
||||
exit /b 1
|
||||
)
|
||||
|
||||
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
goto end
|
||||
|
||||
:help
|
||||
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
||||
|
||||
:end
|
||||
popd
|
|
@ -0,0 +1,26 @@
|
|||
FLUTE Overview
|
||||
============
|
||||
|
||||
FLUTE uses a distributed processing architecture backed by OpenMPI. An FLUTE job consists of one or more nodes (physical or virtual machines) executing a total of K workers (independent OS-level processes).
|
||||
|
||||
Worker 0 acts as a central orchestrator, maintaining and distributing a central model to workers, and subsequently distributing client tasks to the workers.
|
||||
|
||||
Each worker>0 processes client tasks sequentially, consisting of data encoding and one or more batch updates to the central model (note the central model is reset to its original state for each client task). As each client task completes, the model delta, aka the pseudo-gradient is sent back to the orchestrator for federation into a new central model.
|
||||
|
||||
Execution runs for up to N training rounds. In each round the orchestrator may sample a subset of clients, and may also randomly delay pseudo-gradient updates from some clients to future rounds. The orchestrator will also periodically distribute evaluation tasks to determine model quality on validation and test data.
|
||||
|
||||
.. note:: AzureML generally expects there will be one worker per GPU on each node.
|
||||
.. note:: Due to networking overhead, it is often faster to run jobs on a single node with 8 or 16 GPUs, rather than on multiple nodes.
|
||||
|
||||
Architecture
|
||||
------------
|
||||
|
||||
.. figure:: img/concepts.png
|
||||
:width: 400
|
||||
|
||||
An FLUTE job consists of one or more independent nodes (multi-GPU VMs) executing up to K workers.
|
||||
|
||||
.. figure:: img/client-server.png
|
||||
:width: 600
|
||||
|
||||
On each training round the orchestrator (Worker 0) dispatches the central model to the rest of the workers, and then queues up client tasks for workers to execute. Workers receive client tasks (client training data and training config) and execute SGD on the central model using their client's training data, sending the model delta (pseudo-gradient) back to the orchestrator.
|
|
@ -0,0 +1,33 @@
|
|||
Option Reference
|
||||
================
|
||||
|
||||
Command Line Arguments
|
||||
----------------------
|
||||
|
||||
YAML Configuration
|
||||
------------------
|
||||
|
||||
FLUTE yaml files consist of three main sections, and a few optional sections. The `model_config` specifies model architecture and pretrained model setup path. The `server_config` section defines server settings such as total training rounds, aggregation method, optimizer settings, learning rate schedule, and any server-side training data. The `client_config` section specifies client optimizer settings and the client-side training data.
|
||||
|
||||
.. note:: Training data is loaded by the server and dispatched to the clients. The configuration settings for this data are specified in the `client_config`.
|
||||
|
||||
|
||||
model_config
|
||||
~~~~~~~~~~~~
|
||||
|
||||
server_config
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
client_config
|
||||
~~~~~~~~~~~~~
|
||||
|
||||
Optional Sections
|
||||
-----------------
|
||||
In addition to the main sections, some optional sections may be specified to control privacy settings, specifically a `dp_config` section for differential privacy settings, and `privacy_metrics_config` for applying privacy metrics.
|
||||
|
||||
|
||||
dp_config
|
||||
~~~~~~~~~
|
||||
|
||||
privacy_metrics_config
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
|
@ -0,0 +1,64 @@
|
|||
Adding New Scenarios
|
||||
====================
|
||||
|
||||
Requirements
|
||||
------------
|
||||
|
||||
Before adding a new scenario in FLUTE, make sure that your files comply with following:
|
||||
|
||||
* The model class has declared the functions: loss(), inference(), set_eval() and set_train()
|
||||
* Inference function is used for testing and must return loss, accuracy and batch size.
|
||||
* Raw data input must be stored in JSON or HDF5 files
|
||||
* FLUTE assumes the following format from the text data
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
{"num_samples": [sample_1, ......, sample_n],
|
||||
"users":[user_1, ......, user_n],
|
||||
"user_data": {
|
||||
"user_1":{
|
||||
"x":[ .. data..,
|
||||
.....data_n..]
|
||||
},
|
||||
.......
|
||||
"user_n":{
|
||||
"x":[ .. data..,
|
||||
.....data_n..]
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
.. note:: The list 'x' inside of the dictionary with the name of the user, can contain data or arrays.
|
||||
Alternatively, instead of using a single-key dictionary, a list of lists might be assigned to each user.
|
||||
|
||||
Copy the files
|
||||
------------
|
||||
|
||||
All mandatory files must be inside a folder with the same name as the model in /models. Please adjust your files with the following
|
||||
naming structure so FLUTE can be able to find all the scripts needed.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
model_name
|
||||
|---- dataloaders
|
||||
|---- text_dataloader.py
|
||||
|---- utils
|
||||
|---- utils.py
|
||||
|---- model.py
|
||||
|---- README.txt
|
||||
|
||||
In case you need to import a module that has not been considered in FLUTE, this can be added in requirements.txt
|
||||
|
||||
.. note:: All files must contain only absolute imports, in order to avoid issues when running your job.
|
||||
|
||||
Create a model configuration file
|
||||
------------
|
||||
Once your model has been added into FLUTE, it is necessary to create a configuration file (in case you haven't already), specifiying all the parameters
|
||||
for the model. A template has been provided for this in ./configs/hello_world_local_nlg_gru_json.yaml
|
||||
|
||||
Troubleshooting
|
||||
------------
|
||||
* If a module is not being recognized by Python, verify that this module has been previously installed or is included in requirements.txt
|
||||
* If the model class is not being detected, make sure the name of the model class is the same as specified in the yaml configuration file (case sensitive)
|
||||
* If the dataloader type is not being detected, make sure that field 'loader_type' has been declared in the yaml configuration file.
|
|
@ -0,0 +1 @@
|
|||
cf6d3e4ce75fc1d8f74e62a29933ccb4d956cdfe
|
|
@ -0,0 +1,306 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
'''
|
||||
This is the main script to run on each MPI thread. It will spawn either a
|
||||
Server or Worker object -- the former is responsible for orchestrating and
|
||||
aggregating models, where as the latter processes clients' data to generate
|
||||
a new model. The Server lives on the very first thread, whereas remaining
|
||||
threads contain each a diferent Worker.
|
||||
'''
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import yaml
|
||||
from psutil import virtual_memory
|
||||
|
||||
import torch
|
||||
from azureml.core import Run
|
||||
|
||||
from core import federated
|
||||
from core.server import select_server
|
||||
from core.client import Client
|
||||
from core.globals import TRAINING_FRAMEWORK_TYPE, logging_level, define_file_type
|
||||
from experiments import make_model
|
||||
from utils import (
|
||||
make_optimizer,
|
||||
init_logging,
|
||||
print_rank,
|
||||
find_pretrained_model
|
||||
)
|
||||
from utils.dataloaders_utils import (
|
||||
make_train_dataloader,
|
||||
make_val_dataloader,
|
||||
make_test_dataloader,
|
||||
)
|
||||
from config_file_parser import (
|
||||
check_server_config,
|
||||
check_client_config
|
||||
)
|
||||
|
||||
assert TRAINING_FRAMEWORK_TYPE == "mpi", "Unsupported platform {}".format(TRAINING_FRAMEWORK_TYPE)
|
||||
|
||||
|
||||
def log_run_properties(config):
|
||||
"""Log parameters on AzureML.
|
||||
|
||||
Args:
|
||||
config (dict): config containing parameters to log.
|
||||
"""
|
||||
|
||||
properties = {}
|
||||
|
||||
def lookup(key, cfg, default):
|
||||
"""Look for key on dict"""
|
||||
keys = key.split(".")
|
||||
if len(keys) == 1:
|
||||
return cfg.get(key, default)
|
||||
if keys[0] in cfg:
|
||||
return lookup(".".join(keys[1:]), cfg[keys[0]], default)
|
||||
else:
|
||||
return default
|
||||
|
||||
# Build properties dictionary
|
||||
mem = virtual_memory()
|
||||
properties["System memory (GB)"] = float(mem.total) / (1024**3)
|
||||
|
||||
props = [
|
||||
("server_config.num_clients_per_iteration", 0),
|
||||
("server_config.max_iteration", 0),
|
||||
("dp_config.eps", 0),
|
||||
("dp_config.max_weight", 0),
|
||||
("dp_config.min_weight", 0),
|
||||
("server_config.optimizer_config.type", "sgd"),
|
||||
("server_config.optimizer_config.lr", 1.0),
|
||||
("server_config.optimizer_config.amsgrad", False),
|
||||
("server_config.annealing_config.type", "step_lr"),
|
||||
("server_config.annealing_config.step_interval", "epoch"),
|
||||
("server_config.annealing_config.gamma", 1.0),
|
||||
("server_config.annealing_config.step_size", 100),
|
||||
]
|
||||
|
||||
for (key, default) in props:
|
||||
properties[key] = lookup(key, config, default)
|
||||
|
||||
# Log the properties dictionary into AzureML
|
||||
run = Run.get_context()
|
||||
for k in properties:
|
||||
run.log(k, properties[k])
|
||||
|
||||
|
||||
def run_worker(model_path, config, task, data_path, local_rank):
|
||||
"""Spawn worker object that lives throughout MPI thread.
|
||||
|
||||
Args:
|
||||
model_path (str): path to the pretrained model.
|
||||
config (dict): dictionary containing parameters.
|
||||
task (str): what task to solve, must be a folder of :code:`experiments`.
|
||||
data_path (str): path to data.
|
||||
local_rank (int): the rank of the MPI thread.
|
||||
"""
|
||||
model_config = config["model_config"]
|
||||
server_config = config["server_config"]
|
||||
define_file_type(data_path, config)
|
||||
|
||||
# Get the rank on MPI
|
||||
rank = local_rank if local_rank > -1 else federated.rank()
|
||||
|
||||
# Assign MPI thread to a specific GPU
|
||||
if torch.cuda.is_available():
|
||||
n_gpus = torch.cuda.device_count()
|
||||
torch.cuda.set_device(federated.local_rank() % n_gpus)
|
||||
print_rank(f"Assigning worker to GPU {federated.local_rank() % n_gpus}")
|
||||
|
||||
# Make the Model to distribute to workers
|
||||
model = make_model(model_config)
|
||||
|
||||
# Instantiate the Server object on the first thread
|
||||
if rank == 0:
|
||||
try:
|
||||
print_rank('Server data preparation')
|
||||
|
||||
# pre-cache the training data and capture the number of clients for sampling
|
||||
training_filename = os.path.join(data_path, config["client_config"]["data_config"]["train"]["list_of_train_data"])
|
||||
config["server_config"]["data_config"]["num_clients"] = Client.get_num_users(training_filename)
|
||||
data_config = config['server_config']['data_config']
|
||||
|
||||
# Make the Dataloaders
|
||||
if 'train' in data_config:
|
||||
server_train_dataloader = make_train_dataloader(data_config['train'], data_path, task=task, clientx=None)
|
||||
else:
|
||||
server_train_dataloader = None
|
||||
val_dataloader = make_val_dataloader(data_config["val"], data_path, task=task)
|
||||
test_dataloader = make_test_dataloader(data_config["test"], data_path, task=task)
|
||||
|
||||
print_rank("Prepared the dataloaders")
|
||||
|
||||
# Create the optimizer on the server
|
||||
optimizer = make_optimizer(server_config["optimizer_config"], model)
|
||||
|
||||
# Load a model that's already trained
|
||||
best_trained_model = find_pretrained_model(model_path, model_config)
|
||||
if best_trained_model is not None:
|
||||
model_state_dict = torch.load(best_trained_model,
|
||||
map_location=None if torch.cuda.is_available() else torch.device("cpu"))
|
||||
model.load_state_dict(model_state_dict)
|
||||
|
||||
server_type = server_config["type"]
|
||||
server_setup = select_server(server_type, config) # Return the server class
|
||||
server = server_setup(
|
||||
data_config["num_clients"],
|
||||
model,
|
||||
optimizer,
|
||||
None,
|
||||
data_path,
|
||||
model_path,
|
||||
server_train_dataloader,
|
||||
val_dataloader,
|
||||
test_dataloader,
|
||||
config,
|
||||
server_config
|
||||
)
|
||||
log_run_properties(config)
|
||||
|
||||
except Exception as e:
|
||||
# Be sure the other workers are shut down.
|
||||
server.terminate_workers()
|
||||
raise e
|
||||
|
||||
print_rank("Launching server")
|
||||
server.run()
|
||||
|
||||
else:
|
||||
# Instantiate client-processing Worker on remaining threads
|
||||
print_rank("Worker on node {}: process started".format(rank))
|
||||
client_config = config["client_config"]
|
||||
worker = federated.Worker(
|
||||
model,
|
||||
data_path,
|
||||
do_profiling=client_config.get("do_profiling", False),
|
||||
clients_in_parallel=client_config.get("clients_in_parallel", None),
|
||||
)
|
||||
worker.run()
|
||||
|
||||
|
||||
def _reconcile_args(args, config):
|
||||
'''Change parameters depending on command-line arguments'''
|
||||
|
||||
if args.dp_config_grad_dir_eps:
|
||||
config["dp_config"]["grad_dir_eps"] = args.dp_config_grad_dir_eps
|
||||
if args.dp_config_grad_mag_eps:
|
||||
config["dp_config"]["grad_mag_eps"] = args.dp_config_grad_mag_eps
|
||||
if args.dp_config_weight_eps:
|
||||
config["dp_config"]["weight_eps"] = args.dp_config_weight_eps
|
||||
|
||||
return config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Parse command-line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-config")
|
||||
parser.add_argument("-outputPath")
|
||||
parser.add_argument("-dataPath", default=None)
|
||||
parser.add_argument("-task", default=None, help="Define the task for the run")
|
||||
parser.add_argument("-num_skip_decoding", default=-1, type=int, help="Skip decoding in unsupervised learning mode")
|
||||
parser.add_argument("--local_rank", default=-1, type=int)
|
||||
parser.add_argument("--dp_config_grad_dir_eps", default=None, type=float, help="DP direction epsilon")
|
||||
parser.add_argument("--dp_config_grad_mag_eps", default=None, type=float, help="DP magnitude epsilon")
|
||||
parser.add_argument("--dp_config_weight_eps", default=None, type=float, help="DP weight epsilon")
|
||||
|
||||
args = parser.parse_args()
|
||||
data_path = args.dataPath
|
||||
task = args.task
|
||||
local_rank = args.local_rank
|
||||
|
||||
# Create dictionaries w/ parameters
|
||||
default_data_conf = {
|
||||
"input_dim": 300,
|
||||
"batch_size": 40,
|
||||
"loader_type": "text",
|
||||
"prepend_datapath": False,
|
||||
"pin_memory": True,
|
||||
"num_frames": 0,
|
||||
"desired_max_samples": 300,
|
||||
"max_grad_norm": 5.0, # max_grad_norm for gradient clipping
|
||||
"num_workers": 1,
|
||||
"max_batch_size": 0, # maximum number of batch size; if 0, no limitation is applied
|
||||
"unsorted_batch": False # do not sort when making batch; this is inefficient in terms of batch, but could be efficient in terms of accuracy
|
||||
}
|
||||
|
||||
default_server_conf = {
|
||||
"val_freq": 1,
|
||||
"rec_freq": 8,
|
||||
"max_iteration": 100000000,
|
||||
"type": "optimization",
|
||||
"data_config": default_data_conf,
|
||||
"aggregate_median": None,
|
||||
"best_model_criterion": "loss",
|
||||
"fall_back_to_best_model": False,
|
||||
"num_clients_per_iteration": -1
|
||||
}
|
||||
|
||||
default_client_conf = {
|
||||
"copying_train_jsonls": True,
|
||||
"type": "gradient_computation",
|
||||
"data_config": default_data_conf,
|
||||
}
|
||||
|
||||
# The mount point can also be retrieved from input_datasets of the run context
|
||||
if data_path is None:
|
||||
data_path = Run.get_context().input_datasets["input"]
|
||||
print("The data can be found here: ", data_path)
|
||||
|
||||
# Update the model path for the sake of AzureML
|
||||
id = Run.get_context().id
|
||||
experiment_name = "-".join(id.split("-")[-4:-2])
|
||||
experiment_root = os.path.join(args.outputPath, experiment_name)
|
||||
model_path = os.path.join(experiment_root, "models")
|
||||
log_path = os.path.join(experiment_root, "log")
|
||||
|
||||
os.makedirs(model_path, exist_ok=True)
|
||||
os.makedirs(log_path, exist_ok=True)
|
||||
|
||||
# Make a copy of the config file into the output folder, for future reference
|
||||
cfg_out = os.path.join(experiment_root, "FLUTE_config.yaml")
|
||||
if local_rank <= 0:
|
||||
shutil.copyfile(args.config, cfg_out)
|
||||
print("Copy created")
|
||||
|
||||
# Initialize logging
|
||||
init_logging(log_path, loglevel=logging_level)
|
||||
|
||||
with open(args.config) as f:
|
||||
config = yaml.safe_load(f)
|
||||
config = _reconcile_args(args, config) # replace params. depending on CL args.
|
||||
|
||||
assert "num_clients" not in config["server_config"]["data_config"], "Remove \"num_clients\" from server data_config since this is a reserved key"
|
||||
assert "num_clients" not in config["client_config"]["data_config"], "Remove \"num_clients\" from client data_config since this is a reserved key"
|
||||
|
||||
# Make sure the pretrained model is found in the correct place
|
||||
if "pretrained_model_path" in config["model_config"]["model_type"]:
|
||||
config["model_config"]["model_type"]["pretrained_model_path"] = os.path.join(data_path, config["model_config"]["model_type"]["pretrained_model_path"])
|
||||
if "pretrained_model_path" in config["model_config"]:
|
||||
config["model_config"]["pretrained_model_path"] = os.path.join(data_path, config["model_config"]["pretrained_model_path"])
|
||||
|
||||
config["data_path"] = data_path
|
||||
|
||||
config = check_server_config(config, default_server_conf)
|
||||
config = check_client_config(config, default_client_conf)
|
||||
|
||||
# Add task specification to client configuration
|
||||
config["client_config"]["task"] = task
|
||||
config["server_config"]["task"] = task
|
||||
|
||||
# RL-related options
|
||||
if config["server_config"].get("wantRL", False):
|
||||
if config["server_config"]["RL"].get("RL_path_global", True):
|
||||
config["server_config"]["RL"]["RL_path"] = os.path.join(args.outputPath,
|
||||
config["server_config"]["RL"]["RL_path"])
|
||||
else:
|
||||
config["server_config"]["RL"]["RL_path"] = os.path.join(args.outputPath, experiment_name,
|
||||
config["server_config"]["RL"]["RL_path"])
|
||||
|
||||
# Instantiate either Server or Worker on the thread
|
||||
run_worker(model_path, config, task, data_path, local_rank)
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
from utils import print_rank, print_cuda_stats
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
def make_model(model_config, dataloader_type=None, input_dim=-1, output_dim=-1):
|
||||
print('Preparing model .. Initializing')
|
||||
|
||||
try:
|
||||
dir = "./"+ str(model_config["model_folder"])
|
||||
model_class = model_config["model_type"]
|
||||
loader = SourceFileLoader(model_class,dir).load_module()
|
||||
model_type = getattr(loader,model_class )
|
||||
except:
|
||||
raise ValueError("{} model not found, make sure to indicate the model path in the .yaml file".format(model_config["type"]))
|
||||
|
||||
model = model_type(model_config)
|
||||
print(model)
|
||||
|
||||
if not "weight_init" in model_config or model_config["weight_init"] == "default":
|
||||
print_rank("initialize model with default settings")
|
||||
pass
|
||||
elif model_config["weight_init"] == "xavier_normal":
|
||||
print_rank("initialize model with xavier_normal")
|
||||
for p in model.parameters():
|
||||
if p.dim() > 1: # weight
|
||||
torch.nn.init.xavier_normal_(p.data)
|
||||
elif p.dim() == 1: # bias
|
||||
p.data.zero_()
|
||||
for m in model.modules():
|
||||
if isinstance(m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.BatchNorm2d)):
|
||||
m.reset_parameters()
|
||||
else:
|
||||
return ValueError("{} not supported".format(model_config["weight_init"]))
|
||||
|
||||
print_rank("trying to move the model to GPU")
|
||||
|
||||
# Move it to GPU if you can
|
||||
if torch.cuda.is_available():
|
||||
model.cuda()
|
||||
print_rank(f"moved the model to GPU {torch.cuda.current_device()}")
|
||||
else:
|
||||
model.cpu()
|
||||
print_rank("no GPU available.")
|
||||
|
||||
print_rank("model: {}".format(model))
|
||||
print_cuda_stats()
|
||||
|
||||
return model
|
|
@ -0,0 +1,3 @@
|
|||
utils/data
|
||||
*.hdf5
|
||||
*.json
|
|
@ -0,0 +1,70 @@
|
|||
# Simple example of a CNN on CIFAR-10
|
||||
|
||||
Our objective here is to bring a simple experiment from the Pytorch tutorials,
|
||||
more specifically the one in https://github.com/pytorch/tutorials/blob/master/beginner_source/blitz/cifar10_tutorial.py,
|
||||
and convert it to FLUTE. Instructions on how to do this are given below.
|
||||
|
||||
An adapted version of the tutorial above is provided in the
|
||||
`utils/centralized_training.py` script.
|
||||
|
||||
## Preparing the data
|
||||
|
||||
Right now FLUTE expects data to be provided either in JSON or HDF5 formats. It
|
||||
should be made data-agnostic in the near future, but right now we need to
|
||||
convert the data to either of these formats. In our case, we can use the script
|
||||
`utils/download_and_convert_data.py` to do that for us; a HDF5 file will be
|
||||
generated.
|
||||
|
||||
## Specifying the model
|
||||
|
||||
Next, we prepare the model. The `model.py` file contains two classes: one is the
|
||||
`Net` class already contained in the original script, and the other, a class
|
||||
called `CNN` which effectively wraps `Net`. Importantly, the `CNN` class defines
|
||||
two methods: `loss` and `inference`; both perform forward steps and then perform
|
||||
additional computations, in particular, the former executes the loss' evaluation,
|
||||
and the latter the metrics' computation. The format of the inputs and outputs
|
||||
should be the same as in this example.
|
||||
|
||||
## Specifying dataset and dataloaders
|
||||
|
||||
Inside the `dataloaders` folder, there are two files: `text_dataset.py` and
|
||||
`text_dataloader.py` (the word "text" is used to mimic the other datasets, even
|
||||
though in practice this loads images -- this will be changed in the future).
|
||||
Both inherit from the Pytorch classes with same name.
|
||||
|
||||
The dataset should be able to access all the data, which is stored in the
|
||||
attributes `user_list`, `user_data`, `user_data_labels` and `num_samples` (user
|
||||
names, user features, user labels if the problem is supervised, and number of
|
||||
samples for each user, respectively). These attributes are required to have
|
||||
these exact names. Otherwise, it should also be able to access the examples of a
|
||||
specific user, which id is passed during initialization via the `user_idx`
|
||||
argument.
|
||||
|
||||
The dataloader is simpler, and essentially just instantiates the dataset and
|
||||
creates batches with a specific format.
|
||||
|
||||
## Creating a config file
|
||||
|
||||
All the parameters of the experiment are passed in a YAML file. A documented
|
||||
example is provided in `config.yaml`.
|
||||
|
||||
## Running the experiment
|
||||
|
||||
Finally, to launch the experiment, it suffices to launch the `e2e_trainer.py`
|
||||
script using MPI (don't forget to first run
|
||||
`utils/download_and_convert_data.py`):
|
||||
|
||||
```
|
||||
mpiexec -n 4 python e2e_trainer.py -dataPath experiments/classif_cnn/utils/data -outputPath scratch -config experiments/classif_cnn/config.yaml -task classif_cnn
|
||||
```
|
||||
|
||||
The `dataPath`, `outputPath` and `config` arguments should just specify the
|
||||
respective files or folders, as in the example above -- in this case, a folder
|
||||
called `scratch` will be created containing logs and checkpoints. The task
|
||||
should be the name of the folder insider `experiments`.
|
||||
|
||||
Following what is specified in the config file, the experiment will run for
|
||||
2000 rounds, and during each of them 10 clients will be selected at random,
|
||||
each of whom has 50 samples. It is more or less the same, then, as the 2
|
||||
epochs in the centralized training, except that clients are selected at
|
||||
random so we might not see all of them.
|
|
@ -0,0 +1,57 @@
|
|||
model_config:
|
||||
model_type: CNN # class w/ `loss` and `inference` methods
|
||||
model_folder: experiments/classif_cnn/model.py # file containing class
|
||||
|
||||
dp_config:
|
||||
enable_local_dp: false # whether to enable user-level DP
|
||||
|
||||
privacy_metrics_config:
|
||||
apply_metrics: false # cache data to compute additional metrics
|
||||
|
||||
server_config:
|
||||
wantRL: false # whether to use RL-based meta-optimizers
|
||||
resume_from_checkpoint: false # restart from checkpoint if file exists
|
||||
do_profiling: false # run profiler and compute runtime metrics
|
||||
optimizer_config: # this is the optimizer used to update the model
|
||||
type: sgd
|
||||
lr: 1.0
|
||||
annealing_config: # annealer for the learning rate
|
||||
type: step_lr
|
||||
step_interval: epoch
|
||||
gamma: 1.0
|
||||
step_size: 100
|
||||
val_freq: 50 # how many iterations between metric eval on val set
|
||||
rec_freq: 100 # how many iterations between metric eval on test set
|
||||
max_iteration: 2000 # how many iterations in total
|
||||
num_clients_per_iteration: 10 # how many clients per iteration
|
||||
data_config: # where to get val and test data from
|
||||
val:
|
||||
batch_size: 10000
|
||||
loader_type: text
|
||||
val_data: test_data.hdf5
|
||||
test:
|
||||
batch_size: 10000
|
||||
loader_type: text
|
||||
test_data: test_data.hdf5
|
||||
type: model_optimization
|
||||
aggregate_median: softmax # how aggregations weights are computed
|
||||
initial_lr_client: 0.001 # learning rate used on client optimizer
|
||||
lr_decay_factor: 1.0
|
||||
weight_train_loss: train_loss
|
||||
best_model_criterion: loss
|
||||
fall_back_to_best_model: false
|
||||
|
||||
client_config:
|
||||
do_profiling: false # run profiling and compute runtime metrics
|
||||
ignore_subtask: false
|
||||
data_config: # where to get training data from
|
||||
train:
|
||||
batch_size: 4
|
||||
loader_type: text
|
||||
list_of_train_data: train_data.hdf5
|
||||
desired_max_samples: 50000
|
||||
optimizer_config: # this is the optimizer used by the client
|
||||
type: sgd
|
||||
lr: 0.001 # this is overridden by `initial_lr_client`
|
||||
momentum: 0.9
|
||||
type: optimization
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from experiments.classif_cnn.dataloaders.text_dataset import TextDataset
|
||||
|
||||
|
||||
class TextDataLoader(DataLoader):
|
||||
def __init__(self, mode, num_workers=0, **kwargs):
|
||||
args = kwargs['args']
|
||||
self.batch_size = args['batch_size']
|
||||
|
||||
dataset = TextDataset(
|
||||
data=kwargs['data'],
|
||||
test_only=(not mode=='train'),
|
||||
user_idx=kwargs.get('user_idx', None),
|
||||
file_type='hdf5',
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
dataset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=(mode=='train'),
|
||||
num_workers=num_workers,
|
||||
collate_fn=self.collate_fn,
|
||||
)
|
||||
|
||||
def create_loader(self):
|
||||
return self
|
||||
|
||||
def collate_fn(self, batch):
|
||||
x, y = list(zip(*batch))
|
||||
return {'x': torch.tensor(x), 'y': torch.tensor(y)}
|
|
@ -0,0 +1,56 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import h5py
|
||||
import json
|
||||
import numpy as np
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
class TextDataset(Dataset):
|
||||
def __init__(self, data, test_only=False, user_idx=None, file_type=None):
|
||||
self.test_only = test_only
|
||||
self.user_idx = user_idx
|
||||
self.file_type = file_type
|
||||
|
||||
# Get all data
|
||||
self.user_list, self.user_data, self.user_data_label, self.num_samples = self.load_data(data, self.file_type)
|
||||
|
||||
if self.test_only: # combine all data into single array
|
||||
self.user = 'test_only'
|
||||
self.features = np.vstack([user_data['x'] for user_data in self.user_data.values()])
|
||||
self.labels = np.hstack(list(self.user_data_label.values()))
|
||||
else: # get a single user's data
|
||||
if user_idx is None:
|
||||
raise ValueError('in train mode, user_idx must be specified')
|
||||
|
||||
self.user = self.user_list[user_idx]
|
||||
self.features = self.user_data[self.user]['x']
|
||||
self.labels = self.user_data_label[self.user]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.features[idx].astype(np.float32).T, self.labels[idx]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.features)
|
||||
|
||||
@staticmethod
|
||||
def load_data(data, file_type):
|
||||
'''Load data from disk or memory.
|
||||
|
||||
The :code:`data` argument can be either the path to the JSON
|
||||
or HDF5 file that contains the expected dictionary, or the
|
||||
actual dictionary.'''
|
||||
|
||||
if isinstance(data, str):
|
||||
if file_type == 'json':
|
||||
with open(data, 'r') as fid:
|
||||
data = json.load(fid)
|
||||
elif file_type == 'hdf5':
|
||||
data = h5py.File(data, 'r')
|
||||
|
||||
users = data['users']
|
||||
features = data['user_data']
|
||||
labels = data['user_data_label']
|
||||
num_samples = data['num_samples']
|
||||
|
||||
return users, features, labels, num_samples
|
|
@ -0,0 +1,62 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
'''The standard PyTorch model we want to federate'''
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(3, 6, 5)
|
||||
self.pool = nn.MaxPool2d(2, 2)
|
||||
self.conv2 = nn.Conv2d(6, 16, 5)
|
||||
self.fc1 = nn.Linear(16 * 5 * 5, 120)
|
||||
self.fc2 = nn.Linear(120, 84)
|
||||
self.fc3 = nn.Linear(84, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pool(F.relu(self.conv1(x)))
|
||||
x = self.pool(F.relu(self.conv2(x)))
|
||||
x = torch.flatten(x, 1) # flatten all dimensions except batch
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
class CNN(nn.Module):
|
||||
'''This is a PyTorch model with some extra methods'''
|
||||
|
||||
def __init__(self, model_config):
|
||||
super().__init__()
|
||||
self.net = Net()
|
||||
|
||||
def loss(self, input: torch.Tensor) -> torch.Tensor:
|
||||
'''Performs forward step and computes the loss'''
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
features, labels = input['x'].to(device), input['y'].to(device)
|
||||
output = self.net.forward(features)
|
||||
return F.cross_entropy(output, labels.long())
|
||||
|
||||
def inference(self, input):
|
||||
'''Performs forward step and computes metrics'''
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
features, labels = input['x'].to(device), input['y'].to(device)
|
||||
output = self.net.forward(features)
|
||||
|
||||
n_samples = features.shape[0]
|
||||
accuracy = torch.mean((torch.argmax(output, dim=1) == labels).float()).item()
|
||||
|
||||
return output, accuracy, n_samples
|
||||
|
||||
def set_eval(self):
|
||||
'''Bring the model into evaluation mode'''
|
||||
self.eval()
|
||||
|
||||
def set_train(self):
|
||||
'''Bring the model into training mode'''
|
||||
self.train()
|
|
@ -0,0 +1,103 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
'''Simple example of a CNN on CIFAR-10
|
||||
|
||||
This is adapted from the Pytorch tutorials. See
|
||||
https://github.com/pytorch/tutorials/blob/master/beginner_source/blitz/cifar10_tutorial.py
|
||||
for more info.
|
||||
'''
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
|
||||
|
||||
# Parameters
|
||||
BATCH_SIZE = 4
|
||||
N_EPOCHS = 2
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Create dataloaders
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
])
|
||||
|
||||
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
|
||||
download=True, transform=transform)
|
||||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
|
||||
shuffle=True, num_workers=2)
|
||||
|
||||
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
|
||||
download=True, transform=transform)
|
||||
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
|
||||
shuffle=False, num_workers=2)
|
||||
|
||||
|
||||
# Define the model
|
||||
class Net(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv1 = nn.Conv2d(3, 6, 5)
|
||||
self.pool = nn.MaxPool2d(2, 2)
|
||||
self.conv2 = nn.Conv2d(6, 16, 5)
|
||||
self.fc1 = nn.Linear(16 * 5 * 5, 120)
|
||||
self.fc2 = nn.Linear(120, 84)
|
||||
self.fc3 = nn.Linear(84, 10)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pool(F.relu(self.conv1(x)))
|
||||
x = self.pool(F.relu(self.conv2(x)))
|
||||
x = torch.flatten(x, 1) # flatten all dimensions except batch
|
||||
x = F.relu(self.fc1(x))
|
||||
x = F.relu(self.fc2(x))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
# Instantiate model, loss and optimizer
|
||||
net = Net().to(device)
|
||||
criterion = nn.CrossEntropyLoss()
|
||||
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
|
||||
|
||||
# Training loop
|
||||
for epoch in range(N_EPOCHS): # loop over the dataset multiple times
|
||||
running_loss = 0.0
|
||||
for i, data in enumerate(trainloader, 0):
|
||||
# Get the inputs; data is a list of [inputs, labels]
|
||||
inputs, labels = data[0].to(device), data[1].to(device)
|
||||
|
||||
# Zero the parameter gradients
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Forward + backward + optimize
|
||||
outputs = net(inputs)
|
||||
loss = criterion(outputs, labels)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# Print statistics
|
||||
running_loss += loss.item()
|
||||
if i % 2000 == 1999: # print every 2000 mini-batches
|
||||
print('[%d, %5d] loss: %.3f' %
|
||||
(epoch + 1, i + 1, running_loss / 2000))
|
||||
running_loss = 0.0
|
||||
|
||||
# Compute accuracy
|
||||
correct = 0
|
||||
total = 0
|
||||
with torch.no_grad():
|
||||
for data in testloader:
|
||||
images, labels = data[0].to(device), data[1].to(device)
|
||||
outputs = net(images)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
|
||||
print('Accuracy of the network on the 10000 test images: %d %%' % (
|
||||
100 * correct / total))
|
|
@ -0,0 +1,81 @@
|
|||
import h5py
|
||||
import json
|
||||
import time
|
||||
|
||||
import torchvision
|
||||
import torchvision.transforms as transforms
|
||||
import tqdm
|
||||
|
||||
|
||||
def _dump_dict_to_hdf5(data_dict: dict, hdf5_file: h5py.File):
|
||||
'''Dump dict with expected structure to HDF5 file'''
|
||||
|
||||
hdf5_file.create_dataset('users', data=data_dict['users'])
|
||||
hdf5_file.create_dataset('num_samples', data=data_dict['num_samples'])
|
||||
|
||||
# Store actual data in groups
|
||||
user_data_group = hdf5_file.create_group('user_data')
|
||||
for user, user_data in tqdm.tqdm(data_dict['user_data'].items()):
|
||||
user_subgroup = user_data_group.create_group(user)
|
||||
user_subgroup.create_dataset('x', data=user_data)
|
||||
|
||||
user_data_label_group = hdf5_file.create_group('user_data_label')
|
||||
for user, user_data_label in tqdm.tqdm(data_dict['user_data_label'].items()):
|
||||
user_data_label_group.create_dataset(user, data=user_data_label)
|
||||
|
||||
def _process_and_save_to_disk(dataset, n_users, file_format, output):
|
||||
'''Process a Torchvision dataset to expected format and save to disk'''
|
||||
|
||||
# Split training data equally among all users
|
||||
total_samples = len(dataset)
|
||||
samples_per_user = total_samples // n_users
|
||||
assert total_samples % n_users == 0
|
||||
|
||||
# Function for getting a given user's data indices
|
||||
user_idxs = lambda user_id: slice(user_id * samples_per_user, (user_id + 1) * samples_per_user)
|
||||
|
||||
# Convert training data to expected format
|
||||
print('Converting data to expected format...')
|
||||
start_time = time.time()
|
||||
|
||||
data_dict = { # the data is expected to have this format
|
||||
'users' : [f'{user_id:04d}' for user_id in range(n_users)],
|
||||
'num_samples' : 10000 * [samples_per_user],
|
||||
'user_data' : {f'{user_id:04d}': dataset.data[user_idxs(user_id)].tolist() for user_id in range(n_users)},
|
||||
'user_data_label': {f'{user_id:04d}': dataset.targets[user_idxs(user_id)] for user_id in range(n_users)},
|
||||
}
|
||||
|
||||
print(f'Finished converting data in {time.time() - start_time:.2f}s.')
|
||||
|
||||
# Save training data to disk
|
||||
print('Saving data to disk...')
|
||||
start_time = time.time()
|
||||
|
||||
if file_format == 'json':
|
||||
with open(output + '.json', 'w') as json_file:
|
||||
json.dump(data_dict, json_file)
|
||||
elif file_format == 'hdf5':
|
||||
with h5py.File(output + '.hdf5', 'w') as hdf5_file:
|
||||
_dump_dict_to_hdf5(data_dict=data_dict, hdf5_file=hdf5_file)
|
||||
else:
|
||||
raise ValueError('unknown format.')
|
||||
|
||||
print(f'Finished saving data in {time.time() - start_time:.2f}s.')
|
||||
|
||||
|
||||
# Get training and testing data from torchvision
|
||||
transform = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
|
||||
])
|
||||
|
||||
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
|
||||
download=True, transform=transform)
|
||||
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
|
||||
download=True, transform=transform)
|
||||
|
||||
print('Processing training set...')
|
||||
_process_and_save_to_disk(trainset, n_users=1000, file_format='hdf5', output='./data/train_data')
|
||||
|
||||
print('Processing test set...')
|
||||
_process_and_save_to_disk(testset, n_users=200, file_format='hdf5', output='./data/test_data')
|
|
@ -0,0 +1,31 @@
|
|||
# Simple example of a MLM task on Reddit Dataset
|
||||
|
||||
Instructions on how to run the experiment, given below.
|
||||
|
||||
## Preparing the data
|
||||
|
||||
Right now FLUTE expects data to be provided either in JSON or HDF5 formats. It
|
||||
should be made data-agnostic in the near future, but at this moment we need to do some
|
||||
preprocessing before handling the data on the model. For this experiment, we can run the
|
||||
script located in `testing/create_data.py` as follows:
|
||||
|
||||
```code
|
||||
python create_data.py -e mlm
|
||||
```
|
||||
to download mock data already preprocessed. A new folder `mockup` will be generated
|
||||
inside `testing` with all data needed for a local run.
|
||||
|
||||
A couple of scripts are provided in `utils/preprocessing` for preprocessing .tsv files
|
||||
in case you want to use your own data.
|
||||
|
||||
## Creating a config file
|
||||
|
||||
All the parameters of the experiment are passed in a YAML file. An example is
|
||||
provided in `configs/hello_world_mlm_bert_json.yaml` with the suggested parameters
|
||||
to do a simple run for this experiment. Make sure to point your training files at
|
||||
the fields: train_data, test_data and val_data inside the config file.
|
||||
|
||||
## Running the experiment
|
||||
|
||||
For submitting jobs in Azure ML, we have included the instructions in the `Experiments`
|
||||
section of the main `README.md`.
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from transformers.data.data_collator import default_data_collator, DataCollatorWithPadding
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import DataCollatorForLanguageModeling
|
||||
from experiments.mlm_bert.dataloaders.text_dataset import TextDataset
|
||||
import torch
|
||||
|
||||
class TextDataLoader(DataLoader):
|
||||
"""
|
||||
PyTorch dataloader for loading text data from
|
||||
text_dataset.
|
||||
"""
|
||||
def __init__(self, mode, data, num_workers=0, **kwargs):
|
||||
|
||||
args = kwargs['args']
|
||||
task = args['task']
|
||||
user_idx = kwargs['user_idx']
|
||||
mlm_probability = args['mlm_probability']
|
||||
self.batch_size = args['batch_size']
|
||||
self.mode = mode
|
||||
self.num_workers = num_workers
|
||||
self.utt_ids = None
|
||||
max_samples_per_user = args.get('max_samples_per_user', -1)
|
||||
min_words_per_utt = args.get('min_words_per_utt', 5)
|
||||
tokenizer_kwargs = {
|
||||
"cache_dir": args['cache_dir'],
|
||||
"use_fast": args['tokenizer_type_fast'],
|
||||
"use_auth_token": None
|
||||
}
|
||||
|
||||
if 'tokenizer_name' in args:
|
||||
tokenizer = AutoTokenizer.from_pretrained(args['tokenizer_name'], **tokenizer_kwargs)
|
||||
elif 'model_name_or_path' in args:
|
||||
tokenizer = AutoTokenizer.from_pretrained(args['model_name_or_path'], **tokenizer_kwargs)
|
||||
else:
|
||||
raise ValueError("You are instantiating a new tokenizer from scratch. This is not supported by this script.")
|
||||
|
||||
print("Tokenizer is: ",tokenizer)
|
||||
|
||||
dataset = TextDataset(
|
||||
data,
|
||||
args= args,
|
||||
test_only = self.mode is not 'train',
|
||||
tokenizer= tokenizer,
|
||||
user_idx=user_idx,
|
||||
max_samples_per_user=max_samples_per_user,
|
||||
min_words_per_utt=min_words_per_utt,
|
||||
)
|
||||
self.utt_ids = dataset.user
|
||||
|
||||
try:
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer,
|
||||
mlm= task=='mlm',
|
||||
mlm_probability=mlm_probability,)
|
||||
except:
|
||||
|
||||
print('There is an issue with the DataCollator .. Falling back to default_data_collator')
|
||||
data_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
|
||||
|
||||
if self.mode == 'train':
|
||||
train_sampler = RandomSampler(dataset)
|
||||
super(TextDataLoader, self).__init__(
|
||||
dataset,
|
||||
batch_size=self.batch_size,
|
||||
sampler=train_sampler,
|
||||
collate_fn=data_collator,
|
||||
drop_last=False,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
elif self.mode == 'val' or self.mode == 'test':
|
||||
eval_sampler = SequentialSampler(dataset)
|
||||
super(TextDataLoader, self).__init__(
|
||||
dataset,
|
||||
sampler=eval_sampler,
|
||||
batch_size= self.batch_size,
|
||||
collate_fn=data_collator,
|
||||
drop_last=False,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=True,
|
||||
)
|
||||
|
||||
else:
|
||||
raise Exception("Sorry, there is something wrong with the 'mode'-parameter ")
|
||||
|
||||
def create_loader(self):
|
||||
return self
|
||||
|
||||
def get_user(self):
|
||||
return self.utt_ids
|
||||
|
|
@ -0,0 +1,201 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from utils import print_rank
|
||||
import logging
|
||||
import json
|
||||
import itertools
|
||||
|
||||
class TextDataset(Dataset):
|
||||
"""
|
||||
Map a text source to the target text
|
||||
"""
|
||||
def __init__(self, data, args, tokenizer, test_only=False, user_idx=None, max_samples_per_user=-1, min_words_per_utt=5):
|
||||
self.utt_list = list()
|
||||
self.test_only= test_only
|
||||
self.padding = args.get('padding', True)
|
||||
self.max_seq_length= args['max_seq_length']
|
||||
self.max_samples_per_user = max_samples_per_user
|
||||
self.min_num_words = min_words_per_utt
|
||||
self.tokenizer = tokenizer
|
||||
self.process_line_by_line=args.get('process_line_by_line', False)
|
||||
self.user = None
|
||||
|
||||
|
||||
if self.max_seq_length is None:
|
||||
self.max_seq_length = self.tokenizer.model_max_length
|
||||
if self.max_seq_length > 512:
|
||||
print_rank(
|
||||
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
|
||||
"Picking 512 instead. You can change that default value by passing --max_seq_length xxx.", loglevel=logging.DEBUG
|
||||
)
|
||||
self.max_seq_length = 512
|
||||
else:
|
||||
if self.max_seq_length > self.tokenizer.model_max_length:
|
||||
print_rank(
|
||||
f"The max_seq_length passed ({self.max_seq_length}) is larger than the maximum length for the"
|
||||
f"model ({self.tokenizer.model_max_length}). Using max_seq_length={self.tokenizer.model_max_length}.", loglevel=logging.DEBUG
|
||||
)
|
||||
self.max_seq_length = min(self.max_seq_length, self.tokenizer.model_max_length)
|
||||
|
||||
self.read_data(data, user_idx)
|
||||
|
||||
if not self.process_line_by_line:
|
||||
self.post_process_list()
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.utt_list)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# Find the index in the available data
|
||||
if self.process_line_by_line:
|
||||
tokenized_text = LineByLineTextDataset(
|
||||
tokenizer=self.tokenizer,
|
||||
input_lines=self.utt_list[idx]['src_text'],
|
||||
line_by_line=True,
|
||||
truncation=True,
|
||||
max_length=self.max_seq_length,
|
||||
padding="max_length")
|
||||
|
||||
self.utt_list[idx]['duration']= len(tokenized_text['input_ids'])
|
||||
return tokenized_text
|
||||
else:
|
||||
return self.utt_list[idx]
|
||||
|
||||
|
||||
def read_data(self, orig_strct, user_idx):
|
||||
""" Reads the data for a specific user (unless it's for val/testing) and returns a
|
||||
list of embeddings and targets."""
|
||||
|
||||
if isinstance(orig_strct, str):
|
||||
print('Loading json-file: ', orig_strct)
|
||||
with open(orig_strct, 'r') as fid:
|
||||
orig_strct = json.load(fid)
|
||||
|
||||
self.user_list = orig_strct['users']
|
||||
self.num_samples= orig_strct['num_samples']
|
||||
self.user_data = orig_strct['user_data']
|
||||
|
||||
if self.test_only:
|
||||
self.user = 'test_only'
|
||||
self.process_x(self.user_data)
|
||||
else:
|
||||
self.user = self.user_list[user_idx]
|
||||
self.process_x(self.user_data[self.user])
|
||||
|
||||
|
||||
def process_x(self, raw_x_batch):
|
||||
|
||||
if self.test_only:
|
||||
for i, user in enumerate(self.user_list):
|
||||
counter=self.process_user(user, raw_x_batch[user])
|
||||
self.num_samples[i] = counter # Update userdata counter "num_samples[user]" after truncation
|
||||
else:
|
||||
counter = self.process_user(self.user, raw_x_batch)
|
||||
self.num_samples[self.user_list.index(self.user)] = counter # Update userdata counter "num_samples[user]" after truncation
|
||||
|
||||
if len(self.utt_list) == 0:
|
||||
self.utt_list = [{'src_text': 'N/A', 'duration': 0, 'loss_weight': 1.0}]
|
||||
|
||||
print_rank('Processing json-structure for User: {} Utterances Processed: {}'.format(self.user, len(self.utt_list)), loglevel=logging.INFO)
|
||||
|
||||
|
||||
def process_user(self, user, user_data):
|
||||
counter=0
|
||||
for line in user_data:
|
||||
for e in line:
|
||||
if len(e.split()) < self.min_num_words:
|
||||
continue
|
||||
if self.max_samples_per_user > -1 and counter >= self.max_samples_per_user:
|
||||
print_rank('Max allowed size per user is reached for user: {}, N: {} utts, Utt_list Len: {}' \
|
||||
.format(user, counter, len(self.utt_list)), loglevel=logging.DEBUG)
|
||||
return counter
|
||||
counter += 1
|
||||
|
||||
utt = {}
|
||||
utt['src_text'] = e
|
||||
utt['duration'] = len(e.split())
|
||||
utt['loss_weight'] = 1.0
|
||||
self.utt_list.append(utt)
|
||||
return counter
|
||||
|
||||
|
||||
def post_process_list(self):
|
||||
|
||||
# Use only the text part of the dataset
|
||||
input_lines=[line['src_text'] for line in self.utt_list]
|
||||
|
||||
# Process all lines of text
|
||||
print_rank('Tokenizing {} Utterances'.format(len(input_lines)), loglevel=logging.DEBUG)
|
||||
self.utt_list= LineByLineTextDataset(self.tokenizer, input_lines) #this one has return_special_tokens_mask as True
|
||||
|
||||
def group_texts(examples):
|
||||
""""Main data processing function that will concatenate all texts
|
||||
from our dataset and generate chunks of max_seq_length."""
|
||||
|
||||
print_rank('Concatenating Frames in Sequences of {} samples'.format(self.max_seq_length), loglevel=logging.DEBUG)
|
||||
|
||||
if self.padding: # Padding last frame
|
||||
|
||||
total_length = sum([len(k) for k in examples['input_ids']])
|
||||
print_rank('Found {} samples Before Concatenation'.format(total_length), loglevel=logging.DEBUG)
|
||||
padN= self.max_seq_length - (total_length % self.max_seq_length)
|
||||
print_rank('Padding last frame with {} samples'.format(padN), loglevel=logging.DEBUG)
|
||||
print_rank('keys {}'.format(examples.keys()), loglevel=logging.DEBUG)
|
||||
examples['input_ids'].append([self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)]*padN)
|
||||
examples['attention_mask'].append([0]*padN)
|
||||
|
||||
if 'special_tokens_mask' in examples.keys():
|
||||
examples['special_tokens_mask'].append([1]*padN)
|
||||
|
||||
if 'token_type_ids' in examples.keys():
|
||||
examples['token_type_ids'].append([0]*padN)
|
||||
|
||||
|
||||
# Concatenate all input.
|
||||
concatenated_examples = {k: list(itertools.chain.from_iterable(examples[k])) for k in examples.keys()}
|
||||
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
||||
print_rank('Concatenated in {} Samples'.format(total_length), loglevel=logging.DEBUG)
|
||||
total_length = (total_length // self.max_seq_length) * self.max_seq_length
|
||||
print_rank('Concatenated in {} Frames'.format(total_length // self.max_seq_length), loglevel=logging.DEBUG)
|
||||
|
||||
# Split by chunks of max_len
|
||||
self.utt_list=[]
|
||||
for i in range(0, total_length, self.max_seq_length):
|
||||
utt={}
|
||||
for k, t in concatenated_examples.items():
|
||||
utt[k]= t[i : i + self.max_seq_length]
|
||||
self.utt_list.append(utt)
|
||||
print_rank('Utterance Len is: {}'.format(len(utt['input_ids'])),loglevel=logging.DEBUG)
|
||||
|
||||
# Process list of text
|
||||
group_texts(self.utt_list)
|
||||
|
||||
total_length = len(self.utt_list)
|
||||
print_rank('Finished Reshaping in Sequences of {} Frames'.format(total_length), loglevel=logging.INFO)
|
||||
|
||||
# Update userdata after truncation
|
||||
if not self.test_only:
|
||||
self.num_samples[self.user_list.index(self.user)] = total_length
|
||||
|
||||
# Not used anywhere but necessary when the dataset is initiated
|
||||
if total_length == 0:
|
||||
self.utt_list = [{"input_ids": [0, 2], "special_tokens_mask": [1, 1], "attention_mask": [0, 0]}]
|
||||
|
||||
def LineByLineTextDataset(tokenizer, input_lines, truncation=True, max_length=512, padding = False, line_by_line=False):
|
||||
|
||||
if input_lines==['N/A']:
|
||||
batch_encoding = {"input_ids": [[0, 2]], "special_tokens_mask": [[1, 1]], "attention_mask": [[0, 0]]}
|
||||
else:
|
||||
lines = [line for line in input_lines if (len(line) > 0 and not line.isspace())]
|
||||
print_rank ('padding is : ' + str(padding),loglevel=logging.DEBUG)
|
||||
print_rank ('max_length is : ' + str(max_length),loglevel=logging.DEBUG)
|
||||
batch_encoding = tokenizer(lines, truncation=truncation, max_length=max_length, padding = padding, return_special_tokens_mask=True,)
|
||||
if line_by_line:
|
||||
batch_encoding["input_ids"] = batch_encoding["input_ids"][0]
|
||||
batch_encoding["special_tokens_mask"] = batch_encoding["special_tokens_mask"][0]
|
||||
batch_encoding["attention_mask"] = batch_encoding["attention_mask"][0]
|
||||
|
||||
return batch_encoding
|
|
@ -0,0 +1,471 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch as T
|
||||
from utils import print_rank
|
||||
import logging
|
||||
import copy
|
||||
from typing import (Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union)
|
||||
|
||||
from experiments.mlm_bert.utils.trainer_pt_utils import (
|
||||
LabelSmoother,
|
||||
DistributedTensorGatherer,
|
||||
nested_concat,
|
||||
nested_detach,
|
||||
nested_numpify,
|
||||
)
|
||||
|
||||
from experiments.mlm_bert.utils.trainer_utils import (
|
||||
EvalPrediction,
|
||||
ComputeMetrics)
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
AutoConfig,
|
||||
AutoModelForMaskedLM,
|
||||
AutoTokenizer,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_MASKED_LM_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
class BERT(T.nn.Module):
|
||||
def __init__(self, model_config, **kwargs):
|
||||
super(BERT, self).__init__()
|
||||
"""
|
||||
from transformers import RobertaConfig
|
||||
config = RobertaConfig(
|
||||
vocab_size=52_000,
|
||||
max_position_embeddings=514,
|
||||
num_attention_heads=12,
|
||||
num_hidden_layers=6,
|
||||
type_vocab_size=1,
|
||||
)
|
||||
|
||||
from transformers import RobertaTokenizerFast
|
||||
tokenizer = RobertaTokenizerFast.from_pretrained("./EsperBERTo", max_len=512)
|
||||
|
||||
from transformers import RobertaForMaskedLM
|
||||
model = RobertaForMaskedLM(config=config)
|
||||
"""
|
||||
|
||||
# Extracting model_config['BERT']
|
||||
args = model_config['BERT']
|
||||
# Split data to smaller configuration parameters
|
||||
model_args, training_args = args['model'], args['training']
|
||||
|
||||
# Set seed before initializing model.
|
||||
set_seed(training_args['seed'])
|
||||
|
||||
self.gradient_accumulation_steps = model_args.get('gradient_accumulation_steps', 1)
|
||||
self.past_index = model_args.get('past_index', -1)
|
||||
self.prediction_loss_only = model_args.get('prediction_loss_only', True)
|
||||
self.eval_accumulation_steps = model_args.get('eval_accumulation_steps', None)
|
||||
self.label_names = model_args.get('label_names', None)
|
||||
self.batch_size= training_args['batch_size']
|
||||
self.model_name=model_args['model_name']
|
||||
|
||||
if 'model_name_or_path' not in model_args:
|
||||
model_args['model_name_or_path']=self.model_name
|
||||
|
||||
# Label smoothing
|
||||
if training_args['label_smoothing_factor'] != 0:
|
||||
self.label_smoother = LabelSmoother(epsilon=training_args['label_smoothing_factor'])
|
||||
else:
|
||||
self.label_smoother = None
|
||||
self.label_names = ( ["labels"]) if self.label_names is None else self.label_names
|
||||
|
||||
config_kwargs = {
|
||||
"cache_dir": model_args['cache_dir'],
|
||||
"revision": None,
|
||||
"use_auth_token": None,
|
||||
}
|
||||
|
||||
if 'config_name' in model_args:
|
||||
config = AutoConfig.from_pretrained(model_args['config_name'], **config_kwargs)
|
||||
elif 'model_name_or_path' in model_args:
|
||||
config = AutoConfig.from_pretrained(model_args['model_name_or_path'], **config_kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new configuration from scratch. This is not supported by this script."
|
||||
)
|
||||
|
||||
|
||||
tokenizer_kwargs = {
|
||||
"cache_dir": model_args['cache_dir'],
|
||||
"use_fast": model_args['use_fast_tokenizer'],
|
||||
"use_auth_token": None,
|
||||
}
|
||||
if 'tokenizer_name' in model_args:
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args['tokenizer_name'], **tokenizer_kwargs)
|
||||
elif 'model_name_or_path' in model_args:
|
||||
print('Loading Tokenizer from Pretrained: {}'.format(model_args['model_name_or_path']) )
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args['model_name_or_path'], **tokenizer_kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
||||
)
|
||||
self.output_layer_size=len(tokenizer)
|
||||
|
||||
if 'model_name_or_path' in model_args:
|
||||
print('Loading Model from Pretrained: {}'.format(model_args['model_name_or_path']) )
|
||||
self.model = AutoModelForMaskedLM.from_pretrained(
|
||||
model_args['model_name_or_path'],
|
||||
from_tf=False,
|
||||
config=config,
|
||||
cache_dir=model_args['cache_dir'],
|
||||
use_auth_token=None,
|
||||
)
|
||||
if 'adapter' in model_args:
|
||||
if model_args['adapter']:
|
||||
self.model.add_adapter("FLUTE")
|
||||
|
||||
#Activate the adapter
|
||||
self.model.train_adapter("FLUTE")
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new model from scratch. This is not supported by this script."
|
||||
)
|
||||
self.model.resize_token_embeddings(self.output_layer_size)
|
||||
total_params = 0
|
||||
trainable_params = 0
|
||||
|
||||
for p in self.model.parameters():
|
||||
total_params += p.numel()
|
||||
if p.requires_grad:
|
||||
trainable_params += p.numel()
|
||||
|
||||
print_rank(f"Total parameters count: {total_params}", loglevel=logging.DEBUG) # ~109M
|
||||
print_rank(f"Trainable parameters count: {trainable_params}", loglevel=logging.DEBUG) # ~1M
|
||||
print_rank(f"Original Bert parameters count: {total_params-trainable_params}", loglevel=logging.DEBUG) # ~1M
|
||||
|
||||
|
||||
def copy_state_dict(self, state_dict):
|
||||
self.model.state_dict=state_dict.clone()
|
||||
|
||||
def get_model(self):
|
||||
return self.model
|
||||
|
||||
|
||||
def _prepare_inputs(self, inputs):
|
||||
"""
|
||||
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
|
||||
handling potential state.
|
||||
"""
|
||||
for k, v in inputs.items():
|
||||
if isinstance(v, T.Tensor):
|
||||
inputs[k] = v.cuda() if T.cuda.is_available() else v
|
||||
if self.past_index >= 0 and self._past is not None:
|
||||
inputs["mems"] = self._past
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
def forward(self, inputs):
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
return self.model(**inputs)
|
||||
|
||||
|
||||
def loss(self, inputs):
|
||||
"""
|
||||
Perform a training step on a batch of inputs.
|
||||
Subclass and override to inject custom behavior.
|
||||
Args:
|
||||
model (:obj:`nn.Module`):
|
||||
The model to train.
|
||||
inputs (:obj:`Dict[str, Union[T.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
||||
Return:
|
||||
:obj:`T.Tensor`: The tensor with training loss on this batch.
|
||||
"""
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
loss = self.compute_loss(inputs)
|
||||
loss = loss / self.gradient_accumulation_steps
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
def compute_loss(self, inputs_orig, return_outputs=False):
|
||||
"""
|
||||
How the loss is computed by Trainer. By default, all models return the loss in the first element.
|
||||
Subclass and override for custom behavior.
|
||||
|
||||
inputs (:obj:`Dict[str, Union[T.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
||||
"""
|
||||
# Copy a local copy of the data
|
||||
inputs=copy.deepcopy(inputs_orig)
|
||||
|
||||
if self.label_smoother is not None and "labels" in inputs:
|
||||
labels = inputs["labels"].detach().cpu()
|
||||
else:
|
||||
labels = None
|
||||
|
||||
# The following fields need to be removed for Roberta
|
||||
if 'roberta' in self.model_name:
|
||||
#print("here")
|
||||
if 'attention_mask' in inputs:
|
||||
inputs.pop('attention_mask')
|
||||
if 'special_tokens_mask' in inputs:
|
||||
inputs.pop('special_tokens_mask')
|
||||
|
||||
|
||||
# Forward pass for the transformer
|
||||
outputs = self.model(**inputs)
|
||||
|
||||
if self.past_index >= 0:
|
||||
self._past = outputs[self.past_index]
|
||||
|
||||
if labels is not None:
|
||||
loss = self.label_smoother(outputs, labels)
|
||||
else:
|
||||
# We don't use .loss here since the model may return tuples instead of ModelOutput.
|
||||
loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
||||
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
|
||||
|
||||
|
||||
def inference(
|
||||
self, inputs, ignore_keys: Optional[List[str]] = [], metric_key_prefix: str = "eval"
|
||||
) -> List[float]:
|
||||
"""
|
||||
Run prediction and returns predictions and potential metrics.
|
||||
Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method
|
||||
will also return metrics, like in :obj:`evaluate()`.
|
||||
Args:
|
||||
inputs (:obj:`Dict[str, Union[T.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
||||
ignore_keys (:obj:`Lst[str]`, `optional`):
|
||||
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||
gathering predictions.
|
||||
metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
|
||||
An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named
|
||||
"eval_bleu" if the prefix is "eval" (default)
|
||||
.. note::
|
||||
If your predictions or labels have different sequence length (for instance because you're doing dynamic
|
||||
padding in a token classification task) the predictions will be padded (on the right) to allow for
|
||||
concatenation into one array. The padding index is -100.
|
||||
Returns: `NamedTuple` A namedtuple with the following keys:
|
||||
- predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
|
||||
- label_ids (:obj:`np.ndarray`, `optional`): The labels (if the dataset contained some).
|
||||
- metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
|
||||
contained labels).
|
||||
"""
|
||||
|
||||
|
||||
output, batch_size = self.prediction_loop(
|
||||
inputs,
|
||||
description="Evaluation",
|
||||
ignore_keys=ignore_keys,
|
||||
metric_key_prefix=metric_key_prefix)
|
||||
return (output['eval_loss'], output['eval_acc'], batch_size[0])
|
||||
|
||||
|
||||
|
||||
def prediction_loop(
|
||||
self,
|
||||
inputs,
|
||||
description: str,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
metric_key_prefix: str = "eval",
|
||||
) -> Union[Dict, List[int]]:
|
||||
"""
|
||||
Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
|
||||
Works both with or without labels.
|
||||
"""
|
||||
|
||||
out_label_ids=None
|
||||
if 'labels' in inputs:
|
||||
out_label_ids = inputs['labels'].detach().cpu()
|
||||
|
||||
if 'attention_mask' in inputs:
|
||||
attention_mask= inputs['attention_mask'].detach().cpu()
|
||||
|
||||
losses_host = None
|
||||
preds_host = None
|
||||
labels_host = None
|
||||
|
||||
world_size = 1
|
||||
num_hosts = 1
|
||||
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_hosts, make_multiple_of=self.batch_size)
|
||||
if not self.prediction_loss_only:
|
||||
preds_gatherer = DistributedTensorGatherer(world_size, num_hosts)
|
||||
labels_gatherer = DistributedTensorGatherer(world_size, num_hosts)
|
||||
|
||||
self.model.eval()
|
||||
if self.past_index >= 0:
|
||||
self._past = None
|
||||
|
||||
loss, logits, _ = self.prediction_step(inputs, ignore_keys=ignore_keys, has_labels=True)
|
||||
if loss is not None:
|
||||
losses = loss.repeat(self.batch_size).cpu()
|
||||
losses_host = losses if losses_host is None else T.cat((losses_host, losses), dim=0)
|
||||
if logits is not None:
|
||||
preds_host = logits.detach().cpu() if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
|
||||
if out_label_ids is not None:
|
||||
labels_host = out_label_ids if labels_host is None else nested_concat(labels_host, out_label_ids, padding_index=-100)
|
||||
|
||||
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
||||
if self.eval_accumulation_steps is not None :
|
||||
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
|
||||
if not self.prediction_loss_only:
|
||||
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
|
||||
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
|
||||
|
||||
# Set back to None to begin a new accumulation
|
||||
losses_host, preds_host, labels_host = None, None, None
|
||||
|
||||
if self.past_index and hasattr(self, "_past"):
|
||||
# Clean the state at the end of the evaluation loop
|
||||
delattr(self, "_past")
|
||||
|
||||
# Gather all remaining tensors and put them back on the CPU
|
||||
if num_hosts>1:
|
||||
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"), want_masked=True)
|
||||
if not self.prediction_loss_only:
|
||||
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
|
||||
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
|
||||
|
||||
eval_loss = eval_losses_gatherer.finalize()
|
||||
preds = preds_gatherer.finalize() if not self.prediction_loss_only else None
|
||||
label_ids = labels_gatherer.finalize() if not self.prediction_loss_only else None
|
||||
else:
|
||||
eval_loss= losses_host
|
||||
preds = preds_host
|
||||
label_ids= labels_host
|
||||
|
||||
if preds is not None and label_ids is not None:
|
||||
metrics = ComputeMetrics.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids), attention_mask)
|
||||
else:
|
||||
metrics = {}
|
||||
|
||||
if eval_loss is not None:
|
||||
metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
|
||||
|
||||
# Prefix all keys with metric_key_prefix + '_'
|
||||
for key in list(metrics.keys()):
|
||||
if not key.startswith(f"{metric_key_prefix}_"):
|
||||
metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key).item()
|
||||
return metrics, preds.size()
|
||||
|
||||
|
||||
def _gather_and_numpify(self, tensors, name):
|
||||
"""
|
||||
Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
|
||||
concatenating them to `gathered`
|
||||
"""
|
||||
if tensors is None:
|
||||
return
|
||||
return nested_numpify(tensors)
|
||||
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
inputs,
|
||||
ignore_keys: Optional[List[str]] = None, has_labels: bool = None
|
||||
) -> Tuple[Optional[float], Optional[T.Tensor], Optional[T.Tensor]]:
|
||||
"""
|
||||
Perform an evaluation step on :obj:`model` using obj:`inputs`.
|
||||
Subclass and override to inject custom behavior.
|
||||
Args:
|
||||
model (:obj:`nn.Module`):
|
||||
The model to evaluate.
|
||||
inputs (:obj:`Dict[str, Union[T.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
|
||||
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
|
||||
prediction_loss_only (:obj:`bool`):
|
||||
Whether or not to return the loss only.
|
||||
ignore_keys (:obj:`Lst[str]`, `optional`):
|
||||
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
|
||||
gathering predictions.
|
||||
Return:
|
||||
Tuple[Optional[float], Optional[T.Tensor], Optional[T.Tensor]]: A tuple with the loss, logits and
|
||||
labels (each being optional).
|
||||
"""
|
||||
|
||||
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
# labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
|
||||
if has_labels:
|
||||
#labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
|
||||
labels = inputs["labels"].detach().cpu()
|
||||
if len(labels) == 1:
|
||||
labels = labels[0]
|
||||
else:
|
||||
labels = None
|
||||
|
||||
with T.no_grad():
|
||||
if has_labels:
|
||||
loss, outputs = self.compute_loss(inputs, return_outputs=True)
|
||||
loss = loss.mean().detach()
|
||||
if isinstance(outputs, dict):
|
||||
logits = outputs["logits"]
|
||||
else:
|
||||
logits = outputs[1:]
|
||||
else:
|
||||
loss = None
|
||||
outputs = self.model(**inputs)
|
||||
if isinstance(outputs, dict):
|
||||
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
|
||||
else:
|
||||
logits = outputs
|
||||
if self.past_index >= 0:
|
||||
self._past = outputs[self.past_index - 1]
|
||||
|
||||
if self.prediction_loss_only:
|
||||
return (loss, None, None)
|
||||
|
||||
logits = nested_detach(logits)
|
||||
if len(logits) == 1:
|
||||
logits = logits[0]
|
||||
|
||||
return (loss, logits, labels)
|
||||
|
||||
|
||||
def floating_point_ops(self, inputs):
|
||||
"""
|
||||
For models that inherit from :class:`~transformers.PreTrainedModel`, uses that method to compute the number of
|
||||
floating point operations for every backward + forward pass. If using another model, either implement such a
|
||||
method in the model or subclass and override this method.
|
||||
Args:
|
||||
inputs (:obj:`Dict[str, Union[T.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
Returns:
|
||||
:obj:`int`: The number of floating-point operations.
|
||||
"""
|
||||
if hasattr(self.model, "floating_point_ops"):
|
||||
return self.model.floating_point_ops(inputs)
|
||||
else:
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
def set_eval(self):
|
||||
"""
|
||||
Bring the model into evaluation mode
|
||||
"""
|
||||
self.model.eval()
|
||||
|
||||
|
||||
def set_train(self):
|
||||
"""
|
||||
Bring the model into train mode
|
||||
"""
|
||||
self.model.train()
|
|
@ -0,0 +1,493 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright 2020-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Torch utilities for the Trainer class.
|
||||
"""
|
||||
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterator, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from packaging import version
|
||||
from torch.utils.data.dataset import Dataset
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from torch.utils.data.sampler import RandomSampler, Sampler
|
||||
|
||||
|
||||
|
||||
# this is used to supress an undesired warning emitted by pytorch versions 1.4.2-1.7.0
|
||||
try:
|
||||
from torch.optim.lr_scheduler import SAVE_STATE_WARNING
|
||||
except ImportError:
|
||||
SAVE_STATE_WARNING = ""
|
||||
|
||||
|
||||
|
||||
def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
|
||||
"""Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
|
||||
if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
|
||||
return torch.cat((tensor1, tensor2), dim=0)
|
||||
|
||||
# Let's figure out the new shape
|
||||
new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tensor1.shape[2:]
|
||||
|
||||
# Now let's fill the result tensor
|
||||
result = tensor1.new_full(new_shape, padding_index)
|
||||
result[: tensor1.shape[0], : tensor1.shape[1]] = tensor1
|
||||
result[tensor1.shape[0] :, : tensor2.shape[1]] = tensor2
|
||||
return result
|
||||
|
||||
|
||||
def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
|
||||
"""Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary."""
|
||||
if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
|
||||
return np.concatenate((array1, array2), dim=0)
|
||||
|
||||
# Let's figure out the new shape
|
||||
new_shape = (array1.shape[0] + array2.shape[0], max(array1.shape[1], array2.shape[1])) + array1.shape[2:]
|
||||
|
||||
# Now let's fill the result tensor
|
||||
result = np.full_like(array1, padding_index, shape=new_shape)
|
||||
result[: array1.shape[0], : array1.shape[1]] = array1
|
||||
result[array1.shape[0] :, : array2.shape[1]] = array2
|
||||
return result
|
||||
|
||||
|
||||
def nested_concat(tensors, new_tensors, padding_index=-100):
|
||||
"""
|
||||
Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
|
||||
nested list/tuples of tensors.
|
||||
"""
|
||||
assert type(tensors) == type(
|
||||
new_tensors
|
||||
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
|
||||
elif isinstance(tensors, torch.Tensor):
|
||||
return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
|
||||
elif isinstance(tensors, np.ndarray):
|
||||
return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
|
||||
else:
|
||||
raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")
|
||||
|
||||
|
||||
def nested_numpify(tensors):
|
||||
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_numpify(t) for t in tensors)
|
||||
return tensors.cpu().numpy()
|
||||
|
||||
|
||||
def nested_detach(tensors):
|
||||
"Detach `tensors` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_detach(t) for t in tensors)
|
||||
return tensors.detach()
|
||||
|
||||
|
||||
|
||||
|
||||
def reissue_pt_warnings(caught_warnings):
|
||||
# Reissue warnings that are not the SAVE_STATE_WARNING
|
||||
if len(caught_warnings) > 1:
|
||||
for w in caught_warnings:
|
||||
if w.category != UserWarning or w.message != SAVE_STATE_WARNING:
|
||||
warnings.warn(w.message, w.category)
|
||||
|
||||
|
||||
|
||||
|
||||
def nested_new_like(arrays, num_samples, padding_index=-100):
|
||||
""" Create the same nested structure as `arrays` with a first dimension always at `num_samples`."""
|
||||
if isinstance(arrays, (list, tuple)):
|
||||
return type(arrays)(nested_new_like(x, num_samples) for x in arrays)
|
||||
return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))
|
||||
|
||||
|
||||
def nested_expand_like(arrays, new_seq_length, padding_index=-100):
|
||||
""" Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding."""
|
||||
if isinstance(arrays, (list, tuple)):
|
||||
return type(arrays)(nested_expand_like(x, new_seq_length, padding_index=padding_index) for x in arrays)
|
||||
|
||||
result = np.full_like(arrays, padding_index, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:])
|
||||
result[:, : arrays.shape[1]] = arrays
|
||||
return result
|
||||
|
||||
|
||||
def nested_truncate(tensors, limit):
|
||||
"Truncate `tensors` at `limit` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_truncate(t, limit) for t in tensors)
|
||||
return tensors[:limit]
|
||||
|
||||
|
||||
def _get_first_shape(arrays):
|
||||
"""Return the shape of the first array found in the nested struct `arrays`."""
|
||||
if isinstance(arrays, (list, tuple)):
|
||||
return _get_first_shape(arrays[0])
|
||||
return arrays.shape
|
||||
|
||||
|
||||
class DistributedTensorGatherer:
|
||||
"""
|
||||
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
|
||||
If our dataset has 16 samples with a batch size of 2 on 3 processes and we gather then transfer on CPU at every
|
||||
step, our sampler will generate the following indices:
|
||||
:obj:`[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1]`
|
||||
to get something of size a multiple of 3 (so that each process gets the same dataset length). Then process 0, 1 and
|
||||
2 will be responsible of making predictions for the following samples:
|
||||
- P0: :obj:`[0, 1, 2, 3, 4, 5]`
|
||||
- P1: :obj:`[6, 7, 8, 9, 10, 11]`
|
||||
- P2: :obj:`[12, 13, 14, 15, 0, 1]`
|
||||
The first batch treated on each process will be
|
||||
- P0: :obj:`[0, 1]`
|
||||
- P1: :obj:`[6, 7]`
|
||||
- P2: :obj:`[12, 13]`
|
||||
So if we gather at the end of the first batch, we will get a tensor (nested list/tuple of tensor) corresponding to
|
||||
the following indices:
|
||||
:obj:`[0, 1, 6, 7, 12, 13]`
|
||||
If we directly concatenate our results without taking any precautions, the user will then get the predictions for
|
||||
the indices in this order at the end of the prediction loop:
|
||||
:obj:`[0, 1, 6, 7, 12, 13, 2, 3, 8, 9, 14, 15, 4, 5, 10, 11, 0, 1]`
|
||||
For some reason, that's not going to roll their boat. This class is there to solve that problem.
|
||||
Args:
|
||||
world_size (:obj:`int`):
|
||||
The number of processes used in the distributed training.
|
||||
num_samples (:obj:`int`):
|
||||
The number of samples in our dataset.
|
||||
make_multiple_of (:obj:`int`, `optional`):
|
||||
If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument
|
||||
(by adding samples).
|
||||
padding_index (:obj:`int`, `optional`, defaults to -100):
|
||||
The padding index to use if the arrays don't all have the same sequence length.
|
||||
"""
|
||||
|
||||
def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100):
|
||||
self.world_size = world_size
|
||||
self.num_samples = num_samples
|
||||
total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
|
||||
self.total_samples = int(np.ceil(num_samples / total_size)) * total_size
|
||||
self.process_length = self.total_samples // world_size
|
||||
self._storage = None
|
||||
self._offsets = None
|
||||
self.padding_index = padding_index
|
||||
|
||||
def add_arrays(self, arrays):
|
||||
"""
|
||||
Add :obj:`arrays` to the internal storage, Will initialize the storage to the full size at the first arrays
|
||||
passed so that if we're bound to get an OOM, it happens at the beginning.
|
||||
"""
|
||||
if arrays is None:
|
||||
return
|
||||
if self._storage is None:
|
||||
self._storage = nested_new_like(arrays, self.total_samples, padding_index=self.padding_index)
|
||||
self._offsets = list(range(0, self.total_samples, self.process_length))
|
||||
else:
|
||||
storage_shape = _get_first_shape(self._storage)
|
||||
arrays_shape = _get_first_shape(arrays)
|
||||
if len(storage_shape) > 1 and storage_shape[1] < arrays_shape[1]:
|
||||
# If we get new arrays that are too big too fit, we expand the shape fo the storage
|
||||
self._storage = nested_expand_like(self._storage, arrays_shape[1], padding_index=self.padding_index)
|
||||
slice_len = self._nested_set_tensors(self._storage, arrays)
|
||||
for i in range(self.world_size):
|
||||
self._offsets[i] += slice_len
|
||||
|
||||
def _nested_set_tensors(self, storage, arrays):
|
||||
if isinstance(arrays, (list, tuple)):
|
||||
for x, y in zip(storage, arrays):
|
||||
slice_len = self._nested_set_tensors(x, y)
|
||||
return slice_len
|
||||
assert (
|
||||
arrays.shape[0] % self.world_size == 0
|
||||
), f"Arrays passed should all have a first dimension multiple of {self.world_size}, found {arrays.shape[0]}."
|
||||
|
||||
slice_len = arrays.shape[0] // self.world_size
|
||||
for i in range(self.world_size):
|
||||
if len(arrays.shape) == 1:
|
||||
storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
|
||||
else:
|
||||
storage[self._offsets[i] : self._offsets[i] + slice_len, : arrays.shape[1]] = arrays[
|
||||
i * slice_len : (i + 1) * slice_len
|
||||
]
|
||||
return slice_len
|
||||
|
||||
def finalize(self):
|
||||
"""
|
||||
Return the properly gathered arrays and truncate to the number of samples (since the sampler added some extras
|
||||
to get each process a dataset of the same length).
|
||||
"""
|
||||
if self._storage is None:
|
||||
return
|
||||
if self._offsets[0] != self.process_length:
|
||||
logger.warn("Not all data has been set. Are you sure you passed all values?")
|
||||
return nested_truncate(self._storage, self.num_samples)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LabelSmoother:
|
||||
"""
|
||||
Adds label-smoothing on a pre-computed output from a Transformers model.
|
||||
Args:
|
||||
epsilon (:obj:`float`, `optional`, defaults to 0.1):
|
||||
The label smoothing factor.
|
||||
ignore_index (:obj:`int`, `optional`, defaults to -100):
|
||||
The index in the labels to ignore when computing the loss.
|
||||
"""
|
||||
|
||||
epsilon: float = 0.1
|
||||
ignore_index: int = -100
|
||||
|
||||
def __call__(self, model_output, labels):
|
||||
logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
|
||||
log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
|
||||
if labels.dim() == log_probs.dim() - 1:
|
||||
labels = labels.unsqueeze(-1)
|
||||
|
||||
padding_mask = labels.eq(self.ignore_index)
|
||||
# In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
|
||||
# will ignore them in any case.
|
||||
labels.clamp_min_(0)
|
||||
nll_loss = log_probs.gather(dim=-1, index=labels)
|
||||
smoothed_loss = log_probs.sum(dim=-1, keepdim=True)
|
||||
|
||||
nll_loss.masked_fill_(padding_mask, 0.0)
|
||||
smoothed_loss.masked_fill_(padding_mask, 0.0)
|
||||
|
||||
# Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
|
||||
num_active_elements = padding_mask.numel() - padding_mask.long().sum()
|
||||
nll_loss = nll_loss.sum() / num_active_elements
|
||||
smoothed_loss = smoothed_loss.sum() / (num_active_elements * log_probs.shape[-1])
|
||||
return (1 - self.epsilon) * nll_loss + self.epsilon * smoothed_loss
|
||||
|
||||
|
||||
def get_length_grouped_indices(lengths, batch_size, mega_batch_mult=None, generator=None):
|
||||
"""
|
||||
Return a list of indices so that each slice of :obj:`batch_size` consecutive indices correspond to elements of
|
||||
similar lengths. To do this, the indices are:
|
||||
- randomly permuted
|
||||
- grouped in mega-batches of size :obj:`mega_batch_mult * batch_size`
|
||||
- sorted by length in each mega-batch
|
||||
The result is the concatenation of all mega-batches, with the batch of :obj:`batch_size` containing the element of
|
||||
maximum length placed first, so that an OOM happens sooner rather than later.
|
||||
"""
|
||||
# Default for mega_batch_mult: 50 or the number to get 4 megabatches, whichever is smaller.
|
||||
if mega_batch_mult is None:
|
||||
mega_batch_mult = min(len(lengths) // (batch_size * 4), 50)
|
||||
# Just in case, for tiny datasets
|
||||
if mega_batch_mult == 0:
|
||||
mega_batch_mult = 1
|
||||
|
||||
# We need to use torch for the random part as a distributed sampler will set the random seed for torch.
|
||||
indices = torch.randperm(len(lengths), generator=generator)
|
||||
megabatch_size = mega_batch_mult * batch_size
|
||||
megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
|
||||
megabatches = [list(sorted(megabatch, key=lambda i: lengths[i], reverse=True)) for megabatch in megabatches]
|
||||
|
||||
# The rest is to get the biggest batch first.
|
||||
# Since each megabatch is sorted by descending length, the longest element is the first
|
||||
megabatch_maximums = [lengths[megabatch[0]] for megabatch in megabatches]
|
||||
max_idx = torch.argmax(torch.tensor(megabatch_maximums)).item()
|
||||
# Switch to put the longest element in first position
|
||||
megabatches[0][0], megabatches[max_idx][0] = megabatches[max_idx][0], megabatches[0][0]
|
||||
|
||||
return sum(megabatches, [])
|
||||
|
||||
|
||||
class LengthGroupedSampler(Sampler):
|
||||
r"""
|
||||
Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
|
||||
keeping a bit of randomness.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: Dataset, batch_size: int, lengths: Optional[List[int]] = None):
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
if lengths is None:
|
||||
if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
|
||||
raise ValueError(
|
||||
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
||||
"'input_ids' key."
|
||||
)
|
||||
lengths = [len(feature["input_ids"]) for feature in dataset]
|
||||
self.lengths = lengths
|
||||
|
||||
def __len__(self):
|
||||
return len(self.lengths)
|
||||
|
||||
def __iter__(self):
|
||||
indices = get_length_grouped_indices(self.lengths, self.batch_size)
|
||||
return iter(indices)
|
||||
|
||||
|
||||
class DistributedLengthGroupedSampler(DistributedSampler):
|
||||
r"""
|
||||
Distributed Sampler that samples indices in a way that groups together features of the dataset of roughly the same
|
||||
length while keeping a bit of randomness.
|
||||
"""
|
||||
# Copied and adapted from PyTorch DistributedSampler.
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
batch_size: int,
|
||||
num_replicas: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
lengths: Optional[List[int]] = None,
|
||||
):
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = dist.get_world_size()
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = dist.get_rank()
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
self.drop_last = drop_last
|
||||
# If the dataset length is evenly divisible by # of replicas, then there
|
||||
# is no need to drop any data, since the dataset will be split equally.
|
||||
if self.drop_last and len(self.dataset) % self.num_replicas != 0:
|
||||
# Split to nearest available length that is evenly divisible.
|
||||
# This is to ensure each rank receives the same amount of data when
|
||||
# using this Sampler.
|
||||
self.num_samples = math.ceil((len(self.dataset) - self.num_replicas) / self.num_replicas)
|
||||
else:
|
||||
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.seed = seed
|
||||
|
||||
if lengths is None:
|
||||
if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
|
||||
raise ValueError(
|
||||
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
||||
"'input_ids' key."
|
||||
)
|
||||
lengths = [len(feature["input_ids"]) for feature in dataset]
|
||||
self.lengths = lengths
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
# Deterministically shuffle based on epoch and seed
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.seed + self.epoch)
|
||||
indices = get_length_grouped_indices(self.lengths, self.batch_size, generator=g)
|
||||
|
||||
if not self.drop_last:
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[: (self.total_size - len(indices))]
|
||||
else:
|
||||
# remove tail of data to make it evenly divisible.
|
||||
indices = indices[: self.total_size]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank : self.total_size : self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
return iter(indices)
|
||||
|
||||
|
||||
# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
|
||||
# helper methods here
|
||||
|
||||
|
||||
def _get_learning_rate(self):
|
||||
if self.deepspeed:
|
||||
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
|
||||
# not run for the first few dozen steps while loss scale is too large, and thus during
|
||||
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
|
||||
try:
|
||||
last_lr = self.lr_scheduler.get_last_lr()[0]
|
||||
except AssertionError as e:
|
||||
if "need to call step" in str(e):
|
||||
logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
|
||||
last_lr = 0
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
last_lr = (
|
||||
# backward compatibility for pytorch schedulers
|
||||
self.lr_scheduler.get_last_lr()[0]
|
||||
if version.parse(torch.__version__) >= version.parse("1.4")
|
||||
else self.lr_scheduler.get_lr()[0]
|
||||
)
|
||||
return last_lr
|
||||
|
||||
|
||||
def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
|
||||
"""
|
||||
Reformat Trainer metrics values to a human-readable format
|
||||
Args:
|
||||
metrics (:obj:`Dict[str, float]`):
|
||||
The metrics returned from train/evaluate/predict
|
||||
Returns:
|
||||
metrics (:obj:`Dict[str, float]`): The reformatted metrics
|
||||
"""
|
||||
|
||||
metrics_copy = metrics.copy()
|
||||
for k, v in metrics_copy.items():
|
||||
if "_mem_" in k:
|
||||
metrics_copy[k] = f"{ v >> 20 }MB"
|
||||
elif k == "total_flos":
|
||||
metrics_copy[k] = f"{ int(v) >> 30 }GF"
|
||||
elif type(metrics_copy[k]) == float:
|
||||
metrics_copy[k] = round(v, 4)
|
||||
|
||||
return metrics_copy
|
||||
|
||||
|
||||
def log_metrics(self, split, metrics):
|
||||
"""
|
||||
Log metrics in a specially formatted way
|
||||
Args:
|
||||
split (:obj:`str`):
|
||||
Mode/split name: one of ``train``, ``eval``, ``test``
|
||||
metrics (:obj:`Dict[str, float]`):
|
||||
The metrics returned from train/evaluate/predictmetrics: metrics dict
|
||||
"""
|
||||
|
||||
logger.info(f"***** {split} metrics *****")
|
||||
metrics_formatted = self.metrics_format(metrics)
|
||||
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
||||
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
||||
for key in sorted(metrics_formatted.keys()):
|
||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
||||
|
||||
|
||||
def save_metrics(self, split, metrics):
|
||||
"""
|
||||
Save metrics into a json file for that split, e.g. ``train_results.json``.
|
||||
Args:
|
||||
split (:obj:`str`):
|
||||
Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
|
||||
metrics (:obj:`Dict[str, float]`):
|
||||
The metrics returned from train/evaluate/predict
|
||||
"""
|
||||
path = os.path.join(self.args.output_dir, f"{split}_results.json")
|
||||
with open(path, "w") as f:
|
||||
json.dump(metrics, f, indent=4, sort_keys=True)
|
|
@ -0,0 +1,86 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
# coding=utf-8
|
||||
# Copyright 2020-present the HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Utilities for the Trainer and TFTrainer class. Should be independent from PyTorch and TensorFlow.
|
||||
"""
|
||||
|
||||
import random
|
||||
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
import logging
|
||||
|
||||
from utils import print_rank
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
"""
|
||||
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if
|
||||
installed).
|
||||
Args:
|
||||
seed (:obj:`int`): The seed to set.
|
||||
"""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
# ^^ safe to call this function even if cuda is not available
|
||||
|
||||
|
||||
class EvalPrediction(NamedTuple):
|
||||
"""
|
||||
Evaluation output (always contains labels), to be used to compute metrics.
|
||||
Parameters:
|
||||
predictions (:obj:`np.ndarray`): Predictions of the model.
|
||||
label_ids (:obj:`np.ndarray`): Targets to be matched.
|
||||
"""
|
||||
|
||||
predictions: Union[np.ndarray, Tuple[np.ndarray]]
|
||||
label_ids: np.ndarray
|
||||
|
||||
|
||||
class PredictionOutput(NamedTuple):
|
||||
predictions: Union[np.ndarray, Tuple[np.ndarray]]
|
||||
label_ids: Optional[np.ndarray]
|
||||
metrics: Optional[Dict[str, float]]
|
||||
|
||||
|
||||
class ComputeMetrics:
|
||||
def __init__(self, p: EvalPrediction, mask=None):
|
||||
self.EvalPrediction = EvalPrediction
|
||||
self.compute_metrics( self.EvalPrediction)
|
||||
|
||||
@staticmethod
|
||||
def compute_metrics(p: EvalPrediction, mask=None):
|
||||
print_rank('Prediction Block Size: {}'.format(p.predictions.size()), loglevel=logging.DEBUG)
|
||||
if len(list(p.predictions.size()))<3:
|
||||
if len(list(p.predictions.size()))<2:
|
||||
print_rank('There is something REALLY wrong with prediction tensor:'.format(p.predictions.size()), loglevel=logging.INFO)
|
||||
return {'acc': torch.tensor(0.0)}
|
||||
print_rank('There is something wrong with prediction tensor:'.format(p.predictions.size()), loglevel=logging.INFO)
|
||||
preds = np.argmax(p.predictions, axis=1)
|
||||
else:
|
||||
preds = np.argmax(p.predictions, axis=2)
|
||||
|
||||
if mask is None:
|
||||
return {'acc': (preds == p.label_ids).float().mean()}
|
||||
else:
|
||||
#valid = preds >1 # reject oov predictions even if they're correct.
|
||||
valid = mask==1
|
||||
return {'acc': (preds.eq(p.label_ids.cpu()) * valid.cpu()).float().mean()}
|
|
@ -0,0 +1,41 @@
|
|||
# Simple example of a NLG task on Reddit Dataset
|
||||
|
||||
Instructions on how to run the experiment, given below.
|
||||
|
||||
## Preparing the data
|
||||
|
||||
Right now FLUTE expects data to be provided either in JSON or HDF5 formats. It
|
||||
should be made data-agnostic in the near future, but at this moment we need to do some
|
||||
preprocessing before handling the data on the model. For this experiment, we can run the
|
||||
script located in `testing/create_data.py` as follows:
|
||||
|
||||
```code
|
||||
python create_data.py -e nlg
|
||||
```
|
||||
to download mock data already preprocessed. A new folder `mockup` will be generated
|
||||
inside `testing` with all data needed for a local run.
|
||||
|
||||
A couple of scripts are provided in `utils/preprocessing` for preprocessing .tsv files
|
||||
in case you want to use your own data.
|
||||
|
||||
## Creating a config file
|
||||
|
||||
All the parameters of the experiment are passed in a YAML file. An basic example is
|
||||
provided in `configs/hello_world_nlg_gru_json.yaml` with the suggested
|
||||
parameters for local runs.
|
||||
|
||||
The example provided above is for running json files. If you want to try with HDF5 files
|
||||
make sure to use the script `utils/preprocessing/from_json_to_hdf5.py` to convert the mock
|
||||
data to HDF5 format.
|
||||
|
||||
## Running the experiment
|
||||
|
||||
Finally, to launch the experiment locally , it suffices to launch the `e2e_trainer.py`
|
||||
script using MPI, you can use as example the following line:
|
||||
|
||||
```code
|
||||
mpiexec -n 3 python e2e_trainer.py -dataPath .\testing\mockup\ -outputPath scratch -config .\testing\configs\hello_world_local.yaml -task nlg_gru
|
||||
```
|
||||
|
||||
For submitting jobs in Azure ML, we have included the instructions in the `Experiments`
|
||||
section of the main `README.md`.
|
|
@ -0,0 +1,93 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import random
|
||||
import torch
|
||||
import numpy as np
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from experiments.nlg_gru.dataloaders.text_dataset import TextDataset
|
||||
from utils.data_utils import BatchSampler, DynamicBatchSampler
|
||||
|
||||
class TextDataLoader(DataLoader):
|
||||
"""
|
||||
PyTorch dataloader for loading text data from
|
||||
text_dataset.
|
||||
"""
|
||||
def __init__(self, mode, num_workers=0, **kwargs):
|
||||
|
||||
args = kwargs['args']
|
||||
self.batch_size = args['batch_size']
|
||||
batch_sampler = None
|
||||
|
||||
dataset = TextDataset(
|
||||
data = kwargs['data'],
|
||||
test_only = not mode=="train",
|
||||
vocab_dict = args['vocab_dict'],
|
||||
user_idx = kwargs['user_idx'],
|
||||
max_num_words= args['max_num_words'],
|
||||
preencoded = args.get('preencoded', False))
|
||||
|
||||
if mode == 'train':
|
||||
|
||||
sampler = DistributedSampler(dataset,num_replicas=1,rank=0)
|
||||
sampler.set_epoch(random.randint(0, 10**10))
|
||||
batch_sampler = DynamicBatchSampler(sampler,
|
||||
frames_threshold = args['max_num_words'],
|
||||
max_batch_size = self.batch_size,
|
||||
unsorted_batch = args['unsorted_batch'],
|
||||
fps=1)
|
||||
|
||||
elif mode == 'val' or mode == 'test':
|
||||
sampler = BatchSampler(dataset, batch_size=self.batch_size, randomize=False, drop_last=False)
|
||||
super().__init__(dataset,
|
||||
batch_sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=self.collate_fn,
|
||||
pin_memory=args["pin_memory"])
|
||||
return
|
||||
|
||||
if batch_sampler is None:
|
||||
super().__init__(dataset,
|
||||
batch_size=self.batch_size,
|
||||
sampler=sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=self.collate_fn,
|
||||
drop_last=True)
|
||||
else:
|
||||
super().__init__(dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
num_workers=num_workers,
|
||||
collate_fn=self.collate_fn,
|
||||
pin_memory=args["pin_memory"])
|
||||
|
||||
|
||||
def create_loader(self):
|
||||
return self
|
||||
|
||||
|
||||
def collate_fn(self, batch):
|
||||
def pad_and_concat_feats(labels):
|
||||
batch_size = len(labels)
|
||||
max_len = max(len(l[0]) for l in labels)
|
||||
cat_labels = np.full((batch_size, max_len), -1)
|
||||
|
||||
for e, l in enumerate(labels):
|
||||
cat_labels[e,:len(l[0])] = np.squeeze(l)
|
||||
return cat_labels
|
||||
|
||||
|
||||
src_seq, utt_ids = zip(*batch)
|
||||
x_len = [len(s[0]) for s in src_seq]
|
||||
|
||||
src_seq = pad_and_concat_feats(src_seq)
|
||||
packed = {
|
||||
'x': torch.from_numpy(src_seq).long(),
|
||||
'x_len': x_len,
|
||||
'utt_ids' : utt_ids,
|
||||
'total_frames' : sum(x_len),
|
||||
'total_frames_with_padding' : np.prod(src_seq.shape),
|
||||
'loss_weight' : None
|
||||
}
|
||||
return packed
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from utils import print_rank
|
||||
from core.globals import file_type
|
||||
from experiments.nlg_gru.utils.utility import *
|
||||
import numpy as np
|
||||
import h5py
|
||||
import logging
|
||||
import json
|
||||
|
||||
class TextDataset(Dataset):
|
||||
"""
|
||||
Map a text source to the target text
|
||||
"""
|
||||
|
||||
def __init__(self, data, min_num_words=2, max_num_words=25, test_only=False, user_idx=None, vocab_dict=None, preencoded=False):
|
||||
|
||||
self.utt_list = list()
|
||||
self.test_only = test_only
|
||||
self.max_num_words = max_num_words
|
||||
self.min_num_words = min_num_words
|
||||
self.preencoded = preencoded
|
||||
|
||||
# Load the vocab
|
||||
self.vocab = load_vocab(vocab_dict)
|
||||
self.vocab_size = len(self.vocab)
|
||||
|
||||
# reading the jsonl for a specific user_idx
|
||||
self.read_data(data, user_idx)
|
||||
|
||||
def __len__(self):
|
||||
"""Return the length of the elements in the list."""
|
||||
return len(self.utt_list)
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""Find the index in the available data"""
|
||||
|
||||
if self.preencoded:
|
||||
batch = np.array([self.utt_list[idx]['src_text']], dtype=np.int32)
|
||||
else:
|
||||
# case_backoff_batch tries to find the best capitalisation that will allow the word to be in vocabulary
|
||||
batch = case_backoff_batch([self.utt_list[idx]['src_text']], self.vocab.term_to_idx)
|
||||
batch = to_indices(self.vocab, batch)
|
||||
|
||||
return batch, self.user
|
||||
|
||||
# Reads JSON or HDF5 files
|
||||
def read_data(self, orig_strct, user_idx):
|
||||
|
||||
if isinstance(orig_strct, str):
|
||||
if file_type == "json":
|
||||
print('Loading json-file: ', orig_strct)
|
||||
with open(orig_strct, 'r') as fid:
|
||||
orig_strct = json.load(fid)
|
||||
|
||||
elif file_type == "hdf5":
|
||||
print('Loading hdf5-file: ', orig_strct)
|
||||
orig_strct = h5py.File(orig_strct, 'r')
|
||||
|
||||
self.user_list = orig_strct['users']
|
||||
self.num_samples = orig_strct['num_samples']
|
||||
self.user_data = orig_strct['user_data']
|
||||
|
||||
if self.test_only:
|
||||
self.user = 'test_only'
|
||||
self.process_x(self.user_data)
|
||||
else:
|
||||
self.user = self.user_list[user_idx]
|
||||
self.process_x(self.user_data[self.user])
|
||||
|
||||
|
||||
def process_x(self, raw_x_batch):
|
||||
print_rank('Processing data-structure: {} Utterances expected'.format(sum(self.num_samples)), loglevel=logging.DEBUG)
|
||||
if self.test_only:
|
||||
for user in self.user_list:
|
||||
for e in raw_x_batch[user]['x']:
|
||||
utt={}
|
||||
utt['src_text'] = e if type(e) is list else e.split()
|
||||
utt['duration'] = len(e)
|
||||
utt["loss_weight"] = 1.0
|
||||
self.utt_list.append(utt)
|
||||
|
||||
else:
|
||||
for e in raw_x_batch['x']:
|
||||
utt={}
|
||||
utt['src_text'] = e if type(e) is list else e.split()
|
||||
utt['duration'] = len(utt["src_text"])
|
||||
if utt['duration']<= self.min_num_words:
|
||||
continue
|
||||
|
||||
if utt['duration'] > self.max_num_words:
|
||||
utt['src_text'] = utt['src_text'][:self.max_num_words]
|
||||
utt['duration'] = self.max_num_words
|
||||
utt["loss_weight"] = 1.0
|
||||
self.utt_list.append(utt)
|
|
@ -0,0 +1,144 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch as T
|
||||
from torch import Tensor
|
||||
from typing import Dict, List, Tuple, Optional, NamedTuple
|
||||
from utils import softmax
|
||||
|
||||
class GRU2(T.nn.Module):
|
||||
def __init__(self, input_size, hidden_size, input_bias, hidden_bias):
|
||||
super(GRU2, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.w_ih = T.nn.Linear(input_size, 3 * hidden_size, input_bias)
|
||||
self.w_hh = T.nn.Linear(hidden_size, 3 * hidden_size, hidden_bias)
|
||||
|
||||
def _forward_cell(self, input : Tensor, hidden : Tensor) -> Tensor:
|
||||
g_i = self.w_ih(input)
|
||||
g_h = self.w_hh(hidden)
|
||||
i_r, i_i, i_n = g_i.chunk(3, 1)
|
||||
h_r, h_i, h_n = g_h.chunk(3, 1)
|
||||
reset_gate = T.sigmoid(i_r + h_r)
|
||||
input_gate = T.sigmoid(i_i + h_i)
|
||||
new_gate = T.tanh(i_n + reset_gate * h_n)
|
||||
hy = new_gate + input_gate * (hidden - new_gate)
|
||||
return hy
|
||||
|
||||
def forward(self, input : Tensor) -> Tuple[Tensor, Tensor]:
|
||||
hiddens : List[Tensor] = [T.zeros((input.shape[0], self.hidden_size)).cuda() if T.cuda.is_available() \
|
||||
else T.zeros((input.shape[0], self.hidden_size))]
|
||||
for step in range(input.shape[1]):
|
||||
hidden = self._forward_cell(input[:, step], hiddens[-1])
|
||||
hiddens.append(hidden)
|
||||
|
||||
return T.stack(hiddens, dim=1), hiddens[-1]
|
||||
|
||||
|
||||
class Embedding(T.nn.Module):
|
||||
def __init__(self, vocab_size, embedding_size):
|
||||
super(Embedding, self).__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.embedding_size = embedding_size
|
||||
self.table = T.nn.Parameter(T.zeros((vocab_size, embedding_size)))
|
||||
self.unembedding_bias = T.nn.Parameter(T.zeros(vocab_size))
|
||||
delta = (3 / self.table.shape[1]) ** 0.5
|
||||
T.nn.init.uniform_(self.table, -delta, delta)
|
||||
|
||||
def forward(self, input : Tensor, embed : bool) -> Tensor:
|
||||
if embed:
|
||||
output = T.nn.functional.embedding(input, self.table)
|
||||
else:
|
||||
output = input @ self.table.t() + self.unembedding_bias
|
||||
return output
|
||||
|
||||
|
||||
class GRU(T.nn.Module): #DLM_2_0
|
||||
def __init__(self, model_config, OOV_correct=False, dropout=0.0, topK_results=1, wantLogits=False, **kwargs):
|
||||
super(GRU, self).__init__()
|
||||
self.vocab_size = model_config['vocab_size']
|
||||
self.embedding_size = model_config['embed_dim']
|
||||
self.hidden_size = model_config['hidden_dim']
|
||||
self.embedding = Embedding(self.vocab_size, self.embedding_size)
|
||||
self.rnn = GRU2(self.embedding_size, self.hidden_size, True, True)
|
||||
self.squeeze = T.nn.Linear(self.hidden_size, self.embedding_size, bias=False)
|
||||
self.OOV_correct = OOV_correct
|
||||
self.topK_results = topK_results
|
||||
self.dropout=dropout
|
||||
self.wantLogits=wantLogits
|
||||
if self.dropout>0.0:
|
||||
self.drop_layer = T.nn.Dropout(p=self.dropout)
|
||||
|
||||
def forward(self, input : T.Tensor) -> Tuple[Tensor, Tensor]:
|
||||
input = input['x'] if isinstance(input, dict) else input
|
||||
input = input.cuda() if T.cuda.is_available() else input
|
||||
embedding = self.embedding(input, True)
|
||||
hiddens, state = self.rnn(embedding)
|
||||
if self.dropout>0.0:
|
||||
hiddens= self.drop_layer(hiddens)
|
||||
output = self.embedding(self.squeeze(hiddens), False)
|
||||
return output, state
|
||||
|
||||
|
||||
def loss(self, input : T.Tensor) -> T.Tensor:
|
||||
input = input['x'] if isinstance(input, dict) else input
|
||||
input = input.cuda() if T.cuda.is_available() else input
|
||||
non_pad_mask = input >= 0
|
||||
input = input * non_pad_mask.long()
|
||||
non_pad_mask = non_pad_mask.view(-1)
|
||||
|
||||
# Run the forward pass
|
||||
output, _ = self.forward(input[:, :-1])
|
||||
|
||||
# Estimate the targets
|
||||
targets = input.view(-1)[non_pad_mask]
|
||||
preds = output.view(-1, self.vocab_size)[non_pad_mask]
|
||||
|
||||
# Estimate the loss
|
||||
return T.nn.functional.cross_entropy(preds, targets)
|
||||
|
||||
|
||||
def inference(self, input):
|
||||
input = input['x'] if isinstance(input, dict) else input
|
||||
input = input.cuda() if T.cuda.is_available() else input
|
||||
non_pad_mask = input >= 0
|
||||
input = input * non_pad_mask.long()
|
||||
non_pad_mask = non_pad_mask.view(-1)
|
||||
output, _ = self.forward(input[:, :-1])
|
||||
|
||||
# Apply mask to input/output
|
||||
targets = input.view(-1)[non_pad_mask]
|
||||
preds = output.view(-1, self.vocab_size)[non_pad_mask]
|
||||
|
||||
# accuracy
|
||||
probs_topK, preds_topK = T.topk(preds, self.topK_results, sorted=True, dim=1)
|
||||
probs, preds = probs_topK[:,0], preds_topK[:,0]
|
||||
if self.OOV_correct:
|
||||
acc = preds.eq(targets).float().mean()
|
||||
else:
|
||||
valid = preds != 0 # reject oov predictions even if they're correct.
|
||||
acc = (preds.eq(targets) * valid).float().mean()
|
||||
|
||||
if self.wantLogits:
|
||||
if 1:
|
||||
output= {'probabilities': softmax(probs_topK.cpu().detach().numpy(), axis=1),
|
||||
'predictions': preds_topK.cpu().detach().numpy(),
|
||||
'labels': targets.cpu().detach().numpy()}
|
||||
else:
|
||||
output = {'probabilities': probs_topK.cpu().detach().numpy(),
|
||||
'predictions': preds_topK.cpu().detach().numpy(),
|
||||
'labels': targets.cpu().detach().numpy()}
|
||||
return output, acc.item(), input.shape[0]
|
||||
|
||||
def set_eval(self):
|
||||
"""
|
||||
Bring the model into evaluation mode
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
|
||||
def set_train(self):
|
||||
"""
|
||||
Bring the model into train mode
|
||||
"""
|
||||
self.train()
|
|
@ -0,0 +1,177 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import numpy as np
|
||||
from collections import namedtuple
|
||||
from tqdm import tqdm
|
||||
|
||||
TR_UPPER = {ord('i'): 'İ'}
|
||||
TR_LOWER = {ord('I'): 'ı'}
|
||||
|
||||
Vocab = namedtuple('Vocab', ['idx_to_term', 'term_to_idx'])
|
||||
|
||||
|
||||
def load_vocab(url):
|
||||
"""Load a vocabulary file.
|
||||
|
||||
url -- string -- url to the txt file
|
||||
|
||||
returns -- Vocab(idx_to_term=list, term_to_idx=dict)
|
||||
"""
|
||||
term_to_idx = {}
|
||||
idx_to_term = []
|
||||
with open(url, 'r', encoding='utf-8') as f:
|
||||
for i, line in enumerate(f):
|
||||
word = line.strip()
|
||||
idx_to_term.append(word)
|
||||
term_to_idx[word] = i
|
||||
return Vocab(idx_to_term, term_to_idx)
|
||||
|
||||
|
||||
def to_indices(vocab, batch, ndim=2, oov_idx=0, pad_idx=-1):
|
||||
"""Convert a nested list of strings to a np.array of integers.
|
||||
|
||||
vocab -- Vocab -- the vocabulary of the model
|
||||
|
||||
batch -- [..[str]..] -- multidimensional batch
|
||||
|
||||
ndim -- int -- number of dimensions in batch
|
||||
|
||||
oov_idx -- int or None -- if specified, replace missing terms by
|
||||
the given index, otherwise raise an error
|
||||
|
||||
pad_idx -- int or None -- if specified, pad short last-dimension
|
||||
as specified, otherwise raise an error
|
||||
|
||||
raises -- ValueError -- if pad is required but pad_idx not specified
|
||||
-- KeyError -- if oov is required but oov_idx not specified
|
||||
|
||||
returns -- np.array(int) -- term indices
|
||||
"""
|
||||
#print_rank(f'to_indices: batch len: {len(batch)} ndim: {ndim}')
|
||||
if ndim == 1:
|
||||
return np.array(
|
||||
[(vocab.term_to_idx[term] if oov_idx is None else
|
||||
vocab.term_to_idx.get(term, oov_idx))
|
||||
for term in batch], dtype=np.int32)
|
||||
|
||||
if ndim == 2:
|
||||
# note: in most circumstances there is only one example in the batch
|
||||
# as a result, padding is never applied. We rely on collate_fn to properly
|
||||
# apply padding.
|
||||
length = max(len(row) for row in batch)
|
||||
if pad_idx is None and min(len(row) for row in batch) != length:
|
||||
raise ValueError('Padding required, but no pad_idx provided')
|
||||
pad = length * [pad_idx]
|
||||
|
||||
result = np.array(
|
||||
[[(vocab.term_to_idx[term] if oov_idx is None else
|
||||
vocab.term_to_idx.get(term, oov_idx))
|
||||
for term in row] + pad[len(row):]
|
||||
for row in batch], dtype=np.int32)
|
||||
#print_rank(f'to_indices result: {result.shape}')
|
||||
return result
|
||||
|
||||
# Flatten to a 2D batch, then recurse & reshape up (this ensures
|
||||
# padding is handled correctly)
|
||||
shape = [len(batch)]
|
||||
for _ in range(2, ndim):
|
||||
shape.append(len(batch[0]))
|
||||
batch = [item for sub_batch in batch for item in sub_batch]
|
||||
shape.append(-1)
|
||||
return to_indices(vocab, batch, ndim=2, oov_idx=oov_idx, pad_idx=pad_idx).reshape(*shape)
|
||||
|
||||
def case_backoff_batch(batch, vocab):
|
||||
"""Perform capitalization backoff on words both to lower & initial-upper case variants.
|
||||
|
||||
batch -- list(list(string)) -- batch of sentences of words, to back off
|
||||
|
||||
vocab -- set(string) -- vocabulary to consider
|
||||
|
||||
returns -- list(list(string)) -- backed-off batch
|
||||
"""
|
||||
|
||||
def _variants(word):
|
||||
yield word
|
||||
yield word.translate(TR_LOWER).lower()
|
||||
yield word.lower()
|
||||
if len(word) > 1:
|
||||
yield word[0].translate(TR_UPPER).capitalize() + word[1:]
|
||||
yield word.capitalize()
|
||||
|
||||
return [[next((variant for variant in _variants(word) if variant in vocab),
|
||||
word) # will become OOV
|
||||
for word in sentence]
|
||||
for sentence in batch]
|
||||
|
||||
|
||||
def encode_data(data_dict, vocab):
|
||||
'''Encode data that is in the format expected by FLUTE
|
||||
|
||||
Parameters
|
||||
----------
|
||||
data_dict: dict
|
||||
Dictionary where keys consist of usernames and values give
|
||||
the data for that user, specified by another dictionary with
|
||||
keys :code:`x` (features) and, optionally, :code:`y` (labels).
|
||||
vocab:
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
Dictionary in the same format as the input one, but now the
|
||||
data in the :code:`x` field is given by tokens (i.e., integers),
|
||||
instead of strings.
|
||||
'''
|
||||
new_dict = {}
|
||||
for key, value in tqdm(data_dict.items()):
|
||||
user_data = [s.split() for s in value['x']]
|
||||
processed_data = case_backoff_batch(user_data, vocab.term_to_idx)
|
||||
encoded_data = [[vocab.term_to_idx.get(term, 0) for term in row] for row in processed_data]
|
||||
new_dict[key] = {'x': encoded_data}
|
||||
|
||||
return new_dict
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = ArgumentParser(description='Encodes data')
|
||||
parser.add_argument('data_path', type=str, help='Path to data')
|
||||
parser.add_argument('vocab_path', type=str, help='Path to vocabulary')
|
||||
args = parser.parse_args()
|
||||
|
||||
if not os.path.isfile(args.data_path):
|
||||
raise ValueError('data file does not exist')
|
||||
if not os.path.isfile(args.vocab_path):
|
||||
raise ValueError('vocabulary file does not exist')
|
||||
if args.data_path[-5:] != '.json':
|
||||
raise ValueError('argument must be a valid json file')
|
||||
|
||||
# Load vocabulary
|
||||
print('Loading vocabulary...')
|
||||
vocab = load_vocab(args.vocab_path)
|
||||
|
||||
# Load and encode data
|
||||
print('Loading data... ', end='', flush=True)
|
||||
start_time = time.time()
|
||||
with open(args.data_path, 'r') as input_file:
|
||||
all_data = json.load(input_file)
|
||||
print(f'Finished in {time.time() - start_time:.2f}s')
|
||||
|
||||
print('Converting data...')
|
||||
converted_user_data = encode_data(all_data['user_data'], vocab)
|
||||
|
||||
# For debug purposes
|
||||
for k, v in converted_user_data.items():
|
||||
print(f'USER: {k}\nDATA: {v}')
|
||||
break
|
||||
|
||||
# Save encoded data to disk
|
||||
print('Saving encoded data to disk...')
|
||||
all_data['user_data'] = converted_user_data
|
||||
with open(f'{args.data_path[:-5]}-encoded.json', 'w') as output_file:
|
||||
json.dump(all_data, output_file)
|
|
@ -0,0 +1,345 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import random
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
from utils import ( make_lr_scheduler,
|
||||
print_rank,
|
||||
torch_save,
|
||||
try_except_save,
|
||||
make_optimizer)
|
||||
|
||||
class SequenceWise(nn.Module):
|
||||
def __init__(self, module):
|
||||
"""
|
||||
Collapses input of dim T*N*H to (T*N)*H, and applies to a module.
|
||||
Allows handling of variable sequence lengths and minibatch sizes.
|
||||
:param module: Module to apply input to.
|
||||
"""
|
||||
super(SequenceWise, self).__init__()
|
||||
self.module = module
|
||||
|
||||
def forward(self, x):
|
||||
t, n = x.size(0), x.size(1)
|
||||
x = x.view(t * n, -1)
|
||||
x = x.contiguous()
|
||||
x = self.module(x)
|
||||
x = x.view(t, n, -1)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
tmpstr = self.__class__.__name__ + ' (\n'
|
||||
tmpstr += self.module.__repr__()
|
||||
tmpstr += ')'
|
||||
return tmpstr
|
||||
|
||||
|
||||
class BatchRNN(nn.Module):
|
||||
def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True,dropout=0.0,multi=1):
|
||||
super(BatchRNN, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.batch_norm_activate = batch_norm
|
||||
self.bidirectional = bidirectional
|
||||
self.multi = multi
|
||||
self.dropout = dropout
|
||||
|
||||
if self.batch_norm_activate:
|
||||
self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size))
|
||||
self.rnn = rnn_type(input_size = input_size,
|
||||
hidden_size = hidden_size,
|
||||
bidirectional= bidirectional,
|
||||
bias = True,
|
||||
batch_first = True,
|
||||
dropout = self.dropout)
|
||||
self.num_directions = 2 if bidirectional else 1
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
if x.dim()==2:
|
||||
x=x.unsqueeze(1)
|
||||
|
||||
if self.batch_norm_activate:
|
||||
x = x.contiguous()
|
||||
x = self.batch_norm(x)
|
||||
x, _ = self.rnn(x)
|
||||
|
||||
if self.bidirectional and self.multi<2:
|
||||
x = x.view(x.size(0), x.size(1), 2, -1).sum(2).view(x.size(0), x.size(1), -1)
|
||||
return x
|
||||
|
||||
|
||||
class NeuralNetwork(nn.Module):
|
||||
def __init__(self, params, wantLSTM=False, batch_norm=False):
|
||||
super(NeuralNetwork, self).__init__()
|
||||
|
||||
"""
|
||||
The following parameters need revisiting
|
||||
self.number_of_actions = 2
|
||||
self.gamma = 0.99
|
||||
self.final_epsilon = 0.0001
|
||||
self.initial_epsilon = 0.1
|
||||
self.number_of_iterations = 2000000
|
||||
self.replay_memory_size = 10000
|
||||
self.minibatch_size = 32
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=1e-6)
|
||||
criterion = nn.MSELoss()
|
||||
|
||||
"""
|
||||
self.wantLSTM = wantLSTM
|
||||
self.batch_norm= batch_norm
|
||||
params = [int(x) for x in params.split(',')]
|
||||
layers = []
|
||||
|
||||
self.softmax = nn.Softmax(dim = 1)
|
||||
if self.wantLSTM:
|
||||
# Recurrent Component of the architecture
|
||||
rnns = []
|
||||
for i in range(1, len(params) - 2):
|
||||
multi = 1 if i==1 else 1
|
||||
rnn = BatchRNN(input_size = params[i-1]*multi,
|
||||
hidden_size = params[i],
|
||||
rnn_type = nn.LSTM,
|
||||
bidirectional= True,
|
||||
batch_norm = batch_norm,
|
||||
multi = 1,
|
||||
dropout = 0.0)
|
||||
rnns.append(('%d' %(i-1), rnn))
|
||||
self.rnn = nn.Sequential(OrderedDict(rnns))
|
||||
|
||||
layers.append(nn.Linear(params[-3], params[-2], bias=True))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
layers.append(nn.Linear(params[-2], params[-1], bias=True))
|
||||
mlp = nn.Sequential(*layers)
|
||||
self.mlp = nn.Sequential(SequenceWise(mlp),)
|
||||
|
||||
else:
|
||||
if self.batch_norm:
|
||||
self.batch_norm = nn.BatchNorm1d(params[0])
|
||||
|
||||
for i in range(1, len(params)-1):
|
||||
layers.append(nn.Linear(params[i-1], params[i], bias=True))
|
||||
layers.append(nn.ReLU(inplace=True))
|
||||
layers.append(nn.Linear(params[-2], params[-1], bias=True))
|
||||
self.mlp = nn.Sequential(*layers)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
if self.wantLSTM:
|
||||
x = self.rnn(x)
|
||||
|
||||
if self.batch_norm:
|
||||
x = self.batch_norm(x)
|
||||
out = self.mlp(x)
|
||||
out = out.squeeze()
|
||||
|
||||
return out
|
||||
|
||||
|
||||
|
||||
|
||||
class RL:
|
||||
def __init__(self, config=None):
|
||||
|
||||
# Finalized config-file
|
||||
self.config= config
|
||||
|
||||
self.out_size = config["num_clients_per_iteration"]
|
||||
self.wantLSTM = config['RL']['wantLSTM'] if 'wantLSTM' in config['RL'] else False
|
||||
self.replay_memory= []
|
||||
self.state_memory = []
|
||||
self.epsilon= config['RL']['initial_epsilon']
|
||||
self.step =0
|
||||
self.runningLoss =0
|
||||
|
||||
model_descriptor = config['RL']['model_descriptor_RL'] if 'model_descriptor_RL' in config['RL'] else 'Default'
|
||||
self.model_name = os.path.join(config['RL']['RL_path'], 'rl_{}.{}.model'.format(self.out_size, model_descriptor))
|
||||
self.stats_name = os.path.join(config['RL']['RL_path'], 'rl_{}.{}.stats'.format(self.out_size, model_descriptor))
|
||||
|
||||
# Initialize RL model
|
||||
self.make_model()
|
||||
self.load_saved_status()
|
||||
|
||||
# Set the RL weights
|
||||
self.rl_weights=None
|
||||
self.rl_losses=None
|
||||
|
||||
self.criterion = nn.MSELoss()
|
||||
|
||||
def set_losses(self, losses):
|
||||
self.rl_losses=losses
|
||||
|
||||
def set_weights(self, weights):
|
||||
self.rl_weights = weights
|
||||
|
||||
def forward(self, state=None):
|
||||
# epsilon greedy exploration
|
||||
|
||||
if self.wantLSTM:
|
||||
N = len(state)
|
||||
state.resize(1, N)
|
||||
if len(self.state_memory)==0:
|
||||
self.state_memory = np.zeros((self.config['RL']['minibatch_size'], N))
|
||||
self.state_memory = np.concatenate((self.state_memory[1:], state), axis=0)
|
||||
state = self.state_memory
|
||||
|
||||
if random.random() <= self.epsilon:
|
||||
print_rank("Performed random action!")
|
||||
action= torch.rand(self.out_size).cuda() if torch.cuda.is_available() else torch.rand(self.out_size)
|
||||
else:
|
||||
state = torch.from_numpy(state).cuda() if torch.cuda.is_available() else torch.from_numpy(state)
|
||||
print_rank(f'RL_state: {state.shape}')
|
||||
action= self.model(state.float())
|
||||
return action
|
||||
|
||||
|
||||
|
||||
def train(self, batch=None):
|
||||
# save transition to replay memory
|
||||
self.replay_memory.append(batch)
|
||||
|
||||
# if replay memory is full, remove the oldest transition
|
||||
if len(self.replay_memory) > self.config['RL']['max_replay_memory_size']:
|
||||
self.replay_memory.pop(0)
|
||||
|
||||
# epsilon annealing
|
||||
self.epsilon *= self.config['RL']['epsilon_gamma'] if self.epsilon*self.config['RL']['epsilon_gamma']>self.config['RL']['final_epsilon'] else 1.0
|
||||
|
||||
# sample random minibatch
|
||||
if self.wantLSTM:
|
||||
if len(self.replay_memory)>= self.config['RL']['minibatch_size']:
|
||||
minibatch = self.replay_memory[-self.config['RL']['minibatch_size']:]
|
||||
else:
|
||||
minibatch = self.replay_memory
|
||||
else:
|
||||
minibatch = random.sample(self.replay_memory, min(len(self.replay_memory), self.config['RL']['minibatch_size']))
|
||||
|
||||
# unpack minibatch
|
||||
state_batch = torch.tensor(tuple(d[0] for d in minibatch)).float()
|
||||
action_batch = torch.tensor(tuple(d[1] for d in minibatch)).float()
|
||||
reward_batch = torch.tensor(tuple(d[2] for d in minibatch)).float()
|
||||
|
||||
if torch.cuda.is_available(): # put on GPU if CUDA is available
|
||||
state_batch = state_batch.cuda()
|
||||
action_batch = action_batch.cuda()
|
||||
reward_batch = reward_batch.cuda()
|
||||
|
||||
|
||||
# set y_j to r_j for terminal state, otherwise to r_j + gamma*max(Q)
|
||||
y_batch = reward_batch
|
||||
|
||||
# extract Q-value
|
||||
print_rank(f'RL state_batch: {state_batch.shape}', loglevel=logging.DEBUG)
|
||||
state_output = self.model(state_batch)
|
||||
print_rank(f'RL train shapes: {state_batch.shape} {action_batch.shape} {state_output.shape}', loglevel=logging.DEBUG)
|
||||
q_value = torch.sum(state_output * action_batch, dim=1)
|
||||
|
||||
# reset gradient
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# returns a new Tensor, detached from the current graph, the result will never require gradient
|
||||
y_batch = y_batch.detach()
|
||||
|
||||
# calculate loss
|
||||
loss = self.criterion(q_value, y_batch)
|
||||
|
||||
# do backward pass
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
|
||||
# Tracking a running average of loss
|
||||
if self.runningLoss==0:
|
||||
self.runningLoss = loss.item()
|
||||
else:
|
||||
self.runningLoss = 0.95 * self.runningLoss + 0.05 * loss.item()
|
||||
print_rank('Running Loss for RL training process: {}'.format(self.runningLoss))
|
||||
|
||||
# Decay learning rate
|
||||
self.lr_scheduler.step()
|
||||
|
||||
|
||||
def make_model(self):
|
||||
# make model
|
||||
self.model = NeuralNetwork(self.config['RL']['network_params'], \
|
||||
self.config['RL']['wantLSTM'] if 'wantLSTM' in self.config['RL'] else False, \
|
||||
self.config['RL']['batchNorm'] if 'batchNorm' in self.config['RL'] else False)
|
||||
print(self.model)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.model = self.model.cuda()
|
||||
|
||||
# make optimizer
|
||||
self.optimizer = make_optimizer(self.config['RL']["optimizer_config"], self.model)
|
||||
|
||||
# make lr_scheduler
|
||||
self.lr_scheduler = make_lr_scheduler(
|
||||
self.config['RL']['annealing_config'],
|
||||
self.optimizer,
|
||||
num_batches=1)
|
||||
|
||||
|
||||
def load_saved_status(self):
|
||||
if os.path.exists(self.model_name):
|
||||
print_rank("Resuming from checkpoint model {}".format(self.model_name))
|
||||
self.load()
|
||||
|
||||
if os.path.exists(self.stats_name):
|
||||
with open(self.stats_name, 'r') as logfp: # loading the iteration no., val_loss and lr_weight
|
||||
elems = json.load(logfp)
|
||||
self.cur_iter_no= elems["i"]
|
||||
self.val_loss = elems["val_loss"]
|
||||
self.val_cer = elems["val_cer"]
|
||||
self.runningLoss= elems["weight"]
|
||||
|
||||
|
||||
|
||||
def load(self):
|
||||
print_rank("Loading checkpoint: {}".format(self.model_name))
|
||||
checkpoint = torch.load(self.model_name)
|
||||
|
||||
self.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
if self.optimizer is not None:
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
|
||||
anl_st_dict = checkpoint.get('lr_scheduler_state_dict')
|
||||
if anl_st_dict and self.lr_scheduler is not None:
|
||||
self.lr_scheduler.load_state_dict(anl_st_dict)
|
||||
|
||||
|
||||
def save(self, i):
|
||||
"""
|
||||
Save a model as well as training information
|
||||
"""
|
||||
|
||||
save_state = {
|
||||
'model_state_dict' : self.model.state_dict(),
|
||||
'optimizer_state_dict' : self.optimizer.state_dict() if self.optimizer is not None else None,
|
||||
'lr_scheduler_state_dict' : self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None
|
||||
}
|
||||
|
||||
outputdir = os.path.dirname(self.model_name)
|
||||
if os.path.exists(outputdir) is False:
|
||||
os.makedirs(outputdir, exist_ok=True)
|
||||
|
||||
print_rank("Saving model to: {}".format(self.model_name))
|
||||
try_except_save(torch_save, state_or_model=save_state,
|
||||
save_path=self.model_name)
|
||||
|
||||
# logging the latest best values
|
||||
print_rank(f'Saving stats to {self.stats_name}')
|
||||
with open(self.stats_name, 'w') as logfp:
|
||||
json.dump({"i":i+1,
|
||||
"val_loss":float(self.rl_losses[0]),
|
||||
"val_cer":float(self.rl_losses[1]),
|
||||
"weight":float(self.runningLoss)},
|
||||
logfp)
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from extensions.RL.RL import *
|
||||
from extensions.quantization.quant import *
|
|
@ -0,0 +1,209 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import numpy as np
|
||||
import torch as T
|
||||
import math
|
||||
import json
|
||||
from utils import print_rank
|
||||
from azureml.core import Run
|
||||
from scipy.special import betainc, betaln
|
||||
|
||||
run = Run.get_context()
|
||||
|
||||
def compute_LDP_noise_std(eps, max_sensitivity, delta):
|
||||
return np.sqrt(2 * np.log(1.25 / delta)) * max_sensitivity / eps
|
||||
|
||||
|
||||
def _beta2betainc_ratio(a, x):
|
||||
return 1 / betainc(a, a, x)
|
||||
|
||||
|
||||
def _log_m1(d, alpha, gamma):
|
||||
return alpha * np.log(1 - gamma**2) - (d - 2) * np.log(2) - np.log(d - 1)
|
||||
|
||||
|
||||
def _log_m2(p, tau, alpha):
|
||||
return np.log(p / (_beta2betainc_ratio(alpha, tau) - 1) - (1 - p)) + np.log(_beta2betainc_ratio(alpha, tau)) - betaln(alpha, alpha)
|
||||
|
||||
|
||||
def _efficient_m(d, gamma, p):
|
||||
alpha = (d - 1) / 2
|
||||
tau = (1 + gamma) / 2
|
||||
return np.exp(_log_m1(d, alpha, gamma) + _log_m2(p, tau, alpha))
|
||||
|
||||
|
||||
def privacy_parameters(eps0, eps, d):
|
||||
exp_eps0 = np.exp(eps0)
|
||||
exp_eps = np.exp(eps)
|
||||
if exp_eps0 == np.inf:
|
||||
p0 = 1
|
||||
else:
|
||||
p0 = exp_eps0 / (1 + exp_eps0)
|
||||
if exp_eps == np.inf:
|
||||
gamma = np.sqrt(np.pi / (2 * (d - 1)))
|
||||
else:
|
||||
gamma = ((exp_eps - 1) / (exp_eps + 1)) * np.sqrt(np.pi / (2 * (d - 1)))
|
||||
return p0, gamma
|
||||
|
||||
|
||||
def private_unit2(grad, gamma, prob):
|
||||
np.testing.assert_almost_equal(grad.norm().cpu().item(), 1, decimal=5)
|
||||
assert prob >= 0.5
|
||||
assert (0 <= gamma <= 1)
|
||||
p = T.rand(())
|
||||
while True:
|
||||
# create a uniform distriubtion over d-sphere
|
||||
V = T.normal(0, 1, grad.shape, device=grad.device)
|
||||
V = V / V.norm()
|
||||
dot_prod = T.dot(V, grad)
|
||||
if (dot_prod >= gamma and p < prob) or (dot_prod < gamma and p >= prob):
|
||||
break
|
||||
d = grad.shape[0]
|
||||
m = _efficient_m(d, gamma, prob)
|
||||
return V / m
|
||||
|
||||
|
||||
def add_gaussian_noise(grad, eps, max_grad, delta):
|
||||
sigma = compute_LDP_noise_std(eps, max_grad, delta)
|
||||
#sigma = np.sqrt(2 * np.log(1.25 / delta)) * max_grad / eps
|
||||
noisy_grad = sigma * T.randn(grad.shape, device=grad.device) + grad
|
||||
return noisy_grad, sigma
|
||||
|
||||
|
||||
def add_private_unit2_noise(eps, grad):
|
||||
eps0 = 0.01 * eps
|
||||
eps1 = 0.99 * eps
|
||||
samp_prob, gamma = privacy_parameters(eps0, eps1, grad.shape[0])
|
||||
return private_unit2(grad, gamma, samp_prob)
|
||||
|
||||
|
||||
def scalar_DP(r, eps, k, r_max):
|
||||
r = np.minimum(r, r_max)
|
||||
val = k * r / r_max
|
||||
f_val = math.floor(val)
|
||||
c_val = math.ceil(val)
|
||||
J = f_val if T.rand(()) < (c_val - val) else c_val
|
||||
exp_eps = np.exp(eps)
|
||||
rand_prob = exp_eps / (exp_eps + k)
|
||||
if T.rand(()) >= rand_prob:
|
||||
while True:
|
||||
J_ = T.randint(0, k + 1, ()).item()
|
||||
if J != J_:
|
||||
J = J_
|
||||
break
|
||||
a = ((exp_eps + k) / (exp_eps - 1)) * (r_max / k)
|
||||
b = (k * (k + 1)) / (2 * (exp_eps + k))
|
||||
return a * (J - b)
|
||||
|
||||
|
||||
def laplace_noise(max_sens, eps, vocab_size):
|
||||
return np.random.laplace(0.0, max_sens/eps, vocab_size)
|
||||
|
||||
|
||||
def unroll_network(named_params, select_grad=False):
|
||||
# Unroll the network as 1D vector and save original values indices
|
||||
params_ids, flat_params = {}, []
|
||||
cur_idx = 0
|
||||
for n, p in named_params:
|
||||
dat = p.grad if select_grad else p.data
|
||||
flat_params.append(dat.view(-1))
|
||||
next_idx = cur_idx + flat_params[-1].shape[0]
|
||||
params_ids[n] = (cur_idx, next_idx)
|
||||
cur_idx = next_idx
|
||||
return T.cat(flat_params), params_ids
|
||||
|
||||
|
||||
def update_network(named_params, params_ids, flat_params, apply_to_grad=False):
|
||||
# Roll back the network parameters to layers
|
||||
for n, p in named_params:
|
||||
s_id, e_id = params_ids[n]
|
||||
if apply_to_grad:
|
||||
p.grad.copy_(flat_params[s_id : e_id].view(*p.grad.shape))
|
||||
else:
|
||||
p.data.copy_(flat_params[s_id : e_id].view(*p.data.shape))
|
||||
|
||||
|
||||
def apply_global_dp(config, model, num_clients_curr_iter, select_grad=True, metric_logger=None):
|
||||
# Add global DP noise here
|
||||
dp_config = config.get('dp_config', None)
|
||||
if dp_config is not None and dp_config.get('enable_global_dp', False):
|
||||
# enable_local_dp must be enabled - client-side gradient clipping must be enabled.
|
||||
assert (dp_config['enable_local_dp'])
|
||||
# Unroll the network grads as 1D vectors
|
||||
flat_grad, params_ids = unroll_network(model.named_parameters(), select_grad=select_grad)
|
||||
|
||||
sigma = dp_config['global_sigma']
|
||||
max_grad = dp_config['max_grad']
|
||||
noise_scale = sigma * max_grad / num_clients_curr_iter
|
||||
noise = T.normal(0, 1, flat_grad.shape, device=flat_grad.device) * noise_scale
|
||||
flat_noisy_grad = flat_grad + noise
|
||||
print_rank('Error from noise {} is {}. grad norm: {} noisy_grad norm: {}'.format(noise_scale, (
|
||||
flat_grad - flat_noisy_grad).norm(), flat_grad.norm(), flat_noisy_grad.norm()))
|
||||
|
||||
# Return back to the network gradients
|
||||
update_network(model.named_parameters(), params_ids, flat_noisy_grad,
|
||||
apply_to_grad=select_grad)
|
||||
|
||||
if metric_logger is None:
|
||||
metric_logger = Run.get_context().log
|
||||
metric_logger('Gradient Norm', flat_grad.norm().cpu().item())
|
||||
|
||||
|
||||
def update_privacy_accountant(config, num_clients, curr_iter, num_clients_curr_iter):
|
||||
# Privacy accounting starts here
|
||||
# We will dump all the needed parameters to the log so as not to slow down training.
|
||||
dp_config = config.get('dp_config', None)
|
||||
if dp_config is not None and dp_config.get('enable_global_dp', False) or dp_config.get('enable_local_dp',
|
||||
False):
|
||||
from math import sqrt, exp, log
|
||||
import extensions.privacy.analysis as privacy_analysis
|
||||
|
||||
K = 1 # from DP perspective each user is contributing one gradient
|
||||
B = num_clients_curr_iter # batch size
|
||||
n = num_clients
|
||||
T = curr_iter + 1
|
||||
_delta = dp_config.get('delta', min(1e-7, 1. / (n * log(n)))) # TODO should be precomputed in config
|
||||
if dp_config.get('global_sigma', None) is None:
|
||||
max_sensitivity = np.sqrt(dp_config['max_grad'] ** 2 + dp_config['max_weight'] ** 2)
|
||||
noise_scale = compute_LDP_noise_std(dp_config['eps'], max_sensitivity, _delta)
|
||||
global_sigma = noise_scale * np.sqrt(B) / max_sensitivity
|
||||
else:
|
||||
global_sigma = dp_config['global_sigma']
|
||||
noise_scale = global_sigma * dp_config['max_grad'] / B
|
||||
|
||||
try:
|
||||
mu = K * B / n * sqrt(T * exp((1. / global_sigma) ** 2 - 1))
|
||||
except OverflowError:
|
||||
print_rank(f"Error computing mu {global_sigma} {K} {B} {n} {T}")
|
||||
mu = -1
|
||||
|
||||
orders = ([1.25, 1.5, 1.75, 2., 2.25, 2.5, 3., 3.5, 4., 4.5] + list(range(5, 64)) + [128, 256, 512])
|
||||
q = B / n
|
||||
_sigma = global_sigma # was: noise_scale but we should apply the noise multiplier.
|
||||
rdp = privacy_analysis.compute_rdp(q, _sigma, T, orders)
|
||||
|
||||
rdp_epsilon, opt_order = privacy_analysis.get_privacy_spent(orders, rdp, _delta)
|
||||
|
||||
props = {
|
||||
'dp_global_K': K, # gradients per user
|
||||
'dp_global_B': B, # users per batch
|
||||
'dp_global_n': n, # total users
|
||||
'dp_global_T': T, # how many iterations
|
||||
'dp_sigma': _sigma, # noise_multiplier. Should be combined global+local sigma.
|
||||
'dp_global_mu': mu,
|
||||
# 'dp_epsilon_fdp': fdp_epsilon,
|
||||
'dp_epsilon_rdp': rdp_epsilon,
|
||||
# 'dp_epsilon_exact': exact_eps,
|
||||
'dp_opt_order': opt_order,
|
||||
'dp_delta': _delta,
|
||||
'dp_noise_scale': noise_scale # Note: not needed for accounting.
|
||||
}
|
||||
|
||||
print_rank(f'DP accounting: {json.dumps(props)}')
|
||||
for k in props:
|
||||
run.log(k, props[k])
|
||||
|
||||
return rdp_epsilon
|
||||
else:
|
||||
return None
|
|
@ -0,0 +1,306 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
||||
|
||||
"""
|
||||
*Based on Google's TF Privacy:* https://github.com/tensorflow/privacy/blob/master/tensorflow_privacy/privacy/analysis/rdp_accountant.py.
|
||||
*Here, we update this code to Python 3, and optimize dependencies.*
|
||||
|
||||
Functionality for computing Renyi Differential Privacy (RDP) of an additive
|
||||
Sampled Gaussian Mechanism (SGM).
|
||||
|
||||
Example:
|
||||
Suppose that we have run an SGM applied to a function with L2-sensitivity of 1.
|
||||
|
||||
Its parameters are given as a list of tuples
|
||||
``[(q_1, sigma_1, steps_1), ..., (q_k, sigma_k, steps_k)],``
|
||||
and we wish to compute epsilon for a given target delta.
|
||||
|
||||
The example code would be:
|
||||
|
||||
>>> max_order = 32
|
||||
>>> orders = range(2, max_order + 1)
|
||||
>>> rdp = np.zeros_like(orders, dtype=float)
|
||||
>>> for q, sigma, steps in parameters:
|
||||
>>> rdp += privacy_analysis.compute_rdp(q, sigma, steps, orders)
|
||||
>>> epsilon, opt_order = privacy_analysis.get_privacy_spent(orders, rdp, delta)
|
||||
|
||||
"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from scipy import special
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
########################
|
||||
# LOG-SPACE ARITHMETIC #
|
||||
########################
|
||||
|
||||
|
||||
def _log_add(logx: float, logy: float) -> float:
|
||||
r"""Adds two numbers in the log space.
|
||||
|
||||
Args:
|
||||
logx: First term in log space.
|
||||
logy: Second term in log space.
|
||||
|
||||
Returns:
|
||||
Sum of numbers in log space.
|
||||
"""
|
||||
a, b = min(logx, logy), max(logx, logy)
|
||||
if a == -np.inf: # adding 0
|
||||
return b
|
||||
# Use exp(a) + exp(b) = (exp(a - b) + 1) * exp(b)
|
||||
return math.log1p(math.exp(a - b)) + b # log1p(x) = log(x + 1)
|
||||
|
||||
|
||||
def _log_sub(logx: float, logy: float) -> float:
|
||||
r"""Subtracts two numbers in the log space.
|
||||
|
||||
Args:
|
||||
logx: First term in log space. Expected to be greater than the second term.
|
||||
logy: First term in log space. Expected to be less than the first term.
|
||||
|
||||
Returns:
|
||||
Difference of numbers in log space.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
If the result is negative.
|
||||
"""
|
||||
if logx < logy:
|
||||
raise ValueError("The result of subtraction must be non-negative.")
|
||||
if logy == -np.inf: # subtracting 0
|
||||
return logx
|
||||
if logx == logy:
|
||||
return -np.inf # 0 is represented as -np.inf in the log space.
|
||||
|
||||
try:
|
||||
# Use exp(x) - exp(y) = (exp(x - y) - 1) * exp(y).
|
||||
return math.log(math.expm1(logx - logy)) + logy # expm1(x) = exp(x) - 1
|
||||
except OverflowError:
|
||||
return logx
|
||||
|
||||
|
||||
def _compute_log_a_for_int_alpha(q: float, sigma: float, alpha: int) -> float:
|
||||
r"""Computes :math:`log(A_\alpha)` for integer ``alpha``.
|
||||
|
||||
Notes:
|
||||
Note that
|
||||
:math:`A_\alpha` is real valued function of ``alpha`` and ``q``,
|
||||
and that 0 < ``q`` < 1.
|
||||
|
||||
Refer to Section 3.3 of https://arxiv.org/pdf/1908.10530.pdf for details.
|
||||
|
||||
Args:
|
||||
q: Sampling rate of SGM.
|
||||
sigma: The standard deviation of the additive Gaussian noise.
|
||||
alpha: The order at which RDP is computed.
|
||||
|
||||
Returns:
|
||||
:math:`log(A_\alpha)` as defined in Section 3.3 of
|
||||
https://arxiv.org/pdf/1908.10530.pdf.
|
||||
"""
|
||||
|
||||
# Initialize with 0 in the log space.
|
||||
log_a = -np.inf
|
||||
|
||||
for i in range(alpha + 1):
|
||||
log_coef_i = (
|
||||
math.log(special.binom(alpha, i))
|
||||
+ i * math.log(q)
|
||||
+ (alpha - i) * math.log(1 - q)
|
||||
)
|
||||
|
||||
s = log_coef_i + (i * i - i) / (2 * (sigma ** 2))
|
||||
log_a = _log_add(log_a, s)
|
||||
|
||||
return float(log_a)
|
||||
|
||||
|
||||
def _compute_log_a_for_frac_alpha(q: float, sigma: float, alpha: float) -> float:
|
||||
r"""Computes :math:`log(A_\alpha)` for fractional ``alpha``.
|
||||
|
||||
Notes:
|
||||
Note that
|
||||
:math:`A_\alpha` is real valued function of ``alpha`` and ``q``,
|
||||
and that 0 < ``q`` < 1.
|
||||
|
||||
Refer to Section 3.3 of https://arxiv.org/pdf/1908.10530.pdf for details.
|
||||
|
||||
Args:
|
||||
q: Sampling rate of SGM.
|
||||
sigma: The standard deviation of the additive Gaussian noise.
|
||||
alpha: The order at which RDP is computed.
|
||||
|
||||
Returns:
|
||||
:math:`log(A_\alpha)` as defined in Section 3.3 of
|
||||
https://arxiv.org/pdf/1908.10530.pdf.
|
||||
"""
|
||||
# The two parts of A_alpha, integrals over (-inf,z0] and [z0, +inf), are
|
||||
# initialized to 0 in the log space:
|
||||
log_a0, log_a1 = -np.inf, -np.inf
|
||||
i = 0
|
||||
|
||||
z0 = sigma ** 2 * math.log(1 / q - 1) + 0.5
|
||||
|
||||
while True: # do ... until loop
|
||||
coef = special.binom(alpha, i)
|
||||
log_coef = math.log(abs(coef))
|
||||
j = alpha - i
|
||||
|
||||
log_t0 = log_coef + i * math.log(q) + j * math.log(1 - q)
|
||||
log_t1 = log_coef + j * math.log(q) + i * math.log(1 - q)
|
||||
|
||||
log_e0 = math.log(0.5) + _log_erfc((i - z0) / (math.sqrt(2) * sigma))
|
||||
log_e1 = math.log(0.5) + _log_erfc((z0 - j) / (math.sqrt(2) * sigma))
|
||||
|
||||
log_s0 = log_t0 + (i * i - i) / (2 * (sigma ** 2)) + log_e0
|
||||
log_s1 = log_t1 + (j * j - j) / (2 * (sigma ** 2)) + log_e1
|
||||
|
||||
if coef > 0:
|
||||
log_a0 = _log_add(log_a0, log_s0)
|
||||
log_a1 = _log_add(log_a1, log_s1)
|
||||
else:
|
||||
log_a0 = _log_sub(log_a0, log_s0)
|
||||
log_a1 = _log_sub(log_a1, log_s1)
|
||||
|
||||
i += 1
|
||||
if max(log_s0, log_s1) < -30:
|
||||
break
|
||||
|
||||
return _log_add(log_a0, log_a1)
|
||||
|
||||
|
||||
def _compute_log_a(q: float, sigma: float, alpha: float) -> float:
|
||||
r"""Computes :math:`log(A_\alpha)` for any positive finite ``alpha``.
|
||||
|
||||
Notes:
|
||||
Note that
|
||||
:math:`A_\alpha` is real valued function of ``alpha`` and ``q``,
|
||||
and that 0 < ``q`` < 1.
|
||||
|
||||
Refer to Section 3.3 of https://arxiv.org/pdf/1908.10530.pdf
|
||||
for details.
|
||||
|
||||
Args:
|
||||
q: Sampling rate of SGM.
|
||||
sigma: The standard deviation of the additive Gaussian noise.
|
||||
alpha: The order at which RDP is computed.
|
||||
|
||||
Returns:
|
||||
:math:`log(A_\alpha)` as defined in the paper mentioned above.
|
||||
"""
|
||||
if float(alpha).is_integer():
|
||||
return _compute_log_a_for_int_alpha(q, sigma, int(alpha))
|
||||
else:
|
||||
return _compute_log_a_for_frac_alpha(q, sigma, alpha)
|
||||
|
||||
|
||||
def _log_erfc(x: float) -> float:
|
||||
r"""Computes :math:`log(erfc(x))` with high accuracy for large ``x``.
|
||||
|
||||
Helper function used in computation of :math:`log(A_\alpha)`
|
||||
for a fractional alpha.
|
||||
|
||||
Args:
|
||||
x: The input to the function
|
||||
|
||||
Returns:
|
||||
:math:`log(erfc(x))`
|
||||
"""
|
||||
return math.log(2) + special.log_ndtr(-x * 2 ** 0.5)
|
||||
|
||||
|
||||
def _compute_rdp(q: float, sigma: float, alpha: float) -> float:
|
||||
r"""Computes RDP of the Sampled Gaussian Mechanism at order ``alpha``.
|
||||
|
||||
Args:
|
||||
q: Sampling rate of SGM.
|
||||
sigma: The standard deviation of the additive Gaussian noise.
|
||||
alpha: The order at which RDP is computed.
|
||||
|
||||
Returns:
|
||||
RDP at order ``alpha``; can be np.inf.
|
||||
"""
|
||||
if q == 0:
|
||||
return 0
|
||||
|
||||
# no privacy
|
||||
if sigma == 0:
|
||||
return np.inf
|
||||
|
||||
if q == 1.0:
|
||||
return alpha / (2 * sigma ** 2)
|
||||
|
||||
if np.isinf(alpha):
|
||||
return np.inf
|
||||
|
||||
return _compute_log_a(q, sigma, alpha) / (alpha - 1)
|
||||
|
||||
|
||||
def compute_rdp(
|
||||
q: float, noise_multiplier: float, steps: int, orders: Union[List[float], float]
|
||||
) -> Union[List[float], float]:
|
||||
r"""Computes Renyi Differential Privacy (RDP) guarantees of the
|
||||
Sampled Gaussian Mechanism (SGM) iterated ``steps`` times.
|
||||
|
||||
Args:
|
||||
q: Sampling rate of SGM.
|
||||
noise_multiplier: The ratio of the standard deviation of the
|
||||
additive Gaussian noise to the L2-sensitivity of the function
|
||||
to which it is added. Note that this is same as the standard
|
||||
deviation of the additive Gaussian noise when the L2-sensitivity
|
||||
of the function is 1.
|
||||
steps: The number of iterations of the mechanism.
|
||||
orders: An array (or a scalar) of RDP orders.
|
||||
|
||||
Returns:
|
||||
The RDP guarantees at all orders; can be ``np.inf``.
|
||||
"""
|
||||
if isinstance(orders, float):
|
||||
rdp = _compute_rdp(q, noise_multiplier, orders)
|
||||
else:
|
||||
rdp = np.array([_compute_rdp(q, noise_multiplier, order) for order in orders])
|
||||
|
||||
return rdp * steps
|
||||
|
||||
|
||||
def get_privacy_spent(
|
||||
orders: Union[List[float], float], rdp: Union[List[float], float], delta: float
|
||||
) -> Tuple[float, float]:
|
||||
r"""Computes epsilon given a list of Renyi Differential Privacy (RDP) values at
|
||||
multiple RDP orders and target ``delta``.
|
||||
|
||||
Args:
|
||||
orders: An array (or a scalar) of orders (alphas).
|
||||
rdp: A list (or a scalar) of RDP guarantees.
|
||||
delta: The target delta.
|
||||
|
||||
Returns:
|
||||
Pair of epsilon and optimal order alpha.
|
||||
|
||||
Raises:
|
||||
ValueError
|
||||
If the lengths of ``orders`` and ``rdp`` are not equal.
|
||||
"""
|
||||
orders_vec = np.atleast_1d(orders)
|
||||
rdp_vec = np.atleast_1d(rdp)
|
||||
|
||||
if len(orders_vec) != len(rdp_vec):
|
||||
raise ValueError(
|
||||
f"Input lists must have the same length.\n"
|
||||
f"\torders_vec = {orders_vec}\n"
|
||||
f"\trdp_vec = {rdp_vec}\n"
|
||||
)
|
||||
|
||||
eps = rdp_vec - math.log(delta) / (orders_vec - 1)
|
||||
|
||||
# special case when there is no privacy
|
||||
if np.isnan(eps).all():
|
||||
return np.inf, np.nan
|
||||
|
||||
idx_opt = np.nanargmin(eps) # Ignore NaNs
|
||||
return eps[idx_opt], orders_vec[idx_opt]
|
|
@ -0,0 +1,192 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import sys
|
||||
import numpy as np
|
||||
from scipy.special import gammainc
|
||||
from sklearn.cluster import KMeans
|
||||
from sklearn import cluster as skcluster
|
||||
|
||||
|
||||
kmeans_single = skcluster._kmeans.lloyd_iter_chunked_dense
|
||||
|
||||
|
||||
def sample(ndim, r, num_samples=1):
|
||||
x = np.random.normal(size=(num_samples, ndim))
|
||||
ssq = np.sum(x**2,axis=1)
|
||||
fr = r*gammainc(ndim/2,ssq/2)**(1/ndim)/np.sqrt(ssq)
|
||||
if num_samples > 1:
|
||||
fr = np.tile(fr.reshape(num_samples,1),(1,ndim))
|
||||
return np.multiply(x,fr)
|
||||
|
||||
|
||||
def sphere_packing_initialization(n_clusters, n_dim, min_cluster_radius,
|
||||
max_space_size, max_failed_cases, verbose=None):
|
||||
a, max_r = min_cluster_radius, max_space_size
|
||||
centers = np.empty((n_clusters, n_dim))
|
||||
cluster_id = 0
|
||||
fail_count = 0
|
||||
r = max_r - a
|
||||
while cluster_id < n_clusters:
|
||||
v = sample(n_dim, r)
|
||||
if cluster_id > 0 and np.min(np.linalg.norm(centers[:cluster_id, :] - v, axis=-1)) < 2 * a:
|
||||
fail_count += 1
|
||||
if fail_count >= max_failed_cases:
|
||||
fail_count = 0
|
||||
cluster_id = 0
|
||||
a = a / 2 # TODO Use binary search to find maximum a that don't fail (vaguely discribed in the diff-p kmeas paper)
|
||||
if verbose:
|
||||
print(f'Failing to pack, halving min_cluster_radius to {a}')
|
||||
r = max_r - a
|
||||
continue
|
||||
|
||||
centers[cluster_id] = v
|
||||
cluster_id += 1
|
||||
if verbose:
|
||||
print('Final min_cluster_radius', a)
|
||||
return centers, a
|
||||
|
||||
|
||||
def add_gaussian_noise(centers_new, weight_in_clusters, eps,
|
||||
max_cluster_l2, max_sample_weight,
|
||||
cluster_to_weight_ratio=-1, delta=1e-7, verbose=None):
|
||||
scaler = 1
|
||||
|
||||
if cluster_to_weight_ratio > 0:
|
||||
# Compute the scaler to apply to the sample weights
|
||||
scaler = max_cluster_l2 / (max_sample_weight * cluster_to_weight_ratio)
|
||||
max_sample_weight *= scaler
|
||||
|
||||
max_l2_sensitivity = np.sqrt(max_cluster_l2 ** 2 + max_sample_weight ** 2)
|
||||
sigma = np.sqrt(2 * np.log(1.25 / delta)) * max_l2_sensitivity / eps
|
||||
if verbose:
|
||||
print('cluster_to_weight_ratio', cluster_to_weight_ratio,
|
||||
'scaler', scaler,
|
||||
'max_sample_weight', max_sample_weight,
|
||||
'max_l2_sensitivity', max_l2_sensitivity,
|
||||
'sigma', sigma)
|
||||
centers_sum = (centers_new * weight_in_clusters.reshape(-1, 1)) + np.random.normal(scale=sigma, size=centers_new.shape)
|
||||
# Scale the sample weights by scaling the cluster weights, since (s*w1 + s*w2, ...) == s*(w1 + w2 + ...), where s is the scaler
|
||||
# Add noise then rescale back. We should never get negative weights because of the noise
|
||||
weight_in_clusters[:] = np.maximum(1e-10, (weight_in_clusters * scaler) + np.random.normal(scale=sigma, size=weight_in_clusters.shape)) / scaler
|
||||
centers_new[:] = centers_sum / weight_in_clusters.reshape(-1, 1)
|
||||
|
||||
|
||||
def DPKMeans(n_dim, eps, max_cluster_l2, max_sample_weight=1.0,
|
||||
max_iter=300, cluster_to_weight_ratio=-1, n_clusters=8,
|
||||
tol=1e-4, verbose=0, delta=1e-7, max_failed_cases=300,
|
||||
min_cluster_radius=None, **kwargs):
|
||||
"""Differentially private KMeans
|
||||
|
||||
Initialise the differentially-private Sklearn.cluster.KMeans overriding lloyd algorithm,
|
||||
by adding Gaussian noise.
|
||||
|
||||
Parameters
|
||||
---------
|
||||
|
||||
n_dim : int
|
||||
The dimension size of the input space
|
||||
|
||||
eps : float
|
||||
The privacy loss (epsilon) per iteration. Currently only fix epsilon is implemented so
|
||||
the overall privacy loss <= eps * max_iter
|
||||
|
||||
max_cluster_l2 : float
|
||||
The maximum l2 norm of any example vector that we want to cluster
|
||||
|
||||
max_sample_weight : float
|
||||
The maximum weight of a sample default=1.0
|
||||
|
||||
max_iter : int, default=300
|
||||
Maximum number of iterations of the k-means algorithm for a
|
||||
single run.
|
||||
|
||||
cluster_to_weight_ratio : float, default=-1
|
||||
The ratio max_cluster_l2 / max_sample_weight used to scale the cluster counts before adding the noise
|
||||
If it is set to -1, do not scale the counts
|
||||
|
||||
n_clusters : int, default=8
|
||||
The number of clusters to form as well as the number of
|
||||
centroids to generate.
|
||||
|
||||
tol : float, default=1e-4
|
||||
Relative tolerance with regards to Frobenius norm of the difference
|
||||
in the cluster centers of two consecutive iterations to declare
|
||||
convergence.
|
||||
|
||||
verbose : int, default=0
|
||||
Verbosity mode.
|
||||
|
||||
delta : float, default=1e-7
|
||||
Gaussian mechanism delta or probability of failure, should be set < 1/num of examples
|
||||
|
||||
max_failed_cases : int, default=300
|
||||
The number of sampling trails in sphere packing before halving the minimum cluster radius
|
||||
|
||||
min_cluster_radius : float, default=None (= max_cluster_l2 / n_clusters)
|
||||
Half the minimum distance between clusters centers
|
||||
"""
|
||||
|
||||
if min_cluster_radius is None:
|
||||
min_cluster_radius = max_cluster_l2 / n_clusters
|
||||
|
||||
# Initalise the cluster centers using sphere packing
|
||||
init_centers, min_cluster_radius = sphere_packing_initialization(n_clusters, n_dim,
|
||||
min_cluster_radius,
|
||||
max_cluster_l2,
|
||||
max_failed_cases,
|
||||
verbose)
|
||||
|
||||
final_eps = [0] # To keep track of the actual number of iterations until convergence
|
||||
def modified_lloyd(X, sample_weight, x_squared_norms, centers, centers_new,
|
||||
weight_in_clusters, labels, center_shift, n_threads,
|
||||
update_centers=True):
|
||||
|
||||
# Clip the maximum client contribution to the cluster count
|
||||
sample_weight = np.minimum(sample_weight, max_sample_weight)
|
||||
|
||||
if not update_centers:
|
||||
return kmeans_single(X, sample_weight, x_squared_norms, centers, centers_new,
|
||||
weight_in_clusters, labels, center_shift, n_threads, update_centers=False)
|
||||
|
||||
|
||||
# Scale input vectors if necessary
|
||||
if np.max(x_squared_norms) > max_cluster_l2 ** 2:
|
||||
if verbose:
|
||||
print(f'Scaling the input examples as their l2 norm is larger than {max_cluster_l2}')
|
||||
scaler_squared = np.minimum(max_cluster_l2 ** 2 / x_squared_norms, 1.0)
|
||||
x_squared_norms[:] = x_squared_norms * scaler_squared
|
||||
X[:] = X * np.sqrt(scaler_squared).reshape(-1, 1)
|
||||
|
||||
kmeans_single(X, sample_weight, x_squared_norms, centers, centers_new,
|
||||
weight_in_clusters, labels, center_shift, n_threads)
|
||||
|
||||
# Add noise to centers_new
|
||||
add_gaussian_noise(centers_new, weight_in_clusters, eps,
|
||||
max_cluster_l2, max_sample_weight,
|
||||
cluster_to_weight_ratio, delta=delta,
|
||||
verbose=verbose)
|
||||
|
||||
# Other values need to be changed because of that: center_shift, labels,
|
||||
center_shift[:] = np.linalg.norm(centers - centers_new, axis=-1)
|
||||
# Run E-step of kmeans to get the new labels
|
||||
kmeans_single(X, sample_weight, x_squared_norms, centers, centers_new,
|
||||
weight_in_clusters, labels, center_shift, n_threads, update_centers=False)
|
||||
|
||||
# Increment the number of iterations
|
||||
final_eps[0] += eps
|
||||
|
||||
sys.modules[KMeans.__module__].lloyd_iter_chunked_dense = modified_lloyd
|
||||
|
||||
kmeans = KMeans(n_clusters=n_clusters,
|
||||
algorithm='full',
|
||||
init=init_centers,
|
||||
verbose=verbose,
|
||||
max_iter=max_iter,
|
||||
tol=tol, **kwargs)
|
||||
kmeans.eps = final_eps
|
||||
return kmeans
|
||||
|
||||
|
||||
def resetKMeans():
|
||||
sys.modules[KMeans.__module__].lloyd_iter_chunked_dense = kmeans_single
|
|
@ -0,0 +1,75 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch as T
|
||||
from copy import deepcopy
|
||||
from utils import make_optimizer, print_rank
|
||||
|
||||
def extract_indices_from_embeddings(gradients, batch, embed_size, vocab_size):
|
||||
# Extract the Input gradient embeddings
|
||||
batch = T.cat([b.view(-1) for b in batch]).cpu().detach().numpy()
|
||||
embed_grad = gradients[:embed_size * vocab_size].reshape(vocab_size, embed_size)
|
||||
valid_batch = batch[batch > 0]
|
||||
tot_valid_tokens, tot_tokens = len(valid_batch), len(batch)
|
||||
# The embedding gradients of the indices seen in the batch have higher l2 norm,
|
||||
# because dl/dembed_i = dl/dembed_input_i * (if word_i is in batch) + dl/dembed_output_i
|
||||
extracted_indices = T.argsort(embed_grad.norm(dim=-1), descending=True)[:tot_tokens].cpu().detach().numpy()
|
||||
# Get the overlap ratio
|
||||
extracted_ratio = np.isin(valid_batch, extracted_indices).mean()
|
||||
# Find True positive extracted indices
|
||||
return extracted_ratio, np.intersect1d(extracted_indices, valid_batch)
|
||||
|
||||
|
||||
def compute_perplexity(encoded_batch, model):
|
||||
outputs = model.inference(encoded_batch)
|
||||
(batch_size, seq_len, vocab_size) = outputs[0].shape
|
||||
perplex = T.nn.functional.log_softmax(outputs[0], dim=-1)
|
||||
return perplex.reshape(-1, vocab_size)[np.arange(batch_size * seq_len),
|
||||
encoded_batch.reshape(-1)].reshape(batch_size, seq_len)
|
||||
|
||||
def practical_epsilon_leakage(original_params, model, encoded_batches, is_weighted_leakage=True,
|
||||
max_ratio=1e9, optimizer_config=None):
|
||||
# Copy the gradients and save the model.
|
||||
current_params = deepcopy(model.state_dict())
|
||||
current_gradients = dict((n,p.grad.clone().detach()) for n,p in model.named_parameters())
|
||||
model.load_state_dict(original_params)
|
||||
pre_perplex, post_perplex = [], []
|
||||
# This is just to initialise the gradients
|
||||
model.loss(encoded_batches[0][:1]).backward()
|
||||
model.zero_grad()
|
||||
tolerance = 1 / max_ratio
|
||||
max_leakage = 0
|
||||
with T.no_grad():
|
||||
# Original model before training on client
|
||||
for encoded_batch in encoded_batches:
|
||||
pre_perplex.append(compute_perplexity(encoded_batch, model))
|
||||
# The attacker doesn't not he optimal gradient magnitude but using Adamax with high lr, is proved to be effective
|
||||
for n, p in model.named_parameters():
|
||||
p.grad = current_gradients[n] #.grad
|
||||
print_rank('grad l2: {}'.format(p.grad), loglevel=logging.DEBUG)
|
||||
if optimizer_config is None:
|
||||
optimizer_config = {'lr': 0.03, 'amsgrad': False, 'type': 'adamax'}
|
||||
#T.optim.Adamax(model.parameters(), lr=optim_lr).step()
|
||||
make_optimizer(optimizer_config, model).step()
|
||||
#model.zero_grad()
|
||||
# The model after training on the client data
|
||||
for encoded_batch in encoded_batches:
|
||||
post_perplex.append(compute_perplexity(encoded_batch, model))
|
||||
|
||||
for pre, post in zip(pre_perplex, post_perplex):
|
||||
# Compute the ratio of preplexity and weight it be the probability of correctly predicting the word
|
||||
leakage = ((pre + tolerance) / (post + tolerance)).clamp_(0, max_ratio)
|
||||
print_rank('perplexities leakage: {} '.format(leakage), loglevel=logging.DEBUG)
|
||||
if is_weighted_leakage:
|
||||
weight_leakage = T.max(pre.exp(), post.exp()) * leakage
|
||||
else:
|
||||
weight_leakage = leakage
|
||||
max_leakage = max(max_leakage, weight_leakage.max().item())
|
||||
print_rank('raw max leakage: {}'.format(max_leakage), loglevel=logging.DEBUG)
|
||||
model.load_state_dict(current_params)
|
||||
for n,p in model.named_parameters():
|
||||
p.grad = current_gradients[n]
|
||||
# WE return the log to match epsilon
|
||||
return max(np.log(max_leakage), 0)
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import logging
|
||||
import torch
|
||||
from utils import print_rank
|
||||
from typing import Optional, Tuple
|
||||
|
||||
def quant_model(
|
||||
model: torch.nn.Module,
|
||||
quant_bits: int = 8,
|
||||
quant_threshold: Optional[int] = None,
|
||||
global_stats: bool = False
|
||||
):
|
||||
'''Quantize the gradients using the desired number of bits.
|
||||
|
||||
Nothing is returned as gradients inside :code:`model` are modified
|
||||
in-place.
|
||||
|
||||
Args:
|
||||
model: model which gradients we want to quantize.
|
||||
quant_bits: how many bits will we use to quantize the gradients.
|
||||
quant_threshold: fraction of components to be set to zero; defaults to
|
||||
None, in which case no quantization happens.
|
||||
global_stats: use a single histogram for all layers when binning,
|
||||
defaults to False.
|
||||
'''
|
||||
|
||||
# If no `quant_threshold`, does nothing
|
||||
if quant_threshold is None:
|
||||
return
|
||||
print_rank('Performing Gradient Quantization with Prob. Threshold: {}'.format(
|
||||
quant_threshold), loglevel=logging.INFO)
|
||||
|
||||
# If `global_stats` is true, min/max and thresh are computed across all layers
|
||||
if global_stats:
|
||||
flattened_grad = torch.cat([p.grad.data.flatten() for p in model.parameters()])
|
||||
min_grad, max_grad, thresh = find_min_max_gradient(flattened_grad,
|
||||
quant_threshold)
|
||||
|
||||
# Loop through all layers
|
||||
for p in model.parameters():
|
||||
if not global_stats:
|
||||
min_grad, max_grad, thresh = find_min_max_gradient(p.grad.data,
|
||||
quant_threshold)
|
||||
|
||||
# Perform binning and sparsification of components
|
||||
binned_grad = quant_bins(p.grad.data, 2 ** quant_bits, min_grad, max_grad)
|
||||
p.grad = torch.where(torch.abs(p.grad.data) > thresh, binned_grad,
|
||||
torch.tensor(0.).to(p.grad))
|
||||
|
||||
|
||||
def find_min_max_gradient(
|
||||
gradient: torch.Tensor,
|
||||
quant_threshold: Optional[float] = None
|
||||
) -> Tuple[float, float, float]:
|
||||
'''Get min and max gradients, as well as threshold gradient.
|
||||
|
||||
Args:
|
||||
gradient: tensor over which statistics will be computed.
|
||||
quant_threshold: which quantile to look for to compute threshold, must
|
||||
be between 0 and 1.
|
||||
'''
|
||||
|
||||
# Computes min/max and quantile corresponding to `quant_threshold`
|
||||
min_grad, max_grad = gradient.min(), gradient.max()
|
||||
thresh = torch.quantile(torch.abs(gradient), quant_threshold)
|
||||
|
||||
print_rank('Min. and Max. Gradients: {}, {}'.format(min_grad, max_grad),
|
||||
loglevel=logging.INFO)
|
||||
print_rank('Grad. Threshold: {}'.format(thresh), loglevel=logging.INFO)
|
||||
|
||||
return min_grad, max_grad, thresh
|
||||
|
||||
|
||||
def quant_bins(
|
||||
gradients: torch.Tensor,
|
||||
n_bins: int,
|
||||
min_grad: float,
|
||||
max_grad: float
|
||||
) -> torch.Tensor:
|
||||
'''Perform quantization using binning.
|
||||
|
||||
Creates histogram with `n_bins` bins between `min_grad` and `max_grad`.
|
||||
Returns a tensor similar to gradients but with components corresponding to
|
||||
bin labels.
|
||||
|
||||
Args:
|
||||
gradients: tensor we want to quantize.
|
||||
n_bins: how many bins to use for binning.
|
||||
min_grad: min. value for bins.
|
||||
max_grad: max. value for bins.
|
||||
'''
|
||||
|
||||
# We remove half bin width, as bucketize always takes the ceil instead of rounding
|
||||
bin_labels = torch.linspace(min_grad, max_grad, n_bins).to(gradients)
|
||||
bin_width = bin_labels[1] - bin_labels[0]
|
||||
grad_bins = torch.bucketize(gradients - .5 * bin_width, bin_labels, right=False)
|
||||
|
||||
return bin_labels[grad_bins]
|
|
@ -0,0 +1,13 @@
|
|||
torch
|
||||
mpi4py
|
||||
easydict
|
||||
scipy
|
||||
psutil
|
||||
transformers
|
||||
torchvision
|
||||
pandas
|
||||
h5py
|
||||
sphinx_rtd_theme
|
||||
azureml-core
|
||||
azureml-defaults
|
||||
pyyaml
|
|
@ -0,0 +1,14 @@
|
|||
## Setup Instructions for Pytest
|
||||
|
||||
1. In order to run test_e2e_trainer.py, we need a dataset for test, train an validation. For demonstrative purposes, we are using as example the Reddit dataset already processed by LEAF, that can be downloaded here: https://github.com/TalwalkarLab/leaf/tree/master/data/reddit (Setup instructions, point I)
|
||||
2. Create the following folder structure mockup/data inside /testing. Make sure that inside /data the files needed are divided by test, train and val folders.
|
||||
3. Run ```python preprocess_data.py``` to adjust the data as per FLUTE requirements.
|
||||
4. Run ```pytest -v``` to test the program.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
In case you encounter any issue while running test_e2e_trainer.py, please check the following points:
|
||||
|
||||
1. The file structure matches the path provided in testing/configs/hello_world_local.yaml
|
||||
2. Timeout in test_e2e_trainer.py is proportional to the amount of data using for the training.
|
||||
3. Command line used in test_e2e_trainer.py is commented according to the OS in use.
|
|
@ -0,0 +1,109 @@
|
|||
"""Builds vocabulary file from data."""
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import json
|
||||
import os
|
||||
|
||||
def build_counter(train_data, initial_counter=None):
|
||||
train_tokens = []
|
||||
for u in train_data:
|
||||
for c in train_data[u]['x']:
|
||||
train_tokens.extend([s for s in c])
|
||||
|
||||
all_tokens = []
|
||||
for i in train_tokens:
|
||||
all_tokens.extend(i)
|
||||
train_tokens = []
|
||||
|
||||
if initial_counter is None:
|
||||
counter = collections.Counter()
|
||||
else:
|
||||
counter = initial_counter
|
||||
|
||||
counter.update(all_tokens)
|
||||
all_tokens = []
|
||||
|
||||
return counter
|
||||
|
||||
|
||||
def build_vocab(counter, vocab_size=10000):
|
||||
pad_symbol, unk_symbol = 0, 1
|
||||
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
|
||||
count_pairs = count_pairs[:(vocab_size - 2)] # -2 to account for the unknown and pad symbols
|
||||
|
||||
words, _ = list(zip(*count_pairs))
|
||||
|
||||
vocab = {}
|
||||
vocab['<PAD>'] = pad_symbol
|
||||
vocab['<UNK>'] = unk_symbol
|
||||
|
||||
for i, w in enumerate(words):
|
||||
if w != '<PAD>':
|
||||
vocab[w] = i + 1
|
||||
|
||||
return {'vocab': vocab, 'size': vocab_size, 'unk_symbol': unk_symbol, 'pad_symbol': pad_symbol}
|
||||
|
||||
|
||||
def load_leaf_data(file_path):
|
||||
with open(file_path) as json_file:
|
||||
data = json.load(json_file)
|
||||
to_ret = data['user_data']
|
||||
data = None
|
||||
return to_ret
|
||||
|
||||
|
||||
def save_vocab(vocab, target_dir):
|
||||
os.makedirs(target_dir)
|
||||
with open('./mockup/models/vocab_reddit.vocab', 'w') as outV:
|
||||
outV.write('<OOV>\n')
|
||||
for t in vocab['vocab'].keys():
|
||||
outV.write(t+'\n')
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
json_files = [f for f in os.listdir(args.data_dir) if f.endswith('.json')]
|
||||
json_files.sort()
|
||||
|
||||
counter = None
|
||||
train_data = {}
|
||||
for f in json_files:
|
||||
print('loading {}'.format(f))
|
||||
train_data = load_leaf_data(os.path.join(args.data_dir, f))
|
||||
print('counting {}'.format(f))
|
||||
counter = build_counter(train_data, initial_counter=counter)
|
||||
print()
|
||||
train_data = {}
|
||||
|
||||
if counter is not None:
|
||||
vocab = build_vocab(counter, vocab_size=args.vocab_size)
|
||||
save_vocab(vocab, args.target_dir)
|
||||
else:
|
||||
print('No files to process.')
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('--data-dir',
|
||||
help='dir with training file;',
|
||||
type=str,
|
||||
required=True)
|
||||
parser.add_argument('--vocab-size',
|
||||
help='size of the vocabulary;',
|
||||
type=int,
|
||||
default=10000,
|
||||
required=False)
|
||||
parser.add_argument('--target-dir',
|
||||
help='dir with training file;',
|
||||
type=str,
|
||||
default='./',
|
||||
required=False)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,198 @@
|
|||
model_config:
|
||||
model_type: GRU
|
||||
model_folder: experiments/nlg_gru/model.py
|
||||
embed_dim: 160
|
||||
vocab_size: 10000
|
||||
hidden_dim: 512
|
||||
OOV_correct: false
|
||||
# Configuration for differential privacy
|
||||
dp_config:
|
||||
# Local dp clips and adds noise on the client and centrally accumulates the privacy budget.
|
||||
enable_local_dp: true
|
||||
eps: 100 # epsilon
|
||||
max_grad: 0.008 # max gradient
|
||||
# The max_weight and min_weight should be already scaled by weight_scaler
|
||||
# Because we scale down the weight using weight_scalar -> clip -> add noise -> scale back up.
|
||||
max_weight: 0.0001
|
||||
weight_scaler: 0.0001
|
||||
min_weight: 0.00009
|
||||
privacy_metrics_config:
|
||||
apply_metrics: true
|
||||
apply_indices_extraction: true
|
||||
# If we extracting word indices we want to consider the rank (sorted freq to rare)
|
||||
# of the words extracted. Any word that rank above this value is considered privacy risk
|
||||
allowed_word_rank: 9000
|
||||
apply_leakage_metric: true
|
||||
max_leakage: 30
|
||||
max_allowed_leakage: 3
|
||||
# take the 95th percentile of the leakage for the next round.
|
||||
adaptive_leakage_threshold: 0.95
|
||||
is_leakage_weighted: true
|
||||
attacker_optimizer_config:
|
||||
lr: 0.03
|
||||
type: adamax
|
||||
amsgrad: false
|
||||
# server_config determines all the server-side settings
|
||||
server_config:
|
||||
wantRL: false # use reinforcement learning to train the optimizer?
|
||||
resume_from_checkpoint: true # if a checkpoint is available start from this iteration?
|
||||
do_profiling: false # enable to capture profiling information during server updates. Will generate a lot of logging
|
||||
RL: # configuration for server-side RL. Ignored if wantRL is false
|
||||
marginal_update_RL: true
|
||||
RL_path: ./RL_models
|
||||
model_descriptor_RL: marginalUpdate
|
||||
network_params: 300,128,128,128,64,100
|
||||
initial_epsilon: 0.5
|
||||
final_epsilon: 0.0001
|
||||
epsilon_gamma: 0.90
|
||||
max_replay_memory_size: 1000
|
||||
minibatch_size: 16
|
||||
gamma: 0.99
|
||||
optimizer_config:
|
||||
lr: 0.0003
|
||||
type: adam
|
||||
amsgrad: true
|
||||
annealing_config:
|
||||
type: step_lr
|
||||
step_interval: epoch
|
||||
step_size: 1
|
||||
gamma: 0.95
|
||||
# configuration for the server-side optimizer.
|
||||
optimizer_config:
|
||||
# this section for sgd
|
||||
#type: sgd
|
||||
#lr: 0.001
|
||||
# this section for adam
|
||||
type: adamax
|
||||
amsgrad: true
|
||||
lr: 0.0005
|
||||
# this section for adamax
|
||||
#type: adamax
|
||||
#amsgrad: true
|
||||
#lr: 0.002
|
||||
# this section for lamb
|
||||
#lr: 0.1
|
||||
#weight_decay: 0.0 #0.005
|
||||
#type: lamb
|
||||
# This section configures how the learning rate decays
|
||||
annealing_config:
|
||||
type: step_lr
|
||||
step_interval: epoch
|
||||
gamma: 0.99 # decrease the learning rate gamma * lambda
|
||||
step_size: 100 # apply gamma every step_size iterations
|
||||
val_freq: 3 # evaluate validation set once every val_freq rounds
|
||||
rec_freq: 3 # evaluate test set once every rec_freq rounds
|
||||
max_iteration: 3 # total rounds of FL
|
||||
num_clients_per_iteration: 5 # number of clients to sample per round
|
||||
# server-side data configuration
|
||||
# we load all the data server side, but training data config is configured in the client config.
|
||||
data_config:
|
||||
# validation data
|
||||
val:
|
||||
batch_size: 128
|
||||
loader_type: text
|
||||
tokenizer_type: not_applicable
|
||||
prepend_datapath: false
|
||||
val_data: ./data/val/val_data.json
|
||||
vocab_dict: ./models/vocab_reddit.vocab
|
||||
pin_memory: true
|
||||
# num_workers indicates how many workers are used for creating batches.
|
||||
# we've found that batch creation is very fast for small models and it's not
|
||||
# worth the overhead to create new worker processes. For large models
|
||||
# with a lot of data per client it might be more efficient to set this larger.
|
||||
# run with profiling enabled and see if a lot of time is spent in process creation/teardown
|
||||
num_workers: 0
|
||||
num_frames: 2400
|
||||
desired_num_samples: null
|
||||
max_batch_size: 128
|
||||
max_num_words: 25
|
||||
unsorted_batch: true
|
||||
# Note this is NOT the main training data configuration, which is configured in the
|
||||
# client config. This section is ignored unless you are running replay data.
|
||||
# If you want to run replay data- set a path name for train_data_server.
|
||||
train:
|
||||
batch_size: 128
|
||||
loader_type: text
|
||||
tokenizer_type: not_applicable
|
||||
prepend_datapath: false
|
||||
train_data: null
|
||||
train_data_server: null
|
||||
vocab_dict: ./models/vocab_reddit.vocab
|
||||
pin_memory: true
|
||||
num_workers: 0
|
||||
num_frames: 2400
|
||||
desired_max_samples: 500
|
||||
max_grad_norm: 10.0
|
||||
max_batch_size: 128
|
||||
max_num_words: 25
|
||||
unsorted_batch: true
|
||||
# test data configuration
|
||||
test:
|
||||
batch_size: 128
|
||||
loader_type: text
|
||||
tokenizer_type: not_applicable
|
||||
prepend_datapath: false
|
||||
train_data: null
|
||||
train_data_server: null
|
||||
test_data: ./data/test/test_data.json
|
||||
vocab_dict: ./models/vocab_reddit.vocab
|
||||
pin_memory: true
|
||||
num_workers: 0
|
||||
max_batch_size: 128
|
||||
max_num_words: 25
|
||||
unsorted_batch: true
|
||||
type: model_optimization
|
||||
aggregate_median: softmax # the FL aggregation method
|
||||
weight_train_loss: grad_mean_loss #train_loss #grad_mean_loss #or train_loss - how each client's weight is determined.
|
||||
softmax_beta: 1000
|
||||
initial_lr_client: 0.1
|
||||
lr_decay_factor: 1.0
|
||||
best_model_criterion: loss # choose best model based on minimal loss, for checkpointing
|
||||
fall_back_to_best_model: false # if a model degrades, use the previous best
|
||||
# replay configuration. This is only applied if the server-side training data is fully configured and loaded.
|
||||
server_replay_config:
|
||||
server_iterations: 50
|
||||
optimizer_config:
|
||||
lr: 0.00002
|
||||
amsgrad: true
|
||||
type: adam
|
||||
# end server config
|
||||
# client config dictates the learning parameters for client-side model updates
|
||||
# most parameters are similar to the server config.
|
||||
# Note the core training data is defined in this config.
|
||||
client_config:
|
||||
meta_learning: basic
|
||||
stats_on_smooth_grad: true
|
||||
ignore_subtask: false
|
||||
num_skips_threshold: 10
|
||||
copying_train_data: false
|
||||
do_profiling: false # set to true to performance profile client-side training.
|
||||
data_config:
|
||||
# this is the main training data configuration
|
||||
train:
|
||||
batch_size: 128
|
||||
loader_type: text
|
||||
tokenizer_type: not_applicable
|
||||
prepend_datapath: false
|
||||
list_of_train_data: ./data/train/train_data.json
|
||||
vocab_dict: ./models/vocab_reddit.vocab
|
||||
pin_memory: true
|
||||
num_workers: 0
|
||||
num_frames: 2400
|
||||
desired_max_samples: 500
|
||||
max_grad_norm: 10.0
|
||||
max_batch_size: 128
|
||||
max_num_words: 25
|
||||
unsorted_batch: true
|
||||
utterance_mvn: false
|
||||
type: optimization
|
||||
meta_optimizer_config:
|
||||
lr: 1.0
|
||||
type: sgd
|
||||
optimizer_config:
|
||||
type: sgd
|
||||
annealing_config:
|
||||
type: step_lr
|
||||
step_interval: epoch
|
||||
step_size: 1
|
||||
gamma: 1.0
|
|
@ -0,0 +1,103 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
from collections import OrderedDict
|
||||
from itertools import islice
|
||||
|
||||
val_file = r"./mockup/data/val/val_data.json"
|
||||
test_file = r"./mockup/data/test/test_data.json"
|
||||
train_file = r"./mockup/data/train/train_data.json"
|
||||
|
||||
def main():
|
||||
|
||||
exp = parse_args()
|
||||
|
||||
# Remove vocab if already exists
|
||||
try:
|
||||
os.remove("./mockup/models/vocab_reddit.vocab")
|
||||
except:
|
||||
print("Vocab file not found")
|
||||
|
||||
# Building vocab
|
||||
os.system("echo Building vocab")
|
||||
os.system("python build_vocab.py --data-dir ./mockup/data/train --target-dir mockup/models")
|
||||
|
||||
# Preprocessing data
|
||||
os.system("echo Preprocessing data")
|
||||
|
||||
min = -25
|
||||
max = 0
|
||||
for iteration in range(3):
|
||||
|
||||
min = min + 25
|
||||
max = max + 25
|
||||
|
||||
if iteration == 0:
|
||||
file = val_file
|
||||
elif iteration == 1:
|
||||
file = test_file
|
||||
elif iteration == 2:
|
||||
file = train_file
|
||||
|
||||
with open(file, 'r') as f:
|
||||
json_file = json.load(f)
|
||||
|
||||
users_list = list()
|
||||
num_samples = json_file['num_samples']
|
||||
user_data = json_file['user_data']
|
||||
|
||||
# Truncate user_data to only 25 elements per file
|
||||
user_data = OrderedDict(islice(user_data.items(), min, max))
|
||||
user_data = dict(user_data)
|
||||
|
||||
# Give format to user_data and create users_list
|
||||
if exp == "nlg":
|
||||
for users in user_data:
|
||||
listToStr = ''
|
||||
users_list.append(users)
|
||||
for i, sentences in enumerate(user_data[users]['x']):
|
||||
for j, pieces in enumerate(sentences):
|
||||
listToStr = ' '.join([elem for elem in pieces])
|
||||
user_data[users]['x'][i][j] = listToStr
|
||||
|
||||
full_sentence = ' '.join([elem for elem in sentences])
|
||||
full_sentence = full_sentence.replace('<PAD>', '').replace('<EOS>', '').replace('<BOS>', '').strip()
|
||||
user_data[users]['x'][i] = full_sentence
|
||||
user_data[users].pop('y',None)
|
||||
|
||||
elif exp == "mlm":
|
||||
|
||||
user_data_aux = dict()
|
||||
for users in user_data:
|
||||
listToStr = ''
|
||||
users_list.append(users)
|
||||
for i, sentences in enumerate(user_data[users]['x']):
|
||||
for j, pieces in enumerate(sentences):
|
||||
listToStr = ' '.join([elem for elem in pieces])
|
||||
listToStr = listToStr.replace('<PAD>', '').replace('<EOS>', '').replace('<BOS>', '').strip()
|
||||
user_data[users]['x'][i][j] = listToStr
|
||||
user_data[users].pop('y',None)
|
||||
user_data_aux[users] = user_data[users]['x']
|
||||
user_data = user_data_aux
|
||||
|
||||
# Adjust number of samples
|
||||
new_dict = {'users':users_list ,'num_samples':num_samples[min:max], 'user_data':user_data}
|
||||
f = open(file,'w')
|
||||
json.dump(new_dict,f)
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-e","--Exp", help="Experiment name (nlg/mlm)")
|
||||
args = parser.parse_args()
|
||||
exp = args.Exp
|
||||
|
||||
if exp != "mlm" and exp!="nlg":
|
||||
raise ValueError ("Invalid experiment name, please try once again with mlm/nlg")
|
||||
else:
|
||||
return exp
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import subprocess
|
||||
import os.path
|
||||
import os
|
||||
import platform
|
||||
|
||||
launcher_path='e2e_trainer.py'
|
||||
data_path=r'./testing/mockup/'
|
||||
output_path=r'./testing/outputs'
|
||||
output_folder='./testing/outputs'
|
||||
config_path=r'./testing/configs/hello_world_local.yaml'
|
||||
|
||||
|
||||
def test_e2e_trainer():
|
||||
|
||||
try:
|
||||
#Verify complete script execution
|
||||
os.system("mkdir -p "+ output_folder)
|
||||
|
||||
command = ['mpiexec', '-np', '2', 'python', launcher_path,\
|
||||
'-dataPath',data_path,'-outputPath',output_path,'-config',config_path,\
|
||||
'-task','nlg_gru']
|
||||
|
||||
command_string = ""
|
||||
for elem in command:
|
||||
command_string = " ".join([command_string, str(elem)])
|
||||
|
||||
if platform.system() == "Windows":
|
||||
command_string = "cd .. &" + command_string
|
||||
else:
|
||||
command_string = "cd .. ;" + command_string # For Linux users
|
||||
|
||||
with open('logs.txt','w') as f:
|
||||
process= subprocess.run(command_string, shell=True,stdout=f,text=True,timeout=420)
|
||||
|
||||
return_code=process.returncode
|
||||
print(process.stderr)
|
||||
assert return_code==0
|
||||
|
||||
#Verify output files
|
||||
directory=len(os.listdir('./outputs'))
|
||||
assert directory > 0
|
||||
|
||||
#Verify logs for config file
|
||||
config_exists=False
|
||||
config_file='Copy created'
|
||||
logs=open('logs.txt','r')
|
||||
readLogs=logs.read()
|
||||
if config_file in readLogs:
|
||||
config_exists=True
|
||||
assert config_exists
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print("Encountered an exception: {}".format(e))
|
||||
raise e
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from .utils import *
|
||||
from utils.optimizers.lars import *
|
||||
from utils.optimizers.lamb import *
|
||||
|
|
@ -0,0 +1,120 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import random
|
||||
import logging
|
||||
from torch.utils.data import sampler
|
||||
from utils import AverageMeter
|
||||
|
||||
class BatchSampler(sampler.Sampler):
|
||||
"""
|
||||
Simply determines the order in which the loader will read samples from the data set.
|
||||
We want to sample batches randomly, but each batch should have samples that are
|
||||
close to each other in the dataset (so that we don't have a lot of zero padding)
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, batch_size, randomize=True, drop_last=False):
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
self.randomize=randomize
|
||||
|
||||
batches = [range(begin_id, begin_id + batch_size) for begin_id in range(0, len(dataset), batch_size)]
|
||||
|
||||
# if the indexes in the last batch are going over len(dataset), we drop the last batch.
|
||||
if batches[-1][-1] > len(dataset):
|
||||
if drop_last:
|
||||
del batches[-1]
|
||||
else:
|
||||
batches[-1]=range(batches[-1][0],len(dataset))
|
||||
self.batches = batches
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
if self.randomize:
|
||||
random.shuffle(self.batches)
|
||||
|
||||
return iter(self.batches)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.batches) * self.batch_size
|
||||
|
||||
|
||||
class DynamicBatchSampler(sampler.Sampler):
|
||||
"""Extension of Sampler that will do the following:
|
||||
1. Change the batch size (essentially number of sequences)
|
||||
in a batch to ensure that the total number of frames are less
|
||||
than a certain threshold.
|
||||
2. Make sure the padding efficiency in the batch is high.
|
||||
"""
|
||||
|
||||
def __init__(self, sampler, frames_threshold, max_batch_size=0, unsorted_batch=False, fps= 1000 / 30):
|
||||
"""
|
||||
@sampler: will mostly be an instance of DistributedSampler.
|
||||
Though it should work with any sampler.
|
||||
@frames_threshold: maximum area of the batch
|
||||
"""
|
||||
self.sampler = sampler
|
||||
self.frames_threshold = frames_threshold
|
||||
self.max_batch_size = max_batch_size
|
||||
self.unsorted_batch = unsorted_batch
|
||||
|
||||
indices, batches = list(), list()
|
||||
# the dataset to which these indices are pointing to
|
||||
dataset = self.sampler.dataset
|
||||
# get all the indices and corresponding durations from
|
||||
# the sampler
|
||||
for idx in self.sampler:
|
||||
indices.append((idx, dataset.utt_list[idx]["duration"]))
|
||||
|
||||
# sort the indices according to duration
|
||||
if self.unsorted_batch is False:
|
||||
indices.sort(key=lambda elem : elem[1])
|
||||
max_dur = indices[-1][1]
|
||||
else:
|
||||
# make sure that you will be able to serve all the utterances
|
||||
max_dur = max([indices[i][1] for i in range(len(indices))])
|
||||
|
||||
# start clubbing the utterances together
|
||||
batch = list()
|
||||
batch_frames, batch_area = 0, 0
|
||||
max_frames_in_batch = 0
|
||||
average_meter = AverageMeter('Padding Efficiency')
|
||||
for idx, duration in indices:
|
||||
if duration > 0:
|
||||
frames = duration * fps
|
||||
if frames > max_frames_in_batch:
|
||||
max_frames_in_batch = frames
|
||||
|
||||
if (self.unsorted_batch and len(batch) < max_batch_size)\
|
||||
or (not self.unsorted_batch and batch_frames + frames <= self.frames_threshold and (max_batch_size == 0 or len(batch) < max_batch_size)):
|
||||
batch.append(idx)
|
||||
batch_frames += frames
|
||||
batch_area = max_frames_in_batch * len(batch)
|
||||
else:
|
||||
# log the stats and add previous batch to batches
|
||||
if batch_area > 0 and len(batch) > 0:
|
||||
average_meter.add(batch_frames, batch_area)
|
||||
batches.append(batch)
|
||||
# make a new one
|
||||
batch = list()
|
||||
batch_frames, batch_area = frames, frames
|
||||
max_frames_in_batch = batch_frames
|
||||
|
||||
# When all indices are processed
|
||||
if batch_area > 0 and len(batch) > 0:
|
||||
average_meter.add(batch_frames, batch_area)
|
||||
batches.append(batch)
|
||||
|
||||
# don't need the 'indices' any more
|
||||
del indices
|
||||
self.batches = batches
|
||||
average_meter.display_results(loglevel=logging.DEBUG)
|
||||
|
||||
def __iter__(self):
|
||||
# shuffle on a batch level
|
||||
random.shuffle(self.batches)
|
||||
return iter(self.batches)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.batches)
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import logging
|
||||
from importlib.machinery import SourceFileLoader
|
||||
from utils import print_rank
|
||||
|
||||
def get_exp_dataloader(task):
|
||||
""" Detect the dataloader declared in the experiment folder
|
||||
|
||||
Args:
|
||||
task (str): task parsed from the console
|
||||
"""
|
||||
|
||||
try:
|
||||
dir = os.path.join('experiments',task,'dataloaders','text_dataloader.py')
|
||||
loader = SourceFileLoader("TextDataLoader",dir).load_module()
|
||||
loader = loader.TextDataLoader
|
||||
except:
|
||||
print_rank("Dataloader not found, please make sure is located inside the experiment folder")
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
def detect_loader_type(my_data, loader_type):
|
||||
""" Detect the loader type declared in the configuration file
|
||||
|
||||
Inside this function should go the implementation of
|
||||
specific detection for any kind of loader.
|
||||
|
||||
Args:
|
||||
my_data (str): path of file or chunk file set
|
||||
loader_type (str): loader description in yaml file
|
||||
"""
|
||||
|
||||
if not loader_type == "auto_detect":
|
||||
return loader_type
|
||||
|
||||
# Here should go the implementation for the rest of loaders
|
||||
else:
|
||||
raise ValueError("Unknown format: {}".format(loader_type))
|
||||
|
||||
|
||||
def make_train_dataloader(data_config, data_path, clientx, task=None, vec_size=300, data_strct=None):
|
||||
""" Create a dataloader for training on either server or client side """
|
||||
|
||||
mode = 'train'
|
||||
tokenizer_type= data_config.get('tokenizer_type', 'not_applicable')
|
||||
|
||||
# Training list for a server
|
||||
if clientx is None:
|
||||
if not "train_data_server" in data_config or data_config["train_data_server"] is None:
|
||||
print_rank("No server training set is defined")
|
||||
return None
|
||||
my_data = os.path.join(data_path, data_config["train_data_server"])
|
||||
mode='val'
|
||||
|
||||
# Training list on a client side
|
||||
else:
|
||||
if tokenizer_type != 'not_applicable':
|
||||
assert clientx >=0 and clientx < len(data_config["train_data"]), "Invalid client index {}".format(clientx)
|
||||
my_data = data_config["train_data"][clientx]
|
||||
else:
|
||||
my_data = data_config["list_of_train_data"]
|
||||
|
||||
# Find the loader_type
|
||||
loader_type = detect_loader_type(my_data, data_config["loader_type"])
|
||||
|
||||
if loader_type == 'text':
|
||||
TextDataLoader = get_exp_dataloader(task)
|
||||
train_dataloader = TextDataLoader(
|
||||
data = data_strct if data_strct is not None else my_data,
|
||||
user_idx = clientx,
|
||||
mode = mode,
|
||||
args=data_config
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Not supported {}: detected_type={} loader_type={} audio_format={}".format(my_data, loader_type, data_config["loader_type"], data_config["audio_format"]))
|
||||
return train_dataloader
|
||||
|
||||
|
||||
|
||||
def make_val_dataloader(data_config, data_path, task=None, data_strct=None):
|
||||
""" Return a data loader for a validation set """
|
||||
|
||||
if not "val_data" in data_config or data_config["val_data"] is None:
|
||||
print_rank("Validation data list is not set", loglevel=logging.DEBUG)
|
||||
return None
|
||||
|
||||
loader_type = detect_loader_type(data_config["val_data"], data_config["loader_type"])
|
||||
|
||||
if loader_type == 'text':
|
||||
TextDataLoader = get_exp_dataloader(task)
|
||||
|
||||
val_dataloader = TextDataLoader(
|
||||
data = data_strct if data_strct is not None else os.path.join(data_path, data_config["val_data"]),
|
||||
user_idx = 0,
|
||||
mode = 'val',
|
||||
args=data_config
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Not supported loader_type={} audio_format={}".format(loader_type, data_config["audio_format"]))
|
||||
return val_dataloader
|
||||
|
||||
|
||||
def make_test_dataloader(data_config, data_path, task=None, data_strct=None):
|
||||
""" Return a data loader for an evaluation set. """
|
||||
|
||||
if not "test_data" in data_config or data_config["test_data"] is None:
|
||||
print_rank("Test data list is not set")
|
||||
return None
|
||||
|
||||
loader_type = detect_loader_type(data_config["test_data"], data_config["loader_type"])
|
||||
|
||||
if loader_type == 'text':
|
||||
TextDataLoader = get_exp_dataloader(task)
|
||||
|
||||
test_dataloader = TextDataLoader(
|
||||
data = data_strct if data_strct is not None else os.path.join(data_path, data_config["test_data"]),
|
||||
user_idx = 0,
|
||||
mode = 'test',
|
||||
args=data_config
|
||||
)
|
||||
|
||||
else:
|
||||
raise NotImplementedError("Not supported loader_type={} audio_format={}".format(loader_type, data_config["audio_format"]))
|
||||
return test_dataloader
|
||||
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
class AdamW(Optimizer):
|
||||
""" Implements Adam algorithm with weight decay fix.
|
||||
Parameters:
|
||||
lr (float): learning rate. Default 1e-3.
|
||||
betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999)
|
||||
eps (float): Adams epsilon. Default: 1e-6
|
||||
weight_decay (float): Weight decay. Default: 0.0
|
||||
correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True.
|
||||
"""
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True):
|
||||
if lr < 0.0:
|
||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
|
||||
correct_bias=correct_bias)
|
||||
super(AdamW, self).__init__(params, defaults)
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(p.data)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
beta1, beta2 = group['betas']
|
||||
|
||||
state['step'] += 1
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
# In-place operations to update the averages at the same time
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
|
||||
denom = exp_avg_sq.sqrt().add_(group['eps'])
|
||||
|
||||
step_size = group['lr']
|
||||
if group['correct_bias']: # No bias correction for Bert
|
||||
bias_correction1 = 1.0 - beta1 ** state['step']
|
||||
bias_correction2 = 1.0 - beta2 ** state['step']
|
||||
step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
p.data.addcdiv_(exp_avg, denom, value = -step_size)
|
||||
|
||||
# Just adding the square of the weights to the loss function is *not*
|
||||
# the correct way of using L2 regularization/weight decay with Adam,
|
||||
# since that will interact with the m and v parameters in strange ways.
|
||||
#
|
||||
# Instead we want to decay the weights in a manner that doesn't interact
|
||||
# with the m/v parameters. This is equivalent to adding the square
|
||||
# of the weights to the loss with plain (non-momentum) SGD.
|
||||
# Add weight decay at the end (fixed version)
|
||||
if group['weight_decay'] > 0.0:
|
||||
p.data.add_(p.data, alpha= -group['lr'] * group['weight_decay'])
|
||||
|
||||
return loss
|
|
@ -0,0 +1,134 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""Lamb optimizer."""
|
||||
|
||||
import collections
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.optim import Optimizer
|
||||
|
||||
try:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
|
||||
"""Log a histogram of trust ratio scalars in across layers."""
|
||||
results = collections.defaultdict(list)
|
||||
for group in optimizer.param_groups:
|
||||
for p in group['params']:
|
||||
state = optimizer.state[p]
|
||||
for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
|
||||
if i in state:
|
||||
results[i].append(state[i])
|
||||
|
||||
for k, v in results.items():
|
||||
event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
|
||||
|
||||
except ImportError:
|
||||
def log_lamb_rs(optimizer, event_writer, token_count):
|
||||
print("tensorboardX is not installed")
|
||||
|
||||
|
||||
class LAMB(Optimizer):
|
||||
r"""Implements Lamb algorithm.
|
||||
|
||||
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
|
||||
|
||||
Arguments:
|
||||
params (iterable): iterable of parameters to optimize or dicts defining
|
||||
parameter groups
|
||||
lr (float, optional): learning rate (default: 1e-3)
|
||||
betas (Tuple[float, float], optional): coefficients used for computing
|
||||
running averages of gradient and its square (default: (0.9, 0.999))
|
||||
eps (float, optional): term added to the denominator to improve
|
||||
numerical stability (default: 1e-8)
|
||||
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
|
||||
adam (bool, optional): always use trust ratio = 1, which turns this into
|
||||
Adam. Useful for comparison purposes.
|
||||
|
||||
.. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
|
||||
https://arxiv.org/abs/1904.00962
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
|
||||
weight_decay=0, adam=False):
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay)
|
||||
self.adam = adam
|
||||
super(LAMB, self).__init__(params, defaults)
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
grad = p.grad.data
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.')
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(p.data)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
beta1, beta2 = group['betas']
|
||||
|
||||
state['step'] += 1
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
# m_t
|
||||
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
|
||||
# v_t
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
|
||||
|
||||
# Paper v3 does not use debiasing.
|
||||
# bias_correction1 = 1 - beta1 ** state['step']
|
||||
# bias_correction2 = 1 - beta2 ** state['step']
|
||||
# Apply bias to lr to avoid broadcast.
|
||||
step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
|
||||
|
||||
weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
|
||||
|
||||
adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
|
||||
if group['weight_decay'] != 0:
|
||||
adam_step.add_(p.data, alpha=group['weight_decay'])
|
||||
|
||||
adam_norm = adam_step.pow(2).sum().sqrt()
|
||||
if weight_norm == 0 or adam_norm == 0:
|
||||
trust_ratio = 1
|
||||
else:
|
||||
trust_ratio = weight_norm / adam_norm
|
||||
state['weight_norm'] = weight_norm
|
||||
state['adam_norm'] = adam_norm
|
||||
state['trust_ratio'] = trust_ratio
|
||||
if self.adam:
|
||||
trust_ratio = 1
|
||||
|
||||
p.data.add_(adam_step, alpha=-step_size * trust_ratio)
|
||||
|
||||
return loss
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
"""distoptim.hit package"""
|
||||
import logging
|
||||
import torch
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
class LarsSGDV1(torch.optim.SGD):
|
||||
""" LARS SGD V1, based on https://arxiv.org/abs/1708.03888
|
||||
2018.
|
||||
Refer to torch.optim.SGD for paramters.
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr, momentum=0, dampening=0,
|
||||
weight_decay=0, nesterov=False):
|
||||
LOG.info("Init LarsSGDV1")
|
||||
super(LarsSGDV1, self).__init__(
|
||||
params, lr, momentum, dampening, weight_decay, nesterov)
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
weight_decay = group['weight_decay']
|
||||
momentum = group['momentum']
|
||||
# dampening = group['dampening']
|
||||
nesterov = group['nesterov']
|
||||
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
|
||||
d_p = p.grad.data
|
||||
|
||||
p_n = p.data.norm()
|
||||
d_p_n = d_p.norm()
|
||||
|
||||
if weight_decay != 0:
|
||||
d_p_n.add_(weight_decay, p_n)
|
||||
d_p.add_(weight_decay, p.data)
|
||||
|
||||
alpha = 0.001 * p_n / d_p_n # This is the LARS eta from the paper
|
||||
lr = alpha * group['lr']
|
||||
lr = min(lr, 5.0)
|
||||
|
||||
if momentum != 0:
|
||||
param_state = self.state[p]
|
||||
if 'momentum_buffer' not in param_state:
|
||||
buf = param_state['momentum_buffer'] = \
|
||||
torch.clone(d_p).detach()
|
||||
else:
|
||||
buf = param_state['momentum_buffer']
|
||||
buf.mul_(momentum).add_(lr, d_p)
|
||||
if nesterov:
|
||||
d_p = d_p.add(momentum, buf)
|
||||
else:
|
||||
d_p = buf
|
||||
|
||||
p.data.add_(-1, d_p)
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class LarsSGD(torch.optim.SGD):
|
||||
""" LARS SGD, based on https://arxiv.org/abs/1904.00962 Algorithm 1
|
||||
2019, a newer version.
|
||||
Refer to torch.optim.SGD for paramters.
|
||||
"""
|
||||
|
||||
def __init__(self, params, lr, momentum=0, dampening=0,
|
||||
weight_decay=0, nesterov=False):
|
||||
LOG.info("Init LarsSGD")
|
||||
super(LarsSGD, self).__init__(
|
||||
params, lr, momentum, dampening, weight_decay, nesterov)
|
||||
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
Arguments:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
weight_decay = group['weight_decay']
|
||||
momentum = group['momentum']
|
||||
# dampening = group['dampening']
|
||||
nesterov = group['nesterov']
|
||||
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
|
||||
d_p = p.grad.data
|
||||
if weight_decay != 0:
|
||||
d_p.add(p.data, alpha=weight_decay)
|
||||
|
||||
if momentum != 0:
|
||||
param_state = self.state[p]
|
||||
if 'momentum_buffer' not in param_state:
|
||||
buf = param_state['momentum_buffer'] = \
|
||||
torch.clone(d_p).detach()
|
||||
else:
|
||||
buf = param_state['momentum_buffer']
|
||||
buf.mul_(momentum).add_(1 - momentum, d_p)
|
||||
if nesterov:
|
||||
d_p = d_p.add(buf, alpha=momentum)
|
||||
else:
|
||||
d_p = buf
|
||||
|
||||
lr = group['lr'] * p.data.norm() / (d_p.norm() + 1e-8)
|
||||
lr.clamp_(0, 10)
|
||||
p.data.add_(d_p, alpha=-lr)
|
||||
|
||||
return loss
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import h5py
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
|
||||
|
||||
path = r'C:\Users\train.tsv'
|
||||
|
||||
def local_time():
|
||||
return str(time.strftime("%H:%M:%S",time.localtime()))
|
||||
|
||||
|
||||
print(local_time() + " Starting script " )
|
||||
columns = ['author','num1','content','str1','str2','num2','subreddit']
|
||||
df = pd.read_csv(path, sep='\t', names=columns, header=None)
|
||||
print(local_time() + " File has been read " )
|
||||
|
||||
df_authors = pd.DataFrame(df['author'])
|
||||
df_content = pd.DataFrame(df['content'])
|
||||
df_file = pd.concat([df_authors,df_content], axis=1)
|
||||
print(local_time() + " Data needed has been concatenated ")
|
||||
|
||||
|
||||
users_group = df_file.groupby('author')
|
||||
group0 = df_file.groupby(['author','content'])
|
||||
group1 = pd.Series(users_group.size())
|
||||
users = (group1.index).to_numpy()
|
||||
print(local_time() + " users been formatted ")
|
||||
num_samples = group1.values
|
||||
print(local_time() + " num_samples has been formatted ")
|
||||
user_data_dict= {}
|
||||
|
||||
user_data_dict= {i: {'x':list()} for i in tqdm(users)}
|
||||
|
||||
for i in tqdm(range(len(df_file))):
|
||||
if df_file['content'][i] not in user_data_dict[df_file['author'][i]]['x']:
|
||||
user_data_dict[df_file['author'][i]]['x'].append(df_file['content'][i])
|
||||
|
||||
|
||||
print(local_time() + " user_data has been formatted ")
|
||||
f = h5py.File(r"C:\Users\train.hdf5", "w")
|
||||
dset_0 = f.create_dataset("num_samples",data=num_samples)
|
||||
dset_1= f.create_dataset("users", data =users)
|
||||
print(local_time() + " starting to store dictionary ")
|
||||
|
||||
user_data = f.create_group("user_data")
|
||||
for user in tqdm(user_data_dict):
|
||||
user_group = user_data.create_group(user)
|
||||
user_data_dict[user]['x'] = [str(e).encode('utf8') for e in user_data_dict[user]['x']]
|
||||
x_dset = user_group.create_dataset('x',data=user_data_dict[user]['x'])
|
||||
|
||||
print(local_time() + " end of script ")
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
|
||||
path = r'C:\Users\train.tsv'
|
||||
|
||||
def local_time():
|
||||
return str(time.strftime("%H:%M:%S",time.localtime()))
|
||||
|
||||
|
||||
print(local_time() + " Starting script " )
|
||||
columns = ['author','num1','content','str1','str2','num2','subreddit']
|
||||
df = pd.read_csv(path, sep='\t', names=columns, header=None)
|
||||
print(local_time() + " File has been read " )
|
||||
|
||||
df_authors = pd.DataFrame(df['author'])
|
||||
df_content = pd.DataFrame(df['content'])
|
||||
df_file = pd.concat([df_authors,df_content], axis=1)
|
||||
print(local_time() + " Data needed has been concatenated ")
|
||||
|
||||
|
||||
users_group = df_file.groupby('author')
|
||||
group0 = df_file.groupby(['author','content'])
|
||||
group1 = pd.Series(users_group.size())
|
||||
users = (group1.index).to_numpy()
|
||||
print(local_time() + " users been formatted ")
|
||||
num_samples = group1.values
|
||||
print(local_time() + " num_samples has been formatted ")
|
||||
user_data_dict= {}
|
||||
|
||||
user_data_dict= {i: {'x':list()} for i in tqdm(users)}
|
||||
|
||||
for i in tqdm(range(len(df_file))):
|
||||
if df_file['content'][i] not in user_data_dict[df_file['author'][i]]['x']:
|
||||
user_data_dict[df_file['author'][i]]['x'].append(df_file['content'][i])
|
||||
|
||||
|
||||
f = open(r'C:\Users\train.json', "w")
|
||||
new_data = {'users': users.tolist(), 'num_samples': num_samples.tolist(), 'user_data': user_data_dict}
|
||||
json.dump(new_data,f)
|
||||
print(local_time() + " end of script ")
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import json
|
||||
import h5py
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
|
||||
json_file = r'C:\Users\train.tsv'
|
||||
|
||||
def local_time():
|
||||
return str(time.strftime("%H:%M:%S",time.localtime()))
|
||||
|
||||
print(local_time() + " Starting script " )
|
||||
with open(json_file, 'r') as f:
|
||||
json_file = json.load(f)
|
||||
print(local_time() + " JSON file read " )
|
||||
|
||||
hdf_file = h5py.File(r"C:\Users\train.hdf5", "w")
|
||||
dset_0 = hdf_file.create_dataset("users",data=json_file['users'])
|
||||
dset_1 = hdf_file.create_dataset("num_samples",data=json_file['num_samples'])
|
||||
print(local_time() + " users and num_samples stored " )
|
||||
|
||||
user_data = hdf_file.create_group("user_data")
|
||||
for user in tqdm(json_file['user_data']):
|
||||
user_group = user_data.create_group(user)
|
||||
dset_2 = user_group.create_dataset('x',data=json_file['user_data'][user]['x'])
|
||||
|
||||
print(local_time() + " end of script " )
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
from mpi4py import MPI
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from utils import print_rank
|
||||
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
size = comm.Get_size()
|
||||
rank = comm.Get_rank()
|
||||
|
||||
""" Here we have classes and functions that allow one to send and process multiple
|
||||
messages in parallel on MPI. """
|
||||
|
||||
def process_in_parallel(client_fn, client_data, server_data, models, data_path):
|
||||
""" Process multiple orders in parallel
|
||||
|
||||
Parameters
|
||||
----------
|
||||
client_fn: callback
|
||||
Function we want to call.
|
||||
client_data: list of tuples
|
||||
Arguments that will be passed to function.
|
||||
server_data: tuple
|
||||
Data passed from server to update model parameters.
|
||||
models: torch.nn.Module
|
||||
Models we will send to the clients.
|
||||
data_path: str
|
||||
Path to data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
Output of each callback in the list passed as input.
|
||||
"""
|
||||
with ThreadPoolExecutor(max_workers=len(client_data)) as pool:
|
||||
requests = []
|
||||
for k, args in enumerate(client_data):
|
||||
requests.append(pool.submit(client_fn, args, server_data, models[k], data_path))
|
||||
|
||||
results = [request.result() for request in requests]
|
||||
print_rank(f'finished processing batch of size {len(client_data)}')
|
||||
|
||||
return results
|
|
@ -0,0 +1,602 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import logging
|
||||
import yaml
|
||||
import time
|
||||
import math
|
||||
import json
|
||||
import copy
|
||||
import io
|
||||
import pstats
|
||||
import functools
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from utils.optimizers.lars import LarsSGD
|
||||
from utils.optimizers.lamb import LAMB
|
||||
from utils.optimizers.adamW import AdamW
|
||||
from core.globals import TRAINING_FRAMEWORK_TYPE
|
||||
from easydict import EasyDict as edict
|
||||
from torch.optim.lr_scheduler import (
|
||||
StepLR,
|
||||
MultiStepLR,
|
||||
ReduceLROnPlateau )
|
||||
|
||||
|
||||
if TRAINING_FRAMEWORK_TYPE == 'mpi':
|
||||
from mpi4py import MPI
|
||||
else:
|
||||
raise NotImplementedError('Training framework is not yet supported')
|
||||
|
||||
def make_optimizer(optimizer_config, model):
|
||||
"""Initialization for optimizer."""
|
||||
|
||||
tmp_config = copy.deepcopy(optimizer_config)
|
||||
if optimizer_config["type"] == "sgd":
|
||||
tmp_config.pop("type", None)
|
||||
return torch.optim.SGD(model.parameters(), **tmp_config)
|
||||
|
||||
elif optimizer_config["type"] == "adam":
|
||||
tmp_config.pop("type", None)
|
||||
return torch.optim.Adam(model.parameters(), **tmp_config)
|
||||
|
||||
elif optimizer_config["type"] == "adamax":
|
||||
tmp_config.pop("type", None)
|
||||
tmp_config.pop("amsgrad", None)
|
||||
return torch.optim.Adamax(model.parameters(), **tmp_config)
|
||||
|
||||
elif optimizer_config["type"] == "lars":
|
||||
tmp_config.pop("type", None)
|
||||
from torchlars import LARS
|
||||
base_optimizer = torch.optim.SGD(model.parameters(), **tmp_config)
|
||||
return LARS(optimizer=base_optimizer, eps=1e-8, trust_coef=0.001)
|
||||
|
||||
elif optimizer_config["type"] == "LarsSGD":
|
||||
tmp_config.pop("type", None)
|
||||
return LarsSGD(model.parameters(),**tmp_config)
|
||||
|
||||
elif optimizer_config["type"] == "lamb":
|
||||
tmp_config.pop("type", None)
|
||||
return LAMB(model.parameters(), **tmp_config)
|
||||
|
||||
elif optimizer_config["type"] == "adamW":
|
||||
tmp_config.pop("type", None)
|
||||
tmp_config.pop("amsgrad", None)
|
||||
return AdamW(model.parameters(), **tmp_config)
|
||||
|
||||
else:
|
||||
raise ValueError("{} optimizer not supported".format(optimizer_config["type"]))
|
||||
|
||||
|
||||
def get_lr(optimizer):
|
||||
"""Obtain LR."""
|
||||
for param_group in optimizer.param_groups:
|
||||
return param_group['lr']
|
||||
|
||||
def get_lr_all(optimizer):
|
||||
"""Double checking for get_lr."""
|
||||
for param_group in optimizer.param_groups:
|
||||
yield param_group['lr']
|
||||
|
||||
|
||||
def softmax(X, theta = 1.0, axis = None):
|
||||
"""Compute the softmax of each element along an axis of X.
|
||||
|
||||
Args:
|
||||
X (ndarray): x, probably should be floats.
|
||||
theta (float): used as a multiplier prior to exponentiation. Default = 1.0
|
||||
axis : axis to compute values along. Default is the first non-singleton axis.
|
||||
|
||||
Returns:
|
||||
An array the same size as X. The result will sum to 1 along the specified axis.
|
||||
"""
|
||||
# make X at least 2d
|
||||
y = np.atleast_2d(X)
|
||||
|
||||
# find axis
|
||||
if axis is None:
|
||||
axis = next(j[0] for j in enumerate(y.shape) if j[1] > 1)
|
||||
|
||||
# multiply y against the theta parameter,
|
||||
y = y * float(theta)
|
||||
|
||||
# subtract the max for numerical stability
|
||||
y = y - np.expand_dims(np.max(y, axis = axis), axis)
|
||||
|
||||
# exponentiate y
|
||||
y = np.exp(y)
|
||||
|
||||
# take the sum along the specified axis
|
||||
ax_sum = np.expand_dims(np.sum(y, axis = axis), axis)
|
||||
|
||||
# finally: divide elementwise
|
||||
p = y / ax_sum
|
||||
|
||||
# flatten if X was 1D
|
||||
if len(X.shape) == 1: p = p.flatten()
|
||||
|
||||
return p
|
||||
|
||||
|
||||
class AverageMeter(object):
|
||||
""" Will calculate running micro and macro averages for various
|
||||
(error/efficiency) rates.
|
||||
"""
|
||||
def __init__(self, metric_name):
|
||||
self.numerators, self.denominators = list(), list()
|
||||
self.metric_name = metric_name
|
||||
|
||||
def add(self, top, bottom):
|
||||
self.numerators.append(top)
|
||||
self.denominators.append(bottom)
|
||||
|
||||
def get_macro_average(self):
|
||||
scores = [float(self.numerators[i]) / self.denominators[i] \
|
||||
for i in range(len(self.denominators))]
|
||||
return self.get_average(scores)
|
||||
|
||||
def get_micro_average(self):
|
||||
return float(sum(self.numerators)) / sum(self.denominators)
|
||||
|
||||
# accepts a list and returns average
|
||||
def get_average(self, l):
|
||||
return sum(l) / float(len(l))
|
||||
|
||||
def reset(self):
|
||||
self.numerators, self.denominators = list(), list()
|
||||
|
||||
def display_results(self, loglevel=logging.INFO):
|
||||
print_rank("{} Macro average: {}".format(self.metric_name,
|
||||
self.get_macro_average()), loglevel)
|
||||
print_rank("{} Micro average: {}".format(self.metric_name,
|
||||
self.get_micro_average()), loglevel)
|
||||
|
||||
|
||||
def make_lr_scheduler(annealing_config, optimizer, num_batches=1):
|
||||
"""Set learning rate scheduler."""
|
||||
|
||||
annealing_config = copy.deepcopy(annealing_config)
|
||||
annealing_type = annealing_config.pop("type")
|
||||
|
||||
# per epoch or per iter
|
||||
step_interval='epoch'
|
||||
if "step_interval" in annealing_config:
|
||||
step_interval = annealing_config.pop("step_interval")
|
||||
|
||||
if annealing_type == "step_lr":
|
||||
# convert epoch steps to iter steps
|
||||
# expochs can also be floats like 1.5
|
||||
if step_interval == "epoch":
|
||||
annealing_config["step_size"] = int(num_batches * \
|
||||
annealing_config["step_size"])
|
||||
lr_scheduler = StepLR(optimizer=optimizer,
|
||||
**annealing_config)
|
||||
elif annealing_type == "multi_step_lr":
|
||||
# convert epoch steps to iter steps
|
||||
if step_interval == "epoch":
|
||||
annealing_config["milestones"] = [int(i * num_batches) for i in annealing_config["milestones"]]
|
||||
lr_scheduler = MultiStepLR(optimizer=optimizer,
|
||||
**annealing_config)
|
||||
elif annealing_type == "rampup-keep-expdecay-keep":
|
||||
# emulate SpecAugment scheduling
|
||||
lr_scheduler = RampupKeepExpdecayKeepLRScheduler(optimizer=optimizer,
|
||||
**annealing_config)
|
||||
elif annealing_type == 'val_loss':
|
||||
lr_scheduler = ReduceLROnPlateau(optimizer,
|
||||
**annealing_config)
|
||||
else:
|
||||
raise ValueError("{} LR scheduler not supported".format(
|
||||
annealing_type))
|
||||
return lr_scheduler
|
||||
|
||||
|
||||
class RampupKeepExpdecayKeepLRScheduler(torch.optim.lr_scheduler._LRScheduler):
|
||||
"""Implements the LR schedule described in the specaugment paper."""
|
||||
|
||||
def __init__(self, optimizer, peak_lr=0.001, floor_lr=0.00001, sr=1000, si=40000, sf=160000, last_epoch=-1):
|
||||
assert(peak_lr>=floor_lr)
|
||||
self.peak_lr = peak_lr
|
||||
self.floor_lr = floor_lr
|
||||
assert(sr<=si)
|
||||
assert(si<=sf)
|
||||
self.sr = sr
|
||||
self.si = si
|
||||
self.sf = sf
|
||||
self.gamma = math.log(self.floor_lr/self.peak_lr)/(float(self.sf-self.si))
|
||||
print('self.gamma')
|
||||
print(self.gamma)
|
||||
self.step_count = 0
|
||||
super(RampupKeepExpdecayKeepLRScheduler, self).__init__(optimizer, last_epoch=last_epoch)
|
||||
|
||||
def step(self, epoch=None):
|
||||
for p, lr in zip(self.optimizer.param_groups, self.get_lr()):
|
||||
p['lr'] = lr
|
||||
self.step_count += 1
|
||||
|
||||
def get_lr(self):
|
||||
lr = self.floor_lr
|
||||
if self.step_count < self.sr:
|
||||
# linear ramp up
|
||||
lr = self.peak_lr * float(self.step_count) / float(self.sr)
|
||||
elif self.step_count < self.si:
|
||||
# keep peak_lr
|
||||
lr = self.peak_lr
|
||||
elif self.step_count < self.sf:
|
||||
# exponential decay from peak_lr to floor_lr
|
||||
lr = self.peak_lr * math.exp(self.gamma * (float(self.step_count-self.si)))
|
||||
|
||||
return [lr for base_lr in self.base_lrs]
|
||||
|
||||
|
||||
|
||||
class ScheduledSamplingScheduler():
|
||||
""" Implementing the schedule sampling rate schedule.
|
||||
|
||||
0 - ramp_start = initial_rate
|
||||
ramp_start - ramp_end = {linearly increase to final_rate}
|
||||
ramp_end - infinity = final_rate
|
||||
"""
|
||||
|
||||
def __init__(self, model, ramp_start, ramp_stop,
|
||||
initial_rate, final_rate):
|
||||
self.model = model
|
||||
self.ramp_start = ramp_start
|
||||
self.ramp_stop = ramp_stop
|
||||
self.initial_rate = initial_rate
|
||||
self.final_rate = final_rate
|
||||
self.iter = 0
|
||||
|
||||
def step(self):
|
||||
if self.iter < self.ramp_start:
|
||||
self.model.scheduled_sampling_rate = self.initial_rate
|
||||
elif self.iter >= self.ramp_start and self.iter <= self.ramp_stop:
|
||||
self.model.scheduled_sampling_rate = self.initial_rate + (self.final_rate - self.initial_rate) * ( (self.iter - self.ramp_start) / (self.ramp_stop - self.ramp_start))
|
||||
else:
|
||||
self.model.scheduled_sampling_rate = self.final_rate
|
||||
|
||||
self.model.scheduled_sampling = (self.model.scheduled_sampling_rate != 0)
|
||||
self.iter += 1
|
||||
|
||||
def state_dict(self):
|
||||
return {key: value for key, value in self.__dict__.items() if key != 'model'}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
self.__dict__.update(state_dict)
|
||||
|
||||
|
||||
class NBestTaskScheduler():
|
||||
""" Implementing the scheduler for multi-task training.
|
||||
|
||||
num_tasks[0]: 0 <= i < iteration_per_task[0]
|
||||
num_tasks[1]: iteration_per_task[0] <= i < iteration_per_task[1]
|
||||
"""
|
||||
def __init__(self, num_tasks, iteration_per_task):
|
||||
assert len(num_tasks) == len(iteration_per_task), "Mismatched length {}!={}".format(len(num_tasks), len(iteration_per_task))
|
||||
self.iter = 0
|
||||
self.stagex = 0
|
||||
self.num_tasks = num_tasks
|
||||
self.iteration_per_task = iteration_per_task
|
||||
|
||||
def current_num_tasks(self):
|
||||
return self.num_tasks[self.stagex]
|
||||
|
||||
def no_label_updates(self):
|
||||
"""Return how many times transcription must be updated."""
|
||||
return (self.iter // self.iteration_per_task[-1]) + 1
|
||||
|
||||
def set_iteration_no(self, iter_no):
|
||||
self.iter = iter_no
|
||||
|
||||
def step(self):
|
||||
print_rank("Iter={}: #tasks {} at stage {}".format(self.iter, self.current_num_tasks(), self.stagex))
|
||||
local_iter = self.iter % self.iteration_per_task[-1]
|
||||
if local_iter == 0:
|
||||
self.stagex = 0
|
||||
elif local_iter >= self.iteration_per_task[self.stagex]:
|
||||
self.stagex += 1
|
||||
|
||||
self.iter += 1
|
||||
|
||||
|
||||
# Logging and write-to-disk utilities
|
||||
|
||||
def init_logging(log_dir, loglevel=logging.DEBUG):
|
||||
"""Initialize logging"""
|
||||
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, "log.out")
|
||||
logging.basicConfig(filename=log_file,
|
||||
level=loglevel)
|
||||
handler = logging.StreamHandler(stream=sys.stdout)
|
||||
logging.getLogger().addHandler(handler)
|
||||
|
||||
|
||||
def print_cuda_stats():
|
||||
if torch.cuda.is_available():
|
||||
print_rank("torch.cuda.memory_allocated(): {}".format(torch.cuda.memory_allocated()))
|
||||
print_rank("torch.cuda.memory_cached(): {}".format(torch.cuda.memory_cached()))
|
||||
print_rank("torch.cuda.synchronize(): {}".format(torch.cuda.synchronize()))
|
||||
else:
|
||||
print_rank("No CUDA GPU available")
|
||||
|
||||
|
||||
def print_rank(str, loglevel=logging.INFO):
|
||||
|
||||
str = "{} : {}".format(time.ctime(), str)
|
||||
logging.log(loglevel, str)
|
||||
|
||||
def print_profiler(profiler, loglevel=logging.INFO):
|
||||
memfile = io.StringIO()
|
||||
pstats.Stats(profiler, stream=memfile) \
|
||||
.strip_dirs() \
|
||||
.sort_stats(pstats.SortKey.CUMULATIVE) \
|
||||
.print_stats(20)
|
||||
for l in memfile.getvalue().split('\n'):
|
||||
print_rank(l, loglevel=loglevel)
|
||||
memfile.close()
|
||||
|
||||
|
||||
def write_yaml(save_path, config):
|
||||
with open(save_path, 'w', encoding='utf8') as yaml_file:
|
||||
yaml.dump(config, yaml_file, default_flow_style=False)
|
||||
|
||||
def torch_save(save_path, state_or_model):
|
||||
torch.save(state_or_model, save_path)
|
||||
|
||||
def write_tokens(save_path, token_list):
|
||||
with open(save_path, 'w', encoding='utf8') as token_fid:
|
||||
for w in token_list:
|
||||
token_fid.write(w + '\n')
|
||||
|
||||
|
||||
def try_except_save(save_fn, **kwargs):
|
||||
""" Try to write it out 3 times."""
|
||||
|
||||
max_attempts = 3
|
||||
for attempt in range(1, max_attempts+1):
|
||||
try:
|
||||
save_fn(**kwargs)
|
||||
except IOError:
|
||||
print_rank("Write operation failed on {} attempt".format(attempt))
|
||||
else:
|
||||
print_rank("Write operation succeeded in {} attempts".format(attempt))
|
||||
return
|
||||
|
||||
|
||||
def write_nbest_jsonl(uttid2jsonl, uttid2hypos, uttid2scores, outputpath, nbest, orgpath="", newpath=""):
|
||||
""" Dump a json list file with n-best hypos."""
|
||||
|
||||
newjsonl = []
|
||||
for uttid, jsonl in uttid2jsonl.items():
|
||||
if not uttid in uttid2hypos:
|
||||
print("Missing utterance {} in results".format(uttid))
|
||||
continue
|
||||
hypos = uttid2hypos[uttid]
|
||||
if nbest > 1:
|
||||
# re-normalize the probablity from N-best: ignoring the events out of the N-best hypos
|
||||
weights = uttid2scores[uttid]
|
||||
if len(weights) < nbest:
|
||||
for n in range(len(weights), nbest):
|
||||
print_rank("Mising {}-th best result in {}. Appending {}".format(n, uttid, weights[0]))
|
||||
weights = np.append(weights, np.array(weights[0]))
|
||||
|
||||
weights = softmax(weights[0:nbest]) if uttid in uttid2scores else np.ones(nbest) / nbest
|
||||
# Filling the missing hypos with the 1st best candidate
|
||||
for n in range(min(nbest, len(hypos))):
|
||||
newjson = copy.deepcopy(jsonl)
|
||||
newjson["id"] = "{}-{}".format(uttid, n)
|
||||
newjson["text"] = " ".join(hypos[n])
|
||||
newjson["loss_weight"] = weights[n]
|
||||
else:
|
||||
newjson = copy.deepcopy(jsonl)
|
||||
newjson["id"] = uttid
|
||||
newjson["text"] = " ".join(hypos[0])
|
||||
|
||||
newjsonl.append(newjson)
|
||||
|
||||
with open(outputpath, 'w') as ofp:
|
||||
for jsonl in newjsonl:
|
||||
jsonl["wav"] = jsonl["wav"].replace(orgpath, newpath)
|
||||
ofp.write("{}\n".format(json.dumps(jsonl)))
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def write_multitask_jsonl(uttid2jsonl, uttid2hypos, uttid2scores, outputpath, nbest, orgpath="", newpath=""):
|
||||
""" Dump a json list file with n-best hypos."""
|
||||
|
||||
if nbest==1:
|
||||
return write_nbest_jsonl(uttid2jsonl, uttid2hypos, uttid2scores, outputpath, nbest, orgpath, newpath)
|
||||
|
||||
newjsonl = []
|
||||
for uttid, jsonl in uttid2jsonl.items():
|
||||
if not uttid in uttid2hypos:
|
||||
print_rank("Missing utterance {} in results".format(uttid))
|
||||
continue
|
||||
hypos = uttid2hypos[uttid]
|
||||
# re-normalize the probablity from N-best: ignoring the events out of the N-best hypos
|
||||
weights = uttid2scores[uttid]
|
||||
if len(weights) < nbest:
|
||||
for n in range(len(weights), nbest):
|
||||
print_rank("Mising {}-th best result in {}. Appending {}".format(n, uttid, weights[0]))
|
||||
weights = np.append(weights, np.array(weights[0]))
|
||||
|
||||
weights = softmax(weights[0:nbest]) if uttid in uttid2scores else np.ones(nbest) / nbest
|
||||
newjson = jsonl
|
||||
newjson["task_weights"] = weights.tolist()
|
||||
assert len(weights) == nbest, "{}: Weight length does not match: {} != {}".format(uttid, len(weights), nbest)
|
||||
newjson["text"] = " ".join(hypos[0])
|
||||
newjson["subtextl"] = []
|
||||
all_null_results = newjson["text"] == ""
|
||||
for n in range(1, nbest):
|
||||
if n < len(hypos):
|
||||
newjson["subtextl"].append(" ".join(hypos[n]))
|
||||
else:
|
||||
print_rank("Mising {}-th best result in {}".format(n, uttid))
|
||||
newjson["subtextl"].append(" ".join(hypos[0]))
|
||||
if all_null_results is True:
|
||||
all_null_results = newjson["subtextl"][n-1] == ""
|
||||
|
||||
assert len(newjson["subtextl"]) == nbest-1, "#sub-rec results does not match: {} != {}".format(len(newjson["subtextl"]), nbest-1)
|
||||
# take meaningful results only and ignore null string
|
||||
if all_null_results is False:
|
||||
newjsonl.append(newjson)
|
||||
else:
|
||||
print_rank("Skip {}: Invalid result '{}'".format(uttid, newjson["text"]))
|
||||
|
||||
with open(outputpath, 'w') as ofp:
|
||||
for jsonl in newjsonl:
|
||||
jsonl["wav"] = jsonl["wav"].replace(orgpath, newpath)
|
||||
ofp.write("{}\n".format(json.dumps(jsonl)))
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def load_eval_result_jsonl(resultjsonl, uttid2hypos=OrderedDict(), uttid2scores=OrderedDict(), dumpfp=None, dump_msg="RESULT: "):
|
||||
"""Load the result JSON list file dumped by Evaluator().
|
||||
|
||||
Args:
|
||||
|
||||
resultjsonl (str): input JSON list file
|
||||
uttid2hypos: (dict): maps the utterance ID to text, [uttid] = hypothesis text
|
||||
uttid2scores (dict): maps the utterance ID to a confidence score, [uttid] = confidence score(s)
|
||||
dumpfp (file): pointer where the WERs will be written out
|
||||
dump_msg (str): message string before the WER result
|
||||
"""
|
||||
total_weighted_best_wer = 0
|
||||
total_weighted_oracle_wer = 0
|
||||
total_length = 0
|
||||
with open(resultjsonl) as resultfp:
|
||||
for line in resultfp:
|
||||
elems = json.loads(line.strip())
|
||||
if "hypothesis" in elems:
|
||||
uttid = elems["utt_id"]
|
||||
params = list(elems["hypothesis"].keys())
|
||||
uttid2hypos[uttid] = elems["hypothesis"][params[0]]
|
||||
if "nbest_model_scores" in elems:
|
||||
uttid2scores[uttid] = np.array(elems["nbest_model_scores"][params[0]])
|
||||
else:
|
||||
print_rank("Result: {}".format(line.strip()))
|
||||
if dumpfp is not None:
|
||||
dumpfp.write("{}{}\n".format(dump_msg, line.strip()))
|
||||
params = list(elems["wer-"].keys())
|
||||
total_weighted_best_wer += elems["wer-"][params[0]]["best_wer"] * elems["wer-"][params[0]]["total_length"]
|
||||
total_weighted_oracle_wer += elems["wer-"][params[0]]["oracle_wer"] * elems["wer-"][params[0]]["total_length"]
|
||||
total_length += elems["wer-"][params[0]]["total_length"]
|
||||
|
||||
return uttid2hypos, uttid2scores, total_weighted_best_wer, total_weighted_oracle_wer, total_length
|
||||
|
||||
|
||||
def find_pretrained_model(model_path, config):
|
||||
""""Load a a pre-trained/seed model if provided in config file."""
|
||||
output_file=None
|
||||
|
||||
if config.get("pretrained_model_path", None):
|
||||
output_file=config["pretrained_model_path"]
|
||||
|
||||
print_rank('Loading Model from: {}'.format(output_file), loglevel=logging.INFO)
|
||||
return output_file
|
||||
|
||||
|
||||
def flatten_grads_model(learner) -> np.ndarray:
|
||||
"""Given a model flatten all params and return as np array."""
|
||||
|
||||
return np.concatenate([w.grad.detach().clone().cpu().numpy().flatten() for w in learner.parameters()])
|
||||
|
||||
def flatten_grads_array(param_array)->np.array:
|
||||
"""Given a model flatten all params and return as np array."""
|
||||
|
||||
N=len(param_array)
|
||||
tmp_array=[]
|
||||
for i in range(N):
|
||||
tmp_array.append(np.concatenate([w.detach().clone().cpu().numpy().flatten() for w in param_array[i]]))
|
||||
return np.array(tmp_array)
|
||||
|
||||
def dist_weights_to_model(weights, parameters):
|
||||
"""Updates the model parameters with the supplied weights."""
|
||||
|
||||
offset = 0
|
||||
for param in parameters:
|
||||
new_size = functools.reduce(lambda x, y: x*y, param.shape)
|
||||
current_data = weights[offset:offset + new_size]
|
||||
param.data[:] = torch.from_numpy(current_data.reshape(param.shape)).to(param.data)
|
||||
offset += new_size
|
||||
|
||||
def dist_params_to_model(grads, model):
|
||||
"""Updates the model gradients (Corresponding to each param) with the supplied grads."""
|
||||
|
||||
offset = 0
|
||||
for p in model:
|
||||
new_size = functools.reduce(lambda x, y: x*y, p.data.shape)
|
||||
current_data = torch.from_numpy(grads[offset:offset + new_size].reshape(p.data.shape)).type(p.data.dtype).to(p)
|
||||
p.grad = current_data if p.grad==None else p.grad+current_data
|
||||
offset += new_size
|
||||
|
||||
def reshape_params_to_model(grads, model):
|
||||
""" Given Gradients and a model architecture this method updates the model gradients (Corresponding to each param)
|
||||
with the supplied grads """
|
||||
offset = 0
|
||||
reshaped_grads=[]
|
||||
for p in model:
|
||||
new_size = functools.reduce(lambda x, y: x*y, p.shape)
|
||||
current_data = torch.from_numpy(grads[offset:offset + new_size].reshape(p.shape)).type(p.dtype).to(p)
|
||||
reshaped_grads.append(current_data)
|
||||
offset += new_size
|
||||
return reshaped_grads
|
||||
|
||||
def _to_cuda(x):
|
||||
return x.cuda() if torch.cuda.is_available() else x
|
||||
|
||||
def update_json_log(log_path, status_info):
|
||||
"""Update J-son elements"""
|
||||
|
||||
elems = {}
|
||||
if os.path.exists(log_path):
|
||||
with open(log_path, 'r') as logfp:
|
||||
elems = json.load(logfp)
|
||||
print_rank("Loaded status info: {}".format(elems))
|
||||
|
||||
for k, v in status_info.items():
|
||||
elems[k] = v
|
||||
|
||||
with open(log_path, 'w') as logfp:
|
||||
json.dump(elems, logfp)
|
||||
print_rank("Updated status info: {}".format(elems))
|
||||
|
||||
|
||||
def scrub_empty_clients(data_strct):
|
||||
""" Clean empty clients in the data structure"""
|
||||
|
||||
users_out = []
|
||||
user_data_out = {}
|
||||
num_samples_out = []
|
||||
if 'user_data_label' in data_strct.keys():
|
||||
user_data_label_out = {}
|
||||
for ix, user in enumerate(data_strct['users']):
|
||||
if data_strct['num_samples'][ix] > 0:
|
||||
users_out.append(user)
|
||||
user_data_out[user] = data_strct['user_data'][user]
|
||||
num_samples_out.append(data_strct['num_samples'][ix])
|
||||
if 'user_data_label' in data_strct.keys():
|
||||
user_data_label_out[user] = data_strct['user_data_label'][user]
|
||||
|
||||
if ('user_data_label' in data_strct.keys()):
|
||||
return edict({'users': users_out, 'user_data': user_data_out, 'num_samples': num_samples_out, 'user_data_label': user_data_label_out})
|
||||
else:
|
||||
return edict({'users': users_out, 'user_data': user_data_out, 'num_samples': num_samples_out})
|
||||
|
||||
|
||||
def compute_grad_cosines(grads, model_grad):
|
||||
def compute_cosine(g, m):
|
||||
tot = 0
|
||||
g2 = 0
|
||||
m2 = 0
|
||||
for p1, p2 in zip(g, m):
|
||||
tot += torch.mul(p1, p2.to('cpu')).sum().item()
|
||||
g2 += torch.mul(p1, p1).sum().item()
|
||||
m2 += torch.mul(p2, p2).sum().item()
|
||||
return tot / (np.sqrt(g2) * np.sqrt(m2)) if g2 > 0 and m2 > 0 else 0
|
||||
return [compute_cosine(g, model_grad) for g in grads]
|
Загрузка…
Ссылка в новой задаче