- Main download logic at parity with current blobxfer
- Refactor multiprocess offload into base class
- Add multiprocess crypto offload
Fred Park 2017-02-26 02:02:56 -08:00
Parent 31ef912cd6
Commit d70d404b46
13 changed files with 695 additions and 247 deletions

View file

@ -35,20 +35,13 @@ import collections
import hashlib
import hmac
import json
import logging
# non-stdlib imports
# local imports
import blobxfer.crypto.operations
import blobxfer.util
# encryption constants
_AES256_KEYLENGTH_BYTES = 32
_AES256_BLOCKSIZE_BYTES = 16
_HMACSHA256_DIGESTSIZE_BYTES = 32
_AES256CBC_HMACSHA256_OVERHEAD_BYTES = (
_AES256_BLOCKSIZE_BYTES + _HMACSHA256_DIGESTSIZE_BYTES
)
# named tuples
EncryptionBlobxferExtensions = collections.namedtuple(
@ -191,8 +184,8 @@ class EncryptionMetadata(object):
)
except KeyError:
pass
self.content_encryption_iv = ed[
EncryptionMetadata._JSON_KEY_CONTENT_IV]
self.content_encryption_iv = base64.b64decode(
ed[EncryptionMetadata._JSON_KEY_CONTENT_IV])
self.encryption_agent = EncryptionAgent(
encryption_algorithm=ed[
EncryptionMetadata._JSON_KEY_ENCRYPTION_AGENT][

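The change above fixes content_encryption_iv handling: the IV is stored in the encryption metadata JSON as a base64 string, but the AES-CBC routines need raw bytes. A minimal sketch of the decode step; the JSON key name and value below are fabricated for illustration:

import base64
import json

# hypothetical metadata fragment as persisted on the entity (key name assumed)
ed = json.loads('{"ContentEncryptionIV": "AAECAwQFBgcICQoLDA0ODw=="}')
iv = base64.b64decode(ed['ContentEncryptionIV'])
assert len(iv) == 16  # one AES block: a valid CBC initialization vector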
View file

@ -31,7 +31,13 @@ from builtins import ( # noqa
next, oct, open, pow, round, super, filter, map, zip)
# stdlib imports
import base64
import enum
import logging
import os
try:
import queue
except ImportError: # noqa
import Queue as queue
# non-stdlib imports
import cryptography.hazmat.backends
import cryptography.hazmat.primitives.asymmetric.padding
@ -44,7 +50,13 @@ import cryptography.hazmat.primitives.hashes
import cryptography.hazmat.primitives.padding
import cryptography.hazmat.primitives.serialization
# local imports
import blobxfer.util
import blobxfer.offload
# create logger
logger = logging.getLogger(__name__)
# encryption constants
_AES256_KEYLENGTH_BYTES = 32
def load_rsa_private_key_file(rsakeyfile, passphrase):
@ -130,7 +142,7 @@ def rsa_encrypt_key_base64_encoded(rsaprivatekey, rsapublickey, plainkey):
return blobxfer.util.base64_encode_as_string(enckey)
def pad_pkcs7(buf):
def pkcs7_pad(buf):
# type: (bytes) -> bytes
"""Appends PKCS7 padding to an input buffer
:param bytes buf: buffer to add padding to
@ -143,7 +155,7 @@ def pad_pkcs7(buf):
return padder.update(buf) + padder.finalize()
def unpad_pkcs7(buf):
def pkcs7_unpad(buf):
# type: (bytes) -> bytes
"""Removes PKCS7 padding a decrypted object
:param bytes buf: buffer to remove padding
@ -154,3 +166,107 @@ def unpad_pkcs7(buf):
cryptography.hazmat.primitives.ciphers.
algorithms.AES.block_size).unpadder()
return unpadder.update(buf) + unpadder.finalize()
def aes256_generate_random_key():
# type: (None) -> bytes
"""Generate random AES256 key
:rtype: bytes
:return: random key
"""
return os.urandom(_AES256_KEYLENGTH_BYTES)
def aes_cbc_decrypt_data(symkey, iv, encdata, unpad):
# type: (bytes, bytes, bytes, bool) -> bytes
"""Decrypt data using AES CBC
:param bytes symkey: symmetric key
:param bytes iv: initialization vector
:param bytes encdata: data to decrypt
:param bool unpad: unpad data
:rtype: bytes
:return: decrypted data
"""
cipher = cryptography.hazmat.primitives.ciphers.Cipher(
cryptography.hazmat.primitives.ciphers.algorithms.AES(symkey),
cryptography.hazmat.primitives.ciphers.modes.CBC(iv),
backend=cryptography.hazmat.backends.default_backend()).decryptor()
decrypted = cipher.update(encdata) + cipher.finalize()
if unpad:
return pkcs7_unpad(decrypted)
else:
return decrypted
def aes_cbc_encrypt_data(symkey, iv, data, pad):
# type: (bytes, bytes, bytes, bool) -> bytes
"""Encrypt data using AES CBC
:param bytes symkey: symmetric key
:param bytes iv: initialization vector
:param bytes data: data to encrypt
:param bool pad: pad data
:rtype: bytes
:return: encrypted data
"""
cipher = cryptography.hazmat.primitives.ciphers.Cipher(
cryptography.hazmat.primitives.ciphers.algorithms.AES(symkey),
cryptography.hazmat.primitives.ciphers.modes.CBC(iv),
backend=cryptography.hazmat.backends.default_backend()).encryptor()
if pad:
return cipher.update(pkcs7_pad(data)) + cipher.finalize()
else:
return cipher.update(data) + cipher.finalize()
class CryptoAction(enum.Enum):
Encrypt = 1
Decrypt = 2
class CryptoOffload(blobxfer.offload._MultiprocessOffload):
def __init__(self, num_workers):
# type: (CryptoOffload, int) -> None
"""Ctor for Crypto Offload
:param CryptoOffload self: this
:param int num_workers: number of worker processes
"""
super(CryptoOffload, self).__init__(num_workers, 'Crypto')
def _worker_process(self):
# type: (CryptoOffload) -> None
"""Crypto worker
:param CryptoOffload self: this
"""
while not self.terminated:
try:
inst = self._task_queue.get(True, 1)
except queue.Empty:
continue
if inst[0] == CryptoAction.Encrypt:
# TODO on upload
raise NotImplementedError()
elif inst[0] == CryptoAction.Decrypt:
final_path, offsets, symkey, iv, encdata = \
inst[1], inst[2], inst[3], inst[4], inst[5]
data = aes_cbc_decrypt_data(symkey, iv, encdata, offsets.unpad)
self._done_cv.acquire()
self._done_queue.put((final_path, offsets, data))
self._done_cv.notify()
self._done_cv.release()
def add_decrypt_chunk(
self, final_path, offsets, symkey, iv, encdata):
# type: (CryptoOffload, str, blobxfer.models.DownloadOffsets, bytes,
# bytes, bytes) -> None
"""Add a chunk to decrypt
:param CryptoOffload self: this
:param str final_path: final path
:param blobxfer.models.DownloadOffsets offsets: offsets
:param bytes symkey: symmetric key
:param bytes iv: initialization vector
:param bytes encdata: encrypted data
"""
self._task_queue.put(
(CryptoAction.Decrypt, final_path, offsets, symkey, iv,
encdata)
)

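A minimal usage sketch for the new CryptoOffload pool, assuming a stand-in offsets object carrying the unpad attribute that DownloadDescriptor.next_offsets provides; the key, IV, ciphertext, and path are generated locally for illustration:

import collections
import os

import blobxfer.crypto.operations as ops

Offsets = collections.namedtuple('Offsets', ['unpad'])  # stand-in for DownloadOffsets

if __name__ == '__main__':
    symkey = ops.aes256_generate_random_key()
    iv = os.urandom(16)
    encdata = ops.aes_cbc_encrypt_data(symkey, iv, b'attack at dawn!', True)
    crypto = ops.CryptoOffload(num_workers=1)
    try:
        crypto.add_decrypt_chunk(
            '/tmp/file.bxtmp', Offsets(unpad=True), symkey, iv, encdata)
        result = None
        crypto.done_cv.acquire()
        while result is None:
            result = crypto.pop_done_queue()
            if result is None:
                crypto.done_cv.wait(1)  # timed wait guards against a missed notify
        crypto.done_cv.release()
        final_path, offsets, data = result
        assert data == b'attack at dawn!'
    finally:
        crypto.finalize_processes()

When crypto_processes is 0, no pool is created and the downloader instead calls aes_cbc_decrypt_data inline, as shown in the downloader diff below.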
View file

@ -43,9 +43,12 @@ try:
except ImportError: # noqa
import Queue as queue
import threading
import time
# non-stdlib imports
import dateutil
# local imports
import blobxfer.crypto.models
import blobxfer.crypto.operations
import blobxfer.md5
import blobxfer.models
import blobxfer.operations
@ -75,18 +78,20 @@ class Downloader(object):
:param blobxfer.models.AzureStorageCredentials creds: creds
:param blobxfer.models.DownloadSpecification spec: download spec
"""
self._md5_meta_lock = threading.Lock()
self._download_lock = threading.Lock()
self._time_start = None
self._all_remote_files_processed = False
self._crypto_offload = None
self._md5_meta_lock = threading.Lock()
self._md5_map = {}
self._md5_offload = None
self._md5_check_thread = None
self._download_lock = threading.Lock()
self._download_queue = queue.Queue()
self._download_set = set()
self._download_threads = []
self._download_count = 0
self._download_total_bytes = 0
self._download_terminate = False
self._dd_map = {}
self._general_options = general_options
self._creds = creds
self._spec = spec
@ -177,46 +182,69 @@ class Downloader(object):
else:
self._add_to_download_queue(lpath, rfile)
def _initialize_check_md5_downloads_thread(self):
def _check_for_downloads_from_md5(self):
# type: (Downloader) -> None
"""Initialize the md5 done queue check thread
"""Check queue for a file to download
:param Downloader self: this
"""
def _check_for_downloads_from_md5(self):
# type: (Downloader) -> None
"""Check queue for a file to download
:param Downloader self: this
"""
cv = self._md5_offload.done_cv
while True:
with self._md5_meta_lock:
if (self._download_terminate or
(len(self._md5_map) == 0 and
self._all_remote_files_processed)):
break
result = None
cv.acquire()
while not self._download_terminate:
result = self._md5_offload.get_localfile_md5_done()
if result is None:
# use cv timeout due to possible non-wake while running
cv.wait(1)
# check for terminating conditions
with self._md5_meta_lock:
if (len(self._md5_map) == 0 and
self._all_remote_files_processed):
break
else:
break
cv.release()
if result is not None:
self._post_md5_skip_on_check(result[0], result[1])
cv = self._md5_offload.done_cv
while True:
with self._md5_meta_lock:
if (self._download_terminate or
(self._all_remote_files_processed and
len(self._md5_map) == 0 and
len(self._download_set) == 0)):
break
result = None
cv.acquire()
while not self._download_terminate:
result = self._md5_offload.pop_done_queue()
if result is None:
# use cv timeout due to possible non-wake while running
cv.wait(1)
# check for terminating conditions
with self._md5_meta_lock:
if (self._all_remote_files_processed and
len(self._md5_map) == 0 and
len(self._download_set) == 0):
break
else:
break
cv.release()
if result is not None:
self._post_md5_skip_on_check(result[0], result[1])
self._md5_check_thread = threading.Thread(
target=_check_for_downloads_from_md5,
args=(self,)
)
self._md5_check_thread.start()
def _check_for_crypto_done(self):
# type: (Downloader) -> None
"""Check queue for crypto done
:param Downloader self: this
"""
cv = self._crypto_offload.done_cv
while True:
with self._download_lock:
if (self._download_terminate or
(self._all_remote_files_processed and
len(self._download_set) == 0)):
break
result = None
cv.acquire()
while not self._download_terminate:
result = self._crypto_offload.pop_done_queue()
if result is None:
# use cv timeout due to possible non-wake while running
cv.wait(1)
# check for terminating conditions
with self._download_lock:
if (self._all_remote_files_processed and
len(self._download_set) == 0):
break
else:
break
cv.release()
if result is not None:
with self._download_lock:
dd = self._dd_map[result[0]]
self._complete_chunk_download(result[1], result[2], dd)
def _add_to_download_queue(self, lpath, rfile):
# type: (Downloader, pathlib.Path,
@ -229,6 +257,9 @@ class Downloader(object):
# prepare remote file for download
dd = blobxfer.models.DownloadDescriptor(
lpath, rfile, self._spec.options)
if dd.entity.is_encrypted:
with self._download_lock:
self._dd_map[str(dd.final_path)] = dd
# add download descriptor to queue
self._download_queue.put(dd)
@ -250,7 +281,8 @@ class Downloader(object):
:param Downloader self: this
:param bool terminate: terminate threads
"""
self._download_terminate = terminate
if terminate:
self._download_terminate = terminate
for thr in self._download_threads:
thr.join()
@ -273,17 +305,15 @@ class Downloader(object):
# get download offsets
offsets = dd.next_offsets()
# check if all operations completed
if offsets is None and dd.outstanding_operations == 0:
# TODO
# 1. complete integrity checks
# 2. set file uid/gid
# 3. set file modes
# 4. move file to final path
if offsets is None and dd.all_operations_completed:
# finalize file
dd.finalize_file()
# accounting
with self._download_lock:
if dd.entity.is_encrypted:
self._dd_map.pop(str(dd.final_path))
self._download_set.remove(dd.final_path)
self._download_count += 1
logger.info('download complete: {}/{} to {}'.format(
dd.entity.container, dd.entity.name, dd.final_path))
continue
# re-enqueue for other threads to download
self._download_queue.put(dd)
@ -291,39 +321,60 @@ class Downloader(object):
continue
# issue get range
if dd.entity.mode == blobxfer.models.AzureStorageModes.File:
chunk = blobxfer.file.operations.get_file_range(
data = blobxfer.file.operations.get_file_range(
dd.entity, offsets, self._general_options.timeout_sec)
else:
chunk = blobxfer.blob.operations.get_blob_range(
data = blobxfer.blob.operations.get_blob_range(
dd.entity, offsets, self._general_options.timeout_sec)
# accounting
with self._download_lock:
self._download_total_bytes += offsets.num_bytes
# decrypt if necessary
if dd.entity.is_encrypted:
# TODO via crypto pool
# 1. compute rolling hmac if present
# - roll through any subsequent unchecked parts
# 2. decrypt chunk
pass
# compute rolling md5 via md5 pool
if dd.must_compute_md5:
# TODO
# - roll through any subsequent unchecked parts
pass
# slice data to proper bounds
encdata = data[blobxfer.crypto.models._AES256_BLOCKSIZE_BYTES:]
intdata = encdata
# get iv for chunk and compute hmac
if offsets.chunk_num == 0:
iv = dd.entity.encryption_metadata.content_encryption_iv
# integrity check for first chunk must include iv
intdata = iv + data
else:
iv = data[:blobxfer.crypto.models._AES256_BLOCKSIZE_BYTES]
# integrity check data
dd.perform_chunked_integrity_check(offsets, intdata)
# decrypt data
if self._crypto_offload is not None:
self._crypto_offload.add_decrypt_chunk(
str(dd.final_path), offsets,
dd.entity.encryption_metadata.symmetric_key,
iv, encdata)
# data will be completed once retrieved from crypto queue
continue
else:
data = blobxfer.crypto.operations.aes_cbc_decrypt_data(
dd.entity.encryption_metadata.symmetric_key,
iv, encdata, offsets.unpad)
elif dd.must_compute_md5:
# rolling compute md5
dd.perform_chunked_integrity_check(offsets, data)
# complete chunk download
self._complete_chunk_download(offsets, data, dd)
# write data to disk
# if no integrity check could be performed due to current
# integrity offset mismatch, add to unchecked set
dd.dec_outstanding_operations()
# pickle dd to resume file
# rfile = dd._ase
# print('<<', rfile.container, rfile.name, rfile.lmt, rfile.size,
# rfile.md5, rfile.mode, rfile.encryption_metadata)
def _complete_chunk_download(self, offsets, data, dd):
# type: (Downloader, blobxfer.models.DownloadOffsets, bytes,
# blobxfer.models.DownloadDescriptor) -> None
"""Complete chunk download
:param Downloader self: this
:param blobxfer.models.DownloadOffsets offsets: offsets
:param bytes data: data
:param blobxfer.models.DownloadDescriptor dd: download descriptor
"""
# write data to disk
dd.write_data(offsets, data)
# decrement outstanding operations
dd.dec_outstanding_operations()
# TODO pickle dd to resume file
def _run(self):
# type: (Downloader) -> None
@ -335,7 +386,14 @@ class Downloader(object):
# initialize MD5 processes
self._md5_offload = blobxfer.md5.LocalFileMd5Offload(
num_workers=self._general_options.concurrency.md5_processes)
self._initialize_check_md5_downloads_thread()
self._md5_offload.initialize_check_thread(
self._check_for_downloads_from_md5)
# initialize crypto processes
if self._general_options.concurrency.crypto_processes > 0:
self._crypto_offload = blobxfer.crypto.operations.CryptoOffload(
num_workers=self._general_options.concurrency.crypto_processes)
self._crypto_offload.initialize_check_thread(
self._check_for_crypto_done)
# initialize download threads
self._initialize_download_threads()
# iterate through source paths to download
@ -344,6 +402,7 @@ class Downloader(object):
skipped_files = 0
total_size = 0
skipped_size = 0
self._time_start = time.clock()
for src in self._spec.sources:
for rfile in src.files(
self._creds, self._spec.options, self._general_options):
@ -369,33 +428,41 @@ class Downloader(object):
self._add_to_download_queue(lpath, rfile)
download_files = nfiles - skipped_files
download_size = total_size - skipped_size
download_size_mib = download_size / 1048576
# clean up processes and threads
with self._md5_meta_lock:
self._all_remote_files_processed = True
logger.debug(
('{0} remote files processed, waiting for download completion '
'of {1:.4f} MiB').format(nfiles, download_size / 1048576))
self._md5_check_thread.join()
'of {1:.4f} MiB').format(nfiles, download_size_mib))
self._wait_for_download_threads(terminate=False)
self._md5_offload.finalize_md5_processes()
end = time.clock()
runtime = end - self._time_start
if (self._download_count != download_files or
self._download_total_bytes != download_size):
raise RuntimeError(
'download mismatch: [count={}/{} bytes={}/{}]'.format(
self._download_count, download_files,
self._download_total_bytes, download_size))
logger.info('all files downloaded')
logger.info('all files downloaded: {0:.3f} sec {1:.4f} Mbps'.format(
runtime, download_size_mib * 8 / runtime))
def start(self):
# type: (Downloader) -> None
"""Start the Downloader"""
try:
self._run()
except KeyboardInterrupt:
logger.error(
'KeyboardInterrupt detected, force terminating '
'processes and threads (this may take a while)...')
except (KeyboardInterrupt, Exception) as ex:
if isinstance(ex, KeyboardInterrupt):
logger.error(
'KeyboardInterrupt detected, force terminating '
'processes and threads (this may take a while)...')
self._wait_for_download_threads(terminate=True)
self._md5_offload.finalize_md5_processes()
# TODO delete all temp files
# TODO close resume file in finally?
raise
finally:
if self._md5_offload is not None:
self._md5_offload.finalize_processes()
if self._crypto_offload is not None:
self._crypto_offload.finalize_processes()

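Both checker threads above share one idiom: wait on the offloader's done condition variable with a timeout (so termination is re-evaluated even if a notify is missed), pop a result when one is available, then dispatch it. Distilled into a generic sketch with hypothetical offload, should_stop, and dispatch callables:

def poll_done_queue(offload, should_stop, dispatch):
    """Generic done-queue poller (sketch of the idiom, not blobxfer API)."""
    cv = offload.done_cv
    while not should_stop():
        result = None
        cv.acquire()
        while not should_stop():
            result = offload.pop_done_queue()
            if result is not None:
                break
            # use cv timeout due to possible non-wake while running
            cv.wait(1)
        cv.release()
        if result is not None:
            dispatch(result)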
View file

@ -30,8 +30,6 @@ from builtins import ( # noqa
)
# stdlib imports
import logging
import hashlib
import multiprocessing
try:
import queue
except ImportError: # noqa
@ -39,21 +37,14 @@ except ImportError: # noqa
# non-stdlib imports
# local imports
import blobxfer.download
import blobxfer.models
import blobxfer.offload
import blobxfer.util
# create logger
logger = logging.getLogger(__name__)
def new_md5_hasher():
# type: (None) -> md5.MD5
"""Create a new MD5 hasher
:rtype: md5.MD5
:return: new MD5 hasher
"""
return hashlib.md5()
def compute_md5_for_file_asbase64(filename, pagealign=False, blocksize=65536):
# type: (str, bool, int) -> str
"""Compute MD5 hash for file and encode as Base64
@ -63,7 +54,7 @@ def compute_md5_for_file_asbase64(filename, pagealign=False, blocksize=65536):
:rtype: str
:return: MD5 for file encoded as Base64
"""
hasher = new_md5_hasher()
hasher = blobxfer.util.new_md5_hasher()
with open(filename, 'rb') as filedesc:
while True:
buf = filedesc.read(blocksize)
@ -85,12 +76,12 @@ def compute_md5_for_data_asbase64(data):
:rtype: str
:return: MD5 for data
"""
hasher = new_md5_hasher()
hasher = blobxfer.util.new_md5_hasher()
hasher.update(data)
return blobxfer.util.base64_encode_as_string(hasher.digest())
class LocalFileMd5Offload(object):
class LocalFileMd5Offload(blobxfer.offload._MultiprocessOffload):
"""LocalFileMd5Offload"""
def __init__(self, num_workers):
# type: (LocalFileMd5Offload, int) -> None
@ -98,52 +89,14 @@ class LocalFileMd5Offload(object):
:param LocalFileMd5Offload self: this
:param int num_workers: number of worker processes
"""
self._task_queue = multiprocessing.Queue()
self._done_queue = multiprocessing.Queue()
self._done_cv = multiprocessing.Condition()
self._term_signal = multiprocessing.Value('i', 0)
self._md5_procs = []
self._initialize_md5_processes(num_workers)
super(LocalFileMd5Offload, self).__init__(num_workers, 'MD5')
@property
def done_cv(self):
# type: (LocalFileMd5Offload) -> multiprocessing.Condition
"""Get Download Done condition variable
:param LocalFileMd5Offload self: this
:rtype: multiprocessing.Condition
:return: cv for download done
"""
return self._done_cv
def _initialize_md5_processes(self, num_workers):
# type: (LocalFileMd5Offload, int) -> None
"""Initialize MD5 checking processes for files for download
:param LocalFileMd5Offload self: this
:param int num_workers: number of worker processes
"""
if num_workers is None or num_workers < 1:
raise ValueError('invalid num_workers: {}'.format(num_workers))
for _ in range(num_workers):
proc = multiprocessing.Process(
target=self._worker_compute_md5_localfile_process)
proc.start()
self._md5_procs.append(proc)
def finalize_md5_processes(self):
# type: (LocalFileMd5Offload) -> None
"""Finalize MD5 checking processes for files for download
:param LocalFileMd5Offload self: this
"""
self._term_signal.value = 1
for proc in self._md5_procs:
proc.join()
def _worker_compute_md5_localfile_process(self):
def _worker_process(self):
# type: (LocalFileMd5Offload) -> None
"""Compute MD5 for local file
:param LocalFileMd5Offload self: this
"""
while self._term_signal.value == 0:
while not self.terminated:
try:
filename, remote_md5, pagealign = self._task_queue.get(True, 1)
except queue.Empty:
@ -153,31 +106,17 @@ class LocalFileMd5Offload(object):
md5, remote_md5, filename))
self._done_cv.acquire()
self._done_queue.put((filename, md5 == remote_md5))
self.done_cv.notify()
self.done_cv.release()
def get_localfile_md5_done(self):
# type: (LocalFileMd5Offload) -> Tuple[str, bool]
"""Get from done queue of local files with MD5 completed
:param LocalFileMd5Offload self: this
:rtype: tuple or None
:return: (local file path, md5 match)
"""
try:
return self._done_queue.get_nowait()
except queue.Empty:
return None
self._done_cv.notify()
self._done_cv.release()
def add_localfile_for_md5_check(self, filename, remote_md5, mode):
# type: (LocalFileMd5Offload, str, str,
# blobxfer.models.AzureStorageModes) -> bool
"""Check an MD5 for a file for download
# blobxfer.models.AzureStorageModes) -> None
"""Add a local file to MD5 check queue
:param LocalFileMd5Offload self: this
:param str filename: file to compute MD5 for
:param str remote_md5: remote MD5 to compare against
:param blobxfer.models.AzureStorageModes mode: mode
:rtype: bool
:return: MD5 match comparison
"""
if mode == blobxfer.models.AzureStorageModes.Page:
pagealign = True

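A short usage sketch for the refactored MD5 offloader, assuming a local file 'some.file' exists; its expected digest is computed up front so the check trivially matches:

import time

import blobxfer.md5
import blobxfer.models

if __name__ == '__main__':
    offload = blobxfer.md5.LocalFileMd5Offload(num_workers=1)
    try:
        md5 = blobxfer.md5.compute_md5_for_file_asbase64('some.file')
        offload.add_localfile_for_md5_check(
            'some.file', md5, blobxfer.models.AzureStorageModes.Block)
        result = None
        while result is None:
            result = offload.pop_done_queue()  # (filename, md5 matched)
            if result is None:
                time.sleep(0.1)
        assert result == ('some.file', True)
    finally:
        offload.finalize_processes()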
View file

@ -41,6 +41,8 @@ try:
except ImportError: # noqa
import pathlib
import multiprocessing
import tempfile
import threading
# non-stdlib imports
# local imports
from .api import (
@ -53,7 +55,6 @@ from azure.storage.blob.models import _BlobTypes as BlobTypes
import blobxfer.blob.operations
import blobxfer.file.operations
import blobxfer.crypto.models
import blobxfer.md5
import blobxfer.util
# create logger
@ -126,6 +127,7 @@ LocalPath = collections.namedtuple(
)
DownloadOffsets = collections.namedtuple(
'DownloadOffsets', [
'chunk_num',
'fd_start',
'num_bytes',
'range_end',
@ -133,6 +135,14 @@ DownloadOffsets = collections.namedtuple(
'unpad',
]
)
UncheckedChunk = collections.namedtuple(
'UncheckedChunk', [
'data_len',
'fd_start',
'file_path',
'temp',
]
)
class ConcurrencyOptions(object):
@ -147,16 +157,16 @@ class ConcurrencyOptions(object):
self.crypto_processes = crypto_processes
self.md5_processes = md5_processes
self.transfer_threads = transfer_threads
# allow crypto processes to be zero (which will inline crypto
# routines with main process)
if self.crypto_processes is None or self.crypto_processes < 1:
self.crypto_processes = multiprocessing.cpu_count() // 2 - 1
if self.crypto_processes < 1:
self.crypto_processes = 1
self.crypto_processes = 0
if self.md5_processes is None or self.md5_processes < 1:
self.md5_processes = multiprocessing.cpu_count() // 2
if self.md5_processes < 1:
self.md5_processes = 1
if self.transfer_threads is None or self.transfer_threads < 1:
self.transfer_threads = multiprocessing.cpu_count() * 2
self.transfer_threads = multiprocessing.cpu_count() * 3
class GeneralOptions(object):
@ -824,7 +834,8 @@ class DownloadDescriptor(object):
_tmp = list(lpath.parts[:-1])
_tmp.append(lpath.name + '.bxtmp')
self.local_path = pathlib.Path(*_tmp)
self._meta_lock = multiprocessing.Lock()
self._meta_lock = threading.Lock()
self._hasher_lock = threading.Lock()
self._ase = ase
# calculate the total number of ops required for transfer
self._chunk_size = min((options.chunk_size_bytes, self._ase.size))
@ -835,9 +846,10 @@ class DownloadDescriptor(object):
self._total_chunks = 0
self.hmac = None
self.md5 = None
self.offset = 0
self.integrity_counter = 0
self.unchecked_chunks = set()
self._offset = 0
self._chunk_num = 0
self._next_integrity_chunk = 0
self._unchecked_chunks = {}
self._outstanding_ops = self._total_chunks
self._completed_ops = 0
# initialize checkers and allocate space
@ -871,9 +883,15 @@ class DownloadDescriptor(object):
:param DownloadOptions options: download options
"""
if self._ase.is_encrypted:
# ensure symmetric key exists
if blobxfer.util.is_none_or_empty(
self._ase.encryption_metadata.symmetric_key):
raise RuntimeError(
'symmetric key is invalid: provide RSA private key '
'or metadata corrupt')
self.hmac = self._ase.encryption_metadata.initialize_hmac()
if self.hmac is None and options.check_file_md5:
self.md5 = blobxfer.md5.new_md5_hasher()
self.md5 = blobxfer.util.new_md5_hasher()
def _allocate_disk_space(self):
# type: (DownloadDescriptor, int) -> None
@ -912,48 +930,182 @@ class DownloadDescriptor(object):
:rtype: DownloadOffsets
:return: download offsets
"""
if self.offset >= self._ase.size:
return None
if self.offset + self._chunk_size > self._ase.size:
chunk = self._ase.size - self.offset
with self._meta_lock:
if self._offset >= self._ase.size:
return None
if self._offset + self._chunk_size > self._ase.size:
chunk = self._ase.size - self._offset
else:
chunk = self._chunk_size
# on download, num_bytes must be offset by -1 as the x-ms-range
# header expects it that way. x -> y bytes means first bits of the
# (x+1)th byte to the last bits of the (y+1)th byte. for example,
# 0 -> 511 means byte 1 to byte 512
num_bytes = chunk - 1
chunk_num = self._chunk_num
fd_start = self._offset
range_start = self._offset
if self._ase.is_encrypted:
# ensure start is AES block size aligned
range_start = range_start - \
(range_start % self._AES_BLOCKSIZE) - \
self._AES_BLOCKSIZE
if range_start <= 0:
range_start = 0
range_end = self._offset + num_bytes
self._offset += chunk
self._chunk_num += 1
if self._ase.is_encrypted and self._offset >= self._ase.size:
unpad = True
else:
unpad = False
return DownloadOffsets(
chunk_num=chunk_num,
fd_start=fd_start,
num_bytes=chunk,
range_start=range_start,
range_end=range_end,
unpad=unpad,
)
def _postpone_integrity_check(self, offsets, data):
# type: (DownloadDescriptor, DownloadOffsets, bytes) -> None
"""Postpone integrity check for chunk
:param DownloadDescriptor self: this
:param DownloadOffsets offsets: download offsets
:param bytes data: data
"""
if self.must_compute_md5:
with self.local_path.open('r+b') as fd:
fd.seek(offsets.fd_start, 0)
fd.write(data)
unchecked = UncheckedChunk(
data_len=len(data),
fd_start=offsets.fd_start,
file_path=self.local_path,
temp=False,
)
else:
chunk = self._chunk_size
# on download, num_bytes must be offset by -1 as the x-ms-range
# header expects it that way. x -> y bytes means first bits of the
# (x+1)th byte to the last bits of the (y+1)th byte. for example,
# 0 -> 511 means byte 1 to byte 512
num_bytes = chunk - 1
fd_start = self.offset
range_start = self.offset
if self._ase.is_encrypted:
# ensure start is AES block size aligned
range_start = range_start - (range_start % self._AES_BLOCKSIZE) - \
self._AES_BLOCKSIZE
if range_start <= 0:
range_start = 0
range_end = self.offset + num_bytes
self.offset += chunk
if self._ase.is_encrypted and self.offset >= self._ase.size:
unpad = True
fname = None
with tempfile.NamedTemporaryFile(mode='wb', delete=False) as fd:
fname = fd.name
fd.write(data)
unchecked = UncheckedChunk(
data_len=len(data),
fd_start=0,
file_path=pathlib.Path(fname),
temp=True,
)
with self._meta_lock:
self._unchecked_chunks[offsets.chunk_num] = unchecked
def perform_chunked_integrity_check(self, offsets, data):
# type: (DownloadDescriptor, DownloadOffsets, bytes) -> None
"""Hash data against stored MD5 hasher safely
:param DownloadDescriptor self: this
:param DownloadOffsets offsets: download offsets
:param bytes data: data
"""
self_check = False
hasher = self.hmac or self.md5
# iterate from next chunk to be checked
while True:
ucc = None
with self._meta_lock:
chunk_num = self._next_integrity_chunk
# check if the next chunk is ready
if chunk_num in self._unchecked_chunks:
ucc = self._unchecked_chunks.pop(chunk_num)
elif chunk_num != offsets.chunk_num:
break
# prepare data for hashing
if ucc is None:
chunk = data
self_check = True
else:
with ucc.file_path.open('rb') as fd:
fd.seek(ucc.fd_start, 0)
chunk = fd.read(ucc.data_len)
if ucc.temp:
ucc.file_path.unlink()
# hash data and set next integrity chunk
with self._hasher_lock:
hasher.update(chunk)
with self._meta_lock:
self._next_integrity_chunk += 1
# store data that hasn't been checked
if not self_check:
self._postpone_integrity_check(offsets, data)
def write_data(self, offsets, data):
# type: (DownloadDescriptor, DownloadOffsets, bytes) -> None
"""Postpone integrity check for chunk
:param DownloadDescriptor self: this
:param DownloadOffsets offsets: download offsets
:param bytes data: data
"""
with self.local_path.open('r+b') as fd:
fd.seek(offsets.fd_start, 0)
fd.write(data)
def finalize_file(self):
# type: (DownloadDescriptor) -> None
"""Finalize file download: verify integrity, then move to final path
:param DownloadDescriptor self: this
"""
# check final file integrity
check = False
msg = None
if self.hmac is not None:
mac = self._ase.encryption_metadata.encryption_authentication.\
message_authentication_code
digest = blobxfer.util.base64_encode_as_string(self.hmac.digest())
if digest == mac:
check = True
msg = '{}: {}, {} {} <L..R> {}'.format(
self._ase.encryption_metadata.encryption_authentication.
algorithm,
'OK' if check else 'MISMATCH',
self._ase.name,
digest,
mac,
)
elif self.md5 is not None:
digest = blobxfer.util.base64_encode_as_string(self.md5.digest())
if digest == self._ase.md5:
check = True
msg = 'MD5: {}, {} {} <L..R> {}'.format(
'OK' if check else 'MISMATCH',
self._ase.name,
digest,
self._ase.md5,
)
else:
unpad = False
return DownloadOffsets(
fd_start=fd_start,
num_bytes=chunk,
range_start=range_start,
range_end=range_end,
unpad=unpad,
)
check = True
msg = 'MD5: SKIPPED, {} None <L..R> {}'.format(
self._ase.name,
self._ase.md5
)
# cleanup if download failed
if not check:
logger.error(msg)
# delete temp download file
self.local_path.unlink()
return
logger.debug(msg)
# TODO set file uid/gid and mode
# move temp download file to final path
self.local_path.rename(self.final_path)
@property
def outstanding_operations(self):
def all_operations_completed(self):
with self._meta_lock:
return self._outstanding_ops
@property
def completed_operations(self):
with self._meta_lock:
return self._completed_ops
return (self._outstanding_ops == 0 and
len(self._unchecked_chunks) == 0)
def dec_outstanding_operations(self):
with self._meta_lock:

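To make the encrypted-range arithmetic in next_offsets concrete: every chunk after the first pulls its range start back one AES block so the preceding ciphertext block can be fetched and used as the chunk IV (chunk 0 takes its IV from the encryption metadata instead), and unpad is set only on the final chunk of an encrypted entity, where the PKCS7 padding lives. A standalone sketch of just the range-start calculation, with the block size inlined:

_AES_BLOCKSIZE = 16

def encrypted_range_start(offset):
    # align down to the AES block boundary, then back one more block so the
    # preceding ciphertext block is fetched and can serve as the chunk IV
    start = offset - (offset % _AES_BLOCKSIZE) - _AES_BLOCKSIZE
    return start if start > 0 else 0

assert encrypted_range_start(0) == 0    # chunk 0: IV comes from metadata instead
assert encrypted_range_start(32) == 16  # chunk 1: fetch bytes 16..31 as the IV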
blobxfer/offload.py (new file, 127 lines)
View file

@ -0,0 +1,127 @@
# Copyright (c) Microsoft Corporation
#
# All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
# compat imports
from __future__ import (
absolute_import, division, print_function, unicode_literals
)
from builtins import ( # noqa
bytes, dict, int, list, object, range, ascii, chr, hex, input,
next, oct, open, pow, round, super, filter, map, zip)
# stdlib imports
import logging
import multiprocessing
import threading
try:
import queue
except ImportError: # noqa
import Queue as queue
# create logger
logger = logging.getLogger(__name__)
class _MultiprocessOffload(object):
def __init__(self, num_workers, description=None):
# type: (_MultiprocessOffload, int, str) -> None
"""Ctor for Crypto Offload
:param _MultiprocessOffload self: this
:param int num_workers: number of worker processes
:param str description: description
"""
self._task_queue = multiprocessing.Queue()
self._done_queue = multiprocessing.Queue()
self._done_cv = multiprocessing.Condition()
self._term_signal = multiprocessing.Value('i', 0)
self._procs = []
self._check_thread = None
self._initialize_processes(num_workers, description)
@property
def done_cv(self):
# type: (_MultiprocessOffload) -> multiprocessing.Condition
"""Get Done condition variable
:param _MultiprocessOffload self: this
:rtype: multiprocessing.Condition
:return: cv for done queue
"""
return self._done_cv
@property
def terminated(self):
# type: (_MultiprocessOffload) -> bool
"""Check if terminated
:param _MultiprocessOffload self: this
:rtype: bool
:return: if terminated
"""
return self._term_signal.value == 1
def _initialize_processes(self, num_workers, description):
# type: (_MultiprocessOffload, int, str) -> None
"""Initialize processes
:param _MultiprocessOffload self: this
:param int num_workers: number of worker processes
:param str description: description
"""
if num_workers is None or num_workers < 1:
raise ValueError('invalid num_workers: {}'.format(num_workers))
logger.debug('initializing {}{} processes'.format(
num_workers, ' ' + description if description is not None else ''))
for _ in range(num_workers):
proc = multiprocessing.Process(target=self._worker_process)
proc.start()
self._procs.append(proc)
def finalize_processes(self):
# type: (_MultiprocessOffload) -> None
"""Finalize processes
:param _MultiprocessOffload self: this
"""
self._term_signal.value = 1
if self._check_thread is not None:
self._check_thread.join()
for proc in self._procs:
proc.join()
def pop_done_queue(self):
# type: (_MultiprocessOffload) -> object
"""Get item from done queue
:param _MultiprocessOffload self: this
:rtype: object or None
:return: object from done queue, if exists
"""
try:
return self._done_queue.get_nowait()
except queue.Empty:
return None
def initialize_check_thread(self, check_func):
# type: (_MultiprocessOffload, object) -> None
"""Initialize the crypto done queue check thread
:param Downloader self: this
:param object check_func: check function
"""
self._check_thread = threading.Thread(target=check_func)
self._check_thread.start()

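Subclasses of _MultiprocessOffload only implement _worker_process; the base class owns the task and done queues, the condition variable, the termination flag, and the optional check thread. A minimal hypothetical subclass that upper-cases strings:

try:
    import queue
except ImportError:  # noqa
    import Queue as queue

import blobxfer.offload


class UpperOffload(blobxfer.offload._MultiprocessOffload):
    def __init__(self, num_workers):
        super(UpperOffload, self).__init__(num_workers, 'Upper')

    def _worker_process(self):
        # poll with a timeout so the termination flag is honored promptly
        while not self.terminated:
            try:
                text = self._task_queue.get(True, 1)
            except queue.Empty:
                continue
            self._done_cv.acquire()
            self._done_queue.put(text.upper())
            self._done_cv.notify()
            self._done_cv.release()

Tasks and results must be picklable, since they cross process boundaries through multiprocessing queues.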
View file

@ -32,6 +32,7 @@ from builtins import ( # noqa
import base64
import copy
import dateutil
import hashlib
import logging
import logging.handlers
import mimetypes
@ -164,6 +165,15 @@ def base64_decode_string(string):
return base64.b64decode(string)
def new_md5_hasher():
# type: (None) -> md5.MD5
"""Create a new MD5 hasher
:rtype: md5.MD5
:return: new MD5 hasher
"""
return hashlib.md5()
def page_align_content_length(length):
# type: (int) -> int
"""Compute page boundary alignment

View file

@ -265,9 +265,9 @@ def create_download_specifications(config):
elif confmode == 'block':
mode = blobxfer.models.AzureStorageModes.Block
elif confmode == 'file':
mode == blobxfer.models.AzureStorageModes.File
mode = blobxfer.models.AzureStorageModes.File
elif confmode == 'page':
mode == blobxfer.models.AzureStorageModes.Page
mode = blobxfer.models.AzureStorageModes.Page
else:
raise ValueError('unknown mode: {}'.format(confmode))
# load RSA private key PEM file if specified

View file

@ -4,7 +4,6 @@
# stdlib imports
# non-stdlib imports
import azure.storage
import pytest
# local imports
import blobxfer.models as models
# module under test

View file

@ -46,6 +46,42 @@ def test_rsa_encrypt_decrypt_keys():
def test_pkcs7_padding():
buf = os.urandom(32)
pbuf = ops.pad_pkcs7(buf)
buf2 = ops.unpad_pkcs7(pbuf)
pbuf = ops.pkcs7_pad(buf)
buf2 = ops.pkcs7_unpad(pbuf)
assert buf == buf2
def test_aes_cbc_encryption():
enckey = ops.aes256_generate_random_key()
assert len(enckey) == ops._AES256_KEYLENGTH_BYTES
# test random binary data, unaligned
iv = os.urandom(16)
plaindata = os.urandom(31)
encdata = ops.aes_cbc_encrypt_data(enckey, iv, plaindata, True)
assert encdata != plaindata
decdata = ops.aes_cbc_decrypt_data(enckey, iv, encdata, True)
assert decdata == plaindata
# test random binary data aligned on boundary
plaindata = os.urandom(32)
encdata = ops.aes_cbc_encrypt_data(enckey, iv, plaindata, True)
assert encdata != plaindata
decdata = ops.aes_cbc_decrypt_data(enckey, iv, encdata, True)
assert decdata == plaindata
# test "text" data
plaintext = 'attack at dawn!'
plaindata = plaintext.encode('utf8')
encdata = ops.aes_cbc_encrypt_data(enckey, iv, plaindata, True)
assert encdata != plaindata
decdata = ops.aes_cbc_decrypt_data(enckey, iv, encdata, True)
assert decdata == plaindata
assert plaindata.decode('utf8') == plaintext
# test unpadded
plaindata = os.urandom(32)
encdata = ops.aes_cbc_encrypt_data(enckey, iv, plaindata, False)
assert encdata != plaindata
decdata = ops.aes_cbc_decrypt_data(enckey, iv, encdata, False)
assert decdata == plaindata

View file

@ -197,27 +197,18 @@ def test_post_md5_skip_on_check():
assert d._add_to_download_queue.call_count == 1
def test_initialize_check_md5_downloads_thread():
def test_check_for_downloads_from_md5():
lpath = 'lpath'
d = dl.Downloader(mock.MagicMock(), mock.MagicMock(), mock.MagicMock())
d._md5_map[lpath] = mock.MagicMock()
d._download_set.add(pathlib.Path(lpath))
d._md5_offload = mock.MagicMock()
d._md5_offload.done_cv = multiprocessing.Condition()
d._md5_offload.get_localfile_md5_done = mock.MagicMock()
d._md5_offload.get_localfile_md5_done.side_effect = [None, (lpath, False)]
d._md5_offload.pop_done_queue.side_effect = [None, (lpath, False)]
d._add_to_download_queue = mock.MagicMock()
d._initialize_check_md5_downloads_thread()
while len(d._md5_map) > 0:
d._md5_offload.done_cv.acquire()
d._md5_offload.done_cv.notify()
d._md5_offload.done_cv.release()
d._all_remote_files_processed = True
d._md5_offload.done_cv.acquire()
d._md5_offload.done_cv.notify()
d._md5_offload.done_cv.release()
d._md5_check_thread.join()
with pytest.raises(StopIteration):
d._check_for_downloads_from_md5()
assert d._add_to_download_queue.call_count == 1
@ -237,14 +228,15 @@ def test_initialize_and_terminate_download_threads():
assert not thr.is_alive()
@mock.patch('time.clock')
@mock.patch('blobxfer.md5.LocalFileMd5Offload')
@mock.patch('blobxfer.blob.operations.list_blobs')
@mock.patch('blobxfer.operations.ensure_local_destination', return_value=True)
def test_start(patched_eld, patched_lb, patched_lfmo, tmpdir):
def test_start(patched_eld, patched_lb, patched_lfmo, patched_tc, tmpdir):
d = dl.Downloader(mock.MagicMock(), mock.MagicMock(), mock.MagicMock())
d._initialize_check_md5_downloads_thread = mock.MagicMock()
d._initialize_download_threads = mock.MagicMock()
d._md5_check_thread = mock.MagicMock()
patched_lfmo._check_thread = mock.MagicMock()
d._general_options.concurrency.crypto_processes = 0
d._spec.sources = []
d._spec.options = mock.MagicMock()
d._spec.options.chunk_size_bytes = 1
@ -270,12 +262,14 @@ def test_start(patched_eld, patched_lb, patched_lfmo, tmpdir):
d._check_download_conditions = mock.MagicMock()
d._check_download_conditions.return_value = dl.DownloadAction.Skip
patched_tc.side_effect = [1, 2]
d.start()
assert d._pre_md5_skip_on_check.call_count == 0
patched_lb.side_effect = [[b]]
d._all_remote_files_processed = False
d._check_download_conditions.return_value = dl.DownloadAction.CheckMd5
patched_tc.side_effect = [1, 2]
with pytest.raises(RuntimeError):
d.start()
assert d._pre_md5_skip_on_check.call_count == 1
@ -284,6 +278,7 @@ def test_start(patched_eld, patched_lb, patched_lfmo, tmpdir):
patched_lb.side_effect = [[b]]
d._all_remote_files_processed = False
d._check_download_conditions.return_value = dl.DownloadAction.Download
patched_tc.side_effect = [1, 2]
with pytest.raises(RuntimeError):
d.start()
assert d._download_queue.qsize() == 1

View file

@ -36,7 +36,7 @@ def test_done_cv():
assert a.done_cv == a._done_cv
finally:
if a:
a.finalize_md5_processes()
a.finalize_processes()
def test_finalize_md5_processes():
@ -48,9 +48,9 @@ def test_finalize_md5_processes():
a = md5.LocalFileMd5Offload(num_workers=1)
finally:
if a:
a.finalize_md5_processes()
a.finalize_processes()
for proc in a._md5_procs:
for proc in a._procs:
assert not proc.is_alive()
@ -63,7 +63,7 @@ def test_from_add_to_done_non_pagealigned(tmpdir):
a = None
try:
a = md5.LocalFileMd5Offload(num_workers=1)
result = a.get_localfile_md5_done()
result = a.pop_done_queue()
assert result is None
a.add_localfile_for_md5_check(
@ -71,7 +71,7 @@ def test_from_add_to_done_non_pagealigned(tmpdir):
i = 33
checked = False
while i > 0:
result = a.get_localfile_md5_done()
result = a.pop_done_queue()
if result is None:
time.sleep(0.3)
i -= 1
@ -84,7 +84,7 @@ def test_from_add_to_done_non_pagealigned(tmpdir):
assert checked
finally:
if a:
a.finalize_md5_processes()
a.finalize_processes()
def test_from_add_to_done_pagealigned(tmpdir):
@ -96,7 +96,7 @@ def test_from_add_to_done_pagealigned(tmpdir):
a = None
try:
a = md5.LocalFileMd5Offload(num_workers=1)
result = a.get_localfile_md5_done()
result = a.pop_done_queue()
assert result is None
a.add_localfile_for_md5_check(
@ -104,7 +104,7 @@ def test_from_add_to_done_pagealigned(tmpdir):
i = 33
checked = False
while i > 0:
result = a.get_localfile_md5_done()
result = a.pop_done_queue()
if result is None:
time.sleep(0.3)
i -= 1
@ -117,4 +117,4 @@ def test_from_add_to_done_pagealigned(tmpdir):
assert checked
finally:
if a:
a.finalize_md5_processes()
a.finalize_processes()

View file

@ -25,9 +25,9 @@ def test_concurrency_options(patched_cc):
transfer_threads=-2,
)
assert a.crypto_processes == 1
assert a.crypto_processes == 0
assert a.md5_processes == 1
assert a.transfer_threads == 2
assert a.transfer_threads == 3
def test_general_options():
@ -359,12 +359,16 @@ def test_downloaddescriptor(tmpdir):
ase = models.AzureStorageEntity('cont')
ase._size = 1024
ase._encryption = mock.MagicMock()
with pytest.raises(RuntimeError):
d = models.DownloadDescriptor(lp, ase, opts)
ase._encryption.symmetric_key = b'123'
d = models.DownloadDescriptor(lp, ase, opts)
assert d.entity == ase
assert not d.must_compute_md5
assert d._total_chunks == 64
assert d.offset == 0
assert d._offset == 0
assert d.final_path == lp
assert str(d.local_path) == str(lp) + '.bxtmp'
assert d.local_path.stat().st_size == 1024 - 16
@ -400,6 +404,7 @@ def test_downloaddescriptor_next_offsets(tmpdir):
offsets = d.next_offsets()
assert d._total_chunks == 1
assert offsets.chunk_num == 0
assert offsets.fd_start == 0
assert offsets.num_bytes == 128
assert offsets.range_start == 0
@ -416,6 +421,7 @@ def test_downloaddescriptor_next_offsets(tmpdir):
d = models.DownloadDescriptor(lp, ase, opts)
offsets = d.next_offsets()
assert d._total_chunks == 1
assert offsets.chunk_num == 0
assert offsets.fd_start == 0
assert offsets.num_bytes == 1
assert offsets.range_start == 0
@ -427,6 +433,7 @@ def test_downloaddescriptor_next_offsets(tmpdir):
d = models.DownloadDescriptor(lp, ase, opts)
offsets = d.next_offsets()
assert d._total_chunks == 1
assert offsets.chunk_num == 0
assert offsets.fd_start == 0
assert offsets.num_bytes == 256
assert offsets.range_start == 0
@ -438,12 +445,14 @@ def test_downloaddescriptor_next_offsets(tmpdir):
d = models.DownloadDescriptor(lp, ase, opts)
offsets = d.next_offsets()
assert d._total_chunks == 2
assert offsets.chunk_num == 0
assert offsets.fd_start == 0
assert offsets.num_bytes == 256
assert offsets.range_start == 0
assert offsets.range_end == 255
assert not offsets.unpad
offsets = d.next_offsets()
assert offsets.chunk_num == 1
assert offsets.fd_start == 256
assert offsets.num_bytes == 16
assert offsets.range_start == 256
@ -452,10 +461,12 @@ def test_downloaddescriptor_next_offsets(tmpdir):
assert d.next_offsets() is None
ase._encryption = mock.MagicMock()
ase._encryption.symmetric_key = b'123'
ase._size = 128
d = models.DownloadDescriptor(lp, ase, opts)
offsets = d.next_offsets()
assert d._total_chunks == 1
assert offsets.chunk_num == 0
assert offsets.fd_start == 0
assert offsets.num_bytes == 128
assert offsets.range_start == 0
@ -467,6 +478,7 @@ def test_downloaddescriptor_next_offsets(tmpdir):
d = models.DownloadDescriptor(lp, ase, opts)
offsets = d.next_offsets()
assert d._total_chunks == 1
assert offsets.chunk_num == 0
assert offsets.fd_start == 0
assert offsets.num_bytes == 256
assert offsets.range_start == 0
@ -478,12 +490,14 @@ def test_downloaddescriptor_next_offsets(tmpdir):
d = models.DownloadDescriptor(lp, ase, opts)
offsets = d.next_offsets()
assert d._total_chunks == 2
assert offsets.chunk_num == 0
assert offsets.fd_start == 0
assert offsets.num_bytes == 256
assert offsets.range_start == 0
assert offsets.range_end == 255
assert not offsets.unpad
offsets = d.next_offsets()
assert offsets.chunk_num == 1
assert offsets.fd_start == 256
assert offsets.num_bytes == 32
assert offsets.range_start == 256 - 16