handling headers with field matching and streaming the file

This commit is contained in:
Arthur Wiedmer 2015-04-24 15:14:29 -07:00
Parent 6bd5bc04c3
Commit 844abbdba8
1 changed file: 44 additions and 11 deletions

View file

@ -37,6 +37,9 @@ class S3ToHiveTransfer(BaseOperator):
    :param partition: target partition as a dict of partition columns
        and values
    :type partition: dict
    :param headers: whether the file contains column names on the first
        line
    :type headers: bool
    :param delimiter: field delimiter in the file
    :type delimiter: str
    :param s3_conn_id: source s3 connection
@ -84,17 +87,47 @@ class S3ToHiveTransfer(BaseOperator):
            raise Exception("The key {0} does not exists".format(self.s3_key))
        s3_key_object = self.s3.get_key(self.s3_key)
        with NamedTemporaryFile("w") as f:
            logging.info("Dumping S3 file {0} contents to local"
                         " file {1}".format(self.s3_key, f.name))
            s3_key_object.get_contents_to_file(f)
            f.flush()
            self.s3.connection.close()
            if not self.headers:
                logging.info("Loading file into Hive")
                self.hive.load_file(
                    f.name,
                    self.hive_table,
                    field_dict=self.field_dict,
                    create=self.create,
                    partition=self.partition,
                    delimiter=self.delimiter,
                    recreate=self.recreate)
            else:
                f.seek(0)
                header_l = f.readline()
                header_line = header_l.replace('\n','')
                header_list = header_line.split(self.delimiter)
                field_names = list(self.field_dict.keys())
                test_field_match = [h1.lower() == h2.lower() for h1, h2
                                    in zip(header_list, field_names)]
                if not all(test_field_match):
                    logging.warning("Headers do not match field names"
                                    "File headers:\n {header_list}\n"
                                    "Field names: \n {field_names}\n"
                                    "".format(**locals()))
                    raise Exception("The headers do not match the field names")
                with NamedTemporaryFile("w") as fnoheaders:
                    f.seek(0)
                    next(f)
                    for line in f:
                        fnoheaders.write(line)
                    fnoheaders.flush()
                    logging.info("Loading file without headers into Hive")
                    self.hive.load_file(
                        fnoheaders.name,
                        self.hive_table,
                        field_dict=self.field_dict,
                        create=self.create,
                        partition=self.partition,
                        delimiter=self.delimiter,
                        recreate=self.recreate)