Mirror of https://github.com/Azure/blobxfer.git

Current download parity

- Main download logic at parity with current blobxfer
- Refactor multiprocess offload into base class
- Add multiprocess crypto offload

This commit is contained in:
Parent: 31ef912cd6
Commit: d70d404b46
@@ -35,20 +35,13 @@ import collections
import hashlib
import hmac
import json
import logging
# non-stdlib imports
# local imports
import blobxfer.crypto.operations
import blobxfer.util


# encryption constants
_AES256_KEYLENGTH_BYTES = 32
_AES256_BLOCKSIZE_BYTES = 16
_HMACSHA256_DIGESTSIZE_BYTES = 32
_AES256CBC_HMACSHA256_OVERHEAD_BYTES = (
    _AES256_BLOCKSIZE_BYTES + _HMACSHA256_DIGESTSIZE_BYTES
)

# named tuples
EncryptionBlobxferExtensions = collections.namedtuple(
@@ -191,8 +184,8 @@ class EncryptionMetadata(object):
                 )
             except KeyError:
                 pass
-        self.content_encryption_iv = ed[
-            EncryptionMetadata._JSON_KEY_CONTENT_IV]
+        self.content_encryption_iv = base64.b64decode(
+            ed[EncryptionMetadata._JSON_KEY_CONTENT_IV])
         self.encryption_agent = EncryptionAgent(
             encryption_algorithm=ed[
                 EncryptionMetadata._JSON_KEY_ENCRYPTION_AGENT][
@@ -31,7 +31,13 @@ from builtins import (  # noqa
    next, oct, open, pow, round, super, filter, map, zip)
# stdlib imports
import base64
import enum
import logging
import os
try:
    import queue
except ImportError:  # noqa
    import Queue as queue
# non-stdlib imports
import cryptography.hazmat.backends
import cryptography.hazmat.primitives.asymmetric.padding
@@ -44,7 +50,13 @@ import cryptography.hazmat.primitives.hashes
import cryptography.hazmat.primitives.padding
import cryptography.hazmat.primitives.serialization
# local imports
import blobxfer.util
import blobxfer.offload

# create logger
logger = logging.getLogger(__name__)

# encryption constants
_AES256_KEYLENGTH_BYTES = 32


def load_rsa_private_key_file(rsakeyfile, passphrase):
@@ -130,7 +142,7 @@ def rsa_encrypt_key_base64_encoded(rsaprivatekey, rsapublickey, plainkey):
     return blobxfer.util.base64_encode_as_string(enckey)
 
 
-def pad_pkcs7(buf):
+def pkcs7_pad(buf):
     # type: (bytes) -> bytes
     """Appends PKCS7 padding to an input buffer
     :param bytes buf: buffer to add padding
@@ -143,7 +155,7 @@ def pad_pkcs7(buf):
     return padder.update(buf) + padder.finalize()
 
 
-def unpad_pkcs7(buf):
+def pkcs7_unpad(buf):
     # type: (bytes) -> bytes
     """Removes PKCS7 padding from a decrypted object
     :param bytes buf: buffer to remove padding
@@ -154,3 +166,107 @@ def unpad_pkcs7(buf):
        cryptography.hazmat.primitives.ciphers.
        algorithms.AES.block_size).unpadder()
    return unpadder.update(buf) + unpadder.finalize()


def aes256_generate_random_key():
    # type: (None) -> bytes
    """Generate random AES256 key
    :rtype: bytes
    :return: random key
    """
    return os.urandom(_AES256_KEYLENGTH_BYTES)


def aes_cbc_decrypt_data(symkey, iv, encdata, unpad):
    # type: (bytes, bytes, bytes, bool) -> bytes
    """Decrypt data using AES CBC
    :param bytes symkey: symmetric key
    :param bytes iv: initialization vector
    :param bytes encdata: data to decrypt
    :param bool unpad: unpad data
    :rtype: bytes
    :return: decrypted data
    """
    cipher = cryptography.hazmat.primitives.ciphers.Cipher(
        cryptography.hazmat.primitives.ciphers.algorithms.AES(symkey),
        cryptography.hazmat.primitives.ciphers.modes.CBC(iv),
        backend=cryptography.hazmat.backends.default_backend()).decryptor()
    decrypted = cipher.update(encdata) + cipher.finalize()
    if unpad:
        return pkcs7_unpad(decrypted)
    else:
        return decrypted


def aes_cbc_encrypt_data(symkey, iv, data, pad):
    # type: (bytes, bytes, bytes, bool) -> bytes
    """Encrypt data using AES CBC
    :param bytes symkey: symmetric key
    :param bytes iv: initialization vector
    :param bytes data: data to encrypt
    :param bool pad: pad data
    :rtype: bytes
    :return: encrypted data
    """
    cipher = cryptography.hazmat.primitives.ciphers.Cipher(
        cryptography.hazmat.primitives.ciphers.algorithms.AES(symkey),
        cryptography.hazmat.primitives.ciphers.modes.CBC(iv),
        backend=cryptography.hazmat.backends.default_backend()).encryptor()
    if pad:
        return cipher.update(pkcs7_pad(data)) + cipher.finalize()
    else:
        return cipher.update(data) + cipher.finalize()

class CryptoAction(enum.Enum):
    Encrypt = 1
    Decrypt = 2


class CryptoOffload(blobxfer.offload._MultiprocessOffload):
    def __init__(self, num_workers):
        # type: (CryptoOffload, int) -> None
        """Ctor for Crypto Offload
        :param CryptoOffload self: this
        :param int num_workers: number of worker processes
        """
        super(CryptoOffload, self).__init__(num_workers, 'Crypto')

    def _worker_process(self):
        # type: (CryptoOffload) -> None
        """Crypto worker
        :param CryptoOffload self: this
        """
        while not self.terminated:
            try:
                inst = self._task_queue.get(True, 1)
            except queue.Empty:
                continue
            if inst[0] == CryptoAction.Encrypt:
                # TODO on upload
                raise NotImplementedError()
            elif inst[0] == CryptoAction.Decrypt:
                final_path, offsets, symkey, iv, encdata = \
                    inst[1], inst[2], inst[3], inst[4], inst[5]
                data = aes_cbc_decrypt_data(symkey, iv, encdata, offsets.unpad)
            self._done_cv.acquire()
            self._done_queue.put((final_path, offsets, data))
            self._done_cv.notify()
            self._done_cv.release()

    def add_decrypt_chunk(
            self, final_path, offsets, symkey, iv, encdata):
        # type: (CryptoOffload, str, blobxfer.models.DownloadOffsets, bytes,
        #        bytes, bytes) -> None
        """Add a chunk to decrypt
        :param CryptoOffload self: this
        :param str final_path: final path
        :param blobxfer.models.DownloadOffsets offsets: offsets
        :param bytes symkey: symmetric key
        :param bytes iv: initialization vector
        :param bytes encdata: encrypted data
        """
        self._task_queue.put(
            (CryptoAction.Decrypt, final_path, offsets, symkey, iv,
             encdata)
        )
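CryptoOffload inherits its queues, condition variable, and termination flag from the _MultiprocessOffload base class added later in this commit. A minimal, self-contained sketch of the task-queue/done-queue handshake the worker loop above relies on (illustrative names only; the doubling stands in for the real decrypt call):

    import multiprocessing
    import queue

    def _worker(task_queue, done_queue, done_cv, term):
        # mirror of the worker loop: poll with a timeout so the process
        # can observe the termination signal between tasks
        while term.value == 0:
            try:
                item = task_queue.get(True, 1)
            except queue.Empty:
                continue
            result = item * 2  # stand-in for aes_cbc_decrypt_data()
            with done_cv:
                done_queue.put(result)
                done_cv.notify()

    if __name__ == '__main__':
        task_q = multiprocessing.Queue()
        done_q = multiprocessing.Queue()
        done_cv = multiprocessing.Condition()
        term = multiprocessing.Value('i', 0)
        proc = multiprocessing.Process(
            target=_worker, args=(task_q, done_q, done_cv, term))
        proc.start()
        task_q.put(21)
        print(done_q.get())  # 42
        term.value = 1
        proc.join()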
@@ -43,9 +43,12 @@ try:
 except ImportError:  # noqa
     import Queue as queue
 import threading
+import time
 # non-stdlib imports
 import dateutil
 # local imports
+import blobxfer.crypto.models
+import blobxfer.crypto.operations
 import blobxfer.md5
 import blobxfer.models
 import blobxfer.operations
@@ -75,18 +78,20 @@ class Downloader(object):
         :param blobxfer.models.AzureStorageCredentials creds: creds
         :param blobxfer.models.DownloadSpecification spec: download spec
         """
-        self._md5_meta_lock = threading.Lock()
-        self._download_lock = threading.Lock()
+        self._time_start = None
         self._all_remote_files_processed = False
+        self._crypto_offload = None
+        self._md5_meta_lock = threading.Lock()
         self._md5_map = {}
         self._md5_offload = None
-        self._md5_check_thread = None
+        self._download_lock = threading.Lock()
         self._download_queue = queue.Queue()
         self._download_set = set()
         self._download_threads = []
         self._download_count = 0
         self._download_total_bytes = 0
         self._download_terminate = False
+        self._dd_map = {}
         self._general_options = general_options
         self._creds = creds
         self._spec = spec
@@ -177,46 +182,69 @@ class Downloader(object):
             else:
                 self._add_to_download_queue(lpath, rfile)
 
-    def _initialize_check_md5_downloads_thread(self):
+    def _check_for_downloads_from_md5(self):
         # type: (Downloader) -> None
-        """Initialize the md5 done queue check thread
+        """Check queue for a file to download
         :param Downloader self: this
         """
-        def _check_for_downloads_from_md5(self):
-            # type: (Downloader) -> None
-            """Check queue for a file to download
-            :param Downloader self: this
-            """
-            cv = self._md5_offload.done_cv
-            while True:
-                with self._md5_meta_lock:
-                    if (self._download_terminate or
-                            (len(self._md5_map) == 0 and
-                             self._all_remote_files_processed)):
-                        break
-                result = None
-                cv.acquire()
-                while not self._download_terminate:
-                    result = self._md5_offload.get_localfile_md5_done()
-                    if result is None:
-                        # use cv timeout due to possible non-wake while running
-                        cv.wait(1)
-                        # check for terminating conditions
-                        with self._md5_meta_lock:
-                            if (len(self._md5_map) == 0 and
-                                    self._all_remote_files_processed):
-                                break
-                    else:
-                        break
-                cv.release()
-                if result is not None:
-                    self._post_md5_skip_on_check(result[0], result[1])
+        cv = self._md5_offload.done_cv
+        while True:
+            with self._md5_meta_lock:
+                if (self._download_terminate or
+                        (self._all_remote_files_processed and
+                         len(self._md5_map) == 0 and
+                         len(self._download_set) == 0)):
+                    break
+            result = None
+            cv.acquire()
+            while not self._download_terminate:
+                result = self._md5_offload.pop_done_queue()
+                if result is None:
+                    # use cv timeout due to possible non-wake while running
+                    cv.wait(1)
+                    # check for terminating conditions
+                    with self._md5_meta_lock:
+                        if (self._all_remote_files_processed and
+                                len(self._md5_map) == 0 and
+                                len(self._download_set) == 0):
+                            break
+                else:
+                    break
+            cv.release()
+            if result is not None:
+                self._post_md5_skip_on_check(result[0], result[1])
 
-        self._md5_check_thread = threading.Thread(
-            target=_check_for_downloads_from_md5,
-            args=(self,)
-        )
-        self._md5_check_thread.start()
+    def _check_for_crypto_done(self):
+        # type: (Downloader) -> None
+        """Check queue for crypto done
+        :param Downloader self: this
+        """
+        cv = self._crypto_offload.done_cv
+        while True:
+            with self._download_lock:
+                if (self._download_terminate or
+                        (self._all_remote_files_processed and
+                         len(self._download_set) == 0)):
+                    break
+            result = None
+            cv.acquire()
+            while not self._download_terminate:
+                result = self._crypto_offload.pop_done_queue()
+                if result is None:
+                    # use cv timeout due to possible non-wake while running
+                    cv.wait(1)
+                    # check for terminating conditions
+                    with self._download_lock:
+                        if (self._all_remote_files_processed and
+                                len(self._download_set) == 0):
+                            break
+                else:
+                    break
+            cv.release()
+            if result is not None:
+                with self._download_lock:
+                    dd = self._dd_map[result[0]]
+                self._complete_chunk_download(result[1], result[2], dd)
 
     def _add_to_download_queue(self, lpath, rfile):
         # type: (Downloader, pathlib.Path,
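Both checker threads above poll with cv.wait(1) rather than a bare wait: a worker process may notify before the checker re-acquires the condition variable, and the one-second timeout guarantees the loop still wakes to re-test its termination conditions. The same pattern, distilled into a sketch (drain, handle, and should_stop are hypothetical placeholders, not names from the commit):

    def drain(offload, handle, should_stop):
        # poll the offload's done queue under its condition variable;
        # wait with a timeout so a notify that fired early cannot hang us
        cv = offload.done_cv
        while not should_stop():
            with cv:
                result = offload.pop_done_queue()
                if result is None:
                    cv.wait(1)
                    continue
            handle(result)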
@@ -229,6 +257,9 @@ class Downloader(object):
         # prepare remote file for download
         dd = blobxfer.models.DownloadDescriptor(
             lpath, rfile, self._spec.options)
+        if dd.entity.is_encrypted:
+            with self._download_lock:
+                self._dd_map[str(dd.final_path)] = dd
         # add download descriptor to queue
         self._download_queue.put(dd)
 
@@ -250,7 +281,8 @@ class Downloader(object):
         :param Downloader self: this
         :param bool terminate: terminate threads
         """
-        self._download_terminate = terminate
+        if terminate:
+            self._download_terminate = terminate
         for thr in self._download_threads:
             thr.join()
 
@@ -273,17 +305,15 @@ class Downloader(object):
             # get download offsets
             offsets = dd.next_offsets()
             # check if all operations completed
-            if offsets is None and dd.outstanding_operations == 0:
-                # TODO
-                # 1. complete integrity checks
-                # 2. set file uid/gid
-                # 3. set file modes
-                # 4. move file to final path
+            if offsets is None and dd.all_operations_completed:
+                # finalize file
+                dd.finalize_file()
                 # accounting
                 with self._download_lock:
+                    if dd.entity.is_encrypted:
+                        self._dd_map.pop(str(dd.final_path))
                     self._download_set.remove(dd.final_path)
                     self._download_count += 1
                 logger.info('download complete: {}/{} to {}'.format(
                     dd.entity.container, dd.entity.name, dd.final_path))
                 continue
             # re-enqueue for other threads to download
             self._download_queue.put(dd)
@@ -291,39 +321,60 @@ class Downloader(object):
                 continue
             # issue get range
             if dd.entity.mode == blobxfer.models.AzureStorageModes.File:
-                chunk = blobxfer.file.operations.get_file_range(
+                data = blobxfer.file.operations.get_file_range(
                     dd.entity, offsets, self._general_options.timeout_sec)
             else:
-                chunk = blobxfer.blob.operations.get_blob_range(
+                data = blobxfer.blob.operations.get_blob_range(
                     dd.entity, offsets, self._general_options.timeout_sec)
             # accounting
             with self._download_lock:
                 self._download_total_bytes += offsets.num_bytes
             # decrypt if necessary
             if dd.entity.is_encrypted:
-                # TODO via crypto pool
-                # 1. compute rolling hmac if present
-                #    - roll through any subsequent unchecked parts
-                # 2. decrypt chunk
-                pass
-            # compute rolling md5 via md5 pool
-            if dd.must_compute_md5:
-                # TODO
-                #    - roll through any subsequent unchecked parts
-                pass
+                # slice data to proper bounds
+                encdata = data[blobxfer.crypto.models._AES256_BLOCKSIZE_BYTES:]
+                intdata = encdata
+                # get iv for chunk and compute hmac
+                if offsets.chunk_num == 0:
+                    iv = dd.entity.encryption_metadata.content_encryption_iv
+                    # integrity check for first chunk must include iv
+                    intdata = iv + data
+                else:
+                    iv = data[:blobxfer.crypto.models._AES256_BLOCKSIZE_BYTES]
+                # integrity check data
+                dd.perform_chunked_integrity_check(offsets, intdata)
+                # decrypt data
+                if self._crypto_offload is not None:
+                    self._crypto_offload.add_decrypt_chunk(
+                        str(dd.final_path), offsets,
+                        dd.entity.encryption_metadata.symmetric_key,
+                        iv, encdata)
+                    # data will be completed once retrieved from crypto queue
+                    continue
+                else:
+                    data = blobxfer.crypto.operations.aes_cbc_decrypt_data(
+                        dd.entity.encryption_metadata.symmetric_key,
+                        iv, encdata, offsets.unpad)
+            elif dd.must_compute_md5:
+                # rolling compute md5
+                dd.perform_chunked_integrity_check(offsets, data)
+            # complete chunk download
+            self._complete_chunk_download(offsets, data, dd)
 
-            # write data to disk
-
-            # if no integrity check could be performed due to current
-            # integrity offset mismatch, add to unchecked set
-
-            dd.dec_outstanding_operations()
-
-            # pickle dd to resume file
-
-            # rfile = dd._ase
-            # print('<<', rfile.container, rfile.name, rfile.lmt, rfile.size,
-            #       rfile.md5, rfile.mode, rfile.encryption_metadata)
+    def _complete_chunk_download(self, offsets, data, dd):
+        # type: (Downloader, blobxfer.models.DownloadOffsets, bytes,
+        #        blobxfer.models.DownloadDescriptor) -> None
+        """Complete chunk download
+        :param Downloader self: this
+        :param blobxfer.models.DownloadOffsets offsets: offsets
+        :param bytes data: data
+        :param blobxfer.models.DownloadDescriptor dd: download descriptor
+        """
+        # write data to disk
+        dd.write_data(offsets, data)
+        # decrement outstanding operations
+        dd.dec_outstanding_operations()
+        # TODO pickle dd to resume file
 
     def _run(self):
         # type: (Downloader) -> None
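Because next_offsets backs every encrypted range fetch up by one AES block, the IV for chunk n > 0 arrives as the first 16 bytes of the fetched data, while chunk 0 takes its IV from the encryption metadata. The slicing, illustrated with hypothetical bytes:

    AES_BLOCKSIZE = 16

    # body returned for a non-first chunk: the range was fetched one block
    # early, so the previous ciphertext block leads the payload
    data = bytes(range(48))            # three 16-byte AES blocks

    iv = data[:AES_BLOCKSIZE]          # prior ciphertext block doubles as IV
    encdata = data[AES_BLOCKSIZE:]     # ciphertext owned by this chunk

    assert len(iv) == 16 and len(encdata) == 32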
@@ -335,7 +386,14 @@ class Downloader(object):
         # initialize MD5 processes
         self._md5_offload = blobxfer.md5.LocalFileMd5Offload(
             num_workers=self._general_options.concurrency.md5_processes)
-        self._initialize_check_md5_downloads_thread()
+        self._md5_offload.initialize_check_thread(
+            self._check_for_downloads_from_md5)
+        # initialize crypto processes
+        if self._general_options.concurrency.crypto_processes > 0:
+            self._crypto_offload = blobxfer.crypto.operations.CryptoOffload(
+                num_workers=self._general_options.concurrency.crypto_processes)
+            self._crypto_offload.initialize_check_thread(
+                self._check_for_crypto_done)
         # initialize download threads
         self._initialize_download_threads()
         # iterate through source paths to download
@@ -344,6 +402,7 @@ class Downloader(object):
         skipped_files = 0
         total_size = 0
         skipped_size = 0
+        self._time_start = time.clock()
         for src in self._spec.sources:
             for rfile in src.files(
                     self._creds, self._spec.options, self._general_options):
@@ -369,33 +428,41 @@ class Downloader(object):
                     self._add_to_download_queue(lpath, rfile)
         download_files = nfiles - skipped_files
         download_size = total_size - skipped_size
+        download_size_mib = download_size / 1048576
         # clean up processes and threads
         with self._md5_meta_lock:
             self._all_remote_files_processed = True
         logger.debug(
             ('{0} remote files processed, waiting for download completion '
-             'of {1:.4f} MiB').format(nfiles, download_size / 1048576))
-        self._md5_check_thread.join()
+             'of {1:.4f} MiB').format(nfiles, download_size_mib))
         self._wait_for_download_threads(terminate=False)
-        self._md5_offload.finalize_md5_processes()
+        end = time.clock()
+        runtime = end - self._time_start
         if (self._download_count != download_files or
                 self._download_total_bytes != download_size):
             raise RuntimeError(
                 'download mismatch: [count={}/{} bytes={}/{}]'.format(
                     self._download_count, download_files,
                     self._download_total_bytes, download_size))
-        logger.info('all files downloaded')
+        logger.info('all files downloaded: {0:.3f} sec {1:.4f} Mbps'.format(
+            runtime, download_size_mib * 8 / runtime))
 
     def start(self):
         # type: (Downloader) -> None
         """Start the Downloader"""
         try:
             self._run()
-        except KeyboardInterrupt:
-            logger.error(
-                'KeyboardInterrupt detected, force terminating '
-                'processes and threads (this may take a while)...')
+        except (KeyboardInterrupt, Exception) as ex:
+            if isinstance(ex, KeyboardInterrupt):
+                logger.error(
+                    'KeyboardInterrupt detected, force terminating '
+                    'processes and threads (this may take a while)...')
             self._wait_for_download_threads(terminate=True)
-            self._md5_offload.finalize_md5_processes()
             # TODO delete all temp files
             # TODO close resume file in finally?
             raise
+        finally:
+            if self._md5_offload is not None:
+                self._md5_offload.finalize_processes()
+            if self._crypto_offload is not None:
+                self._crypto_offload.finalize_processes()
@@ -30,8 +30,6 @@ from builtins import (  # noqa
 )
 # stdlib imports
 import logging
-import hashlib
-import multiprocessing
 try:
     import queue
 except ImportError:  # noqa
@@ -39,21 +37,14 @@ except ImportError:  # noqa
 # non-stdlib imports
 # local imports
+import blobxfer.download
 import blobxfer.models
+import blobxfer.offload
 import blobxfer.util
 
 # create logger
 logger = logging.getLogger(__name__)
-
-
-def new_md5_hasher():
-    # type: (None) -> md5.MD5
-    """Create a new MD5 hasher
-    :rtype: md5.MD5
-    :return: new MD5 hasher
-    """
-    return hashlib.md5()
 
 
 def compute_md5_for_file_asbase64(filename, pagealign=False, blocksize=65536):
     # type: (str, bool, int) -> str
     """Compute MD5 hash for file and encode as Base64
@@ -63,7 +54,7 @@ def compute_md5_for_file_asbase64(filename, pagealign=False, blocksize=65536):
     :rtype: str
     :return: MD5 for file encoded as Base64
     """
-    hasher = new_md5_hasher()
+    hasher = blobxfer.util.new_md5_hasher()
     with open(filename, 'rb') as filedesc:
         while True:
             buf = filedesc.read(blocksize)
@@ -85,12 +76,12 @@ def compute_md5_for_data_asbase64(data):
     :rtype: str
     :return: MD5 for data
     """
-    hasher = new_md5_hasher()
+    hasher = blobxfer.util.new_md5_hasher()
     hasher.update(data)
     return blobxfer.util.base64_encode_as_string(hasher.digest())
 
 
-class LocalFileMd5Offload(object):
+class LocalFileMd5Offload(blobxfer.offload._MultiprocessOffload):
     """LocalFileMd5Offload"""
     def __init__(self, num_workers):
         # type: (LocalFileMd5Offload, int) -> None
@@ -98,52 +89,14 @@ class LocalFileMd5Offload(object):
         :param LocalFileMd5Offload self: this
         :param int num_workers: number of worker processes
         """
-        self._task_queue = multiprocessing.Queue()
-        self._done_queue = multiprocessing.Queue()
-        self._done_cv = multiprocessing.Condition()
-        self._term_signal = multiprocessing.Value('i', 0)
-        self._md5_procs = []
-        self._initialize_md5_processes(num_workers)
+        super(LocalFileMd5Offload, self).__init__(num_workers, 'MD5')
 
-    @property
-    def done_cv(self):
-        # type: (LocalFileMd5Offload) -> multiprocessing.Condition
-        """Get Download Done condition variable
-        :param LocalFileMd5Offload self: this
-        :rtype: multiprocessing.Condition
-        :return: cv for download done
-        """
-        return self._done_cv
-
-    def _initialize_md5_processes(self, num_workers):
-        # type: (LocalFileMd5Offload, int) -> None
-        """Initialize MD5 checking processes for files for download
-        :param LocalFileMd5Offload self: this
-        :param int num_workers: number of worker processes
-        """
-        if num_workers is None or num_workers < 1:
-            raise ValueError('invalid num_workers: {}'.format(num_workers))
-        for _ in range(num_workers):
-            proc = multiprocessing.Process(
-                target=self._worker_compute_md5_localfile_process)
-            proc.start()
-            self._md5_procs.append(proc)
-
-    def finalize_md5_processes(self):
-        # type: (LocalFileMd5Offload) -> None
-        """Finalize MD5 checking processes for files for download
-        :param LocalFileMd5Offload self: this
-        """
-        self._term_signal.value = 1
-        for proc in self._md5_procs:
-            proc.join()
-
-    def _worker_compute_md5_localfile_process(self):
+    def _worker_process(self):
         # type: (LocalFileMd5Offload) -> None
         """Compute MD5 for local file
         :param LocalFileMd5Offload self: this
         """
-        while self._term_signal.value == 0:
+        while not self.terminated:
             try:
                 filename, remote_md5, pagealign = self._task_queue.get(True, 1)
             except queue.Empty:
@@ -153,31 +106,17 @@ class LocalFileMd5Offload(object):
                 md5, remote_md5, filename))
             self._done_cv.acquire()
             self._done_queue.put((filename, md5 == remote_md5))
-            self.done_cv.notify()
-            self.done_cv.release()
-
-    def get_localfile_md5_done(self):
-        # type: (LocalFileMd5Offload) -> Tuple[str, bool]
-        """Get from done queue of local files with MD5 completed
-        :param LocalFileMd5Offload self: this
-        :rtype: tuple or None
-        :return: (local file path, md5 match)
-        """
-        try:
-            return self._done_queue.get_nowait()
-        except queue.Empty:
-            return None
+            self._done_cv.notify()
+            self._done_cv.release()
 
     def add_localfile_for_md5_check(self, filename, remote_md5, mode):
         # type: (LocalFileMd5Offload, str, str,
-        #        blobxfer.models.AzureStorageModes) -> bool
-        """Check an MD5 for a file for download
+        #        blobxfer.models.AzureStorageModes) -> None
+        """Add a local file to MD5 check queue
         :param LocalFileMd5Offload self: this
         :param str filename: file to compute MD5 for
         :param str remote_md5: remote MD5 to compare against
         :param blobxfer.models.AzureStorageModes mode: mode
-        :rtype: bool
-        :return: MD5 match comparison
         """
         if mode == blobxfer.models.AzureStorageModes.Page:
             pagealign = True
@@ -41,6 +41,8 @@ try:
 except ImportError:  # noqa
     import pathlib
 import multiprocessing
+import tempfile
+import threading
 # non-stdlib imports
 # local imports
 from .api import (
@@ -53,7 +55,6 @@ from azure.storage.blob.models import _BlobTypes as BlobTypes
 import blobxfer.blob.operations
 import blobxfer.file.operations
 import blobxfer.crypto.models
-import blobxfer.md5
 import blobxfer.util
 
 # create logger
@@ -126,6 +127,7 @@ LocalPath = collections.namedtuple(
 )
 DownloadOffsets = collections.namedtuple(
     'DownloadOffsets', [
+        'chunk_num',
         'fd_start',
         'num_bytes',
         'range_end',
@@ -133,6 +135,14 @@ DownloadOffsets = collections.namedtuple(
         'unpad',
     ]
 )
+UncheckedChunk = collections.namedtuple(
+    'UncheckedChunk', [
+        'data_len',
+        'fd_start',
+        'file_path',
+        'temp',
+    ]
+)
 
 
 class ConcurrencyOptions(object):
@@ -147,16 +157,16 @@ class ConcurrencyOptions(object):
         self.crypto_processes = crypto_processes
         self.md5_processes = md5_processes
         self.transfer_threads = transfer_threads
+        # allow crypto processes to be zero (which will inline crypto
+        # routines with main process)
         if self.crypto_processes is None or self.crypto_processes < 1:
-            self.crypto_processes = multiprocessing.cpu_count() // 2 - 1
-            if self.crypto_processes < 1:
-                self.crypto_processes = 1
+            self.crypto_processes = 0
         if self.md5_processes is None or self.md5_processes < 1:
             self.md5_processes = multiprocessing.cpu_count() // 2
             if self.md5_processes < 1:
                 self.md5_processes = 1
         if self.transfer_threads is None or self.transfer_threads < 1:
-            self.transfer_threads = multiprocessing.cpu_count() * 2
+            self.transfer_threads = multiprocessing.cpu_count() * 3
 
 
 class GeneralOptions(object):
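The new defaults favor inlining decryption (crypto_processes = 0, per the comment above) and raise the transfer-thread multiplier from 2x to 3x CPU count. Resolved values for a machine reporting eight cores (an illustrative figure):

    import multiprocessing

    cpus = multiprocessing.cpu_count()   # assume 8 for this illustration
    md5_processes = max(cpus // 2, 1)    # 8 -> 4
    transfer_threads = cpus * 3          # 8 -> 24
    crypto_processes = 0                 # decrypt inline in the main process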
@@ -824,7 +834,8 @@ class DownloadDescriptor(object):
         _tmp = list(lpath.parts[:-1])
         _tmp.append(lpath.name + '.bxtmp')
         self.local_path = pathlib.Path(*_tmp)
-        self._meta_lock = multiprocessing.Lock()
+        self._meta_lock = threading.Lock()
+        self._hasher_lock = threading.Lock()
         self._ase = ase
         # calculate the total number of ops required for transfer
         self._chunk_size = min((options.chunk_size_bytes, self._ase.size))
@@ -835,9 +846,10 @@ class DownloadDescriptor(object):
         self._total_chunks = 0
         self.hmac = None
         self.md5 = None
-        self.offset = 0
-        self.integrity_counter = 0
-        self.unchecked_chunks = set()
+        self._offset = 0
+        self._chunk_num = 0
+        self._next_integrity_chunk = 0
+        self._unchecked_chunks = {}
         self._outstanding_ops = self._total_chunks
         self._completed_ops = 0
         # initialize checkers and allocate space
@@ -871,9 +883,15 @@ class DownloadDescriptor(object):
         :param DownloadOptions options: download options
         """
         if self._ase.is_encrypted:
+            # ensure symmetric key exists
+            if blobxfer.util.is_none_or_empty(
+                    self._ase.encryption_metadata.symmetric_key):
+                raise RuntimeError(
+                    'symmetric key is invalid: provide RSA private key '
+                    'or metadata corrupt')
             self.hmac = self._ase.encryption_metadata.initialize_hmac()
         if self.hmac is None and options.check_file_md5:
-            self.md5 = blobxfer.md5.new_md5_hasher()
+            self.md5 = blobxfer.util.new_md5_hasher()
 
     def _allocate_disk_space(self):
         # type: (DownloadDescriptor, int) -> None
@@ -912,48 +930,182 @@ class DownloadDescriptor(object):
         :rtype: DownloadOffsets
         :return: download offsets
         """
-        if self.offset >= self._ase.size:
-            return None
-        if self.offset + self._chunk_size > self._ase.size:
-            chunk = self._ase.size - self.offset
+        with self._meta_lock:
+            if self._offset >= self._ase.size:
+                return None
+            if self._offset + self._chunk_size > self._ase.size:
+                chunk = self._ase.size - self._offset
+            else:
+                chunk = self._chunk_size
+            # on download, num_bytes must be offset by -1 as the x-ms-range
+            # header expects it that way. x -> y bytes means first bits of the
+            # (x+1)th byte to the last bits of the (y+1)th byte. for example,
+            # 0 -> 511 means byte 1 to byte 512
+            num_bytes = chunk - 1
+            chunk_num = self._chunk_num
+            fd_start = self._offset
+            range_start = self._offset
+            if self._ase.is_encrypted:
+                # ensure start is AES block size aligned
+                range_start = range_start - \
+                    (range_start % self._AES_BLOCKSIZE) - \
+                    self._AES_BLOCKSIZE
+                if range_start <= 0:
+                    range_start = 0
+            range_end = self._offset + num_bytes
+            self._offset += chunk
+            self._chunk_num += 1
+            if self._ase.is_encrypted and self._offset >= self._ase.size:
+                unpad = True
+            else:
+                unpad = False
+            return DownloadOffsets(
+                chunk_num=chunk_num,
+                fd_start=fd_start,
+                num_bytes=chunk,
+                range_start=range_start,
+                range_end=range_end,
+                unpad=unpad,
+            )
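Worked through with the figures the unit tests below use: a 288-byte encrypted entity and a 256-byte chunk size. range_end is inclusive (hence num_bytes - 1), and for encrypted entities range_start retreats one extra 16-byte block so the preceding ciphertext block (the IV) is fetched too:

    AES_BLOCKSIZE = 16

    offset, chunk = 256, 32                # second chunk of a 288-byte blob
    num_bytes = chunk - 1                  # inclusive x-ms-range arithmetic
    range_start = offset - (offset % AES_BLOCKSIZE) - AES_BLOCKSIZE
    range_end = offset + num_bytes

    assert range_start == 240              # 256 - 16: one leading AES block
    assert range_end == 287                # last byte of the 288-byte entity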
+    def _postpone_integrity_check(self, offsets, data):
+        # type: (DownloadDescriptor, DownloadOffsets, bytes) -> None
+        """Postpone integrity check for chunk
+        :param DownloadDescriptor self: this
+        :param DownloadOffsets offsets: download offsets
+        :param bytes data: data
+        """
+        if self.must_compute_md5:
+            with self.local_path.open('r+b') as fd:
+                fd.seek(offsets.fd_start, 0)
+                fd.write(data)
+            unchecked = UncheckedChunk(
+                data_len=len(data),
+                fd_start=offsets.fd_start,
+                file_path=self.local_path,
+                temp=False,
+            )
         else:
-            chunk = self._chunk_size
-        # on download, num_bytes must be offset by -1 as the x-ms-range
-        # header expects it that way. x -> y bytes means first bits of the
-        # (x+1)th byte to the last bits of the (y+1)th byte. for example,
-        # 0 -> 511 means byte 1 to byte 512
-        num_bytes = chunk - 1
-        fd_start = self.offset
-        range_start = self.offset
-        if self._ase.is_encrypted:
-            # ensure start is AES block size aligned
-            range_start = range_start - (range_start % self._AES_BLOCKSIZE) - \
-                self._AES_BLOCKSIZE
-            if range_start <= 0:
-                range_start = 0
-        range_end = self.offset + num_bytes
-        self.offset += chunk
-        if self._ase.is_encrypted and self.offset >= self._ase.size:
-            unpad = True
+            fname = None
+            with tempfile.NamedTemporaryFile(mode='wb', delete=False) as fd:
+                fname = fd.name
+                fd.write(data)
+            unchecked = UncheckedChunk(
+                data_len=len(data),
+                fd_start=0,
+                file_path=pathlib.Path(fname),
+                temp=True,
+            )
+        with self._meta_lock:
+            self._unchecked_chunks[offsets.chunk_num] = unchecked
+
+    def perform_chunked_integrity_check(self, offsets, data):
+        # type: (DownloadDescriptor, DownloadOffsets, bytes) -> None
+        """Hash data against stored MD5 hasher safely
+        :param DownloadDescriptor self: this
+        :param DownloadOffsets offsets: download offsets
+        :param bytes data: data
+        """
+        self_check = False
+        hasher = self.hmac or self.md5
+        # iterate from next chunk to be checked
+        while True:
+            ucc = None
+            with self._meta_lock:
+                chunk_num = self._next_integrity_chunk
+                # check if the next chunk is ready
+                if chunk_num in self._unchecked_chunks:
+                    ucc = self._unchecked_chunks.pop(chunk_num)
+                elif chunk_num != offsets.chunk_num:
+                    break
+            # prepare data for hashing
+            if ucc is None:
+                chunk = data
+                self_check = True
+            else:
+                with ucc.file_path.open('rb') as fd:
+                    fd.seek(ucc.fd_start, 0)
+                    chunk = fd.read(ucc.data_len)
+                if ucc.temp:
+                    ucc.file_path.unlink()
+            # hash data and set next integrity chunk
+            with self._hasher_lock:
+                hasher.update(chunk)
+            with self._meta_lock:
+                self._next_integrity_chunk += 1
+        # store data that hasn't been checked
+        if not self_check:
+            self._postpone_integrity_check(offsets, data)
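perform_chunked_integrity_check only feeds the rolling hasher in chunk order; anything arriving early is parked by _postpone_integrity_check and replayed once its predecessors land. The bookkeeping, reduced to a runnable model (in-memory instead of the on-disk parking the real code uses):

    import hashlib

    hasher = hashlib.md5()
    next_chunk = 0
    parked = {}

    def check(chunk_num, data):
        # hash in-order chunks immediately; park early arrivals
        global next_chunk
        if chunk_num != next_chunk:
            parked[chunk_num] = data
            return
        hasher.update(data)
        next_chunk += 1
        while next_chunk in parked:    # drain chunks that became ready
            hasher.update(parked.pop(next_chunk))
            next_chunk += 1

    check(1, b'world')                 # early: parked
    check(0, b'hello')                 # hashes chunk 0, then drains chunk 1
    assert next_chunk == 2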
+    def write_data(self, offsets, data):
+        # type: (DownloadDescriptor, DownloadOffsets, bytes) -> None
+        """Write data to disk at the proper offsets
+        :param DownloadDescriptor self: this
+        :param DownloadOffsets offsets: download offsets
+        :param bytes data: data
+        """
+        with self.local_path.open('r+b') as fd:
+            fd.seek(offsets.fd_start, 0)
+            fd.write(data)
+    def finalize_file(self):
+        # type: (DownloadDescriptor) -> Tuple[bool, str]
+        """Finalize file download
+        :param DownloadDescriptor self: this
+        :rtype: tuple
+        :return: (if integrity check passed or not, message)
+        """
+        # check final file integrity
+        check = False
+        msg = None
+        if self.hmac is not None:
+            mac = self._ase.encryption_metadata.encryption_authentication.\
+                message_authentication_code
+            digest = blobxfer.util.base64_encode_as_string(self.hmac.digest())
+            if digest == mac:
+                check = True
+            msg = '{}: {}, {} {} <L..R> {}'.format(
+                self._ase.encryption_metadata.encryption_authentication.
+                algorithm,
+                'OK' if check else 'MISMATCH',
+                self._ase.name,
+                digest,
+                mac,
+            )
+        elif self.md5 is not None:
+            digest = blobxfer.util.base64_encode_as_string(self.md5.digest())
+            if digest == self._ase.md5:
+                check = True
+            msg = 'MD5: {}, {} {} <L..R> {}'.format(
+                'OK' if check else 'MISMATCH',
+                self._ase.name,
+                digest,
+                self._ase.md5,
+            )
         else:
-            unpad = False
-        return DownloadOffsets(
-            fd_start=fd_start,
-            num_bytes=chunk,
-            range_start=range_start,
-            range_end=range_end,
-            unpad=unpad,
-        )
+            check = True
+            msg = 'MD5: SKIPPED, {} None <L..R> {}'.format(
+                self._ase.name,
+                self._ase.md5
+            )
+        # cleanup if download failed
+        if not check:
+            logger.error(msg)
+            # delete temp download file
+            self.local_path.unlink()
+            return
+        logger.debug(msg)
+
+        # TODO set file uid/gid and mode
+
+        # move temp download file to final path
+        self.local_path.rename(self.final_path)
 
     @property
-    def outstanding_operations(self):
+    def all_operations_completed(self):
         with self._meta_lock:
-            return self._outstanding_ops
-
-    @property
-    def completed_operations(self):
-        with self._meta_lock:
-            return self._completed_ops
+            return (self._outstanding_ops == 0 and
+                    len(self._unchecked_chunks) == 0)
 
     def dec_outstanding_operations(self):
         with self._meta_lock:
@@ -0,0 +1,127 @@
# Copyright (c) Microsoft Corporation
#
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# compat imports
from __future__ import (
    absolute_import, division, print_function, unicode_literals
)
from builtins import (  # noqa
    bytes, dict, int, list, object, range, ascii, chr, hex, input,
    next, oct, open, pow, round, super, filter, map, zip)
# stdlib imports
import logging
import multiprocessing
import threading
try:
    import queue
except ImportError:  # noqa
    import Queue as queue

# create logger
logger = logging.getLogger(__name__)


class _MultiprocessOffload(object):
    def __init__(self, num_workers, description=None):
        # type: (_MultiprocessOffload, int, str) -> None
        """Ctor for _MultiprocessOffload
        :param _MultiprocessOffload self: this
        :param int num_workers: number of worker processes
        :param str description: description
        """
        self._task_queue = multiprocessing.Queue()
        self._done_queue = multiprocessing.Queue()
        self._done_cv = multiprocessing.Condition()
        self._term_signal = multiprocessing.Value('i', 0)
        self._procs = []
        self._check_thread = None
        self._initialize_processes(num_workers, description)

    @property
    def done_cv(self):
        # type: (_MultiprocessOffload) -> multiprocessing.Condition
        """Get done condition variable
        :param _MultiprocessOffload self: this
        :rtype: multiprocessing.Condition
        :return: cv for done queue
        """
        return self._done_cv

    @property
    def terminated(self):
        # type: (_MultiprocessOffload) -> bool
        """Check if terminated
        :param _MultiprocessOffload self: this
        :rtype: bool
        :return: if terminated
        """
        return self._term_signal.value == 1

    def _initialize_processes(self, num_workers, description):
        # type: (_MultiprocessOffload, int, str) -> None
        """Initialize processes
        :param _MultiprocessOffload self: this
        :param int num_workers: number of worker processes
        :param str description: description
        """
        if num_workers is None or num_workers < 1:
            raise ValueError('invalid num_workers: {}'.format(num_workers))
        logger.debug('initializing {}{} processes'.format(
            num_workers, ' ' + description if description is not None else ''))
        for _ in range(num_workers):
            proc = multiprocessing.Process(target=self._worker_process)
            proc.start()
            self._procs.append(proc)

    def finalize_processes(self):
        # type: (_MultiprocessOffload) -> None
        """Finalize processes
        :param _MultiprocessOffload self: this
        """
        self._term_signal.value = 1
        if self._check_thread is not None:
            self._check_thread.join()
        for proc in self._procs:
            proc.join()

    def pop_done_queue(self):
        # type: (_MultiprocessOffload) -> object
        """Get item from done queue
        :param _MultiprocessOffload self: this
        :rtype: object or None
        :return: object from done queue, if exists
        """
        try:
            return self._done_queue.get_nowait()
        except queue.Empty:
            return None

    def initialize_check_thread(self, check_func):
        # type: (_MultiprocessOffload, object) -> None
        """Initialize the done queue check thread
        :param _MultiprocessOffload self: this
        :param object check_func: check function
        """
        self._check_thread = threading.Thread(target=check_func)
        self._check_thread.start()
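A subclass only has to call the base constructor and override _worker_process; process spawning, termination, and done-queue plumbing all come from the base. An illustrative subclass, assuming the module context above (the real subclasses are LocalFileMd5Offload and CryptoOffload):

    class UpperCaseOffload(_MultiprocessOffload):
        # hypothetical example: workers upper-case strings off the queue
        def __init__(self, num_workers):
            super(UpperCaseOffload, self).__init__(num_workers, 'UpperCase')

        def _worker_process(self):
            while not self.terminated:
                try:
                    text = self._task_queue.get(True, 1)
                except queue.Empty:
                    continue
                with self._done_cv:
                    self._done_queue.put(text.upper())
                    self._done_cv.notify()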
@@ -32,6 +32,7 @@ from builtins import (  # noqa
 import base64
 import copy
 import dateutil
+import hashlib
 import logging
 import logging.handlers
 import mimetypes
@@ -164,6 +165,15 @@ def base64_decode_string(string):
     return base64.b64decode(string)
 
 
+def new_md5_hasher():
+    # type: (None) -> md5.MD5
+    """Create a new MD5 hasher
+    :rtype: md5.MD5
+    :return: new MD5 hasher
+    """
+    return hashlib.md5()
+
+
 def page_align_content_length(length):
     # type: (int) -> int
     """Compute page boundary alignment
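page_align_content_length is truncated in this hunk; in blobxfer it rounds a length up to the next 512-byte page blob boundary. A sketch of that computation (an assumption based on Azure page blob semantics, not shown here):

    _PAGE = 512

    def page_align_content_length(length):
        # round up to the nearest 512-byte page boundary
        mod = length % _PAGE
        if mod != 0:
            return length + (_PAGE - mod)
        return length

    assert page_align_content_length(511) == 512
    assert page_align_content_length(512) == 512
    assert page_align_content_length(513) == 1024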
@@ -265,9 +265,9 @@ def create_download_specifications(config):
     elif confmode == 'block':
         mode = blobxfer.models.AzureStorageModes.Block
     elif confmode == 'file':
-        mode == blobxfer.models.AzureStorageModes.File
+        mode = blobxfer.models.AzureStorageModes.File
     elif confmode == 'page':
-        mode == blobxfer.models.AzureStorageModes.Page
+        mode = blobxfer.models.AzureStorageModes.Page
     else:
         raise ValueError('unknown mode: {}'.format(confmode))
     # load RSA private key PEM file if specified
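The two one-character fixes above repair a comparison-for-assignment slip: mode == ... evaluates a boolean and throws it away, so those branches never bound mode. Distilled into a runnable illustration:

    def pick(confmode):
        mode = None
        if confmode == 'file':
            mode == 'File'     # bug: compares and discards; mode stays None
        return mode

    assert pick('file') is None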
@@ -4,7 +4,6 @@
 # stdlib imports
 # non-stdlib imports
 import azure.storage
 import pytest
 # local imports
-import blobxfer.models as models
 # module under test
@@ -46,6 +46,42 @@ def test_rsa_encrypt_decrypt_keys():
 
 def test_pkcs7_padding():
     buf = os.urandom(32)
-    pbuf = ops.pad_pkcs7(buf)
-    buf2 = ops.unpad_pkcs7(pbuf)
+    pbuf = ops.pkcs7_pad(buf)
+    buf2 = ops.pkcs7_unpad(pbuf)
     assert buf == buf2
+
+
+def test_aes_cbc_encryption():
+    enckey = ops.aes256_generate_random_key()
+    assert len(enckey) == ops._AES256_KEYLENGTH_BYTES
+
+    # test random binary data, unaligned
+    iv = os.urandom(16)
+    plaindata = os.urandom(31)
+    encdata = ops.aes_cbc_encrypt_data(enckey, iv, plaindata, True)
+    assert encdata != plaindata
+    decdata = ops.aes_cbc_decrypt_data(enckey, iv, encdata, True)
+    assert decdata == plaindata
+
+    # test random binary data aligned on boundary
+    plaindata = os.urandom(32)
+    encdata = ops.aes_cbc_encrypt_data(enckey, iv, plaindata, True)
+    assert encdata != plaindata
+    decdata = ops.aes_cbc_decrypt_data(enckey, iv, encdata, True)
+    assert decdata == plaindata
+
+    # test "text" data
+    plaintext = 'attack at dawn!'
+    plaindata = plaintext.encode('utf8')
+    encdata = ops.aes_cbc_encrypt_data(enckey, iv, plaindata, True)
+    assert encdata != plaindata
+    decdata = ops.aes_cbc_decrypt_data(enckey, iv, encdata, True)
+    assert decdata == plaindata
+    assert plaindata.decode('utf8') == plaintext
+
+    # test unpadded
+    plaindata = os.urandom(32)
+    encdata = ops.aes_cbc_encrypt_data(enckey, iv, plaindata, False)
+    assert encdata != plaindata
+    decdata = ops.aes_cbc_decrypt_data(enckey, iv, encdata, False)
+    assert decdata == plaindata
@@ -197,27 +197,18 @@ def test_post_md5_skip_on_check():
     assert d._add_to_download_queue.call_count == 1
 
 
-def test_initialize_check_md5_downloads_thread():
+def test_check_for_downloads_from_md5():
     lpath = 'lpath'
     d = dl.Downloader(mock.MagicMock(), mock.MagicMock(), mock.MagicMock())
     d._md5_map[lpath] = mock.MagicMock()
     d._download_set.add(pathlib.Path(lpath))
     d._md5_offload = mock.MagicMock()
     d._md5_offload.done_cv = multiprocessing.Condition()
-    d._md5_offload.get_localfile_md5_done = mock.MagicMock()
-    d._md5_offload.get_localfile_md5_done.side_effect = [None, (lpath, False)]
+    d._md5_offload.pop_done_queue.side_effect = [None, (lpath, False)]
     d._add_to_download_queue = mock.MagicMock()
 
-    d._initialize_check_md5_downloads_thread()
-    while len(d._md5_map) > 0:
-        d._md5_offload.done_cv.acquire()
-        d._md5_offload.done_cv.notify()
-        d._md5_offload.done_cv.release()
     d._all_remote_files_processed = True
-    d._md5_offload.done_cv.acquire()
-    d._md5_offload.done_cv.notify()
-    d._md5_offload.done_cv.release()
-    d._md5_check_thread.join()
+    with pytest.raises(StopIteration):
+        d._check_for_downloads_from_md5()
 
     assert d._add_to_download_queue.call_count == 1
@@ -237,14 +228,15 @@ def test_initialize_and_terminate_download_threads():
     assert not thr.is_alive()
 
 
+@mock.patch('time.clock')
 @mock.patch('blobxfer.md5.LocalFileMd5Offload')
 @mock.patch('blobxfer.blob.operations.list_blobs')
 @mock.patch('blobxfer.operations.ensure_local_destination', return_value=True)
-def test_start(patched_eld, patched_lb, patched_lfmo, tmpdir):
+def test_start(patched_eld, patched_lb, patched_lfmo, patched_tc, tmpdir):
     d = dl.Downloader(mock.MagicMock(), mock.MagicMock(), mock.MagicMock())
-    d._initialize_check_md5_downloads_thread = mock.MagicMock()
     d._initialize_download_threads = mock.MagicMock()
-    d._md5_check_thread = mock.MagicMock()
+    patched_lfmo._check_thread = mock.MagicMock()
+    d._general_options.concurrency.crypto_processes = 0
     d._spec.sources = []
     d._spec.options = mock.MagicMock()
     d._spec.options.chunk_size_bytes = 1
@@ -270,12 +262,14 @@ def test_start(patched_eld, patched_lb, patched_lfmo, tmpdir):
 
     d._check_download_conditions = mock.MagicMock()
     d._check_download_conditions.return_value = dl.DownloadAction.Skip
+    patched_tc.side_effect = [1, 2]
     d.start()
     assert d._pre_md5_skip_on_check.call_count == 0
 
     patched_lb.side_effect = [[b]]
     d._all_remote_files_processed = False
     d._check_download_conditions.return_value = dl.DownloadAction.CheckMd5
+    patched_tc.side_effect = [1, 2]
     with pytest.raises(RuntimeError):
         d.start()
     assert d._pre_md5_skip_on_check.call_count == 1
@@ -284,6 +278,7 @@ def test_start(patched_eld, patched_lb, patched_lfmo, tmpdir):
     patched_lb.side_effect = [[b]]
     d._all_remote_files_processed = False
     d._check_download_conditions.return_value = dl.DownloadAction.Download
+    patched_tc.side_effect = [1, 2]
     with pytest.raises(RuntimeError):
         d.start()
     assert d._download_queue.qsize() == 1
@@ -36,7 +36,7 @@ def test_done_cv():
         assert a.done_cv == a._done_cv
     finally:
         if a:
-            a.finalize_md5_processes()
+            a.finalize_processes()
 
 
 def test_finalize_md5_processes():
@@ -48,9 +48,9 @@ def test_finalize_md5_processes():
         a = md5.LocalFileMd5Offload(num_workers=1)
     finally:
         if a:
-            a.finalize_md5_processes()
+            a.finalize_processes()
 
-    for proc in a._md5_procs:
+    for proc in a._procs:
         assert not proc.is_alive()
 
 
@@ -63,7 +63,7 @@ def test_from_add_to_done_non_pagealigned(tmpdir):
     a = None
     try:
         a = md5.LocalFileMd5Offload(num_workers=1)
-        result = a.get_localfile_md5_done()
+        result = a.pop_done_queue()
         assert result is None
 
         a.add_localfile_for_md5_check(
@@ -71,7 +71,7 @@ def test_from_add_to_done_non_pagealigned(tmpdir):
         i = 33
         checked = False
         while i > 0:
-            result = a.get_localfile_md5_done()
+            result = a.pop_done_queue()
             if result is None:
                 time.sleep(0.3)
                 i -= 1
@@ -84,7 +84,7 @@ def test_from_add_to_done_non_pagealigned(tmpdir):
         assert checked
     finally:
         if a:
-            a.finalize_md5_processes()
+            a.finalize_processes()
 
 
 def test_from_add_to_done_pagealigned(tmpdir):
@@ -96,7 +96,7 @@ def test_from_add_to_done_pagealigned(tmpdir):
     a = None
     try:
         a = md5.LocalFileMd5Offload(num_workers=1)
-        result = a.get_localfile_md5_done()
+        result = a.pop_done_queue()
         assert result is None
 
         a.add_localfile_for_md5_check(
@@ -104,7 +104,7 @@ def test_from_add_to_done_pagealigned(tmpdir):
         i = 33
         checked = False
         while i > 0:
-            result = a.get_localfile_md5_done()
+            result = a.pop_done_queue()
             if result is None:
                 time.sleep(0.3)
                 i -= 1
@@ -117,4 +117,4 @@ def test_from_add_to_done_pagealigned(tmpdir):
         assert checked
     finally:
         if a:
-            a.finalize_md5_processes()
+            a.finalize_processes()
@@ -25,9 +25,9 @@ def test_concurrency_options(patched_cc):
         transfer_threads=-2,
     )
 
-    assert a.crypto_processes == 1
+    assert a.crypto_processes == 0
     assert a.md5_processes == 1
-    assert a.transfer_threads == 2
+    assert a.transfer_threads == 3
 
 
 def test_general_options():
@@ -359,12 +359,16 @@ def test_downloaddescriptor(tmpdir):
     ase = models.AzureStorageEntity('cont')
     ase._size = 1024
+    ase._encryption = mock.MagicMock()
+    with pytest.raises(RuntimeError):
+        d = models.DownloadDescriptor(lp, ase, opts)
+
+    ase._encryption.symmetric_key = b'123'
     d = models.DownloadDescriptor(lp, ase, opts)
 
     assert d.entity == ase
     assert not d.must_compute_md5
     assert d._total_chunks == 64
-    assert d.offset == 0
+    assert d._offset == 0
     assert d.final_path == lp
     assert str(d.local_path) == str(lp) + '.bxtmp'
     assert d.local_path.stat().st_size == 1024 - 16
@@ -400,6 +404,7 @@ def test_downloaddescriptor_next_offsets(tmpdir):
 
     offsets = d.next_offsets()
     assert d._total_chunks == 1
+    assert offsets.chunk_num == 0
     assert offsets.fd_start == 0
     assert offsets.num_bytes == 128
     assert offsets.range_start == 0
@@ -416,6 +421,7 @@ def test_downloaddescriptor_next_offsets(tmpdir):
     d = models.DownloadDescriptor(lp, ase, opts)
     offsets = d.next_offsets()
     assert d._total_chunks == 1
+    assert offsets.chunk_num == 0
     assert offsets.fd_start == 0
     assert offsets.num_bytes == 1
     assert offsets.range_start == 0
@@ -427,6 +433,7 @@ def test_downloaddescriptor_next_offsets(tmpdir):
     d = models.DownloadDescriptor(lp, ase, opts)
     offsets = d.next_offsets()
     assert d._total_chunks == 1
+    assert offsets.chunk_num == 0
     assert offsets.fd_start == 0
     assert offsets.num_bytes == 256
     assert offsets.range_start == 0
@@ -438,12 +445,14 @@ def test_downloaddescriptor_next_offsets(tmpdir):
     d = models.DownloadDescriptor(lp, ase, opts)
     offsets = d.next_offsets()
     assert d._total_chunks == 2
+    assert offsets.chunk_num == 0
     assert offsets.fd_start == 0
     assert offsets.num_bytes == 256
     assert offsets.range_start == 0
     assert offsets.range_end == 255
     assert not offsets.unpad
     offsets = d.next_offsets()
+    assert offsets.chunk_num == 1
     assert offsets.fd_start == 256
     assert offsets.num_bytes == 16
     assert offsets.range_start == 256
@@ -452,10 +461,12 @@ def test_downloaddescriptor_next_offsets(tmpdir):
     assert d.next_offsets() is None
 
     ase._encryption = mock.MagicMock()
+    ase._encryption.symmetric_key = b'123'
     ase._size = 128
     d = models.DownloadDescriptor(lp, ase, opts)
     offsets = d.next_offsets()
     assert d._total_chunks == 1
+    assert offsets.chunk_num == 0
     assert offsets.fd_start == 0
     assert offsets.num_bytes == 128
     assert offsets.range_start == 0
@@ -467,6 +478,7 @@ def test_downloaddescriptor_next_offsets(tmpdir):
     d = models.DownloadDescriptor(lp, ase, opts)
     offsets = d.next_offsets()
     assert d._total_chunks == 1
+    assert offsets.chunk_num == 0
     assert offsets.fd_start == 0
     assert offsets.num_bytes == 256
     assert offsets.range_start == 0
@@ -478,12 +490,14 @@ def test_downloaddescriptor_next_offsets(tmpdir):
     d = models.DownloadDescriptor(lp, ase, opts)
     offsets = d.next_offsets()
     assert d._total_chunks == 2
+    assert offsets.chunk_num == 0
     assert offsets.fd_start == 0
     assert offsets.num_bytes == 256
     assert offsets.range_start == 0
     assert offsets.range_end == 255
     assert not offsets.unpad
     offsets = d.next_offsets()
+    assert offsets.chunk_num == 1
     assert offsets.fd_start == 256
     assert offsets.num_bytes == 32
     assert offsets.range_start == 256 - 16