Fixes to gensen and gensen_utils files

This commit is contained in:
Abhiram E 2019-05-28 07:24:02 -07:00
Parent 324a7d6501
Commit c2c551d094
2 changed files with 23 additions and 44 deletions

View file

@@ -68,20 +68,19 @@ class GenSenClassifier:
json_object = json.load(open(config_file, "r", encoding="utf-8"))
return json_object
def fit(self, train_df, dev_df):
def fit(self, train_df, dev_df, test_df):
""" Method to train the Gensen model.
Args:
train_df: A dataframe containing tokenized sentences from the training set.
dev_df: A dataframe containing tokenized sentences from the validation set.
test_df: A dataframe containing tokenized sentences from the test set.
"""
self._validate_params()
config = self._read_config(self.config_file)
self.cache_dir = self._get_gensen_tokens(train_df, dev_df)
print(self.cache_dir)
self.cache_dir = self._get_gensen_tokens(train_df, dev_df, test_df)
train.train(
data_folder=self.cache_dir,
config=config,

View file

@@ -24,28 +24,28 @@ def _preprocess(data_path):
base_txt_path = os.path.join(
data_path, "clean/snli_1.0/snli_1.0_{}.txt".format(file_split)
)
df["s1.tok"] = df["sentence1_tokens"].apply(lambda x: " ".join(x))
df["s2.tok"] = df["sentence2_tokens"].apply(lambda x: " ".join(x))
df["s1.tok"].to_csv(
"{}.s1.tok".format(base_txt_path),
sep=" ",
header=False,
index=False,
"{}.s1.tok".format(base_txt_path), sep=" ", header=False, index=False
)
df["s2.tok"].to_csv(
"{}.s2.tok".format(base_txt_path),
sep=" ",
header=False,
index=False,
"{}.s2.tok".format(base_txt_path), sep=" ", header=False, index=False
)
df["score"].to_csv(
"{}.lab".format(base_txt_path), sep=" ", header=False, index=False
)
df_clean = df[["s1.tok", "s2.tok", "score"]]
df_noblank = df_clean.loc[df_clean["score"] == "-"].copy()
df_clean.to_csv(
"{}.clean".format(base_txt_path), sep="\t", header=False, index=False
)
# remove rows with blank scores
df_noblank = df_clean.loc[df_clean["score"] != "-"].copy()
print(base_txt_path)
df_noblank.to_csv(
"{}.clean.noblank".format(base_txt_path),
sep="\t",
header=False,
index=False,
"{}.clean.noblank".format(base_txt_path), sep="\t", header=False,
index=False
)
@@ -62,22 +62,18 @@ def _split_and_cleanup(data_path):
for file_split in SPLIT_MAP.keys():
s1_tok_path = os.path.join(
data_path,
"clean/snli_1.0/snli_1.0_{}.txt.s1.tok".format(file_split),
data_path, "clean/snli_1.0/snli_1.0_{}.txt.s1.tok".format(file_split)
)
s2_tok_path = os.path.join(
data_path,
"clean/snli_1.0/snli_1.0_{}.txt.s2.tok".format(file_split),
data_path, "clean/snli_1.0/snli_1.0_{}.txt.s2.tok".format(file_split)
)
with open(s1_tok_path, "r") as fin, open(
"{}.tmp".format(s1_tok_path), "w"
) as tmp:
with open(s1_tok_path, "r") as fin, open("{}.tmp".format(s1_tok_path),
"w") as tmp:
for line in fin:
s = line.replace('"', "")
tmp.write(s)
with open(s2_tok_path, "r") as fin, open(
"{}.tmp".format(s2_tok_path), "w"
) as tmp:
with open(s2_tok_path, "r") as fin, open("{}.tmp".format(s2_tok_path),
"w") as tmp:
for line in fin:
s = line.replace('"', "")
tmp.write(s)
@@ -109,21 +105,5 @@ def gensen_preprocess(train_tok, dev_tok, test_tok, data_path):
_preprocess(data_path)
_split_and_cleanup(data_path)
_clean(data_path)
return os.path.join(data_path, "clean/snli_1.0")
def _clean(data_path):
for file_split in SPLIT_MAP.keys():
src_file_path = os.path.join(
data_path, "raw/snli_1.0/snli_1.0_{}.txt".format(file_split)
)
dest_file_path = os.path.join(
data_path,
"clean/snli_1.0/snli_1.0_clean_{}.txt".format(file_split),
)
clean_df = clean_snli(
src_file_path
).dropna() # drop rows with any NaN vals
clean_df.to_csv(dest_file_path)