add code for query binary classifier knowledge distillation (#3)

* Add a new config for knowledge distillation for the query binary classifier

Parent: 93c5c164c1
Commit: eaef9265b7
@ -1,5 +1,6 @@
.idea/
*~
*.pyc
*.cache/
*.cache*
dataset/GloVe/
models/
Diffs for three files are not shown because they are too large.
@ -0,0 +1,192 @@
{
  "license": "Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the MIT license.",
  "tool_version": "1.1.0",
  "model_description": "This model is used for knowledge distillation for query binary classifier",
  "inputs": {
    "use_cache": true,
    "dataset_type": "regression",
    "data_paths": {
      "train_data_path": "./dataset/knowledge_distillation/query_binary_classifier/train.tsv",
      "valid_data_path": "./dataset/knowledge_distillation/query_binary_classifier/valid.tsv",
      "test_data_path": "./dataset/knowledge_distillation/query_binary_classifier/test.tsv",
      "pre_trained_emb": "./dataset/GloVe/glove.840B.300d.txt"
    },
    "pretrained_emb_type": "glove",
    "pretrained_emb_binary_or_text": "text",
    "add_start_end_for_seq": true,
    "file_header": {
      "query_text": 0,
      "label_score": 1
    },
    "predict_file_header": {
      "query_text": 0
    },
    "model_inputs": {
      "query": ["query_text"]
    },
    "target": ["label_score"]
  },
  "outputs": {
    "save_base_dir": "./models/kdqbc_bilstmattn_cnn/train/",
    "model_name": "model.nb",
    "train_log_name": "train.log",
    "test_log_name": "test.log",
    "predict_log_name": "predict.log",
    "predict_fields": ["prediction"],
    "predict_output_name": "predict.tsv",
    "cache_dir": ".cache.kdqbc_bilstmattn_cnn/"
  },
  "training_params": {
    "vocabulary": {
      "min_word_frequency": 1
    },
    "optimizer": {
      "name": "Adam",
      "params": {
        "lr": 0.001
      }
    },
    "lr_decay": 0.95,
    "minimum_lr": 0.0001,
    "epoch_start_lr_decay": 3,
    "use_gpu": true,
    "batch_size": 256,
    "batch_num_to_show_results": 10,
    "max_epoch": 30,
    "valid_times_per_epoch": 10,
    "fixed_lengths": {
      "query": 30
    }
  },
  "architecture": [
    {
      "layer": "Embedding",
      "conf": {
        "word": {
          "cols": ["query_text"],
          "dim": 300
        }
      }
    },
    {
      "layer_id": "Encoder",
      "layer": "BiLSTMAtt",
      "conf": {
        "hidden_dim": 128,
        "dropout": 0,
        "num_layers": 1
      },
      "inputs": ["query"]
    },
    {
      "layer_id": "Conv1",
      "layer": "Conv",
      "conf": {
        "stride": 1,
        "padding": 0,
        "window_size": 3,
        "input_channel_num": 1,
        "output_channel_num": 128,
        "activation": "PReLU"
      },
      "inputs": ["Encoder"]
    },
    {
      "layer_id": "Pooling1",
      "layer": "Pooling",
      "conf": {
        "pool_axis": 1,
        "pool_type": "max"
      },
      "inputs": ["Conv1"]
    },
    {
      "layer_id": "Conv2",
      "layer": "Conv",
      "conf": {
        "stride": 1,
        "padding": 0,
        "window_size": 2,
        "input_channel_num": 1,
        "output_channel_num": 128,
        "activation": "PReLU"
      },
      "inputs": ["Encoder"]
    },
    {
      "layer_id": "Pooling2",
      "layer": "Pooling",
      "conf": {
        "pool_axis": 1,
        "pool_type": "max"
      },
      "inputs": ["Conv2"]
    },
    {
      "layer_id": "Conv3",
      "layer": "Conv",
      "conf": {
        "stride": 1,
        "padding": 0,
        "window_size": 1,
        "input_channel_num": 1,
        "output_channel_num": 128,
        "activation": "PReLU"
      },
      "inputs": ["Encoder"]
    },
    {
      "layer_id": "Pooling3",
      "layer": "Pooling",
      "conf": {
        "pool_axis": 1,
        "pool_type": "max"
      },
      "inputs": ["Conv3"]
    },
    {
      "layer_id": "Comb",
      "layer": "Combination",
      "conf": {
        "operations": ["origin"]
      },
      "inputs": ["Pooling1", "Pooling2", "Pooling3"]
    },
    {
      "layer_id": "FC",
      "layer": "Linear",
      "conf": {
        "hidden_dim": [256, 128],
        "activation": "PReLU",
        "batch_norm": true,
        "last_hidden_activation": true,
        "last_hidden_softmax": false
      },
      "inputs": ["Comb"]
    },
    {
      "output_layer_flag": true,
      "layer_id": "output",
      "layer": "Linear",
      "conf": {
        "hidden_dim": [1],
        "activation": "Sigmoid",
        "last_hidden_activation": true,
        "last_hidden_softmax": false
      },
      "inputs": ["FC"]
    }
  ],
  "loss": {
    "losses": [
      {
        "type": "MSELoss",
        "conf": {
          "size_average": true
        },
        "inputs": ["output", "label_score"]
      }
    ]
  },
  "metrics": ["RMSE", "MSE"]
}
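For orientation, the config's "file_header" says each training row is a tab-separated pair: "query_text" in column 0 and "label_score" in column 1, where "label_score" is the teacher model's soft score rather than a hard 0/1 label (hence "dataset_type": "regression" and the MSELoss against "label_score"). Below is a minimal, hypothetical sketch of producing such a file; teacher_predict stands in for whatever teacher model supplies the soft scores and is not part of this commit.

# Sketch (not part of the commit): build a KD training TSV in the layout the config expects.
# teacher_predict is a hypothetical callable returning a soft score in [0, 1] for a query.
import codecs
import os

def write_kd_tsv(queries, teacher_predict, out_path):
    # Each row: <query_text>\t<teacher soft score>, matching file_header above.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    with codecs.open(out_path, mode='w', encoding='utf-8') as f:
        for query in queries:
            score = teacher_predict(query)  # soft label distilled from the teacher
            f.write("%s\t%.6f\n" % (query, score))

if __name__ == "__main__":
    # Toy teacher that scores by query length, just to illustrate the file format.
    dummy_teacher = lambda q: min(len(q) / 30.0, 1.0)
    write_kd_tsv(["how to reset my password", "qwerty asdf"], dummy_teacher,
                 "./dataset/knowledge_distillation/query_binary_classifier/train.tsv")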
@ -0,0 +1,49 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.

import argparse
import codecs

from sklearn.metrics import roc_auc_score


def read_tsv(params):
    prediction, label = [], []
    predict_index, label_index = int(params.predict_index), int(params.label_index)
    min_column_num = max(predict_index, label_index) + 1
    with codecs.open(params.input_file, mode='r', encoding='utf-8') as f:
        for index, line in enumerate(f):
            if params.header and index == 0:
                continue
            line = line.rstrip()
            # skip empty lines
            if not line:
                continue
            line = line.split('\t')
            if len(line) < min_column_num:
                print("at line %d: %s" % (index, line))
                raise Exception("the given predict or label column index exceeds the number of columns")
            prediction.append(float(line[predict_index]))
            label.append(int(line[label_index]))
    return prediction, label


def calculate_AUC(prediction, label):
    return roc_auc_score(label, prediction)


def main(params):
    prediction, label = read_tsv(params)
    auc = calculate_AUC(prediction, label)
    print("AUC is ", auc)
    return auc


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="AUC")
    parser.add_argument("--input_file", type=str, help="tsv file")
    parser.add_argument("--predict_index", type=str, help="the column index of the model prediction, starting from 0")
    parser.add_argument("--label_index", type=str, help="the column index of the label, starting from 0")
    parser.add_argument("--header", action='store_true', default=False, help="whether the file contains a header row, default is False")

    params, _ = parser.parse_known_args()

    assert params.input_file, 'Please specify an input file via --input_file'
    assert params.predict_index, 'Please specify the column index of the prediction via --predict_index'
    assert params.label_index, 'Please specify the column index of the label via --label_index'
    main(params)
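As a quick sanity check for the AUC script above, the sketch below writes a toy prediction TSV in the layout the script reads (prediction score and binary label as tab-separated columns, with a header row) and computes the AUC directly with scikit-learn. The file name and column order here are illustrative, not taken from the diff.

# Toy example of the TSV layout the AUC script consumes; file name and columns are illustrative.
import codecs
from sklearn.metrics import roc_auc_score

rows = [
    (0.91, 1),  # (model prediction, gold binary label)
    (0.12, 0),
    (0.67, 1),
    (0.40, 0),
]
with codecs.open("toy_predict.tsv", mode="w", encoding="utf-8") as f:
    f.write("prediction\tlabel\n")            # header row (pass --header when scoring)
    for score, gold in rows:
        f.write("%.4f\t%d\n" % (score, gold))

# Equivalent of reading column 0 as prediction and column 1 as label, then scoring.
preds = [r[0] for r in rows]
labels = [r[1] for r in rows]
print("AUC is", roc_auc_score(labels, preds))  # -> 1.0 for this toy data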