Mirror of https://github.com/microsoft/MWSS.git
initial commit
This commit is contained in:
Parent
55cc0759b5
Commit
fdfc60571e
LICENSE (36 changed lines)
@@ -1,21 +1,23 @@
-MIT License
+Early Detection of Fake News with Multi-source Weak Social Supervision
 
-Copyright (c) Microsoft Corporation.
+Copyright (c) Microsoft Corporation, Yichuan Li and Kai Shu.
+
+MIT License
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
 in the Software without restriction, including without limitation the rights
 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 copies of the Software, and to permit persons to whom the Software is
 furnished to do so, subject to the following conditions:
 
 The above copyright notice and this permission notice shall be included in all
 copies or substantial portions of the Software.
 
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE
+SOFTWARE.
README.md (245 changed lines)
@@ -1,33 +1,230 @@

Removed: the initial repository template.

# Project

> This repo has been populated by an initial template to help get you started. Please
> make sure to update the content to build a great experience for community-building.

As the maintainer of this project, please make a few updates:

- Improving this README.MD file to provide a great experience
- Updating SUPPORT.MD with content about this project's support experience
- Understanding the security reporting process in SECURITY.MD
- Remove this section from the README

## Contributing

This project welcomes contributions and suggestions. Most contributions require you to agree to a
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.

When you submit a pull request, a CLA bot will automatically determine whether you need to provide
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
provided by the bot. You will only need to do this once across all repos using our CLA.

This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.

## Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos are subject to those third-party's policies.

Added: the MWSS README.

## Multi Weak Source Supervision for Fake News Detection

Authors: Guoqing Zheng (zheng@microsoft.com), Yichuan Li, Kai Shu

This repository contains code for fake news detection with Multi-source Weak Social Supervision (MWSS), published at **ECML-PKDD 2020**: [Early Detection of Fake News with Multi-source Weak Social Supervision](https://www.microsoft.com/en-us/research/publication/leveraging-multi-source-weak-social-supervision-for-early-detection-of-fake-news/)

### Requirements

torch=1.x
transformers=2.4.0

### Usage

a. __"--train_type"__ selects the training data: {0: "clean", 1: "noise", 2: "clean+noise"}.

b. __"--meta_learn"__ learns an instance weight for each [noisy sample](https://arxiv.org/abs/1803.09050).

c. __"--multi_head"__ sets the number of weak sources; if you have three different weak sources, set it to 3.

d. __"--group_opt"__ selects the optimizer for the group weights; you can choose __Adam__ or __SGD__.

e. __"--gold_ratio"__ is the float gold ratio for the training data. The default 0 sweeps all of \[0.02, 0.04, 0.06, 0.08, 0.1\]; for a single ratio such as 0.02, pass "--gold_ratio 0.02".
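A minimal sketch of how these flags combine (illustrative only; paths follow the examples below, and every flag not shown keeps its default; see the full, tuned commands for real runs):

    python run_classifiy.py \
        --do_train --do_eval \
        --model_name_or_path roberta-base \
        --clf_model robert \
        --train_type 0 \
        --meta_learn \
        --multi_head 3 \
        --use_group_net \
        --group_opt "adam" \
        --gold_ratio 0.02 \
        --train_path "./data/political/weak" \
        --eval_path "./data/political/test.csv" \
        --output_dir ./output/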
- Finetune on RoBERTa Group Weight

        python3 run_classifiy.py \
            --model_name_or_path roberta-base \
            --evaluate_during_training --do_train --do_eval \
            --num_train_epochs 15 \
            --output_dir ./output/ \
            --logging_steps 100 \
            --max_seq_length 256 \
            --train_type 0 \
            --per_gpu_eval_batch_size 16 \
            --g_train_batch_size 5 \
            --s_train_batch_size 5 \
            --clf_model "robert" \
            --meta_learn \
            --weak_type "none" \
            --multi_head 3 \
            --use_group_net \
            --group_opt "adam" \
            --train_path "./data/political/weak" \
            --eval_path "./data/political/test.csv" \
            --fp16 \
            --fp16_opt_level O1 \
            --learning_rate 1e-4 \
            --group_adam_epsilon 1e-9 \
            --group_lr 1e-3 \
            --gold_ratio 0.04 \
            --id "ParameterGroup1"

  The log information will be stored in

        ~/output

- CNN Baseline Model

        python run_classifiy.py \
            --model_name_or_path distilbert-base-uncased \
            --evaluate_during_training --do_train --do_eval --do_lower_case \
            --num_train_epochs 30 \
            --output_dir ./output/ \
            --logging_steps 10 \
            --max_seq_length 256 \
            --train_type 0 \
            --weak_type most_vote \
            --per_gpu_train_batch_size 256 \
            --per_gpu_eval_batch_size 256 \
            --learning_rate 1e-3 \
            --clf_model cnn

- CNN Instance Weight Model with multi classification heads

        python run_classifiy.py \
            --model_name_or_path distilbert-base-uncased \
            --evaluate_during_training --do_train --do_eval --do_lower_case \
            --num_train_epochs 256 \
            --output_dir ./output/ \
            --logging_steps 10 \
            --max_seq_length 256 \
            --train_type 0 \
            --per_gpu_eval_batch_size 256 \
            --g_train_batch_size 256 \
            --s_train_batch_size 256 \
            --learning_rate 1e-3 \
            --clf_model cnn \
            --meta_learn \
            --weak_type "none"

- CNN group weight

        python run_classifiy.py \
            --model_name_or_path distilbert-base-uncased \
            --evaluate_during_training --do_train --do_eval --do_lower_case \
            --num_train_epochs 256 \
            --output_dir ./output/ \
            --logging_steps 10 \
            --max_seq_length 256 \
            --train_type 0 \
            --per_gpu_eval_batch_size 256 \
            --g_train_batch_size 256 \
            --s_train_batch_size 256 \
            --learning_rate 1e-3 \
            --clf_model cnn \
            --meta_learn \
            --weak_type "none" \
            --multi_head 3 \
            --use_group_weight \
            --group_opt "SGD" \
            --group_momentum 0.9 \
            --group_lr 1e-5

- RoBERTa Baseline Model

        python run_classifiy.py \
            --model_name_or_path roberta-base \
            --evaluate_during_training \
            --do_train --do_eval --do_lower_case \
            --num_train_epochs 30 \
            --output_dir ./output/ \
            --logging_steps 10 \
            --max_seq_length 256 \
            --train_type 0 \
            --weak_type most_vote \
            --per_gpu_train_batch_size 16 \
            --per_gpu_eval_batch_size 16 \
            --learning_rate 5e-5 \
            --clf_model robert

- RoBERTa Instance Weight with Multi Head Classification

        python run_classifiy.py \
            --model_name_or_path roberta-base \
            --evaluate_during_training --do_train --do_eval --do_lower_case \
            --num_train_epochs 256 \
            --output_dir ./output/ \
            --logging_steps 10 \
            --max_seq_length 256 \
            --weak_type most_vote \
            --per_gpu_eval_batch_size 16 \
            --g_train_batch_size 16 \
            --s_train_batch_size 16 \
            --learning_rate 5e-5 \
            --clf_model robert \
            --meta_learn \
            --weak_type "none" \
            --multi_head 3

- RoBERTa Group Weight

        python run_classifiy.py \
            --model_name_or_path roberta-base \
            --evaluate_during_training --do_train --do_eval --do_lower_case \
            --num_train_epochs 256 \
            --output_dir ./output/ \
            --logging_steps 10 \
            --max_seq_length 256 \
            --weak_type most_vote \
            --per_gpu_eval_batch_size 16 \
            --g_train_batch_size 16 \
            --s_train_batch_size 16 \
            --learning_rate 5e-5 \
            --clf_model robert \
            --meta_learn \
            --weak_type "none" \
            --multi_head 3 \
            --use_group_weight \
            --group_opt "SGD" \
            --group_momentum 0.9 \
            --group_lr 1e-5

- Finetune on RoBERTa Group Weight (hyperparameter search; learning rates and epsilons are passed as comma-separated lists)

        python3 run_classifiy.py \
            --model_name_or_path roberta-base \
            --evaluate_during_training --do_train --do_eval \
            --num_train_epochs 15 \
            --output_dir ./output/ \
            --logging_steps 100 \
            --max_seq_length 256 \
            --train_type 0 \
            --per_gpu_eval_batch_size 16 \
            --g_train_batch_size 1 \
            --s_train_batch_size 1 \
            --clf_model "robert" \
            --meta_learn \
            --weak_type "none" \
            --multi_head 3 \
            --use_group_net \
            --group_opt "adam" \
            --train_path "./data/political/weak" \
            --eval_path "./data/political/test.csv" \
            --fp16 \
            --fp16_opt_level O1 \
            --learning_rate "1e-4,5e-4,1e-5,5e-5" \
            --group_adam_epsilon "1e-9, 1e-8, 5e-8" \
            --group_lr "1e-3,1e-4,3e-4,5e-4,1e-5,5e-5" \
            --gold_ratio 0.04

  The log information will be stored in

        ~/ray_results/GoldRatio_{}_GroupNet

You can run the following command to extract the best result, sorted by the average of accuracy and F1:

    export LOG_FILE=~/ray_results/GoldRatio_{}_GroupNet
    python read_json.py --file_name $LOG_FILE --save_dir ./output

You can also visualize the training logs with TensorBoard:

    tensorboard --logdir $LOG_FILE
SECURITY.md (41 lines, file removed)
@@ -1,41 +0,0 @@

<!-- BEGIN MICROSOFT SECURITY.MD V0.0.5 BLOCK -->

## Security

Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).

If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.

## Reporting Security Issues

**Please do not report security vulnerabilities through public GitHub issues.**

Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).

If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).

You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).

Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:

* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue

This information will help us triage your report more quickly.

If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.

## Preferred Languages

We prefer all communications to be in English.

## Policy

Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).

<!-- END MICROSOFT SECURITY.MD BLOCK -->
SUPPORT.md (25 lines, file removed)
@@ -1,25 +0,0 @@

# TODO: The maintainer of this repo has not yet edited this file

**REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?

- **No CSS support:** Fill out this template with information about how to file issues and get help.
- **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
- **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.

*Then remove this first heading from this SUPPORT.MD file before publishing your repo.*

# Support

## How to file issues and get help

This project uses GitHub Issues to track bugs and feature requests. Please search the existing
issues before filing new issues to avoid duplicates. For new issues, file your bug or
feature request as a new Issue.

For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
CHANNEL. WHERE WILL YOU HELP PEOPLE?**.

## Microsoft Support Policy

Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
dataset.py (new file, 143 lines)
@@ -0,0 +1,143 @@
|
|||
'''
|
||||
Copyright (c) Microsoft Corporation, Yichuan Li and Kai Shu.
|
||||
Licensed under the MIT license.
|
||||
Authors: Guoqing Zheng (zheng@microsoft.com), Yichuan Li and Kai Shu
|
||||
'''
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from itertools import chain
|
||||
import os
|
||||
import pickle
|
||||
# from snorkel_process import weak_supervision
|
||||
class FakeNewsDataset(Dataset):
|
||||
def __init__(self, file_name, tokenizer, is_weak, max_length, weak_type="", overwrite=False, balance_weak=False):
|
||||
super(FakeNewsDataset, self).__init__()
|
||||
tokenizer_name = type(tokenizer).__name__
|
||||
# if tokenizer_name == "BertTokenizer" or tokenizer_name == "RobertaTokenizer":
|
||||
pickle_file = file_name.replace(".csv", '_{}_{}.pkl'.format(max_length, tokenizer_name))
|
||||
self.weak_label_count = 3
|
||||
if os.path.exists(pickle_file) and overwrite is False:
|
||||
load_data = pickle.load(open(pickle_file, "rb"))
|
||||
for key, value in load_data.items():
|
||||
setattr(self, key, value)
|
||||
else:
|
||||
save_data = {}
|
||||
data = pd.read_csv(file_name)
|
||||
# if tokenizer_name == "BertTokenizer" or tokenizer_name == "RobertaTokenizer":
|
||||
|
||||
self.news = [tokenizer.encode(i, max_length=max_length, pad_to_max_length=True) for i in data['news'].values.tolist()]
|
||||
self.attention_mask = [[1] * (i.index(tokenizer.pad_token_id) if tokenizer.pad_token_id in i else len(i))
|
||||
for i in self.news]
|
||||
self.attention_mask = [mask + [0] * (max_length - len(mask)) for mask in self.attention_mask]
|
||||
|
||||
if is_weak:
|
||||
self.weak_labels = []
|
||||
assert "label" not in data.columns, "noise data should not contain the clean label"
|
||||
self.weak_labels = data.iloc[:, list(range(1, len(data.columns)))].values.tolist()
|
||||
save_data.update({"weak_labels": data.iloc[:, list(range(1, len(data.columns)))].values.tolist()})
|
||||
else:
|
||||
self.labels = data['label'].values.tolist()
|
||||
|
||||
save_data.update({"news": self.news, "attention_mask": self.attention_mask})
|
||||
if is_weak is False:
|
||||
save_data.update({"labels": self.labels})
|
||||
|
||||
pickle.dump(save_data, open(pickle_file, "wb"))
|
||||
if is_weak:
|
||||
if weak_type == "most_vote":
|
||||
self.weak_labels = [1 if np.sum(i) > 1 else 0 for i in self.weak_labels]
|
||||
elif weak_type == "flat":
|
||||
self.weak_labels = list(chain.from_iterable(self.weak_labels))
|
||||
self.news = list(chain.from_iterable([[i] * self.weak_label_count for i in self.news]))
|
||||
self.attention_mask = list(
|
||||
chain.from_iterable([[i] * self.weak_label_count for i in self.attention_mask]))
|
||||
#"credit_label","polarity_label","bias_label"
|
||||
elif weak_type == "cred":
|
||||
self.weak_labels = [i[0] for i in self.weak_labels]
|
||||
elif weak_type == "polar":
|
||||
self.weak_labels = [i[1] for i in self.weak_labels]
|
||||
elif weak_type == "bias":
|
||||
self.weak_labels = [i[2] for i in self.weak_labels]
|
||||
self.is_weak = is_weak
|
||||
self.weak_type = weak_type
|
||||
if self.is_weak and balance_weak:
|
||||
self.__balance_helper()
|
||||
if self.is_weak:
|
||||
self.__instance_shuffle()
|
||||
|
||||
def __bert_tokenizer(self, tokenizer, max_length, data):
|
||||
encode_output = [tokenizer.encode_plus(i, max_length=max_length, pad_to_max_length=True) for i in
|
||||
data['news'].values.tolist()]
|
||||
self.news = [i['input_ids'] for i in encode_output]
|
||||
self.attention_mask = [i['attention_mask'] for i in encode_output]
|
||||
|
||||
self.token_type_ids = [i['token_type_ids'] for i in encode_output]
|
||||
|
||||
def __instance_shuffle(self):
|
||||
index_array = np.array(list(range(len(self))))
|
||||
np.random.shuffle(index_array)
|
||||
self.news = np.array(self.news)[index_array]
|
||||
self.weak_labels = np.array(self.weak_labels)[index_array]
|
||||
self.attention_mask = np.array(self.attention_mask)[index_array]
|
||||
|
||||
def __balance_helper(self):
|
||||
self.weak_labels = np.array(self.weak_labels)
|
||||
# minority_count = min(int(np.sum(self.weak_labels)), len(self.weak_labels) - int(np.sum(self.weak_labels)))
|
||||
majority_count = max(int(np.sum(self.weak_labels)), len(self.weak_labels) - int(np.sum(self.weak_labels)))
|
||||
one_index = np.argwhere(self.weak_labels == 1)
|
||||
zero_index = np.argwhere(self.weak_labels == 0)
|
||||
zero_index = list(zero_index.reshape(-1,)) + list(np.random.choice(len(zero_index), majority_count-len(zero_index)))
|
||||
one_index = list(one_index.reshape(-1,)) + list(np.random.choice(len(one_index), majority_count-len(one_index)))
|
||||
self.weak_labels = self.weak_labels[one_index + zero_index]
|
||||
self.news = np.array(self.news)[one_index + zero_index]
|
||||
self.attention_mask = np.array(self.attention_mask)[one_index + zero_index]
|
||||
if hasattr(self, "token_type_ids"):
|
||||
self.token_type_ids = np.array(self.token_type_ids)[one_index+zero_index]
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.news)
|
||||
|
||||
def __getitem__(self, item):
|
||||
if self.is_weak:
|
||||
return torch.tensor(self.news[item]), torch.tensor(self.attention_mask[item]), torch.tensor(
|
||||
self.weak_labels[item])
|
||||
else:
|
||||
return torch.tensor(self.news[item]), torch.tensor(self.attention_mask[item]), torch.tensor(self.labels[item])
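# Example usage (a sketch, not part of the original file): how a weak-labeled CSV under
# --train_path might be wrapped for batching. The file name "train.csv" and the tokenizer
# choice are assumptions for illustration only.
#
#   from transformers import RobertaTokenizer
#   from torch.utils.data import DataLoader
#
#   tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
#   weak_set = FakeNewsDataset("./data/political/weak/train.csv", tokenizer,
#                              is_weak=True, max_length=256, weak_type="most_vote")
#   loader = DataLoader(weak_set, batch_size=16, shuffle=True)
#   input_ids, attention_mask, weak_labels = next(iter(loader))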
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class SnorkelDataset(Dataset):
|
||||
def __init__(self, file_name, tokenizer, max_length, overwrite=False):
|
||||
super(SnorkelDataset, self).__init__()
|
||||
tokenizer_name = type(tokenizer).__name__
|
||||
pickle_file = file_name.replace(".csv", '_{}_{}.pkl'.format(max_length, tokenizer_name))
|
||||
assert os.path.exists(pickle_file), "please run loadFakeNewsDataset first"
|
||||
snorkel_file = pickle_file + "_snorkel"
|
||||
if os.path.exists(snorkel_file) and overwrite is False:
|
||||
snorkel_data = pickle.load(open(snorkel_file, "rb"))
|
||||
|
||||
else:
|
||||
snorkel_data = weak_supervision(pickle_file, snorkel_file)
|
||||
|
||||
|
||||
for key, value in snorkel_data.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
def __len__(self):
|
||||
return len(self.news)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return torch.tensor(self.news[item]), torch.tensor(self.attention_mask[item]), torch.tensor(self.snorkel_weak[item])
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
l2w.py (new file, 474 lines)
@@ -0,0 +1,474 @@
|
|||
'''
|
||||
Copyright (c) Microsoft Corporation, Yichuan Li and Kai Shu.
|
||||
Licensed under the MIT license.
|
||||
Authors: Guoqing Zheng (zheng@microsoft.com), Yichuan Li and Kai Shu
|
||||
'''
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from copy import deepcopy
|
||||
import math
|
||||
|
||||
try:
|
||||
from apex import amp
|
||||
except:
|
||||
print("Install AMP!")
|
||||
|
||||
def modify_parameters(net, deltas, eps):
|
||||
for param, delta in zip(net.parameters(), deltas):
|
||||
if delta is None:
|
||||
continue
|
||||
param.data.add_(eps, delta)  # in-place scaled update: param += eps * delta
|
||||
|
||||
|
||||
|
||||
def update_params_sgd(params, grads, opt, eta, args):
|
||||
# supports SGD-like optimizers
|
||||
ans = []
|
||||
|
||||
wdecay = opt.defaults.get('weight_decay', 0.)
|
||||
momentum = opt.defaults.get('momentum', 0.)
|
||||
# eta = opt.defaults["lr"]
|
||||
for i, param in enumerate(params):
|
||||
if grads[i] is None:
|
||||
ans.append(param)
|
||||
continue
|
||||
try:
|
||||
moment = opt.state[param]['momentum_buffer'] * momentum
|
||||
except:
|
||||
moment = torch.zeros_like(param)
|
||||
|
||||
dparam = grads[i] + param * wdecay
|
||||
|
||||
# eta is the learning rate
|
||||
ans.append(param - (dparam + moment) * eta)
|
||||
|
||||
return ans
|
||||
|
||||
def update_params_adam(params, grads, opt):
|
||||
ans = []
|
||||
group = opt.param_groups[0]
|
||||
assert len(opt.param_groups) == 1
|
||||
for p, grad in zip(params, grads):
|
||||
if grad is None:
|
||||
ans.append(p)
|
||||
continue
|
||||
amsgrad = group['amsgrad']
|
||||
state = opt.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(p.data)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p.data)
|
||||
if amsgrad:
|
||||
# Maintains max of all exp. moving avg. of sq. grad. values
|
||||
state['max_exp_avg_sq'] = torch.zeros_like(p.data)
|
||||
|
||||
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
|
||||
if amsgrad:
|
||||
max_exp_avg_sq = state['max_exp_avg_sq']
|
||||
beta1, beta2 = group['betas']
|
||||
|
||||
state['step'] += 1
|
||||
bias_correction1 = 1 - beta1 ** state['step']
|
||||
bias_correction2 = 1 - beta2 ** state['step']
|
||||
|
||||
if group['weight_decay'] != 0:
|
||||
grad.add_(group['weight_decay'], p.data)
|
||||
|
||||
# Decay the first and second moment running average coefficient
|
||||
exp_avg.mul_(beta1).add_(1 - beta1, grad)
|
||||
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
|
||||
if amsgrad:
|
||||
# Maintains the maximum of all 2nd moment running avg. till now
|
||||
torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
|
||||
# Use the max. for normalizing running avg. of gradient
|
||||
denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
||||
else:
|
||||
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
|
||||
|
||||
step_size = group['lr'] / bias_correction1
|
||||
|
||||
|
||||
# ans.append(p.data.addcdiv(-step_size, exp_avg, denom))
|
||||
ans.append(torch.addcdiv(p, -step_size, exp_avg, denom))
|
||||
|
||||
return ans
|
||||
|
||||
# ============== l2w step procedure debug ===================
|
||||
# NOTE: main_net is implemented as nn.Module as usual
|
||||
|
||||
def step_l2w(main_net, main_opt, main_scheduler, g_input, s_input, train_input, args, gold_ratio):
|
||||
# init eps to 0
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
try:
|
||||
eta = main_scheduler.get_lr()[0]
|
||||
except:
|
||||
eta = main_opt.defaults.get("lr", 0)
|
||||
|
||||
|
||||
|
||||
eps = nn.Parameter(torch.zeros_like(s_input['labels'].float()))
|
||||
eps = eps.view(-1)
|
||||
|
||||
# flat the weight for multi head
|
||||
|
||||
|
||||
# calculate current weighted loss
|
||||
main_net.train()
|
||||
|
||||
loss_s = main_net(**s_input)[0]
|
||||
# {reduction: "none"} in s_inputs
|
||||
|
||||
loss_s = (eps * loss_s).sum()
|
||||
if gold_ratio > 0:
|
||||
loss_train = main_net(**train_input)[0]
|
||||
loss_s = (loss_train + loss_s) / 2
|
||||
|
||||
|
||||
|
||||
# get theta grads
|
||||
# 1. update w to w'
|
||||
param_grads = torch.autograd.grad(loss_s, main_net.parameters(), allow_unused=True)
|
||||
|
||||
|
||||
params_new = update_params_sgd(main_net.parameters(), param_grads, main_opt, eta, args)
|
||||
# params_new = update_params_adam(main_net.parameters(), param_grads, main_opt)
|
||||
|
||||
# 2. set w as w'
|
||||
params = []
|
||||
for i, param in enumerate(main_net.parameters()):
|
||||
params.append(param.data.clone())
|
||||
param.data = params_new[i].data # use data only
|
||||
# 3. compute d_w' L_{D}(w')
|
||||
|
||||
loss_g = main_net(**g_input)[0]
|
||||
|
||||
params_new_grad = torch.autograd.grad(loss_g, main_net.parameters(), allow_unused=True)
|
||||
|
||||
# 4. revert from w' to w for main net
|
||||
for i, param in enumerate(main_net.parameters()):
|
||||
param.data = params[i]
|
||||
|
||||
# change main_net parameter
|
||||
_eps = 1e-6 # 1e-3 / _concat(params_new_grad).norm # eta 1e-6 before
|
||||
|
||||
# modify w to w+
|
||||
modify_parameters(main_net, params_new_grad, _eps)
|
||||
|
||||
loss_s_p = main_net(**s_input)[0]
|
||||
loss_s_p = (eps * loss_s_p).sum()
|
||||
if gold_ratio > 0:
|
||||
loss_train_p = main_net(**train_input)[0]
|
||||
loss_s_p = (loss_s_p + loss_train_p) / 2
|
||||
|
||||
|
||||
|
||||
# modify w to w- (from w+)
|
||||
modify_parameters(main_net, params_new_grad, -2 * _eps)
|
||||
loss_s_n = main_net(**s_input)[0]
|
||||
loss_s_n = (eps * loss_s_n).sum()
|
||||
if gold_ratio > 0:
|
||||
loss_train_n = main_net(**train_input)[0]
|
||||
loss_s_n = (loss_train_n + loss_s_n)
|
||||
|
||||
proxy_g = -eta * (loss_s_p - loss_s_n) / (2. * _eps)
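# Descriptive note (not original code): proxy_g is a scalar surrogate whose gradient with
# respect to eps approximates d L_val(w') / d eps, where w' = w - eta * d_w L_s(w, eps).
# The +/- _eps weight perturbations above form a central finite-difference estimate of the
# required Hessian-vector product, so the instance-weight gradient is recovered without
# building a second-order autograd graph (cf. the learning-to-reweight paper linked in the README).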
|
||||
|
||||
# modify to original w
|
||||
modify_parameters(main_net, params_new_grad, _eps)
|
||||
eps_grad = torch.autograd.grad(proxy_g, eps, allow_unused=True)[0]
|
||||
|
||||
# update eps
|
||||
w = F.relu(-eps_grad)
|
||||
|
||||
if w.max() == 0:
|
||||
w = torch.ones_like(w)
|
||||
else:
|
||||
w = w / w.sum()
|
||||
|
||||
|
||||
loss_s = main_net(**s_input)[0]
|
||||
loss_s = (w * loss_s).sum()
|
||||
if gold_ratio > 0:
|
||||
loss_train = main_net(**train_input)[0]
|
||||
loss_s = (loss_s + loss_train) / 2
|
||||
|
||||
# if info['step'] is not None:
|
||||
# writer.add_histogram("weight/GoldRatio_{}_InstanceWeight".format(info['gold_ratio']), w.detach(), global_step=info['step'])
|
||||
|
||||
if gold_ratio > 0:
|
||||
loss_s += main_net(**train_input)[0]
|
||||
|
||||
|
||||
|
||||
# main_opt.zero_grad()
|
||||
main_net.zero_grad()
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss_s, main_opt) as loss_s:
|
||||
loss_s.backward()
|
||||
else:
|
||||
loss_s.backward()
|
||||
main_opt.step()
|
||||
|
||||
|
||||
if type(main_scheduler).__name__ == "LambdaLR":
|
||||
main_scheduler.step(loss_s)
|
||||
else:
|
||||
main_scheduler.step()
|
||||
|
||||
main_net.eval()
|
||||
loss_g = main_net(**g_input)[0]
|
||||
return loss_g, loss_s
|
||||
|
||||
|
||||
|
||||
# w_net now computes both (ins_weight, g_weight)
|
||||
def step_l2w_group_net(main_net, main_opt, main_scheduler, g_input, s_input, train_input, args, gw, gw_opt,
|
||||
gw_scheduler, gold_ratio):
|
||||
#if args.fp16:
|
||||
# try:
|
||||
# from apex import amp
|
||||
# except ImportError:
|
||||
# raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
|
||||
# ATTENTION: s_input["labels"] is [bs, K]
|
||||
# forward function of gw is:
|
||||
# # group weight shape is [1, K]
|
||||
# group_weight = torch.sigmoid(self.pesu_group_weight)
|
||||
# final_weight = torch.matmul(iw.view(-1, 1), group_weight)
|
||||
# return (final_weight * (item_loss.view(final_weight.shape))).sum()
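# A minimal sketch of a module consistent with the comment above (an assumption for
# illustration; the actual GroupWeightModel / FullWeightModel live in model.py, not shown
# here, and the FullWeightModel used below is called as gw(features, weak_labels, losses)
# and also returns the instance weights):
#
#   class GroupWeightSketch(nn.Module):
#       def __init__(self, n_groups):
#           super().__init__()
#           # one learnable logit per weak source
#           self.pesu_group_weight = nn.Parameter(torch.zeros(1, n_groups))
#
#       def forward(self, iw, item_loss):
#           group_weight = torch.sigmoid(self.pesu_group_weight)       # [1, K]
#           final_weight = torch.matmul(iw.view(-1, 1), group_weight)  # [bs, K]
#           return (final_weight * item_loss.view(final_weight.shape)).sum()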
|
||||
|
||||
# get learn rate from optimizer or scheduler
|
||||
'''
|
||||
try:
|
||||
eta_group = gw_scheduler.get_lr()
|
||||
except:
|
||||
eta_group = gw_opt.defaults.get("lr", 0)
|
||||
'''
|
||||
|
||||
try:
|
||||
eta = main_scheduler.get_lr()
|
||||
if type(eta).__name__ == "list":
|
||||
eta = eta[0]
|
||||
except:
|
||||
eta = main_opt.defaults.get("lr", 0)
|
||||
|
||||
# calculate current weighted loss
|
||||
# ATTENTION: loss_s shape: [bs * K, 1]
|
||||
y_weak = s_input['labels']
|
||||
outputs_s = main_net(**s_input)
|
||||
s_feature = outputs_s[2]
|
||||
loss_s = outputs_s[0]
|
||||
loss_s, _ = gw(s_feature, y_weak, loss_s)
|
||||
|
||||
if gold_ratio > 0:
|
||||
loss_train = main_net(**train_input)[0]
|
||||
loss_s = (loss_s + loss_train) / 2
|
||||
else:
|
||||
loss_s = loss_s
|
||||
|
||||
# get theta grads
|
||||
# 1. update w to w'
|
||||
param_grads = torch.autograd.grad(loss_s, main_net.parameters(), allow_unused=True)
|
||||
|
||||
# 2. set w as w'
|
||||
params = [param.data.clone() for param in main_net.parameters()]
|
||||
for i, param in enumerate(main_net.parameters()):
|
||||
if param_grads[i] is not None:
|
||||
param.data.sub_(eta*param_grads[i])
|
||||
|
||||
# 3. compute d_w' L_{D}(w')
|
||||
loss_g = main_net(**g_input)[0]
|
||||
|
||||
params_new_grad = torch.autograd.grad(loss_g, main_net.parameters(), allow_unused=True)
|
||||
|
||||
# 4. revert from w' to w for main net
|
||||
for i, param in enumerate(main_net.parameters()):
|
||||
param.data = params[i]
|
||||
|
||||
# change main_net parameter
|
||||
_eps = 1e-6 # 1e-3 / _concat(params_new_grad).norm # eta 1e-6 before
|
||||
|
||||
# modify w to w+
|
||||
modify_parameters(main_net, params_new_grad, _eps)
|
||||
outputs_s_p = main_net(**s_input)
|
||||
loss_s_p = outputs_s_p[0]
|
||||
s_p_feature = outputs_s_p[2]
|
||||
loss_s_p,_ = gw(s_p_feature, y_weak, loss_s_p)
|
||||
if gold_ratio > 0:
|
||||
loss_train = main_net(**train_input)[0]
|
||||
loss_s_p = (loss_s_p + loss_train ) / 2
|
||||
|
||||
# loss_s_p = (eps * F.cross_entropy(logit_s_p, target_s, reduction='none')).sum()
|
||||
|
||||
# modify w to w- (from w+)
|
||||
modify_parameters(main_net, params_new_grad, -2 * _eps)
|
||||
outputs_s_n = main_net(**s_input)
|
||||
loss_s_n = outputs_s_n[0]
|
||||
s_n_feature = outputs_s_n[2]
|
||||
loss_s_n, _ = gw(s_n_feature, y_weak, loss_s_n)
|
||||
if gold_ratio > 0:
|
||||
loss_train = main_net(**train_input)[0]
|
||||
loss_s_n = (loss_s_n + loss_train) / 2
|
||||
|
||||
# loss_s_n = (eps * F.cross_entropy(logit_s_n, target_s, reduction='none')).sum()
|
||||
|
||||
proxy_g = -eta * (loss_s_p - loss_s_n) / (2. * _eps)
|
||||
|
||||
# modify to original w
|
||||
modify_parameters(main_net, params_new_grad, _eps)
|
||||
|
||||
# eps_grad = torch.autograd.grad(proxy_g, eps)[0]
|
||||
# update gw
|
||||
gw_opt.zero_grad()
|
||||
if args.fp16:
|
||||
with amp.scale_loss(proxy_g, gw_opt) as proxy_gg:
|
||||
proxy_gg.backward()
|
||||
else:
|
||||
proxy_g.backward()
|
||||
gw_opt.step()
|
||||
|
||||
if type(main_scheduler).__name__ == "LambdaLR":
|
||||
gw_scheduler.step(proxy_g)
|
||||
else:
|
||||
gw_scheduler.step()
|
||||
# call scheduler for gw if applicable here
|
||||
|
||||
|
||||
outputs_s = main_net(**s_input)
|
||||
loss_s = outputs_s[0]
|
||||
s_feature = outputs_s[2]
|
||||
loss_s, instance_weight = gw(s_feature, y_weak, loss_s)
|
||||
|
||||
|
||||
# write the group weight and instance weight
|
||||
|
||||
# mean reduction
|
||||
if gold_ratio != 0:
|
||||
loss_train = main_net(**train_input)[0]
|
||||
loss_s = (loss_s + loss_train) / 2
|
||||
|
||||
main_opt.zero_grad()
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss_s, main_opt) as loss_ss:
|
||||
loss_ss.backward()
|
||||
else:
|
||||
loss_s.backward()
|
||||
main_opt.step()
|
||||
if type(main_scheduler).__name__ == "LambdaLR":
|
||||
main_scheduler.step(loss_s)
|
||||
else:
|
||||
main_scheduler.step()
|
||||
|
||||
return loss_g, loss_s, instance_weight
|
||||
|
||||
|
||||
|
||||
# def group_step_l2w(main_net, main_opt, group_weight, group_opt, val_input, s_input, g_input, args, scheduler,
|
||||
# group_scheduler, step=None, writer=None):
|
||||
# # init eps to 0
|
||||
# try:
|
||||
# eta = scheduler.get_lr()[0]
|
||||
# except:
|
||||
# eta = main_opt.defaults.get("lr", 0)
|
||||
#
|
||||
#
|
||||
# eps = nn.Parameter(torch.zeros_like(s_input['labels'][:,0].float()))
|
||||
# eps = eps.view(-1)
|
||||
#
|
||||
#
|
||||
#
|
||||
# # calculate current weighted loss
|
||||
# main_net.train()
|
||||
# loss_s = main_net(**s_input)[0]
|
||||
# # {reduction: "none"} in s_inputs
|
||||
# loss_s = (group_weight(eps, loss_s)).sum()
|
||||
#
|
||||
# # get theta grads
|
||||
# # 1. update w to w'
|
||||
# param_grads = torch.autograd.grad(loss_s, main_net.parameters(), allow_unused=True)
|
||||
#
|
||||
# params_new = update_params_sgd(main_net.parameters(), param_grads, main_opt, args, eta)
|
||||
# # params_new = update_params_adam(main_net.parameters(), param_grads, main_opt)
|
||||
#
|
||||
# # 2. set w as w'
|
||||
# params = []
|
||||
# for i, param in enumerate(main_net.parameters()):
|
||||
# params.append(param.data.clone())
|
||||
# param.data = params_new[i].data # use data only
|
||||
#
|
||||
# # 3. compute d_w' L_{D}(w')
|
||||
# loss_g = main_net(**val_input)[0]
|
||||
#
|
||||
# params_new_grad = torch.autograd.grad(loss_g, main_net.parameters(), allow_unused=True)
|
||||
#
|
||||
# # 4. revert from w' to w for main net
|
||||
# for i, param in enumerate(main_net.parameters()):
|
||||
# param.data = params[i]
|
||||
#
|
||||
# # change main_net parameter
|
||||
# _eps = 1e-6 # 1e-3 / _concat(params_new_grad).norm # eta 1e-6 before
|
||||
#
|
||||
# # modify w to w+
|
||||
# modify_parameters(main_net, params_new_grad, _eps)
|
||||
# loss_s_p = main_net(**s_input)[0]
|
||||
# loss_s_p = (group_weight(eps, loss_s_p)).sum()
|
||||
#
|
||||
# # modify w to w- (from w+)
|
||||
# modify_parameters(main_net, params_new_grad, -2 * _eps)
|
||||
# loss_s_n = main_net(**s_input)[0]
|
||||
# loss_s_n = (group_weight(eps, loss_s_n)).sum()
|
||||
#
|
||||
# proxy_g = -eta * (loss_s_p - loss_s_n) / (2. * _eps)
|
||||
#
|
||||
# # modify to original w
|
||||
# modify_parameters(main_net, params_new_grad, _eps)
|
||||
#
|
||||
# grads = torch.autograd.grad(proxy_g, [eps, group_weight.pesu_group_weight], allow_unused=True)
|
||||
# eps_grad = grads[0]
|
||||
# group_weight_grad = grads[1]
|
||||
#
|
||||
# # update eps
|
||||
# w = F.relu(-eps_grad)
|
||||
#
|
||||
# if w.max() == 0:
|
||||
# w = torch.ones_like(w)
|
||||
# else:
|
||||
# w = w / w.sum()
|
||||
#
|
||||
# group_opt.zero_grad()
|
||||
# group_weight.pesu_group_weight.grad = group_weight_grad
|
||||
# group_opt.step()
|
||||
# group_scheduler.step(proxy_g)
|
||||
#
|
||||
# loss_s = main_net(**s_input)[0]
|
||||
#
|
||||
# loss_s = (group_weight(w, loss_s)).sum()
|
||||
#
|
||||
# if step is not None:
|
||||
# writer.add_histogram("weight/InstanceWeight", w.detach(), global_step=step)
|
||||
# if group_weight is not None:
|
||||
# writer.add_histogram("Weight/GroupWeight", group_weight.pesu_group_weight.data, global_step=step)
|
||||
#
|
||||
# if g_input is not None:
|
||||
# loss_s += main_net(**g_input)[0]
|
||||
#
|
||||
#
|
||||
# # main_opt.zero_grad()
|
||||
# main_net.zero_grad()
|
||||
# loss_s.backward()
|
||||
# main_opt.step()
|
||||
#
|
||||
# if scheduler is not None:
|
||||
# scheduler.step(loss_s)
|
||||
#
|
||||
# main_net.eval()
|
||||
# loss_g = main_net(**val_input)[0]
|
||||
# return loss_g, loss_s
|
|
run_classifiy.py (new file, 883 lines)
@@ -0,0 +1,883 @@
|
|||
'''
|
||||
Copyright (c) Microsoft Corporation, Yichuan Li and Kai Shu.
|
||||
Licensed under the MIT license.
|
||||
Authors: Guoqing Zheng (zheng@microsoft.com), Yichuan Li and Kai Shu
|
||||
'''
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||
from tqdm import trange
|
||||
from dataset import FakeNewsDataset, SnorkelDataset
|
||||
# from dataset import FakeNewsDataset
|
||||
from l2w import step_l2w
|
||||
from l2w import step_l2w_group_net
|
||||
# from model import RobertaForSequenceClassification, CNN_Text, GroupWeightModel
|
||||
from model import RobertaForSequenceClassification, CNN_Text, FullWeightModel, GroupWeightModel, \
|
||||
BertForSequenceClassification
|
||||
import time
|
||||
|
||||
import os, sys
|
||||
from model import DistilBertForSequenceClassification
|
||||
|
||||
from transformers import (
|
||||
AdamW,
|
||||
get_linear_schedule_with_warmup,
|
||||
DistilBertTokenizer,
|
||||
DistilBertConfig,
|
||||
RobertaConfig,
|
||||
RobertaTokenizer,
|
||||
BertTokenizer,
|
||||
BertConfig
|
||||
)
|
||||
|
||||
import shutil
|
||||
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
|
||||
import os
|
||||
|
||||
writer = None
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
except ImportError:
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
#
|
||||
# MODEL_CLASSES = {
|
||||
# "cnn":(None, CNN_Text, DistilBertTokenizer),
|
||||
# "albert":(AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer)
|
||||
# }
|
||||
TRAIN_TYPE = ["gold", "silver", "gold_con_silver"]
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
torch.manual_seed(args.seed)
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
|
||||
def acc_f1_confusion(preds, labels):
|
||||
acc_f1 = {'acc': accuracy_score(y_pred=preds, y_true=labels), 'f1': f1_score(y_pred=preds, y_true=labels)}
|
||||
acc_f1.update({"acc_and_f1": (acc_f1['acc'] + acc_f1['f1']) / 2})
|
||||
c_m = ",".join([str(i) for i in confusion_matrix(y_true=labels, y_pred=preds).ravel()])
|
||||
acc_f1.update({"c_m": c_m})
|
||||
return acc_f1
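# Descriptive note: returns {'acc', 'f1', 'acc_and_f1', 'c_m'}; for binary labels, 'c_m'
# is the flattened confusion matrix "tn,fp,fn,tp" as a comma-separated string.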
|
||||
|
||||
|
||||
'''
|
||||
Test ray
|
||||
'''
|
||||
|
||||
|
||||
def ray_meta_train(config):
|
||||
for key, value in config.items():
|
||||
if "tuneP_" in key:
|
||||
key = key.replace("tuneP_", "")
|
||||
setattr(config['args'], key, value)
|
||||
# train_mnist(config, config['args'])
|
||||
meta_train(config['args'], config['gold_ratio'])
|
||||
# with open("/home/yichuan/ray_results/group_weight_np_array_{}.pkl".format(config['gold_ratio']),'wb') as f1:
|
||||
# pickle.dump(instance_weight, f1)
|
||||
# return np.mean(meta_train(config['args'], config['gold_ratio'])[-1][0:2])
|
||||
|
||||
|
||||
def build_model(args):
|
||||
if args.clf_model.lower() == "cnn":
|
||||
# easy for text tokenization
|
||||
tokenizer = DistilBertTokenizer.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case)
|
||||
model = CNN_Text(args)
|
||||
|
||||
elif args.clf_model.lower() == "robert":
|
||||
print("name is {}".format(args.model_name_or_path))
|
||||
tokenizer = RobertaTokenizer.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case
|
||||
)
|
||||
|
||||
config = RobertaConfig.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
num_labels=args.num_labels,
|
||||
finetuning_task=args.task_name)
|
||||
|
||||
model = RobertaForSequenceClassification.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
config=config
|
||||
)
|
||||
# freeze the weight for transformers
|
||||
if args.freeze:
|
||||
for n, p in model.named_parameters():
|
||||
if "bert" in n:
|
||||
p.requires_grad = False
|
||||
elif args.clf_model.lower() == "bert":
|
||||
tokenizer = BertTokenizer.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case
|
||||
)
|
||||
|
||||
config = BertConfig.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
num_labels=args.num_labels,
|
||||
finetuning_task=args.task_name)
|
||||
|
||||
model = BertForSequenceClassification.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
config=config
|
||||
)
|
||||
# freeze the weight for transformers
|
||||
# if args.freeze:
|
||||
# for n, p in model.named_parameters():
|
||||
# if "bert" in n:
|
||||
# p.requires_grad = False
|
||||
|
||||
else:
|
||||
tokenizer = DistilBertTokenizer.from_pretrained(args.model_name_or_path, do_lower_case=args.do_lower_case)
|
||||
config = DistilBertConfig.from_pretrained(args.model_name_or_path, num_labels=args.num_labels,
|
||||
finetuning_task=args.task_name)
|
||||
model = DistilBertForSequenceClassification.from_pretrained(args.model_name_or_path, config=config)
|
||||
|
||||
model.expand_class_head(args.multi_head)
|
||||
model = model.to(args.device)
|
||||
return tokenizer, model
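# Descriptive note (not original code): build_model returns (tokenizer, model) for
# --clf_model in {"cnn", "robert", "bert"}, falling back to DistilBERT otherwise.
# model.expand_class_head(args.multi_head) comes from model.py (not shown here) and is
# expected to attach one classification head per weak source.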
|
||||
|
||||
|
||||
def train(args, train_dataset, val_dataset, model, tokenizer, gold_ratio, **kwargs):
|
||||
""" Train the model """
|
||||
best_acc = 0.
|
||||
best_f1 = 0.
|
||||
val_acc_and_f1 = 0.
|
||||
best_acc_and_f1 = 0.
|
||||
best_c_m = ""
|
||||
best_loss_val = 10000
|
||||
|
||||
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
||||
train_sampler = RandomSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
# train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=4)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(train_dataloader)) + 1
|
||||
else:
|
||||
t_total = len(train_dataloader) * args.num_train_epochs
|
||||
|
||||
if args.clf_model != "cnn":
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
|
||||
"weight_decay": 0.0},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
else:
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
|
||||
weight_decay=args.weight_decay)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
logger.info("Optimizer type: ")
|
||||
logger.info(type(optimizer).__name__)
|
||||
|
||||
# Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num examples = %d", len(train_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
||||
args.train_batch_size
|
||||
)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
# Check if continuing training from a checkpoint
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
loss_scalar = 0.
|
||||
model.zero_grad()
|
||||
train_iterator = trange(
|
||||
epochs_trained, int(args.num_train_epochs), desc="Epoch"
|
||||
)
|
||||
set_seed(args)  # Added here for reproducibility
|
||||
|
||||
for _ in train_iterator:
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
continue
|
||||
|
||||
model.train()
|
||||
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[2]}
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
if args.fp16:
|
||||
with amp.scale_loss(loss, optimizer) as loss:
|
||||
loss.backward()
|
||||
else:
|
||||
loss.backward()
|
||||
|
||||
tr_loss += loss.item()
|
||||
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
optimizer.step()
|
||||
if args.clf_model != 'cnn':
|
||||
scheduler.step()
|
||||
else:
|
||||
scheduler.step(loss)
|
||||
|
||||
model.zero_grad()
|
||||
|
||||
global_step += 1
|
||||
|
||||
if args.logging_steps != 0 and global_step % args.logging_steps == 0:
|
||||
logs = {}
|
||||
if (
|
||||
args.evaluate_during_training
|
||||
):
|
||||
|
||||
results = evaluate(args, model, tokenizer, gold_ratio, eval_dataset=val_dataset)
|
||||
results.update({"type": "val"})
|
||||
if (val_acc_and_f1 < results['acc_and_f1'] and args.val_acc_f1) \
|
||||
or (best_loss_val > results['loss'] and args.val_acc_f1 is False):
|
||||
val_acc_and_f1 = results['acc_and_f1']
|
||||
best_loss_val = results['loss']
|
||||
results = evaluate(args, model, tokenizer, gold_ratio)
|
||||
best_acc = results['acc']
|
||||
best_f1 = results['f1']
|
||||
best_acc_and_f1 = results["acc_and_f1"]
|
||||
best_c_m = results['c_m']
|
||||
results.update({"type": "test"})
|
||||
print(json.dumps(results))
|
||||
for key, value in results.items():
|
||||
eval_key = "eval_{}".format(key)
|
||||
logs[eval_key] = value
|
||||
logging.info(
|
||||
"Training Loss is {}".format(loss_scalar if loss_scalar > 0 else tr_loss / args.logging_steps))
|
||||
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
||||
learning_rate_scalar = optimizer.defaults.get("lr", 0)
|
||||
logs["learning_rate"] = learning_rate_scalar
|
||||
logs["loss"] = loss_scalar
|
||||
# if results['type'] == "test":
|
||||
|
||||
logging_loss = tr_loss
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
break
|
||||
# logs_epoch = {}
|
||||
# results = evaluate(args, model, tokenizer, gold_ratio)
|
||||
# if val_acc_and_f1 < results['acc_and_f1']:
|
||||
# best_f1 = results['f1']
|
||||
# best_acc = results['acc']
|
||||
# best_c_m = results["c_m"]
|
||||
# val_acc_and_f1 = results['acc_and_f1']
|
||||
# for key, value in results.items():
|
||||
# eval_key = "eval_{}".format(key)
|
||||
# logs_epoch[eval_key] = value
|
||||
# print("EPOCH Finish")
|
||||
# print("EPOCH Result {}".format(json.dumps(logs_epoch)))
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
|
||||
return global_step, tr_loss / global_step, (best_f1, best_acc, best_c_m)
|
||||
|
||||
|
||||
import copy
|
||||
|
||||
|
||||
def meta_train(args, gold_ratio):
|
||||
""" Train the model """
|
||||
best_acc = 0.
|
||||
best_f1 = 0.
|
||||
best_loss_val = 100000
|
||||
val_acc_and_f1 = 0.
|
||||
best_cm = ""
|
||||
fake_acc_and_f1 = 0.
|
||||
fake_best_f1 = 0.
|
||||
fake_best_acc = 0.
|
||||
writer = None
|
||||
tokenizer, model = build_model(args)
|
||||
g_dataset = load_fake_news(args, tokenizer, evaluate=False, train_path=args.gold_train_path)
|
||||
s_dataset = load_fake_news(args, tokenizer, evaluate=False, train_path=args.silver_train_path, is_weak=True,
|
||||
weak_type=args.weak_type)
|
||||
val_dataset = load_fake_news(args, tokenizer, evaluate=False, train_path=args.val_path)
|
||||
|
||||
eval_dataset = copy.deepcopy(val_dataset)
|
||||
|
||||
# make a copy of train and test towards similar size as the weak source
|
||||
if True:
|
||||
max_length = max(len(g_dataset), len(s_dataset), len(val_dataset))
|
||||
g_dataset = torch.utils.data.ConcatDataset([g_dataset] * int(max_length / len(g_dataset)))
|
||||
s_dataset = torch.utils.data.ConcatDataset([s_dataset] * int(max_length / len(s_dataset)))
|
||||
val_dataset = torch.utils.data.ConcatDataset([val_dataset] * int(max_length / len(val_dataset)))
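# Descriptive note: replicating each dataset up to the size of the largest one keeps the
# zip(g_dataloader, s_dataloader, train_dataloader) loop below from being cut short by the
# smallest of the gold, weak, and validation splits.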
|
||||
|
||||
g_sampler = RandomSampler(val_dataset)
|
||||
g_dataloader = DataLoader(val_dataset, sampler=g_sampler, batch_size=args.g_train_batch_size)
|
||||
|
||||
train_sampler = RandomSampler(g_dataset)
|
||||
train_dataloader = DataLoader(g_dataset, sampler=train_sampler, batch_size=args.g_train_batch_size)
|
||||
|
||||
s_sampler = RandomSampler(s_dataset)
|
||||
s_dataloader = DataLoader(s_dataset, sampler=s_sampler, batch_size=args.s_train_batch_size)
|
||||
|
||||
if args.max_steps > 0:
|
||||
t_total = args.max_steps
|
||||
args.num_train_epochs = args.max_steps // (len(g_dataloader)) + 1
|
||||
else:
|
||||
if gold_ratio == 0:
|
||||
t_total = min(len(g_dataloader), len(s_dataloader)) * args.num_train_epochs
|
||||
else:
|
||||
t_total = min(len(g_dataloader), len(train_dataloader), len(s_dataloader)) * args.num_train_epochs
|
||||
|
||||
if args.clf_model != "cnn":
|
||||
no_decay = ["bias", "LayerNorm.weight"]
|
||||
optimizer_grouped_parameters = [
|
||||
{
|
||||
"params": [p for n, p in model.named_parameters() if
|
||||
not any(nd in n for nd in no_decay) and p.requires_grad],
|
||||
"weight_decay": args.weight_decay,
|
||||
},
|
||||
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
|
||||
"weight_decay": 0.0},
|
||||
]
|
||||
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
||||
scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
||||
)
|
||||
else:
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
|
||||
weight_decay=args.weight_decay)
|
||||
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, t_total / args.num_train_epochs)
|
||||
|
||||
if args.fp16:
|
||||
try:
|
||||
from apex import amp
|
||||
except ImportError:
|
||||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
if args.use_group_weight or args.use_group_net:
|
||||
#
|
||||
if args.use_group_weight:
|
||||
group_weight = GroupWeightModel(n_groups=args.multi_head)
|
||||
else:
|
||||
group_weight = FullWeightModel(n_groups=args.multi_head, hidden_size=args.hidden_size)
|
||||
group_weight = group_weight.to(args.device)
|
||||
parameters = [i for i in group_weight.parameters() if i.requires_grad]
|
||||
if "adam" in args.group_opt.lower():
|
||||
|
||||
if "w" in args.group_opt.lower():
|
||||
group_optimizer = AdamW(parameters, lr=args.group_lr, eps=args.group_adam_epsilon,
|
||||
weight_decay=args.group_weight_decay)
|
||||
else:
|
||||
group_optimizer = torch.optim.Adam(parameters, lr=args.group_lr, eps=args.group_adam_epsilon,
|
||||
weight_decay=args.group_weight_decay)
|
||||
|
||||
group_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(group_optimizer,
|
||||
t_total / args.num_train_epochs)
|
||||
# group_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
|
||||
# num_training_steps=t_total)
|
||||
elif args.group_opt.lower() == "sgd":
|
||||
group_optimizer = torch.optim.SGD(parameters, lr=args.group_lr, momentum=args.group_momentum,
|
||||
weight_decay=args.group_weight_decay)
|
||||
group_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(group_optimizer, 'min')
|
||||
|
||||
if args.fp16:
|
||||
group_weight, group_optimizer= amp.initialize(group_weight, group_optimizer, opt_level=args.fp16_opt_level)
|
||||
|
||||
# # Train!
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(" Num Gold examples = %d, Silver Examples = %d", len(val_dataset), len(s_dataset))
|
||||
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
||||
logger.info(
|
||||
" Total train batch size (w. parallel, distributed & accumulation) = %d, %d",
|
||||
args.g_train_batch_size, args.s_train_batch_size
|
||||
)
|
||||
logger.info(" Total optimization steps = %d", t_total)
|
||||
|
||||
global_step = 0
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
# Check if continuing training from a checkpoint
|
||||
|
||||
g_loss, logging_g_loss, logging_s_loss, s_loss = 0.0, 0.0, 0.0, 0.0
|
||||
|
||||
model.zero_grad()
|
||||
train_iterator = trange(
|
||||
epochs_trained, int(args.num_train_epochs), desc="Epoch"
|
||||
)
|
||||
set_seed(args)  # Added here for reproducibility
|
||||
temp_output = open(args.flat_output_file+"_step", "w+", 1)
|
||||
for _ in train_iterator:
|
||||
be_changed = False
|
||||
for step, (g_batch, s_batch, train_batch) in enumerate(zip(g_dataloader, s_dataloader, train_dataloader)):
|
||||
|
||||
# Skip past any already trained steps if resuming training
|
||||
if steps_trained_in_current_epoch > 0:
|
||||
steps_trained_in_current_epoch -= 1
|
||||
continue
|
||||
|
||||
model.train()
|
||||
g_batch = tuple(t.to(args.device) for t in g_batch)
|
||||
g_input = {"input_ids": g_batch[0], "attention_mask": g_batch[1], "labels": g_batch[2]}
|
||||
|
||||
s_batch = tuple(t.to(args.device) for t in s_batch)
|
||||
s_input = {"input_ids": s_batch[0], "attention_mask": s_batch[1], "labels": s_batch[2],
|
||||
"reduction": 'none'}
|
||||
|
||||
train_batch = tuple(t.to(args.device) for t in train_batch)
|
||||
train_input = {"input_ids": train_batch[0], "attention_mask": train_batch[1], "labels": train_batch[2]}
|
||||
# ATTENTION: RoBERTa does not need token type ids
|
||||
if args.multi_head > 1:
|
||||
s_input.update({"is_gold": False})
|
||||
|
||||
if (global_step + 1) % args.logging_steps == 0:
|
||||
step_input = global_step
|
||||
else:
|
||||
step_input = None
|
||||
info = {"gold_ratio": gold_ratio, "step": step_input}
|
||||
|
||||
|
||||
if args.use_group_net:
|
||||
outputs = step_l2w_group_net(model, optimizer, scheduler, g_input, s_input, train_input, args,
|
||||
group_weight, group_optimizer, group_scheduler, gold_ratio)
|
||||
|
||||
loss_g, loss_s, instance_weight = outputs
|
||||
else:
|
||||
outputs = step_l2w(model, optimizer, scheduler, g_input, s_input, train_input, args, gold_ratio)
|
||||
loss_g, loss_s = outputs
|
||||
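# Both branches implement a learn-to-weight (L2W) update. A rough sketch of the
# idea, assuming the usual meta-learning formulation (the exact logic lives in
# step_l2w / step_l2w_group_net, defined elsewhere in this file):
#   1. take a virtual update of the model on the weighted silver (weak) loss;
#   2. evaluate the gold (clean) loss of the virtually-updated model;
#   3. backpropagate the gold loss to the source/instance weights and update them;
#   4. update the model on the silver loss re-weighted with the new weights.
# step_l2w_group_net additionally returns the learned instance_weight.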
|
||||
g_loss += loss_g.item()
|
||||
s_loss += loss_s.item()
|
||||
global_step += 1
|
||||
|
||||
if args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
||||
logs = {}
|
||||
results = {}
|
||||
if args.evaluate_during_training or True:  # NOTE: evaluation currently always runs at logging steps
|
||||
|
||||
results = evaluate(args, model, tokenizer, gold_ratio, eval_dataset=eval_dataset)
|
||||
results = {key + "_val": value for key, value in results.items()}
|
||||
results.update({"type": "val"})
|
||||
print(json.dumps(results))
|
||||
if val_acc_and_f1 < results['acc_and_f1_val']:
|
||||
be_changed = True
|
||||
best_loss_val = results['loss_val']
|
||||
val_acc_and_f1 = results['acc_and_f1_val']
|
||||
test_results = evaluate(args, model, tokenizer, gold_ratio)
|
||||
best_acc = test_results['acc']
|
||||
best_f1 = test_results['f1']
|
||||
best_cm = test_results['c_m']
|
||||
best_acc_and_f1 = test_results["acc_and_f1"]
|
||||
temp_output.write("Step: {}, Test F1: {}, Test ACC: {}; Val Acc_and_F1: {}, Val Loss: {}\n".format(global_step, best_f1, best_acc, val_acc_and_f1, best_loss_val))
|
||||
temp_output.flush()
|
||||
# save the model
|
||||
if args.save_model:
|
||||
save_path = args.flat_output_file + "_save_model"
|
||||
save_dic = {"BaseModel": model,
|
||||
"LWN":group_weight,
|
||||
"step":global_step,
|
||||
"tokenizer":tokenizer
|
||||
}
|
||||
torch.save(save_dic, save_path)
|
||||
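# Sketch: the checkpoint written above can later be restored with something like
#   ckpt = torch.load(save_path)
#   model, group_weight, tokenizer = ckpt["BaseModel"], ckpt["LWN"], ckpt["tokenizer"]
# (torch.save is given the full objects here rather than state_dicts, so torch.load
# returns the objects directly.)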
|
||||
test_results = {key + "_test": value for key, value in test_results.items()}
|
||||
test_results.update({"type": "test"})
|
||||
print(json.dumps(test_results))
|
||||
for key, value in results.items():
|
||||
eval_key = "eval_{}".format(key)
|
||||
logs[eval_key] = value
|
||||
|
||||
loss_scalar = (g_loss - logging_g_loss) / args.logging_steps
|
||||
learning_rate_scalar = optimizer.defaults.get("lr", 0)
|
||||
logs["train_learning_rate"] = learning_rate_scalar
|
||||
logs["train_g_loss"] = loss_scalar
|
||||
logs["train_s_loss"] = (s_loss - logging_s_loss) / args.logging_steps
|
||||
logging_g_loss = g_loss
|
||||
logging_s_loss = s_loss
|
||||
|
||||
# writer.add_scalar("Loss/g_train_{}".format(gold_ratio), logs['train_g_loss'], global_step)
|
||||
# writer.add_scalar("Loss/s_train_{}".format(gold_ratio), logs['train_s_loss'], global_step)
|
||||
# writer.add_scalar("Loss/val_train_{}".format(gold_ratio), results['loss_val'], global_step)
|
||||
|
||||
|
||||
if args.use_group_weight:
|
||||
try:
|
||||
eta_group = group_optimizer.get_lr()
|
||||
except Exception:
|
||||
eta_group = group_optimizer.defaults.get("lr", 0)
|
||||
|
||||
# writer.add_scalar("Loss/group_lr_{}".format(gold_ratio), eta_group, global_step)
|
||||
|
||||
print(json.dumps({**{"step": global_step}, **logs}))
|
||||
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
break
|
||||
if (args.use_group_net or args.use_group_weight) and isinstance(group_scheduler,
|
||||
torch.optim.lr_scheduler.CosineAnnealingLR):
|
||||
group_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(group_optimizer,
|
||||
t_total / args.num_train_epochs)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, t_total / args.num_train_epochs)
|
||||
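# The cosine-annealing schedulers are re-created at the end of every epoch, so the
# learning rate restarts its cosine cycle each epoch (a warm-restart-like schedule).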
|
||||
|
||||
print("EPOCH Finish")
|
||||
if args.max_steps > 0 and global_step > args.max_steps:
|
||||
train_iterator.close()
|
||||
break
|
||||
temp_output.close()
|
||||
# return cache_instance_weight
|
||||
return global_step, g_loss / global_step, (best_f1, best_acc, best_cm)
|
||||
|
||||
|
||||
def evaluate(args, model, tokenizer, gold_ratio, prefix="", eval_dataset=None):
|
||||
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
||||
|
||||
|
||||
results = {}
|
||||
|
||||
if eval_dataset is None:
|
||||
eval_dataset = load_fake_news(args, tokenizer, evaluate=True)
|
||||
|
||||
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
||||
# SequentialSampler keeps the evaluation order deterministic
|
||||
eval_sampler = SequentialSampler(eval_dataset)
|
||||
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
||||
|
||||
# Eval!
|
||||
print("\n")
|
||||
logger.info("***** Running evaluation {} *****".format(prefix))
|
||||
logger.info(" Num examples = %d", len(eval_dataset))
|
||||
logger.info(" Batch size = %d", args.eval_batch_size)
|
||||
eval_loss = 0.0
|
||||
nb_eval_steps = 0
|
||||
preds = None
|
||||
out_label_ids = None
|
||||
for batch in eval_dataloader:
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
|
||||
with torch.no_grad():
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[2]}
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
eval_loss += tmp_eval_loss.mean().item()
|
||||
nb_eval_steps += 1
|
||||
if preds is None:
|
||||
preds = logits.detach().cpu().numpy()
|
||||
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
||||
else:
|
||||
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
||||
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
||||
|
||||
eval_loss = eval_loss / nb_eval_steps
|
||||
|
||||
preds = np.argmax(preds, axis=1)
|
||||
|
||||
result = acc_f1_confusion(preds, out_label_ids)
|
||||
results.update(result)
|
||||
results.update({"loss": eval_loss})
|
||||
|
||||
logger.info("***** Eval results {} Gold Ratio={} *****".format(prefix, gold_ratio))
|
||||
for key in sorted(results.keys()):
|
||||
logger.info(" %s = %s", key, str(results[key]))
|
||||
|
||||
return results
|
||||
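# Example usage (sketch): the returned dict contains the metrics produced by
# acc_f1_confusion (e.g. 'acc', 'f1', 'acc_and_f1', 'c_m') plus 'loss', so callers
# typically do something like:
#   results = evaluate(args, model, tokenizer, gold_ratio, eval_dataset=val_dataset)
#   if results['acc_and_f1'] > best_so_far: ...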
|
||||
|
||||
def load_fake_news(args, tokenizer, is_weak=False, evaluate=False, train_path=None, weak_type=""):
|
||||
file_path = args.eval_path if evaluate else train_path
|
||||
if args.use_snorkel and "noise" in file_path:
|
||||
dataset = SnorkelDataset(file_path, tokenizer, args.max_seq_length, overwrite=True)
|
||||
else:
|
||||
dataset = FakeNewsDataset(file_path, tokenizer, is_weak, args.max_seq_length, weak_type, args.overwrite_cache,
|
||||
args.balance_weak)
|
||||
return dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
'--ray_dir',
|
||||
default='.',
|
||||
type=str,
|
||||
help='Path to Ray tuned results (Default: current directory)',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default="distilbert-base-uncased",
|
||||
type=str,
|
||||
help="Path to pre-trained model or shortcut name selected in the list",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--max_seq_length",
|
||||
default=256,
|
||||
type=int,
|
||||
help="The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded.",
|
||||
)
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=16, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument("--g_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for gold training.")
|
||||
parser.add_argument("--s_train_batch_size", default=4, type=int, help="Batch size per GPU/CPU for silver training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=128, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=1e-3,type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
||||
)
|
||||
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
||||
|
||||
parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
# parser.add_argument("--train_path", type=str, default="/home/yichuan/MSS/data/gossip/weak", help="For distant debugging.")
|
||||
# parser.add_argument("--eval_path", type=str, default="/home/yichuan/MSS/data/gossip/test.csv", help="For distant debugging.")
|
||||
|
||||
parser.add_argument("--train_path", type=str, default="./data/gossip/weak",
|
||||
help="For distant debugging.")
|
||||
parser.add_argument("--eval_path", type=str, default="./data/gossip/test.csv",
|
||||
help="For distant debugging.")
|
||||
parser.add_argument("--meta_learn", action="store_true", help="Whether use meta learning or not")
|
||||
parser.add_argument("--train_type", type=int, default=0,
|
||||
help="0: only clean data, 1: only noise data, 2: concat clean and noise data")
|
||||
parser.add_argument("--weak_type", type=str, default="most_vote",
|
||||
help="method for the weak superivision; for multi-head please set as none")
|
||||
parser.add_argument("--multi_head", type=int, default=1, help="count of head for classification task")
|
||||
|
||||
# CNN parameters
|
||||
parser.add_argument('--dropout', type=float, default=0.5, help='the probability for dropout [default: 0.5]')
|
||||
parser.add_argument('--max-norm', type=float, default=3.0, help='l2 constraint of parameters [default: 3.0]')
|
||||
parser.add_argument('--kernel-num', type=int, default=100, help='number of each kind of kernel')
|
||||
parser.add_argument('--kernel-sizes', type=str, default='3,4,5',
|
||||
help='comma-separated kernel size to use for convolution')
|
||||
parser.add_argument('--momentum', type=float, default=0.9,
|
||||
help='momentum for the cross-entropy classification optimizer')
|
||||
parser.add_argument("--use_group_weight", action="store_true")
|
||||
parser.add_argument("--use_group_net", action="store_true")
|
||||
|
||||
parser.add_argument("--clf_model", type=str, default="cnn", help="fake news classification model 'cnn', 'bert' ")
|
||||
parser.add_argument("--group_lr", type=float, default=1e-5, help="learn rate for group weight")
|
||||
parser.add_argument("--group_momentum", type=float, default=0.9, help="momentum for group weight")
|
||||
parser.add_argument("--group_weight_decay", type=float, default=0.0, help="weight decay for group weight")
|
||||
parser.add_argument("--group_adam_epsilon", type=float, default=1e-8, help="adam epsilon")
|
||||
parser.add_argument("--group_opt", type=str, default="sgd", help="optimizer type for group weight")
|
||||
parser.add_argument("--freeze", action="store_true")
|
||||
parser.add_argument("--balance_weak", action="store_true")
|
||||
|
||||
# validation setting
|
||||
parser.add_argument("--val_acc_f1", action="store_true",
|
||||
help="Whether use the (f1+acc)/2 as the metric for model selection on validation dataset")
|
||||
parser.add_argument("--gold_ratio", default=0, type=float, help="gold ratio selection")
|
||||
|
||||
# baseline setting
|
||||
parser.add_argument("--use_snorkel", action="store_true",
|
||||
help="Snorkel baseline which use LabelModel to combine multiple weak source")
|
||||
parser.add_argument("--fp16", action='store_true', help='whehter use fp16 or not')
|
||||
parser.add_argument(
|
||||
"--fp16_opt_level",
|
||||
type=str,
|
||||
default="O1",
|
||||
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
||||
"See details at https://nvidia.github.io/apex/amp.html",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--id",
|
||||
type=str,
|
||||
default="",
|
||||
help="id for this group of parameters"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--save_model",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
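# Example invocation (sketch; the script name, paths and values are illustrative only):
#   python this_script.py --output_dir ./out --do_train --meta_learn --use_group_net \
#       --clf_model cnn --multi_head 3 --weak_type none --gold_ratio 0.02 \
#       --train_path ./data/gossip/weak --eval_path ./data/gossip/test.csv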
|
||||
args.clf_model = args.clf_model.lower()
|
||||
# args.gold_ratio = [args.gold_ratio] if (
|
||||
# args.gold_ratio != 0 and args.gold_ratio in [0.02, 0.04, 0.06, 0.08, 0.1]) else [0.06, 0.04, 0.08, 0.02, 0.1]
|
||||
|
||||
args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
|
||||
args.hidden_size = len(args.kernel_sizes) * args.kernel_num if args.clf_model == "cnn" else 768
|
||||
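# With the defaults (--kernel-sizes 3,4,5 and --kernel-num 100) the CNN hidden size
# is 3 * 100 = 300; the transformer classifiers use a fixed 768-dim representation.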
|
||||
# Setup CUDA, GPU & distributed training
|
||||
if args.no_cuda:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
|
||||
args.n_gpu = 0
|
||||
else:  # default to a single CUDA device when available
|
||||
device = torch.device("cuda")
|
||||
args.n_gpu = 1
|
||||
args.device = device
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
# Set seed
|
||||
set_seed(args)
|
||||
|
||||
# binary classification task
|
||||
args.task_name = "mrpc".lower()
|
||||
# 0 for genuine news; 1 for fake news
|
||||
args.num_labels = 2
|
||||
if "/" != args.eval_path[0]:
|
||||
args.eval_path = os.path.join(os.getcwd(), args.eval_path)
|
||||
if "/" != args.train_path[0]:
|
||||
args.train_path = os.path.join(os.getcwd(), args.train_path)
|
||||
|
||||
dataset_type = "political" if "political" in args.eval_path else "gossip"
|
||||
args.t_out_dir = os.path.join(args.output_dir, "{}_{}_{}_{}".format(args.clf_model,
|
||||
"meta" if args.meta_learn else TRAIN_TYPE[
|
||||
args.train_type],
|
||||
dataset_type, args.weak_type))
|
||||
|
||||
assert args.use_group_weight + args.use_group_net != 2, "You should choose GroupWeight or GroupNet, not both of them "
|
||||
|
||||
if args.use_group_weight:
|
||||
args.t_out_dir += "_group"
|
||||
if args.use_group_net:
|
||||
args.t_out_dir += "_group_net"
|
||||
elif args.meta_learn:
|
||||
args.t_out_dir += "_L2W"
|
||||
if args.use_snorkel:
|
||||
args.t_out_dir += "_snorkel"
|
||||
|
||||
# ATTENTION: for batch runs, the gold ratio should be set manually
|
||||
args.gold_ratio = [args.gold_ratio]
|
||||
assert len(args.gold_ratio) == 1, "For computation efficiency, please run one gold ratio at a time"
|
||||
flat_output_file = os.path.join(args.t_out_dir, "result_{}.txt".format(args.gold_ratio[0]))
|
||||
|
||||
if len(args.id) > 0:
|
||||
flat_output_file += "-" + str(args.id)
|
||||
else:
|
||||
flat_output_file += "-" + str(time.time())
|
||||
|
||||
|
||||
# will overwrite now
|
||||
#if os.path.exists(flat_output_file):
|
||||
# raise FileExistsError("The result file already exist, please check it")
|
||||
setattr(args, "flat_output_file", flat_output_file)
|
||||
logging.info(args.t_out_dir)
|
||||
|
||||
# try:
|
||||
# shutil.rmtree(args.t_out_dir)
|
||||
# except FileNotFoundError:
|
||||
# print("File Already deleted")
|
||||
global writer
|
||||
writer = SummaryWriter(args.t_out_dir)
|
||||
fout1 = open(flat_output_file, "w")
|
||||
fout1.write("GoldRatio\tF1\tACC\tCM\n")
|
||||
|
||||
# 1e-5; GN 1e-5
|
||||
for gold_ratio in args.gold_ratio:
|
||||
gold_train_path = os.path.join(args.train_path, "gold_{}.csv".format(gold_ratio))
|
||||
silver_train_path = os.path.join(args.train_path, "noise_{}.csv".format(gold_ratio))
|
||||
val_path = os.path.join(args.train_path, "../val.csv")
|
||||
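# Expected data layout (inferred from the paths above): train_path holds
# gold_<ratio>.csv and noise_<ratio>.csv for each gold ratio, val.csv sits one level
# above train_path, and --eval_path points at the held-out test.csv.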
|
||||
logger.info("Training/evaluation parameters %s", args)
|
||||
|
||||
# Training
|
||||
if args.do_train:
|
||||
|
||||
args.gold_train_path = gold_train_path
|
||||
args.silver_train_path = silver_train_path
|
||||
args.val_path = val_path
|
||||
if args.meta_learn:
|
||||
_, _, (best_f1, best_acc, best_cm) = meta_train(args, gold_ratio)
|
||||
fout1.write("{}\t{}\t{}\t{}\n".format(gold_ratio, best_f1, best_acc, best_cm))
|
||||
else:
|
||||
tokenizer, model = build_model(args)
|
||||
gold_dataset = load_fake_news(args, tokenizer, evaluate=False, train_path=gold_train_path)
|
||||
silver_dataset = load_fake_news(args, tokenizer, evaluate=False, train_path=silver_train_path,
|
||||
is_weak=True,
|
||||
weak_type=args.weak_type)
|
||||
val_dataset = load_fake_news(args, tokenizer, evaluate=False, train_path=val_path)
|
||||
if args.train_type == 0:
|
||||
train_dataset = gold_dataset
|
||||
elif args.train_type == 1:
|
||||
train_dataset = silver_dataset
|
||||
else:
|
||||
# make a copy here for data imbalance
|
||||
gold_dataset = torch.utils.data.ConcatDataset(
|
||||
[gold_dataset] * int(len(silver_dataset) / len(gold_dataset)))
|
||||
train_dataset = torch.utils.data.ConcatDataset([gold_dataset, silver_dataset])
|
||||
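# The gold set is replicated roughly len(silver) / len(gold) times before the
# concatenation so that clean and weak examples are more evenly represented.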
global_step, tr_loss, (f1, acc, c_m) = train(args, train_dataset, val_dataset, model, tokenizer,
|
||||
gold_ratio=gold_ratio)
|
||||
fout1.write("{}\t{}\t{}\t{}\n".format(gold_ratio, f1, acc, c_m))
|
||||
logger.info("Gold Ratio {} Training Finish".format(gold_ratio))
|
||||
# logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
|
||||
# logger.info(" f1 = %s, acc = %s, c_m = %s", f1, acc, c_m)
|
||||
# writer.add_scalar("BestResult/F1",f1,global_step=int(gold_ratio * 100))
|
||||
# writer.add_scalar("BestResult/Acc",acc,global_step=int(gold_ratio * 100))
|
||||
# writer.add_text("BestResult/ConfusionMatrix",c_m, global_step=int(gold_ratio * 100))
|
||||
# fout1.write("{}\t{}\t{}\t{}\n".format(gold_ratio, f1, acc, c_m))
|
||||
fout1.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,501 @@
|
|||
'''
|
||||
Copyright (c) Microsoft Corporation, Yichuan Li and Kai Shu.
|
||||
Licensed under the MIT license.
|
||||
Authors: Guoqing Zheng (zheng@microsoft.com), Yichuan Li and Kai Shu
|
||||
'''
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.autograd as autograd
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions.bernoulli import Bernoulli
|
||||
from sklearn.metrics import precision_score, recall_score, accuracy_score, f1_score
|
||||
from itertools import chain
|
||||
from correction_matrix import correction_result, get_correction_matrix
|
||||
import model
|
||||
|
||||
|
||||
def train(gold_iter, sliver_iter, val_iter, model, args, C_hat=None, statues=""):
|
||||
if args.cuda:
|
||||
model.cuda()
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||
steps = 0
|
||||
best_acc = 0
|
||||
last_step = 0
|
||||
model.train()
|
||||
gold_batch_count_1 = len(gold_iter)
|
||||
sliver_batch_count = len(sliver_iter)
|
||||
sliver_time = sliver_batch_count / gold_batch_count_1
|
||||
gold_batch_count = int(sliver_time) * gold_batch_count_1
|
||||
gold_iter_list = [gold_iter for i in range(int(sliver_time))]
|
||||
gold_iter_list.append(sliver_iter)
|
||||
for epoch in range(1, args.epochs+1):
|
||||
sliver_gt_label = []
|
||||
sliver_target_label = []
|
||||
sliver_predic_pro = []
|
||||
|
||||
for batch_idx, batch in enumerate(chain(*gold_iter_list)):  # iterate the repeated gold iterators, then the silver iterator
|
||||
model.train()
|
||||
feature, target = batch.text, batch.label
|
||||
feature = torch.transpose(feature, 1, 0)
|
||||
target = target - 1
|
||||
if args.cuda:
|
||||
feature, target = feature.cuda(), target.cuda()
|
||||
|
||||
optimizer.zero_grad()
|
||||
logit = model(feature)
|
||||
|
||||
if batch_idx >= gold_batch_count and C_hat is not None:
|
||||
# switch to the sliver mode
|
||||
|
||||
sliver_gt_label.append((batch.gt_label.numpy() - 1).tolist())
|
||||
logit = correction_result(logit, C_hat)
|
||||
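# Sketch of the GLC-style correction (correction_result comes from the
# correction_matrix module): it is assumed to adjust the predicted class
# distribution with the estimated correction matrix C_hat, so the NLL loss below is
# computed against corrected probabilities rather than the raw weak predictions.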
|
||||
sliver_target_label.append(target.cpu().numpy().tolist())
|
||||
sliver_predic_pro.append(torch.argmax(logit, dim=-1).cpu().numpy().tolist())
|
||||
|
||||
logit = torch.log(logit)
|
||||
th1 = target[target > 1]
|
||||
th2 = target[target < 0]
|
||||
assert len(th1) == 0 and len(th2) == 0
|
||||
|
||||
loss = F.nll_loss(logit, target)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
|
||||
if steps % args.log_interval == 0:
|
||||
corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
|
||||
accuracy = 100.0 * corrects/batch.batch_size
|
||||
# sys.stdout.write(
|
||||
# '\r{} Batch[{}] - loss: {:.6f} acc: {:.4f}%({}/{})'.format(statues, steps,
|
||||
# loss.data,
|
||||
# accuracy,
|
||||
# corrects,
|
||||
# batch.batch_size))
|
||||
if steps % args.test_interval == 0:
|
||||
# if steps % 1== 0:
|
||||
dev_acc = eval(val_iter, model, args)
|
||||
dev_acc = dev_acc[0]
|
||||
if dev_acc > best_acc:
|
||||
best_acc = dev_acc
|
||||
last_step = steps
|
||||
# save the best model
|
||||
if args.save_best:
|
||||
save(model, args.save_dir, 'best_{}'.format(statues), 0)
|
||||
else:
|
||||
if steps - last_step >= args.early_stop:
|
||||
print('early stop by {} steps.'.format(args.early_stop))  # NOTE: only logs; training is not actually stopped here
|
||||
elif steps % args.save_interval == 0:
|
||||
save(model, args.save_dir, 'snapshot', steps)
|
||||
steps += 1
|
||||
if C_hat is not None:
|
||||
sliver_gt_label = list(chain.from_iterable(sliver_gt_label))
|
||||
sliver_target_label = list(chain.from_iterable(sliver_target_label))
|
||||
sliver_predic_pro = list(chain.from_iterable(sliver_predic_pro))
|
||||
acc = accuracy_score(sliver_gt_label, sliver_target_label)
|
||||
precision = precision_score(sliver_gt_label, sliver_target_label, average="macro")
|
||||
recall = recall_score(sliver_gt_label, sliver_target_label, average="macro")
|
||||
acc_sliver_target = accuracy_score(sliver_target_label, sliver_predic_pro)
|
||||
acc_sliver_gt = accuracy_score(sliver_gt_label, sliver_predic_pro)
|
||||
print("\n" + statues + "\tSliver " + "\t Acc {}, \tPrecision {}, \tRecall {} \n acc_target {}, acc_gt {}"
|
||||
.format(acc, precision, recall, acc_sliver_target, acc_sliver_gt))
|
||||
# print("\n" + statues + "\tSliver " + "\t Acc {}, \tPrecision {}, \tRecall {}".format(acc, precision, recall))
|
||||
# print("\n" + statues + "\tSliver " + "\t Acc {}, \tPrecision {}, \tRecall {}".format(acc, precision, recall))
|
||||
|
||||
def train_hydra_base(gold_iter, sliver_iter, val_iter, model, args, alpha, statues=""):
|
||||
if args.cuda:
|
||||
model.cuda()
|
||||
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||
model.train()
|
||||
steps = 0
|
||||
best_acc = 0
|
||||
last_step = 0
|
||||
model.train()
|
||||
gold_batch_count = len(gold_iter)
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
sliver_gt_all = []
|
||||
sliver_labels_all = []
|
||||
sliver_predic_pro = []
|
||||
for batch_idx, batch in enumerate(chain(gold_iter, sliver_iter)):
|
||||
model.train()
|
||||
feature, target = batch.text, batch.label
|
||||
feature = torch.transpose(feature, 1, 0).contiguous()
|
||||
|
||||
target = target - 1
|
||||
if args.cuda:
|
||||
feature, target = feature.cuda(), target.cuda()
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
if batch_idx >= gold_batch_count:
|
||||
# switch to the sliver mode
|
||||
logit = model.forward_sliver(feature)
|
||||
|
||||
else:
|
||||
logit = model.forward_gold(feature)
|
||||
|
||||
logit = torch.log(logit)
|
||||
loss = F.nll_loss(logit, target)
|
||||
if batch_idx >= gold_batch_count:
|
||||
loss = loss * alpha
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
steps += 1
|
||||
if steps % args.log_interval == 0:
|
||||
corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
|
||||
accuracy = 100.0 * corrects / batch.batch_size
|
||||
# sys.stdout.write(
|
||||
# '\r{} Batch[{}] - loss: {:.6f} acc: {:.4f}%({}/{})'.format(statues, steps,
|
||||
# loss.data,
|
||||
# accuracy,
|
||||
# corrects,
|
||||
# batch.batch_size))
|
||||
if steps % args.test_interval == 0:
|
||||
dev_acc = eval(val_iter, model, args)
|
||||
dev_acc = dev_acc[0]
|
||||
if dev_acc > best_acc:
|
||||
best_acc = dev_acc
|
||||
last_step = steps
|
||||
# save the best model
|
||||
if args.save_best:
|
||||
save(model, args.save_dir, 'best_{}'.format(statues), 0)
|
||||
else:
|
||||
if steps - last_step >= args.early_stop:
|
||||
print('early stop by {} steps.'.format(args.early_stop))  # NOTE: only logs; training is not actually stopped here
|
||||
elif steps % args.save_interval == 0:
|
||||
save(model, args.save_dir, 'snapshot', steps)
|
||||
|
||||
|
||||
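# NOTE: in train_hydra_base the sliver_* lists above are never populated inside the
# batch loop, so the silver-label statistics below are computed on empty lists;
# compare with train_with_glc_label, which appends to them in the silver branch.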
sliver_labels_all = list(chain.from_iterable(sliver_labels_all))
|
||||
sliver_gt_all = list(chain.from_iterable(sliver_gt_all))
|
||||
sliver_predic_pro = list(chain.from_iterable(sliver_predic_pro))
|
||||
acc = accuracy_score(sliver_gt_all,sliver_labels_all)
|
||||
recall = recall_score(sliver_gt_all, sliver_labels_all, average="macro")
|
||||
precesion = precision_score(sliver_gt_all, sliver_labels_all, average='macro')
|
||||
acc_gt = accuracy_score(sliver_gt_all, sliver_predic_pro)
|
||||
acc_target = accuracy_score(sliver_labels_all, sliver_predic_pro)
|
||||
print("\n\n[Correction Label Result] acc: {}, recall: {}, precesion: {}, \n acc_target: {}, acc_gt: {}"
|
||||
.format(acc, recall, precesion, acc_target, acc_gt))
|
||||
|
||||
def train_with_glc_label(gold_iter, sliver_iter, val_iter, glc_model, train_model, args, alpha, statues=""):
|
||||
if args.cuda:
|
||||
glc_model.cuda()
|
||||
train_model.cuda()
|
||||
|
||||
optimizer = torch.optim.Adam(train_model.parameters(), lr=args.lr)
|
||||
|
||||
steps = 0
|
||||
best_acc = 0
|
||||
last_step = 0
|
||||
train_model.train()
|
||||
glc_model.eval()
|
||||
gold_batch_count = len(gold_iter)
|
||||
for epoch in range(1, args.epochs + 1):
|
||||
sliver_gt_all = []
|
||||
sliver_labels_all = []
|
||||
sliver_predic_pro = []
|
||||
for batch_idx, batch in enumerate(chain(gold_iter, sliver_iter)):
|
||||
feature, target = batch.text, batch.label
|
||||
feature = torch.transpose(feature, 1, 0).contiguous()
|
||||
train_model.train()
|
||||
target = target - 1
|
||||
if args.cuda:
|
||||
feature, target = feature.cuda(), target.cuda()
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
if batch_idx >= gold_batch_count:
|
||||
# switch to the sliver mode
|
||||
sliver_logit = glc_model(feature)
|
||||
target = torch.argmax(sliver_logit, dim=-1)
|
||||
logit = train_model.forward_sliver(feature)
|
||||
|
||||
sliver_predic_pro.append(torch.argmax(logit, dim=-1).cpu().numpy().tolist())
|
||||
sliver_labels_all.append(target.cpu().numpy().tolist())
|
||||
sliver_gt_target = batch.gt_label - 1
|
||||
sliver_gt_all.append(sliver_gt_target.numpy().tolist())
|
||||
|
||||
else:
|
||||
logit = train_model.forward_gold(feature)
|
||||
|
||||
logit = torch.log(logit)
|
||||
loss = F.nll_loss(logit, target)
|
||||
if batch_idx >= gold_batch_count:
|
||||
loss = loss * alpha
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
steps += 1
|
||||
if steps % args.log_interval == 0:
|
||||
corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
|
||||
accuracy = 100.0 * corrects / batch.batch_size
|
||||
# sys.stdout.write(
|
||||
# '\r{} Batch[{}] - loss: {:.6f} acc: {:.4f}%({}/{})'.format(statues, steps,
|
||||
# loss.data,
|
||||
# accuracy,
|
||||
# corrects,
|
||||
# batch.batch_size))
|
||||
if steps % args.test_interval == 0:
|
||||
dev_acc = eval(val_iter, train_model, args)
|
||||
dev_acc = dev_acc[0]
|
||||
if dev_acc > best_acc:
|
||||
best_acc = dev_acc
|
||||
last_step = steps
|
||||
# save the best model
|
||||
if args.save_best:
|
||||
save(train_model, args.save_dir, 'best_{}'.format(statues), 0)
|
||||
else:
|
||||
if steps - last_step >= args.early_stop:
|
||||
print('early stop by {} steps.'.format(args.early_stop))  # NOTE: only logs; training is not actually stopped here
|
||||
elif steps % args.save_interval == 0:
|
||||
save(train_model, args.save_dir, 'snapshot', steps)
|
||||
|
||||
|
||||
sliver_labels_all = list(chain.from_iterable(sliver_labels_all))
|
||||
sliver_gt_all = list(chain.from_iterable(sliver_gt_all))
|
||||
sliver_predic_pro = list(chain.from_iterable(sliver_predic_pro))
|
||||
acc = accuracy_score(sliver_gt_all,sliver_labels_all)
|
||||
recall = recall_score(sliver_gt_all, sliver_labels_all, average="macro")
|
||||
precesion = precision_score(sliver_gt_all, sliver_labels_all, average='macro')
|
||||
acc_gt = accuracy_score(sliver_gt_all, sliver_predic_pro)
|
||||
acc_target = accuracy_score(sliver_labels_all, sliver_predic_pro)
|
||||
print("\n\n[Correction Label Result] acc: {}, recall: {}, precesion: {}, \n acc_target: {}, acc_gt: {}"
|
||||
.format(acc, recall, precesion, acc_target, acc_gt))
|
||||
|
||||
|
||||
|
||||
def estimate_c(model, gold_iter, args):
|
||||
# load the best pseudo-clf model
|
||||
model.eval()
|
||||
gold_pro_all = []
|
||||
gold_label_all = []
|
||||
with torch.no_grad():
|
||||
for batch in gold_iter:
|
||||
gold_feature, gold_target = batch.text, batch.label
|
||||
gold_target = gold_target - 1
|
||||
gold_feature = torch.transpose(gold_feature, 1, 0).contiguous()
|
||||
if args.cuda:
|
||||
gold_feature, gold_target = gold_feature.cuda(), gold_target.cuda()
|
||||
gold_pro_all.append(model(gold_feature))
|
||||
gold_label_all.append(gold_target)
|
||||
|
||||
gold_pro_all = torch.cat(gold_pro_all, dim=0)
|
||||
gold_label_all = torch.cat(gold_label_all, dim=0)
|
||||
|
||||
C_hat = get_correction_matrix(gold_pro=gold_pro_all, gold_label=gold_label_all, method=args.gold_method)
|
||||
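# C_hat is a class-correction (confusion-style) matrix estimated from the model's
# probabilities on the trusted gold set; get_correction_matrix presumably follows
# the GLC recipe selected by args.gold_method.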
return C_hat
|
||||
|
||||
|
||||
def eval(data_iter, model, args):
|
||||
model.eval()
|
||||
corrects, avg_loss = 0, 0
|
||||
prediction = []
|
||||
labels = []
|
||||
for batch in data_iter:
|
||||
feature, target = batch.text, batch.label
|
||||
feature = torch.transpose(feature, 1, 0).contiguous()
|
||||
target = target - 1
|
||||
# feature.data.t_(), target.data.sub_(1) # batch first, index align
|
||||
# if args.cuda:
|
||||
|
||||
if args.cuda:
|
||||
feature, target = feature.cuda(), target.cuda()
|
||||
|
||||
|
||||
logit = model(feature)
|
||||
loss = F.cross_entropy(logit, target, reduction='sum')
|
||||
|
||||
avg_loss += loss.data
|
||||
prediction += torch.argmax(logit, 1).cpu().numpy().tolist()
|
||||
labels.extend(target.cpu().numpy().tolist())
|
||||
accuracy = accuracy_score(y_true=labels, y_pred=prediction)
|
||||
size = len(data_iter.dataset.examples)
|
||||
avg_loss /= size
|
||||
corrects = accuracy * size
|
||||
|
||||
recall = recall_score(y_true=labels, y_pred=prediction, average="macro")
|
||||
precision = precision_score(y_true=labels, y_pred=prediction, average="macro")
|
||||
f1 = f1_score(y_true=labels, y_pred=prediction, average="macro")
|
||||
|
||||
print('\nEvaluation - loss: {:.6f} recall: {:.4f}, precision: {:.4f} acc: {:.4f}%({}/{}) \n'.format(avg_loss,
|
||||
recall,
|
||||
precision,
|
||||
accuracy,
|
||||
corrects,
|
||||
size))
|
||||
return accuracy, recall, precision, f1
|
||||
|
||||
|
||||
def predict(text, model, text_field, label_feild, cuda_flag):
|
||||
assert isinstance(text, str)
|
||||
model.eval()
|
||||
# text = text_field.tokenize(text)
|
||||
text = text_field.preprocess(text)
|
||||
text = [[text_field.vocab.stoi[x] for x in text]]
|
||||
x = torch.tensor(text)
|
||||
x = autograd.Variable(x)
|
||||
if cuda_flag:
|
||||
x = x.cuda()
|
||||
print(x)
|
||||
output = model(x)
|
||||
_, predicted = torch.max(output, 1)
|
||||
#return label_feild.vocab.itos[predicted.data[0][0]+1]
|
||||
return label_feild.vocab.itos[predicted.data[0]+1]
|
||||
|
||||
|
||||
def save(model, save_dir, save_prefix, steps):
|
||||
if not os.path.isdir(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
save_prefix = os.path.join(save_dir, save_prefix)
|
||||
save_path = '{}_steps_{}.pt'.format(save_prefix, steps)
|
||||
torch.save(model.state_dict(), save_path)
|
||||
|
||||
def load_model(model, save_dir, save_prefix, steps):
|
||||
save_prefix = os.path.join(save_dir, save_prefix)
|
||||
save_path = '{}_steps_{}.pt'.format(save_prefix, steps)
|
||||
model.load_state_dict(torch.load(save_path))
|
||||
return model
|
||||
|
||||
|
||||
def train_in_one(args, gold_iter, sliver_iter, val_iter, test_iter, gold_frac, alpha):
|
||||
torch.manual_seed(123)
|
||||
fout = open(os.path.join(args.save_dir, "a.result"), "a")
|
||||
fout.write("-" * 90 + "Gold Ratio: {} Alpha: {}".format(gold_frac, alpha) + "-" * 90 + "\n")
|
||||
# train only on the weak data
|
||||
cnn = model.BiLSTM(args)
|
||||
train(gold_iter=gold_iter, sliver_iter=sliver_iter, val_iter=val_iter, model=cnn, args=args, statues="only_weak")
|
||||
test_model = model.BiLSTM(args)
|
||||
test_model = load_model(test_model, args.save_dir, 'best_{}'.format("only_weak"), 0)
|
||||
if args.cuda:
|
||||
test_model.cuda()
|
||||
|
||||
accuracy, recall, precision, f1 = eval(test_iter, test_model, args)
|
||||
fout.write("Weak Acc: {}, recall: {}, precision: {}, f1: {}".format(str(accuracy), str(recall), str(precision),
|
||||
str(f1)))
|
||||
fout.write("\n")
|
||||
del cnn
|
||||
del test_model
|
||||
|
||||
# train only on the weak and gold data
|
||||
cnn = model.BiLSTM(args)
|
||||
train(gold_iter=gold_iter, sliver_iter=sliver_iter, val_iter=val_iter, model=cnn, args=args, statues="weak_gold")
|
||||
test_model = model.BiLSTM(args)
|
||||
test_model = load_model(test_model, args.save_dir, 'best_{}'.format("weak_gold"), 0)
|
||||
if args.cuda:
|
||||
test_model.cuda()
|
||||
|
||||
accuracy, recall, precision, f1 = eval(test_iter, test_model, args)
|
||||
fout.write("WeakGold Acc: {}, recall: {}, precision: {}, f1: {}".format(str(accuracy), str(recall), str(precision),
|
||||
str(f1)))
|
||||
fout.write("\n")
|
||||
del cnn
|
||||
del test_model
|
||||
|
||||
|
||||
|
||||
# train only on the golden data
|
||||
cnn = model.BiLSTM(args)
|
||||
train(gold_iter=gold_iter, sliver_iter=gold_iter, val_iter=val_iter, model=cnn, args=args, statues="test")
|
||||
test_model = model.BiLSTM(args)
|
||||
test_model = load_model(test_model, args.save_dir, 'best_{}'.format("test"), 0)
|
||||
if args.cuda:
|
||||
test_model.cuda()
|
||||
|
||||
accuracy, recall, precision, f1 = eval(test_iter, test_model, args)
|
||||
fout.write("Only Gold Acc: {}, recall: {}, precision: {}, f1: {}".format(str(accuracy), str(recall), str(precision), str(f1)))
|
||||
fout.write("\n")
|
||||
del cnn
|
||||
del test_model
|
||||
|
||||
# hydra-base model
|
||||
cnn = model.BiLSTM(args)
|
||||
train_hydra_base(gold_iter=gold_iter, sliver_iter=sliver_iter, val_iter=val_iter, model=cnn, args=args, statues="hydra_base", alpha=alpha)
|
||||
|
||||
|
||||
test_model = model.BiLSTM(args)
|
||||
test_model = load_model(test_model, args.save_dir, 'best_{}'.format("hydra_base"), 0)
|
||||
if args.cuda:
|
||||
test_model.cuda()
|
||||
|
||||
accuracy, recall, precision, f1 = eval(test_iter, test_model, args)
|
||||
fout.write("HydraBase Acc: {}, recall: {}, precision: {}, f1: {}".format(str(accuracy), str(recall), str(precision),
|
||||
str(f1)))
|
||||
fout.write("\n")
|
||||
del cnn
|
||||
del test_model
|
||||
# //////////////////////// train for estimation ////////////////////////
|
||||
cnn = model.BiLSTM(args)
|
||||
if args.cuda:
|
||||
cnn = cnn.cuda()
|
||||
|
||||
print("\n" + "*" * 40 + "Training in Base Estimation" + "*" * 40)
|
||||
train(gold_iter=sliver_iter, sliver_iter=sliver_iter, val_iter=val_iter, model=cnn, args=args, statues="esti")
|
||||
print("*" * 40 + "Finish in Base Estimation" + "*" * 40)
|
||||
del cnn
|
||||
|
||||
# # //////////////////////// estimate C ////////////////////////
|
||||
cnn = model.BiLSTM(args)
|
||||
cnn = load_model(cnn, args.save_dir, 'best_{}'.format("esti"), 0)
|
||||
if args.cuda:
|
||||
cnn.cuda()
|
||||
C_hat = estimate_c(cnn, gold_iter, args)
|
||||
|
||||
del cnn
|
||||
# //////////////////////// retrain with correction ////////////////////////
|
||||
cnn = model.BiLSTM(args)
|
||||
|
||||
print("\n" + "*"*40 + "Training in Correction" + "*"*40)
|
||||
|
||||
if args.cuda:
|
||||
cnn = cnn.cuda()
|
||||
train(gold_iter=gold_iter, sliver_iter=sliver_iter, val_iter=val_iter, model=cnn, args=args, statues="glc", C_hat=C_hat)
|
||||
# eval(data_iter=test_iter, model=cnn, args=args)
|
||||
|
||||
# del cnn
|
||||
test_model = model.BiLSTM(args)
|
||||
test_model = load_model(test_model, args.save_dir, 'best_{}'.format("glc"), 0)
|
||||
if args.cuda:
|
||||
test_model.cuda()
|
||||
|
||||
accuracy, recall, precision, f1 = eval(test_iter, test_model, args)
|
||||
fout.write("GLC Result: {}, recall: {}, precision: {}, f1: {}".format(str(accuracy), str(recall), str(precision), str(f1)))
|
||||
fout.write("\n")
|
||||
print("\n" + "*"*40 + "Finish in Correction" + "*"*40)
|
||||
|
||||
# //////////////////////// Using GLC labels for training ////////////////////////
|
||||
# correction labels for training the last classifier
|
||||
print("\n" + "*"*40 + "Training with GLC label" + "*"*40)
|
||||
glc_model = model.BiLSTM(args)
|
||||
glc_model = load_model(glc_model, args.save_dir, 'best_{}'.format("glc"), 0)
|
||||
|
||||
final_clf_model = model.BiLSTM(args)
|
||||
# final_clf_model = load_model(final_clf_model, args.save_dir, 'best_{}'.format("glc"), 0 )
|
||||
|
||||
train_with_glc_label(gold_iter=gold_iter, sliver_iter=sliver_iter, val_iter=val_iter, glc_model=glc_model,
|
||||
train_model=final_clf_model, args=args, statues="final", alpha=alpha)
|
||||
del glc_model
|
||||
del final_clf_model
|
||||
print("\n" + "*" * 40 + "Finish in GLC" + "*" * 40)
|
||||
|
||||
# //////////////////////// Test the model on Test dataset ////////////////////////
|
||||
print("\n" + "*" * 40 + "Evaluating in Test data" + "*" * 40)
|
||||
test_model = model.BiLSTM(args)
|
||||
test_model = load_model(test_model, args.save_dir, 'best_{}'.format("final"), 0)
|
||||
if args.cuda:
|
||||
test_model.cuda()
|
||||
|
||||
accuracy, recall, precision, f1 = eval(test_iter, test_model, args)
|
||||
fout.write("Hydra Acc: {}, recall: {}, precision: {}, f1: {}".format(str(accuracy), str(recall), str(precision), str(f1)))
|
||||
fout.write("\n")
|
||||
|
||||
fout.write("-" * 90 + "END THIS" + "-" * 90 + "\n\n\n")
|
||||
fout.close()
|
||||
|
||||
|
||||
|