add code for query binary classifier knowledge distillation (#3)

* Add new config about knowledge distillation for query binary classifier
This commit is contained in:
L.J. SHOU 2019-04-28 18:12:20 +08:00 коммит произвёл GitHub
Родитель 93c5c164c1
Коммит eaef9265b7
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
6 изменённых файлов: 20431 добавлений и 1 удалений

3
.gitignore поставляемый
Просмотреть файл

@ -1,5 +1,6 @@
.idea/
*~
*.pyc
*.cache/
*.cache*
dataset/GloVe/
models/

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Разница между файлами не показана из-за своего большого размера Загрузить разницу

Просмотреть файл

@ -0,0 +1,192 @@
{
"license": "Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT license.",
"tool_version": "1.1.0",
"model_description": "This model is used for knowledge distillation for query binary classifier",
"inputs": {
"use_cache": true,
"dataset_type": "regression",
"data_paths": {
"train_data_path": "./dataset/knowledge_distillation/query_binary_classifier/train.tsv",
"valid_data_path": "./dataset/knowledge_distillation/query_binary_classifier/valid.tsv",
"test_data_path": "./dataset/knowledge_distillation/query_binary_classifier/test.tsv",
"pre_trained_emb": "./dataset/GloVe/glove.840B.300d.txt"
},
"pretrained_emb_type": "glove",
"pretrained_emb_binary_or_text": "text",
"add_start_end_for_seq": true,
"file_header": {
"query_text": 0,
"label_score": 1
},
"predict_file_header": {
"query_text": 0
},
"model_inputs": {
"query": ["query_text"]
},
"target": ["label_score"]
},
"outputs": {
"save_base_dir": "./models/kdqbc_bilstmattn_cnn/train/",
"model_name": "model.nb",
"train_log_name": "train.log",
"test_log_name": "test.log",
"predict_log_name": "predict.log",
"predict_fields": ["prediction"],
"predict_output_name": "predict.tsv",
"cache_dir": ".cache.kdqbc_bilstmattn_cnn/"
},
"training_params": {
"vocabulary": {
"min_word_frequency": 1
},
"optimizer": {
"name": "Adam",
"params": {
"lr": 0.001
}
},
"lr_decay": 0.95,
"minimum_lr": 0.0001,
"epoch_start_lr_decay": 3,
"use_gpu": true,
"batch_size": 256,
"batch_num_to_show_results": 10,
"max_epoch": 30,
"valid_times_per_epoch": 10,
"fixed_lengths":{
"query": 30
}
},
"architecture":[
{
"layer": "Embedding",
"conf": {
"word": {
"cols": [ "query_text"],
"dim": 300
}
}
},
{
"layer_id": "Encoder",
"layer": "BiLSTMAtt",
"conf": {
"hidden_dim": 128,
"dropout": 0,
"num_layers": 1
},
"inputs": ["query"]
},
{
"layer_id": "Conv1",
"layer": "Conv",
"conf": {
"stride": 1,
"padding": 0,
"window_size": 3,
"input_channel_num": 1,
"output_channel_num": 128,
"activation": "PReLU"
},
"inputs": ["Encoder"]
},
{
"layer_id": "Pooling1",
"layer": "Pooling",
"conf": {
"pool_axis": 1,
"pool_type": "max"
},
"inputs": ["Conv1"]
},
{
"layer_id": "Conv2",
"layer": "Conv",
"conf": {
"stride": 1,
"padding": 0,
"window_size": 2,
"input_channel_num": 1,
"output_channel_num": 128,
"activation": "PReLU"
},
"inputs": ["Encoder"]
},
{
"layer_id": "Pooling2",
"layer": "Pooling",
"conf": {
"pool_axis": 1,
"pool_type": "max"
},
"inputs": ["Conv2"]
},
{
"layer_id": "Conv3",
"layer": "Conv",
"conf": {
"stride": 1,
"padding": 0,
"window_size": 1,
"input_channel_num": 1,
"output_channel_num": 128,
"activation": "PReLU"
},
"inputs": ["Encoder"]
},
{
"layer_id": "Pooling3",
"layer": "Pooling",
"conf": {
"pool_axis": 1,
"pool_type": "max"
},
"inputs": ["Conv3"]
},
{
"layer_id": "Comb",
"layer": "Combination",
"conf": {
"operations": ["origin"]
},
"inputs": ["Pooling1", "Pooling2", "Pooling3"]
},
{
"layer_id": "FC",
"layer": "Linear",
"conf": {
"hidden_dim": [256, 128],
"activation": "PReLU",
"batch_norm": true,
"last_hidden_activation": true,
"last_hidden_softmax": false
},
"inputs": ["Comb"]
},
{
"output_layer_flag": true,
"layer_id": "output",
"layer": "Linear",
"conf": {
"hidden_dim": [1],
"activation": "Sigmoid",
"last_hidden_activation": true,
"last_hidden_softmax": false
},
"inputs": ["FC"]
}
],
"loss": {
"losses" :[
{
"type": "MSELoss",
"conf": {
"size_average": true
},
"inputs": ["output","label_score"]
}
]
},
"metrics": ["RMSE", "MSE"]
}

49
tools/calculate_AUC.py Normal file
Просмотреть файл

@ -0,0 +1,49 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import argparse
import codecs
from sklearn.metrics import roc_auc_score
def read_tsv(params):
    """Read the prediction and label columns from a tab-separated file.

    Args:
        params: object exposing ``input_file`` (path of a UTF-8 encoded TSV
            file), ``predict_index`` and ``label_index`` (0-based column
            indices, given as int or numeric string) and ``header`` (truthy
            when the first line is a header row that must be skipped).

    Returns:
        A ``(prediction, label)`` tuple: a list of floats and a list of ints,
        with one entry per non-empty data line.

    Raises:
        ValueError: if a data line has fewer columns than the requested
            indices require, or if a cell cannot be converted to a number.
    """
    prediction, label = [], []
    predict_index, label_index = int(params.predict_index), int(params.label_index)
    # A row must be at least this wide for both indices to be valid.
    min_column_num = max(predict_index, label_index) + 1
    with codecs.open(params.input_file, mode='r', encoding='utf-8') as f:
        for index, line in enumerate(f):
            if params.header and index == 0:
                continue
            line = line.rstrip()
            # skip empty line
            if not line:
                continue
            columns = line.split('\t')
            if len(columns) < min_column_num:
                # Report the 1-based line number of the offending row (the
                # previous message printed the prediction column index here).
                raise ValueError(
                    "at line %d: the given index of predict or label exceeds the "
                    "number of columns (need at least %d, got %d): %r"
                    % (index + 1, min_column_num, len(columns), columns)
                )
            prediction.append(float(columns[predict_index]))
            label.append(int(columns[label_index]))
    return prediction, label
def calculate_AUC(prediction, label):
    """Return the ROC AUC of the ``prediction`` scores against ``label``.

    Args:
        prediction: sequence of predicted scores.
        label: sequence of ground-truth labels, aligned with ``prediction``.

    Returns:
        The value computed by ``roc_auc_score``.
    """
    # roc_auc_score takes the ground truth first and the scores second,
    # i.e. the reverse of this function's own parameter order.
    auc = roc_auc_score(label, prediction)
    return auc
def main(params):
    """Compute, print and return the AUC for the file described by ``params``.

    Args:
        params: parsed command-line options, passed through to ``read_tsv``.

    Returns:
        The AUC value returned by ``calculate_AUC``.
    """
    # read_tsv yields (prediction, label), which is exactly the positional
    # order calculate_AUC takes, so the pair can be unpacked straight in.
    auc = calculate_AUC(*read_tsv(params))
    print("AUC is ", auc)
    return auc
if __name__ == "__main__":
    # Command-line entry point: parse the options, then compute and print the AUC.
    parser = argparse.ArgumentParser(description="AUC")
    # Mandatory options are enforced by argparse itself rather than by `assert`,
    # which is stripped when Python runs with -O and would let missing options
    # slip through to a confusing failure later on.
    parser.add_argument("--input_file", type=str, required=True, help="tsv file")
    # type=int rejects non-numeric indices at parse time with a usage message.
    parser.add_argument("--predict_index", type=int, required=True, help="the column index of prediction of model, start from 0")
    parser.add_argument("--label_index", type=int, required=True, help="the column index of label, start from 0")
    parser.add_argument("--header", action='store_true', default=False, help="whether contains header row or not, default is False")
    # Unknown extra arguments are tolerated, as before.
    params, _ = parser.parse_known_args()
    # required=True still accepts an explicitly empty path; reject it here.
    if not params.input_file:
        parser.error('Please specify a input file via --input_file')
    main(params)