Fixes to gensen and gensen_utils files
This commit is contained in:
Родитель
324a7d6501
Коммит
c2c551d094
|
@ -68,20 +68,19 @@ class GenSenClassifier:
|
|||
json_object = json.load(open(config_file, "r", encoding="utf-8"))
|
||||
return json_object
|
||||
|
||||
def fit(self, train_df, dev_df):
|
||||
def fit(self, train_df, dev_df, test_df):
|
||||
|
||||
""" Method to train the Gensen model.
|
||||
|
||||
Args:
|
||||
train_df: A dataframe containing tokenized sentences from the training set.
|
||||
dev_df: A dataframe containing tokenized sentences from the validation set.
|
||||
|
||||
test_df: A dataframe containing tokenized sentences from the test set.
|
||||
"""
|
||||
|
||||
self._validate_params()
|
||||
config = self._read_config(self.config_file)
|
||||
self.cache_dir = self._get_gensen_tokens(train_df, dev_df)
|
||||
print(self.cache_dir)
|
||||
self.cache_dir = self._get_gensen_tokens(train_df, dev_df, test_df)
|
||||
train.train(
|
||||
data_folder=self.cache_dir,
|
||||
config=config,
|
||||
|
|
|
@ -24,28 +24,28 @@ def _preprocess(data_path):
|
|||
base_txt_path = os.path.join(
|
||||
data_path, "clean/snli_1.0/snli_1.0_{}.txt".format(file_split)
|
||||
)
|
||||
|
||||
df["s1.tok"] = df["sentence1_tokens"].apply(lambda x: " ".join(x))
|
||||
df["s2.tok"] = df["sentence2_tokens"].apply(lambda x: " ".join(x))
|
||||
df["s1.tok"].to_csv(
|
||||
"{}.s1.tok".format(base_txt_path),
|
||||
sep=" ",
|
||||
header=False,
|
||||
index=False,
|
||||
"{}.s1.tok".format(base_txt_path), sep=" ", header=False, index=False
|
||||
)
|
||||
df["s2.tok"].to_csv(
|
||||
"{}.s2.tok".format(base_txt_path),
|
||||
sep=" ",
|
||||
header=False,
|
||||
index=False,
|
||||
"{}.s2.tok".format(base_txt_path), sep=" ", header=False, index=False
|
||||
)
|
||||
df["score"].to_csv(
|
||||
"{}.lab".format(base_txt_path), sep=" ", header=False, index=False
|
||||
)
|
||||
|
||||
df_clean = df[["s1.tok", "s2.tok", "score"]]
|
||||
df_noblank = df_clean.loc[df_clean["score"] == "-"].copy()
|
||||
df_clean.to_csv(
|
||||
"{}.clean".format(base_txt_path), sep="\t", header=False, index=False
|
||||
)
|
||||
# remove rows with blank scores
|
||||
df_noblank = df_clean.loc[df_clean["score"] != "-"].copy()
|
||||
print(base_txt_path)
|
||||
df_noblank.to_csv(
|
||||
"{}.clean.noblank".format(base_txt_path),
|
||||
sep="\t",
|
||||
header=False,
|
||||
index=False,
|
||||
"{}.clean.noblank".format(base_txt_path), sep="\t", header=False,
|
||||
index=False
|
||||
)
|
||||
|
||||
|
||||
|
@ -62,22 +62,18 @@ def _split_and_cleanup(data_path):
|
|||
|
||||
for file_split in SPLIT_MAP.keys():
|
||||
s1_tok_path = os.path.join(
|
||||
data_path,
|
||||
"clean/snli_1.0/snli_1.0_{}.txt.s1.tok".format(file_split),
|
||||
data_path, "clean/snli_1.0/snli_1.0_{}.txt.s1.tok".format(file_split)
|
||||
)
|
||||
s2_tok_path = os.path.join(
|
||||
data_path,
|
||||
"clean/snli_1.0/snli_1.0_{}.txt.s2.tok".format(file_split),
|
||||
data_path, "clean/snli_1.0/snli_1.0_{}.txt.s2.tok".format(file_split)
|
||||
)
|
||||
with open(s1_tok_path, "r") as fin, open(
|
||||
"{}.tmp".format(s1_tok_path), "w"
|
||||
) as tmp:
|
||||
with open(s1_tok_path, "r") as fin, open("{}.tmp".format(s1_tok_path),
|
||||
"w") as tmp:
|
||||
for line in fin:
|
||||
s = line.replace('"', "")
|
||||
tmp.write(s)
|
||||
with open(s2_tok_path, "r") as fin, open(
|
||||
"{}.tmp".format(s2_tok_path), "w"
|
||||
) as tmp:
|
||||
with open(s2_tok_path, "r") as fin, open("{}.tmp".format(s2_tok_path),
|
||||
"w") as tmp:
|
||||
for line in fin:
|
||||
s = line.replace('"', "")
|
||||
tmp.write(s)
|
||||
|
@ -109,21 +105,5 @@ def gensen_preprocess(train_tok, dev_tok, test_tok, data_path):
|
|||
|
||||
_preprocess(data_path)
|
||||
_split_and_cleanup(data_path)
|
||||
_clean(data_path)
|
||||
|
||||
return os.path.join(data_path, "clean/snli_1.0")
|
||||
|
||||
|
||||
def _clean(data_path):
|
||||
for file_split in SPLIT_MAP.keys():
|
||||
src_file_path = os.path.join(
|
||||
data_path, "raw/snli_1.0/snli_1.0_{}.txt".format(file_split)
|
||||
)
|
||||
dest_file_path = os.path.join(
|
||||
data_path,
|
||||
"clean/snli_1.0/snli_1.0_clean_{}.txt".format(file_split),
|
||||
)
|
||||
clean_df = clean_snli(
|
||||
src_file_path
|
||||
).dropna() # drop rows with any NaN vals
|
||||
clean_df.to_csv(dest_file_path)
|
||||
|
|
Загрузка…
Ссылка в новой задаче