# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import argparse
import os
import shutil
import tarfile
from os import listdir
from sys import version_info

from sklearn.model_selection import train_test_split

if version_info.major == 2:
    import urllib as urldownload
else:
    import urllib.request as urldownload
class NewsGroup(object):
    """Download, extract, and preprocess the 20 Newsgroups corpus into TSV files.

    ``params`` is an ``argparse.Namespace`` with attributes:
        root_dir   -- folder for the downloaded archive and the extracted data
        output_dir -- folder for the generated ``.tsv`` files
        isSplit    -- if true, write ``train.tsv``/``test.tsv``; else ``output.tsv``
        test_size  -- held-out fraction for ``test.tsv`` when splitting
    """

    def __init__(self, params):
        self.params = params
        self.url = "http://archive.ics.uci.edu/ml/machine-learning-databases/20newsgroups-mld/20_newsgroups.tar.gz"
        self.file_name = '20_newsgroups.tar.gz'
        self.dirname = '20_newsgroups'

    def download_or_zip(self):
        """Download the archive if needed, extract it, and return the data directory.

        Returns the path ``<root_dir>/20_newsgroups``. Skips the download when
        the archive already exists and skips extraction when the directory does.
        """
        if not os.path.exists(self.params.root_dir):
            os.mkdir(self.params.root_dir)
        path = os.path.join(self.params.root_dir, self.dirname)
        if not os.path.isdir(path):
            file_path = os.path.join(self.params.root_dir, self.file_name)
            if not os.path.isfile(file_path):
                print('DownLoading...')
                urldownload.urlretrieve(self.url, file_path)
            with tarfile.open(file_path, 'r', encoding='utf-8') as fin:
                print('Extracting...')
                fin.extractall(self.params.root_dir)
        return path

    def read_process_file(self, file_path):
        """Read one news file and return its content flattened onto a single line.

        The raw files have mixed/unknown encodings, so bytes are decoded with
        ``errors='replace'``; newlines and tabs become spaces so the text fits
        in one TSV field.
        """
        text_lines = []
        with open(file_path, 'rb') as fin:
            for single_line in fin:
                # BUG FIX: the original used str(bytes), which yields "b'...'"
                # reprs containing literal "\n" escapes, so the newline/tab
                # replacement below never matched real content.
                text_lines.append(single_line.decode('utf-8', errors='replace'))
        return ''.join(text_lines).replace('\n', ' ').replace('\t', ' ')

    def data_combination(self):
        """Build ``label<TAB>text`` lines for every file and write TSV output.

        Writes ``train.tsv``/``test.tsv`` (split with ``random_state=123``)
        when ``params.isSplit`` is true, otherwise a single ``output.tsv``;
        finally removes the raw data under ``params.root_dir`` (best effort).
        """
        data_dir_path = self.download_or_zip()
        class_name_folders = listdir(data_dir_path)
        assert len(class_name_folders) == 20, (
            'The 20_newsgroups data has 20 classes and 20 sub folder accordingly, but we found %d'
            % len(class_name_folders))
        pathname_list = []
        label_list = []
        for sub_folder in class_name_folders:
            sub_folder_path = os.path.join(data_dir_path, sub_folder)
            for single_file in listdir(sub_folder_path):
                pathname_list.append(os.path.join(sub_folder_path, single_file))
                label_list.append(sub_folder)
        # prepare folder and write data
        if not os.path.exists(self.params.output_dir):
            os.mkdir(self.params.output_dir)
        data_all = []
        print('Preprocessing...')
        for single_file_path, single_label in zip(pathname_list, label_list):
            text_line = '%s\t%s\n' % (single_label, self.read_process_file(single_file_path))
            data_all.append(text_line)

        print('Write output file...')
        if self.params.isSplit:
            output_train_file_path = os.path.join(self.params.output_dir, 'train.tsv')
            output_test_file_path = os.path.join(self.params.output_dir, 'test.tsv')
            train_data, test_data = train_test_split(
                data_all, test_size=self.params.test_size, random_state=123)
            with open(output_train_file_path, 'w', encoding='utf-8') as fout:
                fout.writelines(train_data)
            with open(output_test_file_path, 'w', encoding='utf-8') as fout:
                fout.writelines(test_data)
        else:
            # BUG FIX: the original referenced self.output_dir, which is never
            # set (AttributeError); the option lives on self.params.
            output_file_path = os.path.join(self.params.output_dir, 'output.tsv')
            with open(output_file_path, 'w', encoding='utf-8') as fout:
                fout.writelines(data_all)
        try:
            if os.path.exists(self.params.root_dir):
                shutil.rmtree(self.params.root_dir)
        except OSError:
            # Cleanup is best effort; leave the raw data behind on failure
            # instead of swallowing every exception with a bare except.
            pass
if __name__ == '__main__':
    parse = argparse.ArgumentParser(description='20_newsgroups data preprocess')
    parse.add_argument("--root_dir", type=str, default='./data',
                       help='the folder path of saving download file and untar files')
    parse.add_argument("--output_dir", type=str, default='20_newsgroups',
                       help='the folder path of saving tsv format files after preprocess')
    # BUG FIX: argparse's type=bool converts any non-empty string (including
    # "False") to True; parse the flag text explicitly so --isSplit False works.
    parse.add_argument("--isSplit",
                       type=lambda s: s.strip().lower() in ('true', '1', 'yes'),
                       default=True,
                       help='appoint split data into train dataset and test dataset or not')
    parse.add_argument("--test_size", type=float, default=0.2)
    params, _ = parse.parse_known_args()
    newsgroup = NewsGroup(params)
    newsgroup.data_combination()