Merge branch 'master' of https://github.com/Microsoft/ai4eutils

# Conflicts:
#	ai4e_azure_utils.py

Commit c2a5fb4fdc

@@ -1,11 +1,12 @@
"""
|
||||
Miscellaneous Azure Blob Storage utilities
|
||||
|
||||
Requires azure-storage-blob>=12.4.0
|
||||
"""
|
||||
import json
|
||||
from typing import Any, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from azure.storage.blob._models import BlobPrefix
|
||||
from azure.storage.blob import ContainerClient
|
||||
from azure.storage.blob import BlobPrefix, ContainerClient
|
||||
|
||||
import sas_blob_utils
|
||||
|
||||
|
@@ -104,7 +105,7 @@ def write_list_to_file(output_file: str, strings: Sequence[str]) -> None:
        f.write('\n'.join(strings))


def read_list_from_file(filename: str):
def read_list_from_file(filename: str) -> List[str]:
    """
    Reads a json-formatted list of strings from a file.
    """
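For reference, a minimal usage sketch of the two helpers touched by this hunk (not part of the commit). It assumes both functions live in ai4e_azure_utils.py, the file named in the merge conflict, and the file paths are illustrative; note the asymmetry the code above shows: writing is newline-separated, reading expects a JSON list.

```python
from ai4e_azure_utils import read_list_from_file, write_list_to_file

write_list_to_file('blob_names.txt', ['a.jpg', 'b.jpg'])  # writes newline-separated strings
names = read_list_from_file('blob_names.json')            # file must contain a JSON list of strings
print(names)
```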
@@ -0,0 +1,139 @@
import sys

import numpy as np

import rasterio
from rasterio.windows import Window
from rasterio.errors import RasterioIOError

import torch
from torchvision import transforms
from torch.utils.data.dataset import IterableDataset

class StreamingGeospatialDataset(IterableDataset):

    def __init__(self, imagery_fns, label_fns=None, chip_size=256, num_chips_per_tile=200, windowed_sampling=False, transform=None, verbose=False):
        """A torch Dataset for randomly sampling chips from a list of tiles. When used in conjunction with a DataLoader that has `num_workers>1` this Dataset will assign each worker to sample chips from disjoint sets of tiles.

        Args:
            imagery_fns: A list of filenames (or web addresses -- anything that `rasterio.open()` can read) pointing to imagery tiles.
            label_fns: A list of filenames of the same size as `imagery_fns` pointing to label mask tiles, or `None` if the Dataset should operate in "imagery only mode". Note that we expect `imagery_fns[i]` and `label_fns[i]` to have the same dimensions and coordinate system.
            chip_size: Desired size of chips (in pixels).
            num_chips_per_tile: Desired number of chips to sample for each tile.
            windowed_sampling: Flag indicating whether we should sample each chip with a read using `rasterio.windows.Window`, or whether we should read the whole tile into memory, then sample chips.
            transform: The torchvision.transforms object to apply to each chip.
            verbose: If `False` we will be quiet.
        """

        if label_fns is None:
            self.fns = imagery_fns
            self.use_labels = False
        else:
            self.fns = list(zip(imagery_fns, label_fns))
            self.use_labels = True

        self.chip_size = chip_size
        self.num_chips_per_tile = num_chips_per_tile
        self.windowed_sampling = windowed_sampling

        self.transform = transform
        self.verbose = verbose

        if self.verbose:
            print("Constructed StreamingGeospatialDataset")

    def stream_tile_fns(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # in this case we are not loading through a DataLoader with multiple workers
            worker_id = 0
            num_workers = 1
        else:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers

        # We only want to shuffle the order we traverse the files if we are the first worker (else, every worker will shuffle the files...)
        if worker_id == 0:
            np.random.shuffle(self.fns)  # in place
        # NOTE: A warning: when different workers are created they will all have the same numpy random seed, however they will have different torch random seeds. If you want to use numpy random functions, seed appropriately.
        #seed = torch.randint(low=0, high=2**32-1, size=(1,)).item()
        #np.random.seed(seed)  # when different workers spawn, they have the same numpy random seed...

        if self.verbose:
            print("Creating a filename stream for worker %d" % (worker_id))

        # This logic splits up the list of filenames into `num_workers` chunks. Each worker will receive ceil(num_filenames / num_workers) filenames to generate chips from. If the number of workers doesn't divide the number of filenames evenly then the last worker will have fewer filenames.
        N = len(self.fns)
        num_files_per_worker = int(np.ceil(N / num_workers))
        lower_idx = worker_id * num_files_per_worker
        upper_idx = min(N, (worker_id+1) * num_files_per_worker)
        for idx in range(lower_idx, upper_idx):

            label_fn = None
            if self.use_labels:
                img_fn, label_fn = self.fns[idx]
            else:
                img_fn = self.fns[idx]

            if self.verbose:
                print("Worker %d, yielding file %d" % (worker_id, idx))

            yield (img_fn, label_fn)

    def stream_chips(self):
        for img_fn, label_fn in self.stream_tile_fns():

            # Open file pointers
            img_fp = rasterio.open(img_fn, "r")
            label_fp = rasterio.open(label_fn, "r") if self.use_labels else None

            height, width = img_fp.shape
            if self.use_labels:  # guarantee that our label mask has the same dimensions as our imagery
                t_height, t_width = label_fp.shape
                assert height == t_height and width == t_width

            try:
                # If we aren't in windowed sampling mode then we should read the entire tile up front
                if not self.windowed_sampling:
                    img_data = np.rollaxis(img_fp.read(), 0, 3)
                    if self.use_labels:
                        label_data = label_fp.read().squeeze()  # assume the label geotiff has a single channel

                for i in range(self.num_chips_per_tile):
                    # Select the top left pixel of our chip randomly
                    x = np.random.randint(0, width - self.chip_size)
                    y = np.random.randint(0, height - self.chip_size)

                    if self.windowed_sampling:
                        img = np.rollaxis(img_fp.read(window=Window(x, y, self.chip_size, self.chip_size)), 0, 3)
                        if self.use_labels:
                            labels = label_fp.read(window=Window(x, y, self.chip_size, self.chip_size)).squeeze()
                    else:
                        img = img_data[y:y+self.chip_size, x:x+self.chip_size, :]
                        if self.use_labels:
                            labels = label_data[y:y+self.chip_size, x:x+self.chip_size]

                    # TODO: check for nodata and throw away the chip if necessary. Not sure how to do this in a dataset-independent way.

                    if self.use_labels:
                        labels = transforms.ToTensor()(labels).squeeze()

                    if self.transform is not None:
                        img = self.transform(img)

                    if self.use_labels:
                        yield img, labels
                    else:
                        yield img
            except RasterioIOError:  # NOTE(caleb): I put this here to catch occasional odd errors seen when reading from COGs -- I don't remember the details.
                print("Reading %s failed, skipping..." % (img_fn))

            # Close file pointers
            img_fp.close()
            if self.use_labels:
                label_fp.close()

    def __iter__(self):
        if self.verbose:
            print("Creating a new StreamingGeospatialDataset iterator")
        return iter(self.stream_chips())

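As context for the new dataset, here is a minimal usage sketch (not part of the commit): the tile and mask paths, batch size, and worker count are placeholders, and the class is assumed to be importable from the module added above.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical tile/mask paths -- anything rasterio.open() accepts.
imagery = ["tile_0.tif", "tile_1.tif"]
masks = ["mask_0.tif", "mask_1.tif"]

dataset = StreamingGeospatialDataset(
    imagery_fns=imagery, label_fns=masks,
    chip_size=256, num_chips_per_tile=200,
    windowed_sampling=True, verbose=True)

# Each DataLoader worker samples chips from a disjoint subset of the tiles.
loader = DataLoader(dataset, batch_size=32, num_workers=2)
for imgs, labels in loader:
    pass  # training/validation step would go here
```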
@@ -4,7 +4,8 @@ This is a list of recipes for working with geospatial data using the GDAL command line tools.

## Table of Contents

- [Clip shapefile to the extent of a raster](#clip-shapefile-to-the-extent-of-a-raster)
- [Clip shapefile to the extent of a raster](#clip-shapefile-to-the-extent-of-a-raster)
- [Create polygon of the extent of a raster](#create-polygon-of-the-extent-of-a-raster)
- [Convert shapefile to geojson](#convert-shapefile-to-geojson)
- [Reproject a raster](#reproject-a-raster)
- [Convert a raster into an XYZ basemap](#convert-a-raster-into-an-xyz-basemap)

@@ -43,6 +44,15 @@ gdaltindex -t_srs epsg:4326 -f GeoJSON OUTPUT_EXTENT.geojson INPUT_RASTER.tif
ogr2ogr -f GeoJSON -clipsrc OUTPUT_EXTENT OUTPUT_SHAPES_CLIPPED.geojson INPUT_SHAPES.shp
```


### Create polygon of the extent of a raster
<a name="create-polygon-of-the-extent-of-a-raster"></a>

```
gdaltindex -t_srs epsg:4326 -f GeoJSON OUTPUT_EXTENT.geojson INPUT_RASTER.tif
```


### Convert shapefile to geojson
<a name="convert-shapefile-to-geojson"></a>

@@ -1,12 +1,11 @@
"""
Class to visualize raster mask labels and hardmax or softmax model predictions, for semantic segmentation tasks.

"""

import json
import os
from io import BytesIO
from typing import Union, Tuple, List
from typing import Union, Tuple

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

@@ -32,7 +31,6 @@ class RasterLabelVisualizer(object):
                as a color; additionally a (R, G, B) tuple or list with uint8 values will also be parsed)
            }
        """

        if isinstance(label_map, str):
            assert os.path.exists(label_map)
            with open(label_map) as f:

@@ -139,25 +137,26 @@ class RasterLabelVisualizer(object):
            colormap[num] = RasterLabelVisualizer.matplotlib_color_to_uint8_rgb(color)
        return colormap

    def get_tool_colormap(self) -> List[dict]:
        """Returns a list of items specifying the name and color of categories. Example:
        [
    def get_tool_colormap(self) -> str:
        """Returns a string that is a JSON of a list of items specifying the name and color
        of classes. Example:
        "[
            {"name": "Water", "color": "#0000FF"},
            {"name": "Tree Canopy", "color": "#008000"},
            {"name": "Field", "color": "#80FF80"},
            {"name": "Built", "color": "#806060"}
        ]
        ]"
        """
        li = []
        classes = []
        for num, name in self.num_to_name.items():
            color = self.num_to_color[num]
            color_hex = mcolors.to_hex(color)
            li.append({
            classes.append({
                'name': name,
                'color': color_hex
            })
        return li

        classes = json.dumps(classes, indent=4)
        return classes

    @staticmethod
    def plot_colortable(name_to_color: dict, title: str, sort_colors: bool = False, emptycols: int = 0) -> plt.Figure:

@@ -216,7 +215,7 @@ class RasterLabelVisualizer(object):

        return fig

    def plot_color_legend(self, legend_title: str = 'Categories'):
    def plot_color_legend(self, legend_title: str = 'Categories') -> plt.Figure:
        """Builds a legend of color block, numerical categories and names of the categories.

        Returns:

@@ -229,8 +228,8 @@ class RasterLabelVisualizer(object):
        fig = RasterLabelVisualizer.plot_colortable(label_map, legend_title, sort_colors=False, emptycols=3)
        return fig

    def show_label_raster(self, label_raster: Union[Image.Image, np.ndarray], size: Tuple[int, int] = (10, 10)) -> Tuple[
            Image.Image, BytesIO]:
    def show_label_raster(self, label_raster: Union[Image.Image, np.ndarray],
                          size: Tuple[int, int] = (10, 10)) -> Tuple[Image.Image, BytesIO]:
        """Visualizes a label mask or hardmax predictions of a model, according to the category color map
        provided when the class was initialized.

@@ -0,0 +1,9 @@
# global options
[mypy]
disallow_incomplete_defs = True
namespace_packages = True
warn_redundant_casts = True
warn_unreachable = True
warn_unused_configs = True
warn_unused_ignores = True

@@ -8,13 +8,13 @@ This module contains helper functions for dealing with Shared Access Signatures
The default Azure Storage SAS URI format is:
https://<account>.blob.core.windows.net/<container>/<blob>?<sas_token>

This module assumes azure-storage-blob version 12.3.
This module assumes azure-storage-blob version 12.5.

Documentation for Azure Blob Storage:
https://docs.microsoft.com/en-us/azure/developer/python/sdk/storage/storage-blob-readme
docs.microsoft.com/en-us/azure/developer/python/sdk/storage/storage-blob-readme

Documentation for SAS:
https://docs.microsoft.com/en-us/azure/storage/common/storage-sas-overview
docs.microsoft.com/en-us/azure/storage/common/storage-sas-overview
"""
from datetime import datetime, timedelta
import io
@@ -31,9 +31,7 @@ from azure.storage.blob import (
    BlobProperties,
    ContainerClient,
    ContainerSasPermissions,
    generate_container_sas,
    upload_blob_to_url)
from azure.core.exceptions import ResourceNotFoundError
    generate_container_sas)


def build_azure_storage_uri(
@@ -47,7 +45,7 @@ def build_azure_storage_uri(
    Args:
        account: str, name of Azure Storage account
        container: optional str, name of Azure Blob Storage container
        blob: optional str, name of blob
        blob: optional str, name of blob, not URL-escaped
            if blob is given, must also specify container
        sas_token: optional str, Shared Access Signature (SAS)
            does not start with '?'
@@ -63,6 +61,7 @@ def build_azure_storage_uri(
        uri = f'{uri}/{container}'
    if blob is not None:
        assert container is not None
        blob = parse.quote(blob)
        uri = f'{uri}/{blob}'
    if sas_token is not None:
        assert sas_token[0] != '?'
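A hedged usage sketch for the updated `build_azure_storage_uri` (not part of the diff): the account, container, blob, and token values are placeholders, and the expected output assumes the default `blob.core.windows.net` endpoint described in the module docstring.

```python
uri = build_azure_storage_uri(
    account='myaccount',
    container='mycontainer',
    blob='path/my blob.json',              # URL-escaped by the function
    sas_token='st=2020-01-01&sig=abc123')  # no leading '?'
# expected form:
# https://myaccount.blob.core.windows.net/mycontainer/path/my%20blob.json?st=2020-01-01&sig=abc123
```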
@@ -149,7 +148,8 @@ def get_sas_token_from_uri(sas_uri: str) -> Optional[str]:
    Args:
        sas_uri: str, Azure blob storage SAS token

    Returns: Query part of the SAS token, or None if URI has no token.
    Returns: str, query part of the SAS token (without leading '?'),
        or None if URI has no token.
    """
    url_parts = parse.urlsplit(sas_uri)
    sas_token = url_parts.query or None  # None if query is empty string
@@ -180,8 +180,8 @@ def get_endpoint_suffix(sas_uri):
    Args:
        sas_uri: str, Azure blob storage URI with SAS token

    Returns: A string, usually 'core.windows.net' or 'core.chinacloudapi.cn', to use for the
        `endpoint` argument in various blob storage SDK functions.
    Returns: A string, usually 'core.windows.net' or 'core.chinacloudapi.cn', to
        use for the `endpoint` argument in various blob storage SDK functions.
    """
    url_parts = parse.urlsplit(sas_uri)
    suffix = url_parts.netloc.split('.blob.')[1].split('/')[0]
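For context, a small sketch of the two URI helpers above (not part of the commit); the URI is a made-up example.

```python
sample = 'https://myaccount.blob.core.windows.net/mycontainer/img.jpg?st=2020-01-01&sig=abc123'

print(get_sas_token_from_uri(sample))                 # 'st=2020-01-01&sig=abc123' (no leading '?')
print(get_sas_token_from_uri(sample.split('?')[0]))   # None: URI without a token
print(get_endpoint_suffix(sample))                    # 'core.windows.net'
```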
@@ -219,15 +219,20 @@ def get_all_query_parts(sas_uri: str) -> Dict[str, Any]:
    return parse.parse_qs(url_parts.query)


def check_blob_existence(sas_uri: str,
                         blob_name: Optional[str] = None) -> bool:
def check_blob_exists(sas_uri: str, blob_name: Optional[str] = None) -> bool:
    """Checks whether a given URI points to an actual blob.

    Assumes that sas_uri points to an Azure Blob Storage account hosted at
    a default Azure URI. Does not work for locally-emulated Azure Storage
    or Azure Storage hosted at custom endpoints. In these cases, create a
    BlobClient using the default constructor, instead of from_blob_url(),
    and use the BlobClient.exists() method directly.

    Args:
        sas_uri: str, URI to a container or a blob
            if blob_name is given, sas_uri is treated as a container URI
            otherwise, sas_uri is treated as a blob URI
        blob_name: optional str, name of blob
        blob_name: optional str, name of blob, not URL-escaped
            must be given if sas_uri is a URI to a container

    Returns: bool, whether the sas_uri given points to an existing blob
@@ -236,15 +241,8 @@ def check_blob_existence(sas_uri: str,
        sas_uri = build_blob_uri(
            container_uri=sas_uri, blob_name=blob_name)

    # until Azure implements a proper BlobClient.exists() method, we can
    # only use try/except to determine blob existence
    # see: https://github.com/Azure/azure-sdk-for-python/issues/9507
    with BlobClient.from_blob_url(sas_uri) as blob_client:
        try:
            blob_client.get_blob_properties()
        except ResourceNotFoundError:
            return False
        return True
        return blob_client.exists()


def list_blobs_in_container(
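A hedged sketch of the alternative the new docstring recommends for emulated or custom-endpoint storage (not part of the commit); the account URL, names, and credential are placeholders.

```python
from azure.storage.blob import BlobClient

# For Azurite or a custom endpoint, build the client explicitly instead of from_blob_url():
with BlobClient(account_url='http://127.0.0.1:10000/devstoreaccount1',
                container_name='mycontainer',
                blob_name='myblob.txt',
                credential='<account key or SAS token>') as blob_client:
    print(blob_client.exists())
```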
@@ -331,18 +329,13 @@ def generate_writable_container_sas(account_name: str,

    Raises: azure.core.exceptions.ResourceExistsError, if container already
        exists

    NOTE: This method currently fails on non-default Azure Storage URLs. The
    initializer for ContainerClient() assumes the default Azure Storage URL
    format, which is a bug that has been reported here:
    https://github.com/Azure/azure-sdk-for-python/issues/12568
    """
    if account_url is None:
        account_url = build_azure_storage_uri(account=account_name)
    container_client = ContainerClient(account_url=account_url,
                                       container_name=container_name,
                                       credential=account_key)
    container_client.create_container()
    with ContainerClient(account_url=account_url,
                         container_name=container_name,
                         credential=account_key) as container_client:
        container_client.create_container()

    permissions = ContainerSasPermissions(read=True, write=True, list=True)
    container_sas_token = generate_container_sas(
@@ -356,7 +349,8 @@ def generate_writable_container_sas(account_name: str,


def upload_blob(container_uri: str, blob_name: str,
                data: Union[Iterable[AnyStr], IO[AnyStr]]) -> str:
                data: Union[Iterable[AnyStr], IO[AnyStr]],
                overwrite: bool = False) -> str:
    """Creates a new blob of the given name from an IO stream.

    Args:

@@ -364,12 +358,15 @@ def upload_blob(container_uri: str, blob_name: str,
        blob_name: str, name of blob to upload
        data: str, bytes, or IO stream
            if str, assumes utf-8 encoding
        overwrite: bool, whether to overwrite existing blob (if any)

    Returns: str, URI to blob, includes SAS token if container_uri has SAS token
    Returns: str, URL to blob, includes SAS token if container_uri has SAS token
    """
    blob_url = build_blob_uri(container_uri, blob_name)
    upload_blob_to_url(blob_url, data=data)
    return blob_url
    account_url, container, sas_token = split_container_uri(container_uri)
    with BlobClient(account_url=account_url, container_name=container,
                    blob_name=blob_name, credential=sas_token) as blob_client:
        blob_client.upload_blob(data, overwrite=overwrite)
        return blob_client.url


def download_blob_to_stream(sas_uri: str) -> Tuple[io.BytesIO, BlobProperties]:
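A hedged usage sketch for the revised `upload_blob` above (not part of the commit): the container URI and blob name are placeholders, and the container URI is assumed to carry a SAS token with write permission.

```python
container_uri = 'https://myaccount.blob.core.windows.net/mycontainer?<sas_token>'
blob_url = upload_blob(container_uri, blob_name='results/run1.json',
                       data='{"ok": true}', overwrite=True)
print(blob_url)  # blob URL; includes the SAS token because container_uri had one
```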
@@ -384,10 +381,6 @@ def download_blob_to_stream(sas_uri: str) -> Tuple[io.BytesIO, BlobProperties]:

    Raises: azure.core.exceptions.ResourceNotFoundError, if sas_uri points
        to a non-existent blob

    NOTE: the returned BlobProperties object may have incorrect values for
    the blob name and container name. This is a bug which has been reported
    here: https://github.com/Azure/azure-sdk-for-python/issues/12563
    """
    with BlobClient.from_blob_url(sas_uri) as blob_client:
        output_stream = io.BytesIO()

@@ -397,20 +390,34 @@ def download_blob_to_stream(sas_uri: str) -> Tuple[io.BytesIO, BlobProperties]:
    return output_stream, blob_properties


def build_blob_uri(container_uri: str, blob_name: str) -> str:
def split_container_uri(container_uri: str) -> Tuple[str, str, Optional[str]]:
    """
    Args:
        container_uri: str, URI to blob storage container
            <account_url>/<container_name>?<sas_token>
        blob_name: str, name of blob
            <account_url>/<container>?<sas_token>

    Returns: str, blob URI
        <account_url>/<container_name>/<blob_name>?<sas_token>
    Returns: account_url, container_name, sas_token
    """
    account_container = container_uri.split('?', maxsplit=1)[0]
    account_url, container_name = account_container.rsplit('/', maxsplit=1)
    sas_token = get_sas_token_from_uri(container_uri)
    blob_uri = f'{account_url}/{container_name}/{blob_name}'
    return account_url, container_name, sas_token


def build_blob_uri(container_uri: str, blob_name: str) -> str:
    """
    Args:
        container_uri: str, URI to blob storage container
            <account_url>/<container>?<sas_token>
        blob_name: str, name of blob, not URL-escaped

    Returns: str, blob URI <account_url>/<container>/<blob_name>?<sas_token>,
        <blob_name> is URL-escaped
    """
    account_url, container, sas_token = split_container_uri(container_uri)

    blob_name = parse.quote(blob_name)
    blob_uri = f'{account_url}/{container}/{blob_name}'
    if sas_token is not None:
        blob_uri += f'?{sas_token}'
    return blob_uri
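A small sketch of the new URI helpers (not part of the commit); the container URI is a made-up example.

```python
container_uri = 'https://myaccount.blob.core.windows.net/mycontainer?st=2020-01-01&sig=abc123'

print(split_container_uri(container_uri))
# ('https://myaccount.blob.core.windows.net', 'mycontainer', 'st=2020-01-01&sig=abc123')

print(build_blob_uri(container_uri, 'dir/my blob.json'))
# 'https://myaccount.blob.core.windows.net/mycontainer/dir/my%20blob.json?st=2020-01-01&sig=abc123'
```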
@@ -18,7 +18,8 @@ https://github.com/azure/azurite
3) Run Azurite. The -l flag sets a temp folder where Azurite can store data to
disk. By default, Azurite's blob service runs at 127.0.0.1:10000, which can be
changed by the parameters --blobHost 1.2.3.4 --blobPort 5678.
    mkdir $HOME/tmp/azurite
    mkdir -p $HOME/tmp/azurite
    rm -r $HOME/tmp/azurite/*  # if the folder already existed, clear it
    azurite-blob -l $HOME/tmp/azurite

4) In a separate terminal, activate a virtual environment with the Azure Storage

@@ -40,7 +41,7 @@ from azure.storage.blob import BlobClient, ContainerClient

from sas_blob_utils import (
    build_blob_uri,
    check_blob_existence,
    check_blob_exists,
    download_blob_to_stream,
    generate_writable_container_sas,
    get_account_from_uri,
@@ -82,21 +83,23 @@ class Tests(unittest.TestCase):
        # cleanup: delete the private emulated container
        print('running cleanup')

        # until the private emulated account is able to work, skip cleanup
        # with ContainerClient.from_container_url(
        #         PRIVATE_CONTAINER_URI,
        #         credential=PRIVATE_ACCOUNT_KEY) as cc:
        #     try:
        #         cc.get_container_properties()
        #         cc.delete_container()
        #     except ResourceNotFoundError:
        #         pass
        with BlobClient(account_url=PRIVATE_ACCOUNT_URI,
                        container_name=PRIVATE_CONTAINER_NAME,
                        blob_name=PRIVATE_BLOB_NAME,
                        credential=PRIVATE_ACCOUNT_KEY) as bc:
            if bc.exists():
                print('deleted blob')
                bc.delete_blob(delete_snapshots='include')

        # if check_blob_existence(PRIVATE_BLOB_URI):
        #     with BlobClient.from_blob_url(
        #             PRIVATE_BLOB_URI,
        #             credential=PRIVATE_ACCOUNT_KEY) as bc:
        #         bc.delete_blob(delete_snapshots=True)
        with ContainerClient.from_container_url(
                PRIVATE_CONTAINER_URI,
                credential=PRIVATE_ACCOUNT_KEY) as cc:
            try:
                cc.get_container_properties()
                cc.delete_container()
                print('deleted container')
            except ResourceNotFoundError:
                pass
        self.needs_cleanup = False

    def test_get_account_from_uri(self):
@@ -118,22 +121,18 @@ class Tests(unittest.TestCase):
            get_sas_token_from_uri(PUBLIC_CONTAINER_URI_SAS),
            PUBLIC_CONTAINER_SAS)

    def test_check_blob_existence(self):
    def test_check_blob_exists(self):
        print('PUBLIC_BLOB_URI')
        self.assertTrue(check_blob_existence(PUBLIC_BLOB_URI))
        self.assertTrue(check_blob_exists(PUBLIC_BLOB_URI))
        print('PUBLIC_CONTAINER_URI + PUBLIC_BLOB_NAME')
        self.assertTrue(check_blob_existence(
        self.assertTrue(check_blob_exists(
            PUBLIC_CONTAINER_URI, blob_name=PUBLIC_BLOB_NAME))

        print('PUBLIC_CONTAINER_URI')
        with self.assertRaises(IndexError):
            check_blob_existence(PUBLIC_CONTAINER_URI)
            check_blob_exists(PUBLIC_CONTAINER_URI)
        print('PUBLIC_INVALID_BLOB_URI')
        self.assertFalse(check_blob_existence(PUBLIC_INVALID_BLOB_URI))

        print('PRIVATE_BLOB_URI')
        with self.assertRaises(HttpResponseError):
            check_blob_existence(PRIVATE_BLOB_URI)
        self.assertFalse(check_blob_exists(PUBLIC_INVALID_BLOB_URI))

    def test_list_blobs_in_container(self):
        blobs_list = list_blobs_in_container(
@@ -155,9 +154,6 @@ class Tests(unittest.TestCase):
        self.assertEqual(blobs_list, expected)

    def test_generate_writable_container_sas(self):
        # until the private emulated account is able to work, skip this test
        self.skipTest('skipping private account tests for now')

        self.needs_cleanup = True
        new_sas_uri = generate_writable_container_sas(
            account_name=PRIVATE_ACCOUNT_NAME,
@@ -172,9 +168,9 @@ class Tests(unittest.TestCase):
    def test_upload_blob(self):
        self.needs_cleanup = True
        # uploading to a read-only public container without a SAS token yields
        # ResourceNotFoundError('The specified resource does not exist.')
        # HttpResponseError('Server failed to authenticate the request.')
        print('PUBLIC_CONTAINER_URI')
        with self.assertRaises(ResourceNotFoundError):
        with self.assertRaises(HttpResponseError):
            upload_blob(PUBLIC_CONTAINER_URI,
                        blob_name='failblob', data='fail')

@@ -195,17 +191,25 @@ class Tests(unittest.TestCase):
            upload_blob(PRIVATE_CONTAINER_URI,
                        blob_name=PRIVATE_BLOB_NAME, data='success')

        # until the private emulated account is able to work, skip this test
        # private_container_uri_sas = generate_writable_container_sas(
        #     account_name=PRIVATE_ACCOUNT_NAME,
        #     account_key=PRIVATE_ACCOUNT_KEY,
        #     container_name=PRIVATE_CONTAINER_NAME,
        #     access_duration_hrs=1,
        #     account_url=PRIVATE_ACCOUNT_URI)
        # blob_url = upload_blob(
        #     private_container_uri_sas,
        #     blob_name=PRIVATE_BLOB_NAME, data='success')
        # self.assertEqual(blob_url, PRIVATE_BLOB_URI)
        # upload to a private container with a SAS token
        private_container_uri_sas = generate_writable_container_sas(
            account_name=PRIVATE_ACCOUNT_NAME,
            account_key=PRIVATE_ACCOUNT_KEY,
            container_name=PRIVATE_CONTAINER_NAME,
            access_duration_hrs=1,
            account_url=PRIVATE_ACCOUNT_URI)
        container_sas = get_sas_token_from_uri(private_container_uri_sas)
        private_blob_uri_sas = f'{PRIVATE_BLOB_URI}?{container_sas}'
        blob_url = upload_blob(
            private_container_uri_sas,
            blob_name=PRIVATE_BLOB_NAME, data='success')
        self.assertEqual(blob_url, private_blob_uri_sas)

        with BlobClient(account_url=PRIVATE_ACCOUNT_URI,
                        container_name=PRIVATE_CONTAINER_NAME,
                        blob_name=PRIVATE_BLOB_NAME,
                        credential=container_sas) as blob_client:
            self.assertTrue(blob_client.exists())

    def test_download_blob_to_stream(self):
        output, props = download_blob_to_stream(PUBLIC_BLOB_URI)
@@ -213,11 +217,10 @@ class Tests(unittest.TestCase):
        self.assertEqual(len(x), 376645)
        output.close()

        # see https://github.com/Azure/azure-sdk-for-python/issues/12563
        expected_properties = {
            'size': 376645,
            # 'name': PUBLIC_BLOB_NAME,
            # 'container': 'nacti-unzipped'
            'name': PUBLIC_BLOB_NAME,
            'container': 'nacti-unzipped'
        }

        for k, v in expected_properties.items():