add support for different dataset (#3)

* Update README.md

* update README.md

* add figure

* update README.md

* add model directory

* provide the generalization of the dataset
This commit is contained in:
Yichuan LI 2021-04-12 22:27:45 -04:00 коммит произвёл GitHub
Родитель 744cbb9dc0
Коммит ca0b478036
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 14 добавлений и 9 удалений

Просмотреть файл

@ -40,7 +40,15 @@ c. __"--multi_head"__ is to set the weak source count, if you have three differe
d. __"--group_opt"__: specific optimizer for group weight. You can choose __Adam__ and __SGD__.
e. __"--gold_ratio"__: Float gold ratio for the training data. Default is 0 which will use \[0.02, 0.04, 0.06, 0.08, 0.1\] all the gold ratio. For gold ratio 0.02, set it as "--gold_ratio 0.02"
e. __"--gold_ratio"__: Float gold ratio for the training data. Default is 0 which will use \[0.02, 0.04, 0.06, 0.08, 0.1\] all the gold ratio. For gold ratio 0.02, set it as "--gold_ratio 0.02"
f. __"--weak_type"__: type of weak label.
1. "none" is the default setting for our paper, where each weak labeled instance will have the number of weak label sources for the model training.
2. "flat" will copy the weak labeled instance for number of weak labeling sources, and ignore the difference of weak sources.
3. "most_vote",
4. "0" for weak_label_1, "1" for weak_label_2 and etc.
If you choose other variables except "none", you should change the multi_head to 2 (one for the clean dataset and another one for the weak labeled data).
- Finetune on RoBERTa Group Weight
@ -79,8 +87,6 @@ The log information will stored in
- CNN Baseline Model
python run_classifiy.py \

Просмотреть файл

@ -55,12 +55,11 @@ class FakeNewsDataset(Dataset):
self.attention_mask = list(
chain.from_iterable([[i] * self.weak_label_count for i in self.attention_mask]))
#"credit_label","polarity_label","bias_label"
elif weak_type == "cred":
self.weak_labels = [i[0] for i in self.weak_labels]
elif weak_type == "polar":
self.weak_labels = [i[1] for i in self.weak_labels]
elif weak_type == "bias":
self.weak_labels = [i[2] for i in self.weak_labels]
elif weak_type.isdigit():
self.weak_labels = [i[int(weak_type)] for i in self.weak_labels]
else:
print("Default setting of dataset")
self.is_weak = is_weak
self.weak_type = weak_type
if self.is_weak and balance_weak: