Add Horovod distributed training to the GenSen model and stop training early when the validation loss stops improving
This commit is contained in:
Parent
2d5bfe6862
Commit
5137d91c46
|
@ -16,11 +16,11 @@
|
|||
"# GenSen Deep Dive on AzureML\n",
|
||||
"**Learning General Purpose Distributed Sentence Representations via Large Scale Multi-task Learning** [\\[1\\]](#References)\n",
|
||||
"\n",
|
||||
"### What is sentence similarity?\n",
|
||||
"## What is sentence similarity?\n",
|
||||
"\n",
|
||||
"Sentence similarity or semantic textual similarity deals with determining how similar two pieces of texts are. This can take the form of assigning a score from 1 to 5. Related tasks are parahrase or duplicate identification.\n",
|
||||
"\n",
|
||||
"### How to evaluate?\n",
|
||||
"## How to evaluate?\n",
|
||||
"\n",
|
||||
"[SentEval](https://arxiv.org/abs/1803.05449) [\\[2\\]](#References) is an evaluation toolkit for evaluating sentence representations. It includes 17 downstream tasks, including common semantic textual similarity tasks. The semantic textual similarity (**STS**) benchmark tasks from 2012-2016 (STS12, STS13, STS14, STS15, STS16, STSB) measure the relatedness of two sentences based on the cosine similarity of the two representations. The evaluation criterion is Pearson correlation.\n",
|
||||
"\n",
|
||||
|
@ -28,11 +28,23 @@
|
|||
"\n",
|
||||
"The Microsoft Research Paraphrase Corpus [(**MRPC**)](https://www.microsoft.com/en-us/download/details.aspx?id=52398) corpus is a paraphrase identification dataset, where systems aim to identify if two sentences are paraphrases of each other. The evaluation metric is classification accuracy and F1.\n",
|
||||
"\n",
|
||||
"### What is GenSen?\n",
|
||||
"## What is GenSen?\n",
|
||||
"\n",
|
||||
"GenSen is a technique to learn general purpose, fixed-length representations of sentences via multi-task training. GenSen is to combine the benefits of These representations are useful for transfer and low-resource learning. GenSen is trained on several data sources with multiple training objectives on over 100 milion sentences.\n",
|
||||
"GenSen is a technique to learn general purpose, fixed-length representations of sentences via multi-task training. GenSen model is to combine the benefits of diverse sentence-representation learning objectives into a single multi-task framework. This is the first large-scale reusable sentence representation model obtained by combining a set of training objectives with the level of diversity explored here, i.e. multi-lingual NMT, natural language inference, constituency parsing and skip-thought vectors. These representations are useful for transfer and low-resource learning. GenSen is trained on several data sources with multiple training objectives on over 100 milion sentences.\n",
|
||||
"\n",
|
||||
"### Why GenSen?\n",
|
||||
"The GenSen model is most similar to that of Luong et al. (2015) [\\[4\\]](#References), who train a many-to-many **sequence-to-sequence** model on a diverse set of weakly ralated tasks that includes machine translation, constituency parsing, image captioning, sequence autoencoding, and intra-sentence skip-thoughts. However, there are two key differences. GenSen uses an attention mechanism preventing learning a fixed-length vector representation for a sentence and it aims for learning re-usable sentence representations that transfers elsewhere, as opposed to Luong's work aims for improvements on the same tasks on which the model is trained.\n",
|
||||
"\n",
|
||||
"### Sequence to Sequence Learning\n",
|
||||
"\n",
|
||||
"![Sequence to sequence learning examples - (left) machine translation and (right) constituent parsing](img/seq2seq.png)**Sequence to sequence learning examples - (left) machine translation and (right) constituent parsing**\n",
|
||||
"\n",
|
||||
"Sequence to sequence learning (*seq2seq*) aims to directly model the conditional probability $p(x|y)$ of mapping an input sequence, $x_1,...,x_n$, into an output sequence, $y_1,...,y_m$. It accomplishes such goal through the *encoder-decoder* framework. As illustrated in the above figure, the encoder computes a representation $s$ for each input sequence. Based on that input representation, the *decoder* generates an ouput sequence, one unit at a time, and hence, decomposes the conditional probability as:\n",
|
||||
"\n",
|
||||
"$$\n",
|
||||
"\\log p(y|x)=\\sum_{j=1}^{m} \\log p(y_i|y_{<j}, x, s)\n",
|
||||
"$$\n",
|
||||
"\n",
|
||||
"## Why GenSen?\n",
|
||||
"\n",
|
||||
"GenSen model performs the state-of-the-art results on multiple datasets, such as MRPC, SICK-R, SICK-E and STS, for sentence similarity. The reported results are as follows compared with other models [\\[3\\]](#References):\n",
|
||||
"\n",
|
||||
|
@ -65,7 +77,7 @@
|
|||
" * 1.2. [Tokenize](#1.2-Tokenize) \n",
|
||||
" * 1.3. [Preprocess for GenSen Model](#1.3-Preprocess-for-GenSen-Model) \n",
|
||||
" * 1.4. [Upload to Azure Blob Storage](#1.4-Upload-to-Azure-Blob-Storage) \n",
|
||||
"2. [Train GenSen Model with Distributed Pytorch on AzureML](#2-Train-GenSen-Model-with-Distributed-Pytorch-on-AzureML) \n",
|
||||
"2. [Train GenSen Model with Distributed Pytorch with Horovod on AzureML](#2-Train-GenSen-Model-with-Distributed-Pytorch-with-Horovod-on-AzureML) \n",
|
||||
" * 2.1. [Initialization](#2.1-Initialization) \n",
|
||||
" * 2.1.1 [Initialize Workspace](#2.1.1-Initialize-Workspace) \n",
|
||||
" * 2.1.2 [Create or Attach Existing AmlCompute](#2.1.2-Create-or-Attach-Existing-AmlCompute) \n",
|
||||
|
@ -81,7 +93,7 @@
|
|||
"3. [Tune Model Hyperparameters](#3-Tune-Model-Hyperparameters)\n",
|
||||
" * 3.1 [Start a Hyperparameter Sweep](#3.1-Start-a-Hyperparameter-Sweep)\n",
|
||||
" * 3.2 [Monitor HyperDrive runs](#3.2-Monitor-HyperDrive-runs)\n",
|
||||
"* [Evaluate Model by SentEval](#Comparison-of-Baseline-Models)"
|
||||
"- [References](#References)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -188,7 +200,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 42,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -527,21 +539,9 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 47,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'BASE_DATA_PATH' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[1;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"\u001b[1;32m<ipython-input-1-7876e18c3dbe>\u001b[0m in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[1;32m----> 1\u001b[1;33m \u001b[0mdata_folder\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mos\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mBASE_DATA_PATH\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m\"clean/snli_1.0/\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m",
|
||||
"\u001b[1;31mNameError\u001b[0m: name 'BASE_DATA_PATH' is not defined"
|
||||
]
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_folder = os.path.join(BASE_DATA_PATH, \"clean/snli_1.0/\")"
|
||||
]
|
||||
|
@ -550,7 +550,7 @@
|
|||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# 2 Train GenSen Model with Distributed Pytorch on AzureML\n",
|
||||
"# 2 Train GenSen Model with Distributed Pytorch with Horovod on AzureML\n",
|
||||
"In this tutorial, you will train a GenSen model with PyTorch on AML using distributed training across a GPU cluster. This could also be a generic guideline to train models using GPU cluster.\n",
|
||||
"\n",
|
||||
"Once you've created your workspace and set up your development environment, training a model in Azure Machine Learning involves the following steps:\n",
|
||||
|
@ -580,7 +580,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 1,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -620,7 +620,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 2,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -628,7 +628,7 @@
|
|||
"output_type": "stream",
|
||||
"text": [
|
||||
"Found existing compute target.\n",
|
||||
"{'currentNodeCount': 0, 'targetNodeCount': 0, 'nodeStateCounts': {'preparingNodeCount': 0, 'runningNodeCount': 0, 'idleNodeCount': 0, 'unusableNodeCount': 0, 'leavingNodeCount': 0, 'preemptedNodeCount': 0}, 'allocationState': 'Steady', 'allocationStateTransitionTime': '2019-05-21T00:01:54.616000+00:00', 'errors': None, 'creationTime': '2019-05-20T22:09:40.142683+00:00', 'modifiedTime': '2019-05-20T22:10:11.888950+00:00', 'provisioningState': 'Succeeded', 'provisioningStateTransitionTime': None, 'scaleSettings': {'minNodeCount': 0, 'maxNodeCount': 4, 'nodeIdleTimeBeforeScaleDown': 'PT120S'}, 'vmPriority': 'Dedicated', 'vmSize': 'STANDARD_NC6'}\n"
|
||||
"{'currentNodeCount': 0, 'targetNodeCount': 0, 'nodeStateCounts': {'preparingNodeCount': 0, 'runningNodeCount': 0, 'idleNodeCount': 0, 'unusableNodeCount': 0, 'leavingNodeCount': 0, 'preemptedNodeCount': 0}, 'allocationState': 'Steady', 'allocationStateTransitionTime': '2019-05-30T18:42:14.260000+00:00', 'errors': None, 'creationTime': '2019-05-20T22:09:40.142683+00:00', 'modifiedTime': '2019-05-20T22:10:11.888950+00:00', 'provisioningState': 'Succeeded', 'provisioningStateTransitionTime': None, 'scaleSettings': {'minNodeCount': 0, 'maxNodeCount': 4, 'nodeIdleTimeBeforeScaleDown': 'PT120S'}, 'vmPriority': 'Dedicated', 'vmSize': 'STANDARD_NC6'}\n"
|
||||
]
|
||||
}
|
||||
],
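For reference, the "Found existing compute target." message and the `STANDARD_NC6` status above typically come from a create-or-attach pattern like the sketch below; the cluster name is an assumption and `ws` is the workspace initialized earlier in the notebook.

```python
# Hedged sketch: attach to the GPU cluster if it exists, otherwise create it.
from azureml.core.compute import AmlCompute, ComputeTarget
from azureml.core.compute_target import ComputeTargetException

cluster_name = "gensen-cluster"  # assumed name, not taken from this diff

try:
    compute_target = ComputeTarget(workspace=ws, name=cluster_name)
    print("Found existing compute target.")
except ComputeTargetException:
    provisioning_config = AmlCompute.provisioning_configuration(
        vm_size="STANDARD_NC6",  # one GPU per node, matching the status output above
        max_nodes=4,
    )
    compute_target = ComputeTarget.create(ws, cluster_name, provisioning_config)
    compute_target.wait_for_completion(show_output=True)

print(compute_target.get_status().serialize())
```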
|
||||
|
@ -676,7 +676,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 20,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -700,7 +700,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 3,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -723,7 +723,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -732,7 +732,7 @@
|
|||
"$AZUREML_DATAREFERENCE_gensen"
|
||||
]
|
||||
},
|
||||
"execution_count": 25,
|
||||
"execution_count": 4,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
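The `$AZUREML_DATAREFERENCE_gensen` value above is a mounted datastore path. A hedged sketch of how the cleaned data might be uploaded to blob storage and referenced from the remote run is shown here; the target folder name is an assumption and `ds` is taken to be the workspace's default datastore.

```python
# Hedged sketch: push the preprocessed SNLI data to the workspace datastore
# and build a reference the remote training job can mount.
ds = ws.get_default_datastore()
ds.upload(
    src_dir=os.path.join(BASE_DATA_PATH, "clean/snli_1.0/"),  # path used earlier in the notebook
    target_path="gensen",  # assumed blob folder name
    overwrite=True,
    show_progress=True,
)

# Mounted on the compute target; rendered as $AZUREML_DATAREFERENCE_gensen.
data_ref = ds.path("gensen").as_mount()
```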
|
||||
|
@ -779,7 +779,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -801,7 +801,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 39,
|
||||
"execution_count": 55,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -821,7 +821,7 @@
|
|||
" distributed_backend='mpi',\n",
|
||||
" use_gpu=True,\n",
|
||||
" conda_packages=['scikit-learn=0.20.3']\n",
|
||||
" )"
|
||||
" )\n"
|
||||
]
|
||||
},
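The cell above shows only the tail of the estimator definition. A hedged sketch of what the full `PyTorch` estimator call might look like is given below; the entry-script name, folder variables, and script parameters are assumptions pieced together from the rest of this diff, not confirmed by it. `distributed_backend='mpi'` is what routes the run through Horovod.

```python
from azureml.train.dnn import PyTorch

estimator = PyTorch(
    source_directory=project_folder,       # assumed variable holding the training scripts
    script_params={
        "--config": "sample_config.json",  # hypothetical name for the JSON config edited below
        "--data_folder": ds.path("gensen").as_mount(),
        "--learning_rate": 0.0001,
    },
    compute_target=compute_target,
    entry_script="gensen_train.py",        # assumed entry script
    node_count=4,                          # assumption: matches the cluster's max_nodes
    process_count_per_node=1,              # one Horovod worker per STANDARD_NC6 GPU
    distributed_backend="mpi",             # MPI backend used by Horovod
    use_gpu=True,
    conda_packages=["scikit-learn=0.20.3"],
)
```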
|
||||
{
|
||||
|
@ -853,7 +853,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 40,
|
||||
"execution_count": 56,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
|
@ -861,7 +861,7 @@
|
|||
"output_type": "stream",
|
||||
"text": [
|
||||
"Run(Experiment: pytorch-gensen,\n",
|
||||
"Id: pytorch-gensen_1559160755_910c17f3,\n",
|
||||
"Id: pytorch-gensen_1559254459_202b7e15,\n",
|
||||
"Type: azureml.scriptrun,\n",
|
||||
"Status: Queued)\n"
|
||||
]
|
||||
|
@ -883,7 +883,7 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 38,
|
||||
"execution_count": 51,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
|
@ -905,13 +905,13 @@
|
|||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 41,
|
||||
"execution_count": 57,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "7a3652129ef147258e649eaaf2e7a83c",
|
||||
"model_id": "e452ca03e8d2473a85df3b27dd2defad",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
|
@ -1548,7 +1548,8 @@
|
|||
"\n",
|
||||
"1. Subramanian, Sandeep and Trischler, Adam and Bengio, Yoshua and Pal, Christopher J, [*Learning general purpose distributed sentence representations via large scale multi-task learning*](https://arxiv.org/abs/1804.00079), ICLR, 2018.\n",
|
||||
"2. A. Conneau, D. Kiela, [*SentEval: An Evaluation Toolkit for Universal Sentence Representations*](https://arxiv.org/abs/1803.05449).\n",
|
||||
"3. Semantic textual similarity. url: http://nlpprogress.com/english/semantic_textual_similarity.html"
|
||||
"3. Semantic textual similarity. url: http://nlpprogress.com/english/semantic_textual_similarity.html\n",
|
||||
"4. Minh-Thang Luong, Quoc V Le, Ilya Sutskever, Oriol Vinyals, and Lukasz Kaiser. [*Multi-task sequence to sequence learning*](https://arxiv.org/abs/1511.06114), 2015."
|
||||
]
|
||||
}
|
||||
],
|
||||
|
|
|
@ -267,7 +267,6 @@ class MultitaskModel(nn.Module):
|
|||
h_t = h_t.unsqueeze(0)
|
||||
h_t = self.enc_drp(h_t)
|
||||
|
||||
print("INSIDE FORWARD:", h_t.shape)
|
||||
# Debug with squeeze on error.
|
||||
trg_h, _ = self.decoders[task_idx](
|
||||
trg_emb, h_t.view(-1, self.trg_hidden_dim), h_t.view(-1, self.trg_hidden_dim)
|
||||
|
@ -336,7 +335,6 @@ class MultitaskModel(nn.Module):
|
|||
h_t = src_h_t[-1]
|
||||
else:
|
||||
src_h, _ = pad_packed_sequence(src_h, batch_first=True)
|
||||
print("INSIDE GET HIDDEN",torch.max(src_h, 1)[0].shape)
|
||||
h_t = torch.max(src_h, 1)[0].squeeze()
|
||||
|
||||
return src_h, h_t
|
||||
|
|
|
@ -4,7 +4,8 @@
|
|||
"clip_c": 1,
|
||||
"lrate": 0.0001,
|
||||
"batch_size": 48,
|
||||
"n_gpus": 1
|
||||
"n_gpus": 1,
|
||||
"stop_patience": 10000
|
||||
},
|
||||
"management": {
|
||||
"monitor_loss": 9600,
|
||||
|
@ -14,31 +15,31 @@
|
|||
},
|
||||
"data": {"paths": [
|
||||
{
|
||||
"train_src": "snli_1.0_train.txt.s1.tok",
|
||||
"train_trg": "snli_1.0_train.txt.s2.tok",
|
||||
"val_src": "snli_1.0_dev.txt.s1.tok",
|
||||
"val_trg": "snli_1.0_dev.txt.s2.tok",
|
||||
"train_src": "data/processed/snli_1.0_train.txt.s1.tok",
|
||||
"train_trg": "data/processed/snli_1.0_train.txt.s2.tok",
|
||||
"val_src": "data/processed/snli_1.0_dev.txt.s1.tok",
|
||||
"val_trg": "data/processed/snli_1.0_dev.txt.s1.tok",
|
||||
"taskname": "snli"
|
||||
}
|
||||
],
|
||||
"max_src_length": 90,
|
||||
"max_trg_length": 90,
|
||||
"task": "multi-seq2seq-nli",
|
||||
"save_dir": "model",
|
||||
"save_dir": "data/models/example",
|
||||
"load_dir": "auto",
|
||||
"nli_train": "snli_1.0_train.txt.clean.noblank",
|
||||
"nli_dev": "snli_1.0_dev.txt.clean.noblank",
|
||||
"nli_test": "snli_1.0_test.txt.clean.noblank"
|
||||
},
|
||||
"nli_train": "data/processed/snli_1.0_train.txt.clean.noblank",
|
||||
"nli_dev": "data/processed/snli_1.0_dev.txt.clean.noblank",
|
||||
"nli_test": "data/processed/snli_1.0_test.txt.clean.noblank"
|
||||
},
|
||||
"model": {
|
||||
"dim_src": 2048,
|
||||
"dim_trg": 2048,
|
||||
"dim_word_src": 512,
|
||||
"dim_word_trg": 512,
|
||||
"n_words_src": 80000,
|
||||
"n_words_trg": 30000,
|
||||
"n_layers_src": 1,
|
||||
"bidirectional": true,
|
||||
"dim_src": 2048,
|
||||
"dim_trg": 2048,
|
||||
"dim_word_src": 512,
|
||||
"dim_word_trg": 512,
|
||||
"n_words_src": 80000,
|
||||
"n_words_trg": 30000,
|
||||
"n_layers_src": 1,
|
||||
"bidirectional": true,
|
||||
"layernorm": false,
|
||||
"dropout": 0.3
|
||||
}
|
||||
|
|
|
@ -2,8 +2,10 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""Run script."""
|
||||
import logging
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
@ -13,9 +15,10 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from azureml.core.run import Run
|
||||
import horovod.torch as hvd
|
||||
|
||||
from utils_nlp.model.gensen.models import MultitaskModel
|
||||
from utils_nlp.model.gensen.utils import (
|
||||
from models import MultitaskModel
|
||||
from utils import (
|
||||
BufferedDataIterator,
|
||||
NLIIterator,
|
||||
compute_validation_loss,
|
||||
|
@ -23,21 +26,16 @@ from utils_nlp.model.gensen.utils import (
|
|||
|
||||
sys.path.append(".") # Required to run on the MILA machines with SLURM
|
||||
|
||||
# parser = argparse.ArgumentParser()
|
||||
# parser.add_argument('--data_folder', type=str, help='data folder')
|
||||
#
|
||||
# args = parser.parse_args()
|
||||
#
|
||||
# input_data = args.input_data
|
||||
# import os
|
||||
# os.chdir(input_data)
|
||||
# print("the input data is at %s" % input_data)
|
||||
#
|
||||
# get the Azure ML run object
|
||||
run = Run.get_context()
|
||||
|
||||
cudnn.benchmark = True
|
||||
|
||||
hvd.init()
|
||||
if torch.cuda.is_available():
|
||||
# Horovod: pin GPU to local rank.
|
||||
torch.cuda.set_device(hvd.local_rank())
|
||||
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config", help="path to json config", required=True)
|
||||
|
@ -186,9 +184,25 @@ def train(config, data_folder, learning_rate=0.0001):
|
|||
).cuda()
|
||||
|
||||
logging.info(model)
|
||||
"""Using Horovod"""
|
||||
# Horovod: scale learning rate by the number of GPUs.
|
||||
optimizer = optim.SGD(model.parameters(), lr=learning_rate * hvd.size(),
|
||||
momentum=args.momentum)
|
||||
|
||||
# Horovod: broadcast parameters & optimizer state.
|
||||
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
|
||||
hvd.broadcast_optimizer_state(optimizer, root_rank=0)
|
||||
|
||||
# Horovod: (optional) compression algorithm.
|
||||
compression = hvd.Compression.fp16 if args.fp16_allreduce else hvd.Compression.none
|
||||
|
||||
# Horovod: wrap optimizer with DistributedOptimizer.
|
||||
optimizer = hvd.DistributedOptimizer(optimizer,
|
||||
named_parameters=model.named_parameters(),
|
||||
compression=compression)
|
||||
|
||||
n_gpus = config["training"]["n_gpus"]
|
||||
model = torch.nn.DataParallel(model, device_ids=range(n_gpus))
|
||||
# model = torch.nn.DataParallel(model, device_ids=range(n_gpus))
|
||||
|
||||
if load_dir == "auto":
|
||||
ckpt = os.path.join(save_dir, "best_model.model")
|
||||
|
@ -203,9 +217,9 @@ def train(config, data_folder, learning_rate=0.0001):
|
|||
logging.info("Loading model from specified checkpoint %s " % load_dir)
|
||||
model.load_state_dict(torch.load(open(load_dir, encoding="utf-8")))
|
||||
|
||||
lr = learning_rate
|
||||
# lr = learning_rate
|
||||
# lr = config['training']['lrate']
|
||||
optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||
# optimizer = optim.Adam(model.parameters(), lr=lr)
|
||||
|
||||
task_losses = [[] for task in tasknames]
|
||||
task_idxs = [0 for task in tasknames]
|
||||
|
@ -218,6 +232,9 @@ def train(config, data_folder, learning_rate=0.0001):
|
|||
logging.info("Commencing Training ...")
|
||||
rng_num_tasks = len(tasknames) - 1 if paired_tasks else len(tasknames)
|
||||
|
||||
min_val_loss = 10000000  # best validation loss seen so far (for early stopping)
|
||||
min_val_loss_epoch = -1  # update count at which the best validation loss occurred
|
||||
|
||||
while True:
|
||||
start = time.time()
|
||||
# Train NLI once every 10 minibatches of other tasks
|
||||
|
@ -320,6 +337,8 @@ def train(config, data_folder, learning_rate=0.0001):
|
|||
torch.nn.utils.clip_grad_norm(model.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
|
||||
|
||||
|
||||
end = time.time()
|
||||
mbatch_times.append(end - start)
|
||||
|
||||
|
@ -363,18 +382,18 @@ def train(config, data_folder, learning_rate=0.0001):
|
|||
mbatch_times = []
|
||||
nli_losses = []
|
||||
|
||||
if (
|
||||
updates % config["management"]["checkpoint_freq"] == 0
|
||||
and updates != 0
|
||||
):
|
||||
logging.info("Saving model ...")
|
||||
|
||||
torch.save(
|
||||
model.state_dict(),
|
||||
open(os.path.join(save_dir, "best_model.model"), "wb"),
|
||||
)
|
||||
# Let the training end.
|
||||
break
|
||||
# if (
|
||||
# updates % config["management"]["checkpoint_freq"] == 0
|
||||
# and updates != 0
|
||||
# ):
|
||||
# logging.info("Saving model ...")
|
||||
#
|
||||
# torch.save(
|
||||
# model.state_dict(),
|
||||
# open(os.path.join(save_dir, "best_model.model"), "wb"),
|
||||
# )
|
||||
# # Let the training end.
|
||||
# break
|
||||
|
||||
if updates % config["management"]["eval_freq"] == 0:
|
||||
logging.info("############################")
|
||||
|
@ -446,8 +465,53 @@ def train(config, data_folder, learning_rate=0.0001):
|
|||
logging.info(
|
||||
"******************************************************"
|
||||
)
|
||||
# Early stopping: track the best validation loss seen so far.
|
||||
# Stop training once it has not improved for `stop_patience` updates.
|
||||
if validation_loss < min_val_loss:
|
||||
min_val_loss = validation_loss
|
||||
min_val_loss_epoch = updates
|
||||
|
||||
if updates - min_val_loss_epoch > config['training']['stop_patience']:
|
||||
logging.info("Saving model ...")
|
||||
|
||||
torch.save(
|
||||
model.state_dict(),
|
||||
open(os.path.join(save_dir, "best_model.model"), "wb"),
|
||||
)
|
||||
# Let the training end.
|
||||
break
|
||||
|
||||
updates += batch_size * n_gpus
|
||||
nli_ctr += 1
|
||||
finally:
|
||||
os.chdir(owd)
|
||||
|
||||
|
||||
def read_config(json_file):
|
||||
"""Read JSON config."""
|
||||
json_object = json.load(open(json_file, 'r', encoding='utf-8'))
|
||||
return json_object
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
help="path to json config",
|
||||
required=True
|
||||
)
|
||||
parser.add_argument('--data_folder', type=str, help='data folder')
|
||||
# Add learning rate to tune model.
|
||||
parser.add_argument('--learning_rate', type=float, default=0.0001, help='learning rate')
|
||||
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
|
||||
help='SGD momentum (default: 0.5)')
|
||||
parser.add_argument('--fp16-allreduce', action='store_true', default=False,
|
||||
help='use fp16 compression during allreduce')
|
||||
args = parser.parse_args()
|
||||
data_folder = args.data_folder
|
||||
learning_rate = args.learning_rate
|
||||
# os.chdir(data_folder)
|
||||
|
||||
config_file_path = args.config
|
||||
config = read_config(config_file_path)
|
||||
train(config, data_folder, learning_rate)
|