Merge branch 'master' of github.com:Microsoft/NeuronBlocks into dev/wutlin
This commit is contained in:
Commit
627d3216c5

.gitignore

@@ -2,5 +2,7 @@
*~
*.pyc
*.cache*
*.vs*
dataset/GloVe/
dataset/20_newsgroups/
models/

ModelConf.py

@@ -10,6 +10,7 @@ import string
import copy
import torch
import logging
import shutil

from losses.BaseLossConf import BaseLossConf
#import traceback
@@ -60,8 +61,8 @@ class ModelConf(object):
        self.tool_version = self.get_item(['tool_version'])
        self.language = self.get_item(['language'], default='english').lower()
        self.problem_type = self.get_item(['inputs', 'dataset_type']).lower()
-        if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
-            self.tagging_scheme = self.get_item(['inputs', 'tagging_scheme'], default=None, use_default=True)
+        #if ProblemTypes[self.problem_type] == ProblemTypes.sequence_tagging:
+        self.tagging_scheme = self.get_item(['inputs', 'tagging_scheme'], default=None, use_default=True)

        if self.mode == 'normal':
            self.use_cache = self.get_item(['inputs', 'use_cache'], True)

@@ -519,3 +520,7 @@ class ModelConf(object):
        if not (nb_version_split[0] == conf_version_split[0] and nb_version_split[1] == conf_version_split[1]):
            raise ConfigurationError('The NeuronBlocks version is %s, but the configuration version is %s, please update your configuration to %s.%s.X' % (nb_version, conf_version, nb_version_split[0], nb_version_split[1]))

+    def back_up(self, params):
+        shutil.copy(params.conf_path, self.save_base_dir)
+        logging.info('Configuration file is backed up to %s' % (self.save_base_dir))
README.md

@@ -7,6 +7,8 @@
[简体中文](README_zh_CN.md)

[Tutorial](Tutorial.md) [中文教程](Tutorial_zh_CN.md)

# Table of Contents
* [Overview](#Overview)

@@ -129,6 +131,14 @@ Anyone who are familiar with are highly encouraged to contribute code.

# Reference
**NeuronBlocks -- Building Your NLP DNN Models Like Playing Lego**, at https://arxiv.org/abs/1904.09535.
```
@article{gong2019neuronblocks,
  title={NeuronBlocks--Building Your NLP DNN Models Like Playing Lego},
  author={Gong, Ming and Shou, Linjun and Lin, Wutao and Sang, Zhijie and Yan, Quanjia and Yang, Ze and Jiang, Daxin},
  journal={arXiv preprint arXiv:1904.09535},
  year={2019}
}
```

# Related Project
* [OpenPAI](https://github.com/Microsoft/pai) is an open source platform that provides complete AI model training and resource management capabilities; it is easy to extend and supports on-premises, cloud, and hybrid environments at various scales.

README_zh_CN.md

@@ -7,6 +7,8 @@
[English version](README.md)

[中文教程](Tutorial_zh_CN.md) [Tutorial](Tutorial.md)

# Table of Contents
* [Overview](#概览)

@@ -129,7 +131,15 @@ NeuronBlocks operates in an open mode. It is designed by the **Microsoft STCA NLP Group**.
Interested users are highly encouraged to join us and contribute code.

# References
[**NeuronBlocks -- Building Your NLP DNN Models Like Playing Lego**](https://arxiv.org/abs/1904.09535).
**NeuronBlocks -- Building Your NLP DNN Models Like Playing Lego**, at https://arxiv.org/abs/1904.09535.
```
@article{gong2019neuronblocks,
  title={NeuronBlocks--Building Your NLP DNN Models Like Playing Lego},
  author={Gong, Ming and Shou, Linjun and Lin, Wutao and Sang, Zhijie and Yan, Quanjia and Yang, Ze and Jiang, Daxin},
  journal={arXiv preprint arXiv:1904.09535},
  year={2019}
}
```

# Related Projects
* [OpenPAI](https://github.com/Microsoft/pai): an open-source platform that provides complete AI model training and resource management capabilities; it is easy to extend and supports private, cloud, and hybrid deployments at any scale.
@@ -300,6 +300,7 @@ Question answer matching is a crucial subtask of the question answering problem,
BiLSTM (NeuronBlocks) | 0.767
BiLSTM+Attn (NeuronBlocks) | 0.754
BiLSTM+Match Attention (NeuronBlocks) | 0.785
[MatchPyramid](https://arxiv.org/abs/1602.06359) (NeuronBlocks) | 0.763

*Tips: the model file and the training log can be found under the `outputs/save_base_dir` specified in the JSON config file after training finishes.*
@@ -289,6 +289,7 @@ Question answer matching is a crucial subtask of the question answering problem,
BiLSTM (NeuronBlocks) | 0.767
BiLSTM+Attn (NeuronBlocks) | 0.754
BiLSTM+Match Attention (NeuronBlocks) | 0.785
[MatchPyramid](https://arxiv.org/abs/1602.06359) (NeuronBlocks) | 0.763

*Tips: the model file and the training log can be found under the `outputs/save_base_dir` specified in the JSON config file after training finishes.*

block_zoo/Conv2D.py (new file)

@@ -0,0 +1,136 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

from block_zoo.BaseLayer import BaseLayer, BaseConf
from utils.DocInherit import DocInherit


class Conv2DConf(BaseConf):
    """ Configuration of Conv2D

    Args:
        stride (int): the stride of the convolving kernel. Can be a single number or a tuple (sH, sW). Default: 1
        padding (int): implicit zero padding on both sides of the input. Can be a single number or a tuple (padH, padW). Default: 0
        window_size (int): the kernel size (window_sizeH, window_sizeW); a single number is expanded to both dimensions (for NLP tasks 1-D convolution is more common).
        output_channel_num (int): number of feature maps
        batch_norm (bool): if True, apply batch normalization before activation
        activation (string): activation function, e.g. ReLU

    """
    def __init__(self, **kwargs):
        super(Conv2DConf, self).__init__(**kwargs)

    @DocInherit
    def default(self):
        self.stride = 1
        self.padding = 0
        self.window_size = 3
        self.output_channel_num = 16
        self.batch_norm = True
        self.activation = 'ReLU'

    @DocInherit
    def declare(self):
        self.num_of_inputs = 1
        self.input_ranks = [4]

    def check_size(self, value, attr):
        res = value
        if isinstance(value, int):
            res = [value, value]
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            res = list(value)
        else:
            raise AttributeError("The attribute `%s' should be given an integer or a list/tuple of length 2, instead of %s." % (attr, str(value)))
        return res

    @DocInherit
    def inference(self):
        self.window_size = self.check_size(self.window_size, "window_size")
        self.stride = self.check_size(self.stride, "stride")
        self.padding = self.check_size(self.padding, "padding")

        self.input_channel_num = self.input_dims[0][-1]

        self.output_dim = [self.input_dims[0][0]]
        if self.input_dims[0][1] != -1:
            self.output_dim.append((self.input_dims[0][1] + 2 * self.padding[0] - self.window_size[0]) // self.stride[0] + 1)
        else:
            self.output_dim.append(-1)
        if self.input_dims[0][2] != -1:
            self.output_dim.append((self.input_dims[0][2] + 2 * self.padding[1] - self.window_size[1]) // self.stride[1] + 1)
        else:
            self.output_dim.append(-1)
        self.output_dim.append(self.output_channel_num)

        super(Conv2DConf, self).inference()  # PUT THIS LINE AT THE END OF inference()

    @DocInherit
    def verify_before_inference(self):
        super(Conv2DConf, self).verify_before_inference()
        necessary_attrs_for_user = ['output_channel_num']
        for attr in necessary_attrs_for_user:
            self.add_attr_exist_assertion_for_user(attr)

    @DocInherit
    def verify(self):
        super(Conv2DConf, self).verify()

        necessary_attrs_for_user = ['stride', 'padding', 'window_size', 'input_channel_num', 'output_channel_num', 'activation']
        for attr in necessary_attrs_for_user:
            self.add_attr_exist_assertion_for_user(attr)


class Conv2D(BaseLayer):
    """ Convolution over both spatial dimensions of a rank-4 input

    Args:
        layer_conf (Conv2DConf): configuration of a layer
    """
    def __init__(self, layer_conf):
        super(Conv2D, self).__init__(layer_conf)
        self.layer_conf = layer_conf
        if layer_conf.activation:
            self.activation = eval("nn." + self.layer_conf.activation)()
        else:
            self.activation = None

        self.cnn = nn.Conv2d(in_channels=layer_conf.input_channel_num, out_channels=layer_conf.output_channel_num,
                             kernel_size=layer_conf.window_size, stride=layer_conf.stride, padding=layer_conf.padding)

        if layer_conf.batch_norm:
            self.batch_norm = nn.BatchNorm2d(layer_conf.output_channel_num)  # the output channel of Conv is the input channel of BN
        else:
            self.batch_norm = None  # batch_norm is None when disabled

    def forward(self, string, string_len=None):
        """ process inputs

        Args:
            string (Tensor): tensor with shape: [batch_size, length, width, feature_dim]
            string_len (Tensor): [batch_size]

        Returns:
            Tensor: shape: [batch_size, (length - conv_window_size) // stride + 1, (width - conv_window_size) // stride + 1, output_channel_num]

        """
        string = string.permute([0, 3, 1, 2]).contiguous()
        string_out = self.cnn(string)
        if self.batch_norm:
            string_out = self.batch_norm(string_out)

        string_out = string_out.permute([0, 2, 3, 1]).contiguous()

        if self.activation:
            string_out = self.activation(string_out)
        if string_len is not None:
            string_len_out = (string_len - self.layer_conf.window_size[0]) // self.layer_conf.stride[0] + 1
        else:
            string_len_out = None
        return string_out, string_len_out
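Conv2D expects channels-last input ([batch, length, width, feature]) and permutes to NCHW around the torch call. A minimal standalone sketch of the shape arithmetic, outside the framework (plain torch only; the numbers are illustrative):

```
# Standalone shape check mirroring Conv2D.forward (illustrative; bypasses BaseLayer/Conv2DConf).
import torch
import torch.nn as nn

batch, length, width, feat = 2, 30, 120, 1          # [batch_size, length, width, feature_dim]
cnn = nn.Conv2d(in_channels=feat, out_channels=8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

string = torch.randn(batch, length, width, feat)
out = cnn(string.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)   # NHWC -> NCHW -> conv -> NHWC
print(out.shape)   # torch.Size([2, 30, 120, 8]): (L + 2*pad - window) // stride + 1 per spatial dim
```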

block_zoo/Pooling2D.py (new file)

@@ -0,0 +1,122 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from block_zoo.BaseLayer import BaseLayer, BaseConf
from utils.DocInherit import DocInherit


class Pooling2DConf(BaseConf):
    """

    Args:
        pool_type (str): 'max' or 'mean', default is 'max'.
        stride (int): the stride of the pooling window. Can be a single number or a tuple (sH, sW). Default: 1
        padding (int): implicit zero padding on both sides of the input. Can be a single number or a tuple (padH, padW). Default: 0
        window_size (int): the size of the pooling window
        activation (string): activation function, e.g. ReLU

    """
    def __init__(self, **kwargs):
        super(Pooling2DConf, self).__init__(**kwargs)

    @DocInherit
    def default(self):
        self.pool_type = 'max'  # Supported: ['max', 'mean']
        self.stride = 1
        self.padding = 0
        self.window_size = 3

    @DocInherit
    def declare(self):
        self.num_of_inputs = 1
        self.input_ranks = [4]

    def check_size(self, value, attr):
        res = value
        if isinstance(value, int):
            res = [value, value]
        elif isinstance(value, (tuple, list)) and len(value) == 2:
            res = list(value)
        else:
            raise AttributeError("The attribute `%s' should be given an integer or a list/tuple of length 2, instead of %s." % (attr, str(value)))
        return res

    @DocInherit
    def inference(self):
        self.window_size = self.check_size(self.window_size, "window_size")
        self.stride = self.check_size(self.stride, "stride")
        self.padding = self.check_size(self.padding, "padding")

        self.output_dim = [self.input_dims[0][0]]
        if self.input_dims[0][1] != -1:
            self.output_dim.append((self.input_dims[0][1] + 2 * self.padding[0] - self.window_size[0]) // self.stride[0] + 1)
        else:
            self.output_dim.append(-1)
        if self.input_dims[0][2] != -1:
            self.output_dim.append((self.input_dims[0][2] + 2 * self.padding[1] - self.window_size[1]) // self.stride[1] + 1)
        else:
            self.output_dim.append(-1)
        # print("pool", self.output_dim)
        self.input_channel_num = self.input_dims[0][-1]

        self.output_dim.append(self.input_dims[0][-1])

        # DON'T MODIFY THIS
        self.output_rank = len(self.output_dim)

    @DocInherit
    def verify(self):
        super(Pooling2DConf, self).verify()

        necessary_attrs_for_user = ['pool_type']
        for attr in necessary_attrs_for_user:
            self.add_attr_exist_assertion_for_user(attr)

        self.add_attr_value_assertion('pool_type', ['max', 'mean'])

        assert all([input_rank >= 4 for input_rank in self.input_ranks]), "Cannot apply a pooling layer on a tensor whose rank is less than 4, e.g. [batch_size, length, width, feature]."

        assert self.output_dim[-1] != -1, "The shape of input is %s, and the input channel number of pooling should not be -1." % (str(self.input_dims[0]))


class Pooling2D(BaseLayer):
    """ Pooling layer

    Args:
        layer_conf (Pooling2DConf): configuration of a layer
    """
    def __init__(self, layer_conf):
        super(Pooling2D, self).__init__(layer_conf)
        self.pool = None
        if layer_conf.pool_type == "max":
            self.pool = nn.MaxPool2d(kernel_size=layer_conf.window_size, stride=layer_conf.stride, padding=layer_conf.padding)
        elif layer_conf.pool_type == "mean":
            self.pool = nn.AvgPool2d(kernel_size=layer_conf.window_size, stride=layer_conf.stride, padding=layer_conf.padding)

    def forward(self, string, string_len=None):
        """ process inputs

        Args:
            string (Tensor): tensor with shape: [batch_size, length, width, feature_dim]
            string_len (Tensor): [batch_size], default is None.

        Returns:
            Tensor: pooling result of string

        """
        string = string.permute([0, 3, 1, 2]).contiguous()

        string = self.pool(string)

        string = string.permute([0, 2, 3, 1]).contiguous()

        return string, string_len
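Pooling2D follows the same permute-in/permute-out convention and the same output-size formula as Conv2D. A standalone check (illustrative):

```
# Standalone check of the Pooling2D output-shape formula (illustrative).
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
x = torch.randn(2, 30, 120, 8)                       # [batch, length, width, channels]
y = pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
print(y.shape)  # torch.Size([2, 30, 120, 8]) -- (30 + 2*1 - 3) // 1 + 1 == 30
```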

block_zoo/__init__.py

@@ -15,6 +15,9 @@ from .ConvPooling import ConvPooling, ConvPoolingConf
from .Dropout import Dropout, DropoutConf

+from .Conv2D import Conv2D, Conv2DConf
+from .Pooling2D import Pooling2D, Pooling2DConf

from .embedding import CNNCharEmbedding, CNNCharEmbeddingConf

from .attentions import FullAttention, FullAttentionConf

@@ -24,6 +27,7 @@ from .attentions import BiAttFlow, BiAttFlowConf
from .attentions import MatchAttention, MatchAttentionConf
from .attentions import Attention, AttentionConf
from .attentions import BilinearAttention, BilinearAttentionConf
+from .attentions import Interaction, InteractionConf

# Operators
from .op import *

block_zoo/attentions/Interaction.py (new file)

@@ -0,0 +1,134 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy

from block_zoo.BaseLayer import BaseLayer, BaseConf
from utils.DocInherit import DocInherit
from utils.common_utils import transfer_to_gpu
from utils.exceptions import ConfigurationError  # assumed import: inference() raises ConfigurationError below


class InteractionConf(BaseConf):
    """Configuration of the Interaction layer

    Args:
        hidden_dim (int): dimension of hidden state
        matching_type (string): should be 'general', 'mul', 'plus', 'minus', 'dot' or 'concat'

    """
    def __init__(self, **kwargs):
        super(InteractionConf, self).__init__(**kwargs)

    @DocInherit
    def default(self):
        self.matching_type = 'general'

    @DocInherit
    def declare(self):
        self.num_of_inputs = 2
        self.input_ranks = [3, 3]

    @DocInherit
    def inference(self):
        shape1 = self.input_dims[0]
        shape2 = self.input_dims[1]
        if shape1[1] == -1 or shape2[1] == -1:
            raise ConfigurationError("For Interaction layer, the sequence length should be fixed")
        # print(shape1, shape2)
        self.output_dim = None
        if self.matching_type in ['mul', 'plus', 'minus']:
            self.output_dim = [shape1[0], shape1[1], shape2[1], shape1[2]]
        elif self.matching_type in ['dot', 'general']:
            self.output_dim = [shape1[0], shape1[1], shape2[1], 1]
        elif self.matching_type == 'concat':
            self.output_dim = [shape1[0], shape1[1], shape2[1], shape1[2] + shape2[2]]
        else:
            raise ValueError(f"Invalid `matching_type` {self.matching_type} received. "
                             f"Must be one of `mul`, `general`, `plus`, `minus`, `dot` and `concat`.")
        # print(self.output_dim)
        super(InteractionConf, self).inference()  # PUT THIS LINE AT THE END OF inference()

    @DocInherit
    def verify(self):
        super(InteractionConf, self).verify()
        assert hasattr(self, 'matching_type'), "Please define the matching_type attribute of InteractionConf in default() or the configuration file"
        assert self.matching_type in ['general', 'dot', 'mul', 'plus', 'minus', 'concat'], \
            "Invalid `matching_type` %s received. Must be one of `mul`, `general`, `plus`, `minus`, `dot` and `concat`." % self.matching_type


class Interaction(BaseLayer):
    """Interaction layer: builds a word-by-word interaction tensor between two sequences

    Args:
        layer_conf (InteractionConf): configuration of a layer
    """
    def __init__(self, layer_conf):
        super(Interaction, self).__init__(layer_conf)
        self.matching_type = layer_conf.matching_type
        shape1 = layer_conf.input_dims[0]
        shape2 = layer_conf.input_dims[1]
        if self.matching_type == 'general':
            self.linear_in = nn.Linear(shape1[-1], shape2[-1], bias=False)

    def forward(self, string1, string1_len, string2, string2_len):
        """ process inputs

        Args:
            string1 (Tensor): [batch_size, seq_len1, dim]
            string1_len (Tensor): [batch_size]
            string2 (Tensor): [batch_size, seq_len2, dim]
            string2_len (Tensor): [batch_size]

        Returns:
            Tensor: [batch_size, seq_len1, seq_len2]

        """
        padded_seq_len1 = string1.shape[1]
        padded_seq_len2 = string2.shape[1]
        seq_dim1 = string1.shape[-1]
        seq_dim2 = string2.shape[-1]
        x1 = string1
        x2 = string2
        result = None

        if self.matching_type == 'dot' or self.matching_type == 'general':
            # if self._normalize:
            #     x1 = K.l2_normalize(x1, axis=2)
            #     x2 = K.l2_normalize(x2, axis=2)
            if self.matching_type == 'general':
                x1 = x1.view(-1, seq_dim1)
                x1 = self.linear_in(x1)
                x1 = x1.view(-1, padded_seq_len1, seq_dim2)
            result = torch.bmm(x1, x2.transpose(1, 2).contiguous())
            result = torch.unsqueeze(result, -1)
            # print("result", result.size())
        else:
            if self.matching_type == 'mul':
                def func(x, y):
                    return x * y
            elif self.matching_type == 'plus':
                def func(x, y):
                    return x + y
            elif self.matching_type == 'minus':
                def func(x, y):
                    return x - y
            elif self.matching_type == 'concat':
                def func(x, y):
                    return torch.cat([x, y], dim=-1)
            else:
                raise ValueError(f"Invalid matching type {self.matching_type} received. "
                                 f"Must be one of `dot`, `general`, `mul`, `plus`, `minus` and `concat`.")
            x1_exp = torch.stack([x1] * padded_seq_len2, dim=2)
            x2_exp = torch.stack([x2] * padded_seq_len1, dim=1)
            result = func(x1_exp, x2_exp)

        return result, padded_seq_len1
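The matching types fall into two families: bilinear scores (`dot`, `general`) produce a rank-4 tensor with a trailing singleton channel, while the broadcast types (`mul`, `plus`, `minus`, `concat`) tile both sequences against each other. A standalone shape sketch (plain torch, illustrative):

```
# Shape behaviour of the Interaction matching types (illustrative, outside the framework).
import torch

b, len1, len2, dim = 2, 30, 120, 16
x1, x2 = torch.randn(b, len1, dim), torch.randn(b, len2, dim)

dot = torch.bmm(x1, x2.transpose(1, 2)).unsqueeze(-1)   # 'dot': [2, 30, 120, 1]
x1_exp = torch.stack([x1] * len2, dim=2)                # [2, 30, 120, 16]
x2_exp = torch.stack([x2] * len1, dim=1)                # [2, 30, 120, 16]
mul = x1_exp * x2_exp                                   # 'mul': [2, 30, 120, 16]
concat = torch.cat([x1_exp, x2_exp], dim=-1)            # 'concat': [2, 30, 120, 32]
print(dot.shape, mul.shape, concat.shape)
```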

block_zoo/attentions/__init__.py

@@ -6,4 +6,5 @@ from .LinearAttention import LinearAttention, LinearAttentionConf
from .BiAttFlow import BiAttFlow, BiAttFlowConf
from .MatchAttention import MatchAttention, MatchAttentionConf
from .Attention import Attention, AttentionConf
from .BilinearAttention import BilinearAttention, BilinearAttentionConf
+from .Interaction import Interaction, InteractionConf
@@ -25,16 +25,19 @@ class FlattenConf(BaseConf):
    @DocInherit
    def declare(self):
        self.num_of_inputs = 1
-        self.input_ranks = [3]
+        self.input_ranks = [-1]

    @DocInherit
    def inference(self):
-        self.output_dim = []
-        if self.input_dims[0][1] == -1:
-            raise ConfigurationError("For Flatten layer, the sequence length should be fixed")
-        else:
-            self.output_dim.append(self.input_dims[0][0])
-            self.output_dim.append(self.input_dims[0][1]*self.input_dims[0][-1])
+        flatted_length = 1
+        for i in range(1, len(self.input_dims[0])):
+            if self.input_dims[0][i] == -1:
+                raise ConfigurationError("For Flatten layer, the sequence length should be fixed")
+            else:
+                flatted_length *= self.input_dims[0][i]
+
+        self.output_dim = [self.input_dims[0][0], flatted_length]

        super(FlattenConf, self).inference()

@@ -62,6 +65,8 @@ class Flatten(nn.Module):
        Returns:
            Tensor: [batch_size, seq_len*dim], [batch_size]
        """
-        return string.view(string.shape[0], -1), None
+        flattened = string.view(string.shape[0], -1)
+        string_len = flattened.size(1)
+
+        return flattened, string_len
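The generalized Flatten collapses every non-batch dimension into one, so the output width is the product of the fixed input dims. A one-line check (illustrative):

```
# The generalized Flatten: rank-4 [batch, 30, 120, 8] -> [batch, 28800] (illustrative).
import torch

x = torch.randn(2, 30, 120, 8)
flattened = x.view(x.shape[0], -1)
print(flattened.shape, 30 * 120 * 8)   # torch.Size([2, 28800]) 28800
```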

@@ -1,61 +0,0 @@ (deleted file)
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

# add the project root to python path
import os
from settings import ProblemTypes, version

import argparse
import logging

from ModelConf import ModelConf
from problem import Problem
from utils.common_utils import log_set, dump_to_pkl, load_from_pkl

def main(params, data_path, save_path):
    conf = ModelConf("cache", params.conf_path, version, params)

    if ProblemTypes[conf.problem_type] == ProblemTypes.sequence_tagging:
        problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
            source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
            target_with_start=True, target_with_end=True, target_with_unk=True, target_with_pad=True, same_length=True,
            with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
            remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
    elif ProblemTypes[conf.problem_type] == ProblemTypes.classification \
            or ProblemTypes[conf.problem_type] == ProblemTypes.regression:
        problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
            source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
            target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, same_length=True,
            with_bos_eos=conf.add_start_end_for_seq, tokenizer=conf.tokenizer, remove_stopwords=conf.remove_stopwords,
            DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)

    if os.path.isfile(conf.problem_path):
        problem.load_problem(conf.problem_path)
        logging.info("Cache loaded!")
        logging.debug("Cache loaded from %s" % conf.problem_path)
    else:
        raise Exception("Cache does not exist!")

    data, length, target = problem.encode(data_path, conf.file_columns, conf.input_types, conf.file_with_col_header,
        conf.object_inputs, conf.answer_column_name, conf.min_sentence_len,
        extra_feature=conf.extra_feature, max_lengths=conf.max_lengths, file_format='tsv',
        cpu_num_workers=conf.cpu_num_workers)
    if not os.path.isdir(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))
    dump_to_pkl({'data': data, 'length': length, 'target': target}, save_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Data encoding')
    parser.add_argument("data_path", type=str)
    parser.add_argument("save_path", type=str)
    parser.add_argument("--conf_path", type=str, default='conf.json', help="configuration path")
    parser.add_argument("--debug", type=bool, default=False)
    parser.add_argument("--force", type=bool, default=False)

    log_set('encoding_data.log')

    params, _ = parser.parse_known_args()

    if params.debug is True:
        import debugger
    main(params, params.data_path, params.save_path)
@ -0,0 +1,176 @@
|
|||
{
|
||||
"license": "Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT license.",
|
||||
"tool_version": "1.1.0",
|
||||
"model_description": "This model is used for question answer matching task, and it achieved auc: 0.7634 in WikiQACorpus test set",
|
||||
"inputs": {
|
||||
"use_cache": true,
|
||||
"dataset_type": "classification",
|
||||
"data_paths": {
|
||||
"train_data_path": "./dataset/WikiQACorpus/WikiQA-train.tsv",
|
||||
"valid_data_path": "./dataset/WikiQACorpus/WikiQA-dev.tsv",
|
||||
"test_data_path": "./dataset/WikiQACorpus/WikiQA-test.tsv",
|
||||
"pre_trained_emb": "./dataset/GloVe/glove.840B.300d.txt"
|
||||
},
|
||||
"file_with_col_header": true,
|
||||
"add_start_end_for_seq": true,
|
||||
"file_header": {
|
||||
"question_id": 0,
|
||||
"question_text": 1,
|
||||
"document_id": 2,
|
||||
"document_title": 3,
|
||||
"passage_id": 4,
|
||||
"passage_text": 5,
|
||||
"label": 6
|
||||
},
|
||||
"model_inputs": {
|
||||
"question": ["question_text"],
|
||||
"passage": ["passage_text"]
|
||||
},
|
||||
"target": ["label"]
|
||||
},
|
||||
"outputs":{
|
||||
"save_base_dir": "./models/wikiqa_bilstm/",
|
||||
"model_name": "model.nb",
|
||||
"train_log_name": "train.log",
|
||||
"test_log_name": "test.log",
|
||||
"predict_log_name": "predict.log",
|
||||
"predict_fields": ["prediction"],
|
||||
"predict_output_name": "predict.tsv",
|
||||
"cache_dir": ".cache.wikiqa/"
|
||||
},
|
||||
"training_params": {
|
||||
"vocabulary": {
|
||||
"min_word_frequency": 1
|
||||
},
|
||||
"optimizer": {
|
||||
"name": "Adam",
|
||||
"params": {
|
||||
}
|
||||
},
|
||||
"lr_decay": 0.95,
|
||||
"minimum_lr": 0.0001,
|
||||
"epoch_start_lr_decay": 3,
|
||||
"use_gpu": true,
|
||||
"batch_size": 30,
|
||||
"batch_num_to_show_results": 100,
|
||||
"max_epoch": 20,
|
||||
"valid_times_per_epoch": 5,
|
||||
"fixed_lengths":{
|
||||
"question": 30,
|
||||
"passage": 120
|
||||
}
|
||||
},
|
||||
"architecture":[
|
||||
{
|
||||
"layer": "Embedding",
|
||||
"conf": {
|
||||
"word": {
|
||||
"cols": ["question_text", "passage_text"],
|
||||
"dim": 300,
|
||||
"fix_weight": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"layer_id": "question_dropout",
|
||||
"layer": "Dropout",
|
||||
"conf": {
|
||||
"dropout": 0.2
|
||||
},
|
||||
"inputs": ["question"]
|
||||
},
|
||||
{
|
||||
"layer_id": "passage_dropout",
|
||||
"layer": "Dropout",
|
||||
"conf": {
|
||||
"dropout": 0.2
|
||||
},
|
||||
"inputs": ["passage"]
|
||||
},
|
||||
{
|
||||
"layer_id": "interaction",
|
||||
"layer": "Interaction",
|
||||
"conf": {
|
||||
"dropout": 0.2,
|
||||
"hidden_dim": 300,
|
||||
"matching_type": "general"
|
||||
},
|
||||
"inputs": ["question_dropout", "passage_dropout"]
|
||||
},
|
||||
{
|
||||
"layer_id": "cnn_1",
|
||||
"layer": "Conv2D",
|
||||
"conf": {
|
||||
"stride": 1,
|
||||
"window_size": 3,
|
||||
"padding": 1,
|
||||
"output_channel_num": 8
|
||||
},
|
||||
"inputs": ["interaction"]
|
||||
},
|
||||
{
|
||||
"layer_id": "maxpooling_1",
|
||||
"layer": "Pooling2D",
|
||||
"conf": {
|
||||
"stride": 1,
|
||||
"window_size": 3,
|
||||
"padding": 1,
|
||||
"pool_type": "max"
|
||||
},
|
||||
"inputs": ["cnn_1"]
|
||||
},
|
||||
{
|
||||
"layer_id": "cnn_2",
|
||||
"layer": "Conv2D",
|
||||
"conf": {
|
||||
"stride": 1,
|
||||
"window_size": [5,5],
|
||||
"padding": 2,
|
||||
"output_channel_num": 8
|
||||
},
|
||||
"inputs": ["maxpooling_1"]
|
||||
},
|
||||
{
|
||||
"layer_id": "maxpooling_2",
|
||||
"layer": "Pooling2D",
|
||||
"conf": {
|
||||
"stride": 1,
|
||||
"window_size": 5,
|
||||
"padding": 2,
|
||||
"pool_type": "max"
|
||||
},
|
||||
"inputs": ["cnn_2"]
|
||||
},
|
||||
{
|
||||
"layer_id": "flatten",
|
||||
"layer": "Flatten",
|
||||
"conf": {
|
||||
},
|
||||
"inputs": ["maxpooling_2"]
|
||||
},
|
||||
{
|
||||
"output_layer_flag": true,
|
||||
"layer_id": "output",
|
||||
"layer": "Linear",
|
||||
"conf": {
|
||||
"hidden_dim": [128,2],
|
||||
"activation": "PReLU",
|
||||
"last_hidden_activation": false
|
||||
},
|
||||
"inputs": ["flatten"]
|
||||
}
|
||||
],
|
||||
"loss": {
|
||||
"losses": [
|
||||
{
|
||||
"type": "CrossEntropyLoss",
|
||||
"conf": {
|
||||
"weight": [0.1,0.9],
|
||||
"size_average": true
|
||||
},
|
||||
"inputs": ["output","label"]
|
||||
}
|
||||
]
|
||||
},
|
||||
"metrics": ["auc","accuracy"]
|
||||
}
|
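This config wires Embedding → Dropout → Interaction → two Conv2D/Pooling2D stages → Flatten → Linear into a MatchPyramid-style matcher. A sketch of the resulting shape chain, assuming the declared fixed_lengths and ignoring the optional start/end tokens added by `add_start_end_for_seq` (illustrative only):

```
# Shape walk-through of the architecture above (a sketch; assumes fixed_lengths 30/120).
q_len, p_len = 30, 120

# Interaction with matching_type="general" -> [batch, q_len, p_len, 1]
h, w, c = q_len, p_len, 1

for window, pad, out_c in [(3, 1, 8), (5, 2, 8)]:   # cnn_1/maxpooling_1, then cnn_2/maxpooling_2
    h = (h + 2 * pad - window) // 1 + 1             # Conv2D, stride 1
    w = (w + 2 * pad - window) // 1 + 1
    h = (h + 2 * pad - window) // 1 + 1             # Pooling2D, stride 1
    w = (w + 2 * pad - window) // 1 + 1
    c = out_c

print(h, w, c, h * w * c)   # 30 120 8 28800 -> Flatten feeds a 28800-dim vector to the Linear block
```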
predict.py
@@ -14,27 +14,10 @@ from LearningMachine import LearningMachine

def main(params):
    conf = ModelConf('predict', params.conf_path, version, params, mode=params.mode)

-    if ProblemTypes[conf.problem_type] == ProblemTypes.sequence_tagging:
-        problem = Problem(conf.problem_type, conf.input_types, None,
-            source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
-            target_with_start=True, target_with_end=True, target_with_unk=True, target_with_pad=True, same_length=True,
-            with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
-            remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
-    elif ProblemTypes[conf.problem_type] == ProblemTypes.classification \
-            or ProblemTypes[conf.problem_type] == ProblemTypes.regression:
-        problem = Problem(conf.problem_type, conf.input_types, None,
-            source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
-            target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, same_length=True,
-            with_bos_eos=conf.add_start_end_for_seq, tokenizer=conf.tokenizer, remove_stopwords=conf.remove_stopwords,
-            DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
-    elif ProblemTypes[conf.problem_type] == ProblemTypes.mrc:
-        problem = Problem(conf.problem_type, conf.input_types,
-            source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
-            target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False,
-            same_length=False, with_bos_eos=False, tokenizer=conf.tokenizer,
-            remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
+    problem = Problem('predict', conf.problem_type, conf.input_types, None,
+        with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
+        remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)

    if os.path.isfile(conf.saved_problem_path):
        problem.load_problem(conf.saved_problem_path)
        logging.info("Problem loaded!")
problem.py
@@ -26,10 +26,8 @@ import torch
import torch.nn as nn

class Problem():
-    def __init__(self, problem_type, input_types, answer_column_name=None, lowercase=False,
-                 source_with_start=True, source_with_end=True, source_with_unk=True,
-                 source_with_pad=True, target_with_start=False, target_with_end=False,
-                 target_with_unk=True, target_with_pad=True, same_length=True, with_bos_eos=True,
+    def __init__(self, phase, problem_type, input_types, answer_column_name=None, lowercase=False, with_bos_eos=True,
                 tagging_scheme=None, tokenizer="nltk", remove_stopwords=False, DBC2SBC=True, unicode_fix=True):
        """

@@ -50,9 +48,24 @@ class Problem():
            same_length:
            with_bos_eos: whether to add bos and eos when encoding
        """
-        self.lowercase = lowercase
-
-        self.input_dicts = dict()
+        # init: the source/target flags are now derived from the problem type and phase
+        source_with_start, source_with_end, source_with_unk, source_with_pad, \
+            target_with_start, target_with_end, target_with_unk, target_with_pad, \
+            same_length = (True, ) * 9
+        if ProblemTypes[problem_type] == ProblemTypes.sequence_tagging:
+            pass
+        elif \
+                ProblemTypes[problem_type] == ProblemTypes.classification or \
+                ProblemTypes[problem_type] == ProblemTypes.regression:
+            target_with_start, target_with_end, target_with_unk, target_with_pad, same_length = (False, ) * 5
+            if phase != 'train':
+                same_length = True
+        elif ProblemTypes[problem_type] == ProblemTypes.mrc:
+            target_with_start, target_with_end, target_with_unk, target_with_pad, same_length = (False, ) * 5
+            with_bos_eos = False
+
+        self.lowercase = lowercase
        self.problem_type = problem_type
        self.tagging_scheme = tagging_scheme
        self.with_bos_eos = with_bos_eos

@@ -65,6 +78,7 @@ class Problem():
        self.target_with_unk = target_with_unk
        self.target_with_pad = target_with_pad

+        self.input_dicts = dict()
        for input_type in input_types:
            self.input_dicts[input_type] = CellDict(with_unk=source_with_unk, with_pad=source_with_pad,
                                                    with_start=source_with_start, with_end=source_with_end)

@@ -245,6 +259,10 @@ class Problem():
        Returns:

        """
+        # parameter check
+        if not word2vec_path:
+            word_emb_dim, format, file_type, involve_all_words = None, None, None, None
+
        if 'bpe' in input_types:
            try:
                bpe_encoder = BPEEncoder(input_types['bpe']['bpe_path'])

@@ -324,51 +342,65 @@ class Problem():

        return word_emb_matrix

-    def encode_data_multi_processor(self, data_generator, cpu_num_workers, file_columns, input_types, object_inputs,
-                                    answer_column_name, min_sentence_len, extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):
+    @staticmethod
+    def _merge_encode_data(dest_dict, src_dict):
+        if len(dest_dict) == 0:
+            dest_dict = src_dict
+        else:
+            for branch in src_dict:
+                for input_type in dest_dict[branch]:
+                    dest_dict[branch][input_type].extend(src_dict[branch][input_type])
+        return dest_dict
+
+    @staticmethod
+    def _merge_encode_lengths(dest_dict, src_dict):
+        def judge_dict(obj):
+            return True if isinstance(obj, dict) else False
-        cnt_legal, cnt_illegal = 0, 0
-        output_data = dict()
-        lengths = dict()
-        target = dict()
-        for data in tqdm(data_generator):
+        if len(dest_dict) == 0:
+            dest_dict = src_dict
+        else:
+            for branch in src_dict:
+                if judge_dict(src_dict[branch]):
+                    for type_branch in src_dict[branch]:
+                        dest_dict[branch][type_branch].extend(src_dict[branch][type_branch])
+                else:
+                    dest_dict[branch].extend(src_dict[branch])
+        return dest_dict
+
+    @staticmethod
+    def _merge_target(dest_dict, src_dict):
+        if not src_dict:
+            return src_dict
+
+        if len(dest_dict) == 0:
+            dest_dict = src_dict
+        else:
+            for single_type in src_dict:
+                dest_dict[single_type].extend(src_dict[single_type])
+        return dest_dict
+
+    def encode_data_multi_processor(self, data_generator, cpu_num_workers, file_columns, input_types, object_inputs,
+                                    answer_column_name, min_sentence_len, extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):
+
+        for data in data_generator:
            scheduler = ProcessorsScheduler(cpu_num_workers)
            func_args = (data, file_columns, input_types, object_inputs,
                         answer_column_name, min_sentence_len, extra_feature, max_lengths, fixed_lengths, file_format, bpe_encoder)
            res = scheduler.run_data_parallel(self.encode_data_list, func_args)

+            output_data, lengths, target = dict(), dict(), dict()
+            cnt_legal, cnt_illegal = 0, 0
            for (index, j) in res:
                # logging.info("collect processor %d result" % index)
                tmp_data, tmp_lengths, tmp_target, tmp_cnt_legal, tmp_cnt_illegal = j.get()

-                if len(output_data) == 0:
-                    output_data = tmp_data
-                else:
-                    for branch in tmp_data:
-                        for input_type in output_data[branch]:
-                            output_data[branch][input_type].extend(tmp_data[branch][input_type])
-                if len(lengths) == 0:
-                    lengths = tmp_lengths
-                else:
-                    for branch in tmp_lengths:
-                        if judge_dict(tmp_lengths[branch]):
-                            for type_branch in tmp_lengths[branch]:
-                                lengths[branch][type_branch].extend(tmp_lengths[branch][type_branch])
-                        else:
-                            lengths[branch].extend(tmp_lengths[branch])
-                if not tmp_target:
-                    target = None
-                else:
-                    if len(target) == 0:
-                        target = tmp_target
-                    else:
-                        for single_type in tmp_target:
-                            target[single_type].extend(tmp_target[single_type])
+                output_data = self._merge_encode_data(output_data, tmp_data)
+                lengths = self._merge_encode_lengths(lengths, tmp_lengths)
+                target = self._merge_target(target, tmp_target)
                cnt_legal += tmp_cnt_legal
                cnt_illegal += tmp_cnt_illegal

-        return output_data, lengths, target, cnt_legal, cnt_illegal
+            yield output_data, lengths, target, cnt_legal, cnt_illegal

    def encode_data_list(self, data_list, file_columns, input_types, object_inputs, answer_column_name, min_sentence_len,
                         extra_feature, max_lengths=None, fixed_lengths=None, file_format="tsv", bpe_encoder=None):

@@ -678,9 +710,19 @@ class Problem():
            bpe_encoder = None

        progress = self.get_data_generator_from_file([data_path], file_with_col_header)
-        data, lengths, target, cnt_legal, cnt_illegal = self.encode_data_multi_processor(progress, cpu_num_workers,
+        encoder_generator = self.encode_data_multi_processor(progress, cpu_num_workers,
            file_columns, input_types, object_inputs, answer_column_name, min_sentence_len, extra_feature, max_lengths,
            fixed_lengths, file_format, bpe_encoder=bpe_encoder)

+        data, lengths, target = dict(), dict(), dict()
+        cnt_legal, cnt_illegal = 0, 0
+        for temp_data, temp_lengths, temp_target, temp_cnt_legal, temp_cnt_illegal in tqdm(encoder_generator):
+            data = self._merge_encode_data(data, temp_data)
+            lengths = self._merge_encode_lengths(lengths, temp_lengths)
+            target = self._merge_target(target, temp_target)
+            cnt_legal += temp_cnt_legal
+            cnt_illegal += temp_cnt_illegal
+
        logging.info("%s: %d legal samples, %d illegal samples" % (data_path, cnt_legal, cnt_illegal))
        return data, lengths, target
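The refactor replaces the eager, inline dict merging with three reusable helpers plus a generator. A toy, runnable illustration of the merge pattern (dict shapes mirror `output_data[branch][input_type]`; the data is invented):

```
# Toy illustration of the _merge_encode_data pattern: per-chunk dicts of lists
# are folded into one accumulator, exactly as encode() now folds the generator output.
def merge_encode_data(dest, src):
    if not dest:
        return src
    for branch in src:
        for input_type in dest[branch]:
            dest[branch][input_type].extend(src[branch][input_type])
    return dest

chunks = [
    {"question": {"word": [[1, 2], [3]]}},
    {"question": {"word": [[4, 5]]}},
]
merged = {}
for chunk in chunks:   # same loop structure encode() uses over encoder_generator
    merged = merge_encode_data(merged, chunk)
print(merged)          # {'question': {'word': [[1, 2], [3], [4, 5]]}}
```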
test.py
@@ -16,27 +16,10 @@ from LearningMachine import LearningMachine

def main(params):
    conf = ModelConf("test", params.conf_path, version, params, mode=params.mode)

-    if ProblemTypes[conf.problem_type] == ProblemTypes.sequence_tagging:
-        problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
-            source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
-            target_with_start=True, target_with_end=True, target_with_unk=True, target_with_pad=True, same_length=True,
-            with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
-            remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
-    elif ProblemTypes[conf.problem_type] == ProblemTypes.classification \
-            or ProblemTypes[conf.problem_type] == ProblemTypes.regression:
-        problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
-            source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
-            target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False, same_length=True,
-            with_bos_eos=conf.add_start_end_for_seq, tokenizer=conf.tokenizer, remove_stopwords=conf.remove_stopwords,
-            DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
-    elif ProblemTypes[conf.problem_type] == ProblemTypes.mrc:
-        problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
-            source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
-            target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False,
-            same_length=False, with_bos_eos=False, tokenizer=conf.tokenizer,
-            remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
+    problem = Problem("test", conf.problem_type, conf.input_types, conf.answer_column_name,
+        with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
+        remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)

    if os.path.isfile(conf.saved_problem_path):
        problem.load_problem(conf.saved_problem_path)
        logging.info("Problem loaded!")
train.py
@@ -21,202 +21,207 @@ from optimizers import *

from LearningMachine import LearningMachine

+class Cache:
+    def __init__(self):
+        self.dictionary_invalid = True
+        self.embedding_invalid = True
+        self.encoding_invalid = True
+
+    def _check_dictionary(self, conf, params):
+        # init status
+        self.dictionary_invalid = True
+        self.embedding_invalid = True
+
-def verify_cache(cache_conf, cur_conf):
-    """ To verify if the cache is applicable to the current configuration
+        # cache_conf
+        cache_conf = None
+        cache_conf_path = os.path.join(conf.cache_dir, 'conf_cache.json')
+        if os.path.isfile(cache_conf_path):
+            params_cache = copy.deepcopy(params)
+            try:
+                cache_conf = ModelConf('cache', cache_conf_path, version, params_cache)
+            except Exception as e:
+                cache_conf = None
+        if cache_conf is None or not self._verify_conf(cache_conf, conf):
+            return False
+
+        # problem
+        if not os.path.isfile(conf.problem_path):
+            return False
+
-    Args:
-        cache_conf (ModelConf):
-        cur_conf (ModelConf):
+        # embedding
+        if conf.emb_pkl_path:
+            if not os.path.isfile(conf.emb_pkl_path):
+                return False
+            self.embedding_invalid = False
+
+        self.dictionary_invalid = False
+        return True
+
+    def _check_encoding(self, conf):
+        self.encoding_invalid = False
+        return True
+
-    Returns:
+    def check(self, conf, params):
+        # dictionary
+        if not self._check_dictionary(conf, params):
+            self._renew_cache(params, conf.cache_dir)
+            return
+        # encoding
+        if not self._check_encoding(conf):
+            self._renew_cache(params, conf.cache_dir)
+
-    """
-    if cache_conf.tool_version != cur_conf.tool_version:
-        return False
-
-    attribute_to_cmp = ['file_columns', 'object_inputs', 'answer_column_name', 'input_types']
-
-    flag = True
-    for attr in attribute_to_cmp:
-        if not (hasattr(cache_conf, attr) and hasattr(cur_conf, attr) and getattr(cache_conf, attr) == getattr(cur_conf, attr)):
-            logging.error('configuration %s is inconsistent with the old cache' % attr)
-            flag = False
-    return flag
-
-
-def main(params):
-    conf = ModelConf("train", params.conf_path, version, params, mode=params.mode)
-
-    shutil.copy(params.conf_path, conf.save_base_dir)
-    logging.info('Configuration file is backed up to %s' % (conf.save_base_dir))
-
-    if ProblemTypes[conf.problem_type] == ProblemTypes.sequence_tagging:
-        problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
-            source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
-            target_with_start=True, target_with_end=True, target_with_unk=True, target_with_pad=True, same_length=True,
-            with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
-            remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
-    elif ProblemTypes[conf.problem_type] == ProblemTypes.classification \
-            or ProblemTypes[conf.problem_type] == ProblemTypes.regression:
-        problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
-            source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
-            target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False,
-            same_length=False, with_bos_eos=conf.add_start_end_for_seq, tokenizer=conf.tokenizer,
-            remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
-    elif ProblemTypes[conf.problem_type] == ProblemTypes.mrc:
-        problem = Problem(conf.problem_type, conf.input_types, conf.answer_column_name,
-            source_with_start=True, source_with_end=True, source_with_unk=True, source_with_pad=True,
-            target_with_start=False, target_with_end=False, target_with_unk=False, target_with_pad=False,
-            same_length=False, with_bos_eos=False, tokenizer=conf.tokenizer,
-            remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
-
-    cache_load_flag = False
-    if not conf.pretrained_model_path:
-        # first time training, load cache if applicable
-        if conf.use_cache:
-            cache_conf_path = os.path.join(conf.cache_dir, 'conf_cache.json')
-            if os.path.isfile(cache_conf_path):
-                params_cache = copy.deepcopy(params)
-                '''
-                for key in vars(params_cache):
-                    setattr(params_cache, key, None)
-                params_cache.mode = params.mode
-                '''
-                try:
-                    cache_conf = ModelConf('cache', cache_conf_path, version, params_cache)
-                except Exception as e:
-                    cache_conf = None
-                if cache_conf is None or verify_cache(cache_conf, conf) is not True:
-                    logging.info('Found cache that is ineffective')
-                    if params.mode == 'philly' or params.force is True:
-                        renew_option = 'yes'
-                    else:
-                        renew_option = input('There exists ineffective cache %s for old models. Input "yes" to renew cache and "no" to exit. (default:no): ' % os.path.abspath(conf.cache_dir))
-                    if renew_option.lower() != 'yes':
-                        exit(0)
-                    else:
-                        shutil.rmtree(conf.cache_dir)
-                        time.sleep(2)  # sleep 2 seconds since the deleting is asynchronous
-                        logging.info('Old cache is deleted')
-                else:
-                    logging.info('Found cache that is applicable to current configuration...')
-
-            elif os.path.isdir(conf.cache_dir):
-                renew_option = input('There exists ineffective cache %s for old models. Input "yes" to renew cache and "no" to exit. (default:no): ' % os.path.abspath(conf.cache_dir))
-                if renew_option.lower() != 'yes':
-                    exit(0)
-                else:
-                    shutil.rmtree(conf.cache_dir)
-                    time.sleep(2)  # Sleep 2 seconds since the deleting is asynchronous
-                    logging.info('Old cache is deleted')
-
-            if not os.path.exists(conf.cache_dir):
-                os.makedirs(conf.cache_dir)
-            shutil.copy(params.conf_path, os.path.join(conf.cache_dir, 'conf_cache.json'))
-
-            # first time training, load problem from cache, and then backup the cache to model_save_dir/.necessary_cache/
-            if conf.use_cache and os.path.isfile(conf.problem_path):
+    def load(self, conf, problem, emb_matrix):
+        # load dictionary when (not finetune) and (cache valid)
+        if not conf.pretrained_model_path and not self.dictionary_invalid:
+            problem.load_problem(conf.problem_path)
-                if conf.emb_pkl_path is not None:
-                    if os.path.isfile(conf.emb_pkl_path):
-                        emb_matrix = np.array(load_from_pkl(conf.emb_pkl_path))
-                        cache_load_flag = True
-                    else:
-                        if params.mode == 'normal':
-                            renew_option = input('The cache is invalid because the embedding matrix does not exist in the cache directory. Input "yes" to renew cache and "no" to exit. (default:no): ')
-                            if renew_option.lower() != 'yes':
-                                exit(0)
-                        else:
-                            # by default, renew cache
-                            renew_option = 'yes'
-                else:
-                    emb_matrix = None
-                    cache_load_flag = True
-            if cache_load_flag:
-                logging.info("Cache loaded!")
-
-        if cache_load_flag is False:
-            logging.info("Preprocessing... Depending on your corpus size, this step may take a while.")
-            # modify train_data_path to [train_data_path, valid_data_path, test_data_path]
-            # remember the test_data may be None
-            data_path_list = [conf.train_data_path, conf.valid_data_path, conf.test_data_path]
-            if conf.pretrained_emb_path:
-                emb_matrix = problem.build(data_path_list, conf.file_columns, conf.input_types, conf.file_with_col_header,
-                    conf.answer_column_name, word2vec_path=conf.pretrained_emb_path,
-                    word_emb_dim=conf.pretrained_emb_dim, format=conf.pretrained_emb_type,
-                    file_type=conf.pretrained_emb_binary_or_text, involve_all_words=conf.involve_all_words_in_pretrained_emb,
-                    show_progress=True if params.mode == 'normal' else False, cpu_num_workers=conf.cpu_num_workers,
-                    max_vocabulary=conf.max_vocabulary, word_frequency=conf.min_word_frequency)
-            else:
-                emb_matrix = problem.build(data_path_list, conf.file_columns, conf.input_types, conf.file_with_col_header,
-                    conf.answer_column_name, word2vec_path=None, word_emb_dim=None, format=None,
-                    file_type=None, involve_all_words=conf.involve_all_words_in_pretrained_emb,
-                    show_progress=True if params.mode == 'normal' else False, cpu_num_workers=conf.cpu_num_workers,
-                    max_vocabulary=conf.max_vocabulary, word_frequency=conf.min_word_frequency)
+        if not self.embedding_invalid:
+            emb_matrix = np.array(load_from_pkl(conf.emb_pkl_path))
+        logging.info('[Cache] loading dictionary successfully')
+
+        if not self.encoding_invalid:
+            pass
+        return problem, emb_matrix
+
+    def save(self, conf, params, problem, emb_matrix):
+        if not os.path.exists(conf.cache_dir):
+            os.makedirs(conf.cache_dir)
+        shutil.copy(params.conf_path, os.path.join(conf.cache_dir, 'conf_cache.json'))
+        if self.dictionary_invalid:
+            if conf.mode == 'philly' and conf.emb_pkl_path.startswith('/hdfs/'):
+                with HDFSDirectTransferer(conf.problem_path, with_hdfs_command=True) as transferer:
+                    transferer.pkl_dump(problem.export_problem(conf.problem_path, ret_without_save=True))
+            else:
+                problem.export_problem(conf.problem_path)
-            if conf.use_cache:
-                logging.info("Cache saved to %s" % conf.problem_path)
-                if emb_matrix is not None and conf.emb_pkl_path is not None:
-                    if conf.mode == 'philly' and conf.emb_pkl_path.startswith('/hdfs/'):
-                        with HDFSDirectTransferer(conf.emb_pkl_path, with_hdfs_command=True) as transferer:
-                            transferer.pkl_dump(emb_matrix)
-                    else:
-                        dump_to_pkl(emb_matrix, conf.emb_pkl_path)
-                    logging.info("Embedding matrix saved to %s" % conf.emb_pkl_path)
-            else:
-                logging.debug("Cache saved to %s" % conf.problem_path)
+            logging.info("[Cache] problem is saved to %s" % conf.problem_path)
+            if emb_matrix is not None and conf.emb_pkl_path is not None:
+                if conf.mode == 'philly' and conf.emb_pkl_path.startswith('/hdfs/'):
+                    with HDFSDirectTransferer(conf.emb_pkl_path, with_hdfs_command=True) as transferer:
+                        transferer.pkl_dump(emb_matrix)
+                else:
+                    dump_to_pkl(emb_matrix, conf.emb_pkl_path)
+                logging.info("Embedding matrix saved to %s" % conf.emb_pkl_path)
+
+        if self.encoding_invalid:
+            pass
+
+    # Back up the problem.pkl to save_base_dir/.necessary_cache. During test phase, we would load cache from save_base_dir/.necessary_cache/problem.pkl
+    def back_up(self, conf, problem):
+        cache_bakup_path = os.path.join(conf.save_base_dir, 'necessary_cache/')
+        logging.debug('Prepare dir: %s' % cache_bakup_path)
+        prepare_dir(cache_bakup_path, True, allow_overwrite=True, clear_dir_if_exist=True)
+
-        shutil.copy(conf.problem_path, cache_bakup_path)
+        problem.export_problem(cache_bakup_path + 'problem.pkl')
+        logging.debug("Problem %s is backed up to %s" % (conf.problem_path, cache_bakup_path))
-        if problem.output_dict:
-            logging.debug("Problem target cell dict: %s" % (problem.output_dict.cell_id_map))
-
-        if params.make_cache_only:
-            logging.info("Finish building cache!")
+    def _renew_cache(self, params, cache_path):
+        if not os.path.exists(cache_path):
+            return
+        logging.info('Found cache that is ineffective')
+        renew_option = 'yes'
+        if params.mode != 'philly' and params.force is not True:
+            renew_option = input('There exists ineffective cache %s for old models. Input "yes" to renew cache and "no" to exit. (default:no): ' % os.path.abspath(cache_path))
+        if renew_option.lower() != 'yes':
+            exit(0)
+        else:
+            shutil.rmtree(cache_path)
+            time.sleep(2)  # sleep 2 seconds since the deleting is asynchronous
+            logging.info('Old cache is deleted')
+
-        vocab_info = dict()  # include input_type's vocab_size & init_emd_matrix
-        vocab_sizes = problem.get_vocab_sizes()
-        for input_cluster in vocab_sizes:
-            vocab_info[input_cluster] = dict()
-            vocab_info[input_cluster]['vocab_size'] = vocab_sizes[input_cluster]
-            # add extra info for char_emb
-            if input_cluster.lower() == 'char':
-                for key, value in conf.input_types[input_cluster].items():
-                    if key != 'cols':
-                        vocab_info[input_cluster][key] = value
-            if input_cluster == 'word' and emb_matrix is not None:
-                vocab_info[input_cluster]['init_weights'] = emb_matrix
-            else:
-                vocab_info[input_cluster]['init_weights'] = None
+    def _verify_conf(self, cache_conf, cur_conf):
+        """ To verify if the cache is applicable to the current configuration
+
-        lm = LearningMachine('train', conf, problem, vocab_info=vocab_info, initialize=True, use_gpu=conf.use_gpu)
-    else:
-        # when finetuning, load previous saved problem
+        Args:
+            cache_conf (ModelConf):
+            cur_conf (ModelConf):
+
+        Returns:
+
+        """
+        if cache_conf.tool_version != cur_conf.tool_version:
+            return False
+
+        attribute_to_cmp = ['file_columns', 'object_inputs', 'answer_column_name', 'input_types', 'language']
+
+        flag = True
+        for attr in attribute_to_cmp:
+            if not (hasattr(cache_conf, attr) and hasattr(cur_conf, attr) and getattr(cache_conf, attr) == getattr(cur_conf, attr)):
+                logging.error('configuration %s is inconsistent with the old cache' % attr)
+                flag = False
+        return flag
+
+def main(params):
+    # init
+    conf = ModelConf("train", params.conf_path, version, params, mode=params.mode)
+    problem = Problem("train", conf.problem_type, conf.input_types, conf.answer_column_name,
+        with_bos_eos=conf.add_start_end_for_seq, tagging_scheme=conf.tagging_scheme, tokenizer=conf.tokenizer,
+        remove_stopwords=conf.remove_stopwords, DBC2SBC=conf.DBC2SBC, unicode_fix=conf.unicode_fix)
+    if conf.pretrained_model_path:
+        ### when finetuning, load previous saved problem
+        problem.load_problem(conf.saved_problem_path)
-        lm = LearningMachine('train', conf, problem, vocab_info=None, initialize=False, use_gpu=conf.use_gpu)
+
+    # cache verification
+    emb_matrix = None
+    cache = Cache()
+    if conf.use_cache:
+        ## check
+        cache.check(conf, params)
+        ## load
+        problem, emb_matrix = cache.load(conf, problem, emb_matrix)
+
+    # data preprocessing
+    ## build dictionary when (not in finetune model) and (not use cache or cache invalid)
+    if (not conf.pretrained_model_path) and ((conf.use_cache == False) or cache.dictionary_invalid):
+        logging.info("Preprocessing... Depending on your corpus size, this step may take a while.")
+        # modify train_data_path to [train_data_path, valid_data_path, test_data_path]
+        # remember the test_data may be None
+        data_path_list = [conf.train_data_path, conf.valid_data_path, conf.test_data_path]
+        emb_matrix = problem.build(data_path_list, conf.file_columns, conf.input_types, conf.file_with_col_header,
+            conf.answer_column_name, word2vec_path=conf.pretrained_emb_path,
+            word_emb_dim=conf.pretrained_emb_dim, format=conf.pretrained_emb_type,
+            file_type=conf.pretrained_emb_binary_or_text, involve_all_words=conf.involve_all_words_in_pretrained_emb,
+            show_progress=True if params.mode == 'normal' else False, cpu_num_workers=conf.cpu_num_workers,
+            max_vocabulary=conf.max_vocabulary, word_frequency=conf.min_word_frequency)
+
+    ## encode rawdata when do not use cache
+    if conf.use_cache == False:
+        pass
+
+    # environment preparing
+    ## cache save
+    if conf.use_cache:
+        cache.save(conf, params, problem, emb_matrix)
+
+    if params.make_cache_only:
+        if conf.use_cache:
+            logging.info("Finish building cache!")
+        else:
+            logging.info('Please set the parameter "use_cache" to true')
+        return
+
+    ## back up the problem.pkl to save_base_dir/.necessary_cache.
+    ## During test phase, we would load cache from save_base_dir/.necessary_cache/problem.pkl
+    conf.back_up(params)
+    cache.back_up(conf, problem)
+    if problem.output_dict:
+        logging.debug("Problem target cell dict: %s" % (problem.output_dict.cell_id_map))
+
+    # train phase
+    ## init
+    ### model
+    vocab_info, initialize = None, False
+    if not conf.pretrained_model_path:
+        vocab_info, initialize = get_vocab_info(conf, problem, emb_matrix), True
+
+    lm = LearningMachine('train', conf, problem, vocab_info=vocab_info, initialize=initialize, use_gpu=conf.use_gpu)
+    if conf.pretrained_model_path:
+        logging.info('Loading the pretrained model: %s...' % conf.pretrained_model_path)
+        lm.load_model(conf.pretrained_model_path)
+
+    ### loss
    if len(conf.metrics_post_check) > 0:
        for metric_to_chk in conf.metrics_post_check:
            metric, target = metric_to_chk.split('@')
            if not problem.output_dict.has_cell(target):
                raise Exception("The target %s of %s does not exist in the training data." % (target, metric_to_chk))

-    if conf.pretrained_model_path:
-        logging.info('Loading the pretrained model: %s...' % conf.pretrained_model_path)
-        lm.load_model(conf.pretrained_model_path)
-
    loss_conf = conf.loss
    loss_conf['output_layer_id'] = conf.output_layer_id
    loss_conf['answer_column_name'] = conf.answer_column_name

@@ -225,11 +230,13 @@ def main(params):
    if conf.use_gpu is True:
        loss_fn.cuda()

+    ### optimizer
    optimizer = eval(conf.optimizer_name)(lm.model.parameters(), **conf.optimizer_params)

+    ## train
    lm.train(optimizer, loss_fn)

-    # test the best model with the best model saved
+    ## test the best model with the best model saved
    lm.load_model(conf.model_save_path)
    if conf.test_data_path is not None:
        test_path = conf.test_data_path

@@ -241,6 +248,22 @@ def main(params):
    else:
        lm.test(loss_fn, test_path)

+def get_vocab_info(conf, problem, emb_matrix):
+    vocab_info = dict()  # include input_type's vocab_size & init_emd_matrix
+    vocab_sizes = problem.get_vocab_sizes()
+    for input_cluster in vocab_sizes:
+        vocab_info[input_cluster] = dict()
+        vocab_info[input_cluster]['vocab_size'] = vocab_sizes[input_cluster]
+        # add extra info for char_emb
+        if input_cluster.lower() == 'char':
+            for key, value in conf.input_types[input_cluster].items():
+                if key != 'cols':
+                    vocab_info[input_cluster][key] = value
+        if input_cluster == 'word' and emb_matrix is not None:
+            vocab_info[input_cluster]['init_weights'] = emb_matrix
+        else:
+            vocab_info[input_cluster]['init_weights'] = None
+    return vocab_info
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Training')
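After this refactor, main() drives the whole cache lifecycle through `Cache.check / load / save / back_up`. A toy, runnable skeleton of that flow (everything here is a simplified stand-in, not the real ModelConf/Problem):

```
# Toy, runnable skeleton of the check -> load-or-build -> save lifecycle (illustrative only).
import os, pickle, tempfile

class TinyCache:
    def __init__(self, cache_dir):
        self.cache_dir = cache_dir
        self.dictionary_invalid = True

    def check(self):
        # valid only if a previously saved dictionary exists
        self.dictionary_invalid = not os.path.isfile(os.path.join(self.cache_dir, "problem.pkl"))

    def load(self):
        with open(os.path.join(self.cache_dir, "problem.pkl"), "rb") as fin:
            return pickle.load(fin)

    def save(self, problem):
        os.makedirs(self.cache_dir, exist_ok=True)
        with open(os.path.join(self.cache_dir, "problem.pkl"), "wb") as fout:
            pickle.dump(problem, fout)

cache = TinyCache(os.path.join(tempfile.mkdtemp(), ".cache.demo"))
cache.check()
problem = cache.load() if not cache.dictionary_invalid else {"vocab": ["hello", "world"]}  # build when invalid
cache.save(problem)   # a second run would now hit the cache
print(cache.dictionary_invalid, problem)
```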
@ -9,6 +9,7 @@ import shutil
|
|||
import time
|
||||
import tempfile
|
||||
import subprocess
|
||||
import hashlib
|
||||
|
||||
def log_set(log_path, console_level='INFO', console_detailed=False, disable_log_file=False):
|
||||
"""
|
||||
|
@ -216,3 +217,23 @@ def prepare_dir(path, is_dir, allow_overwrite=False, clear_dir_if_exist=False, e
|
|||
overwrite_option = input('The file %s already exists, input "yes" to allow us to overwrite it or "no" to exit. (default:no): ' % path)
|
||||
if overwrite_option.lower() != 'yes':
|
||||
exit(0)
|
||||
|
||||
def md5(file_paths, chunk_size=1024*1024*1024):
|
||||
""" Calculate a md5 of lists of files.
|
||||
|
||||
Args:
|
||||
file_paths: an iterable object contains files. Files will be concatenated orderly if there are more than one file
|
||||
chunk_size: unit is byte, default value is 1GB
|
||||
Returns:
|
||||
md5
|
||||
|
||||
"""
|
||||
md5 = hashlib.md5()
|
||||
for path in file_paths:
|
||||
with open(path, 'rb') as fin:
|
||||
while True:
|
||||
data = fin.read(chunk_size)
|
||||
if not data:
|
||||
break
|
||||
md5.update(data)
|
||||
return md5.hexdigest()
|
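A usage sketch for the new `md5` helper (self-contained: it writes two throwaway temp files; the names are invented):

```
# Usage sketch for utils.common_utils.md5 (file names invented for illustration).
import tempfile, os
from utils.common_utils import md5

tmp = tempfile.mkdtemp()
paths = []
for name, text in [("part-0.tsv", b"hello\t1\n"), ("part-1.tsv", b"world\t0\n")]:
    path = os.path.join(tmp, name)
    with open(path, "wb") as fout:
        fout.write(text)
    paths.append(path)

print(md5(paths[:1]))                  # hash of a single file
print(md5(paths))                      # hash of the ordered concatenation
assert md5(paths) != md5(paths[::-1])  # order matters, since files are concatenated
```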