Add ImageFeatureExtractionMixin (#10905)

* Add ImageFeatureExtractionMixin

* Add dummy vision objects

* Add require_vision

* Add tests

* Fix test
This commit is contained in:
Sylvain Gugger 2021-03-26 11:23:56 -04:00 коммит произвёл GitHub
Родитель 3c27d246e5
Коммит b0595d33c1
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
8 изменённых файлов: 544 добавлений и 26 удалений

Просмотреть файл

@ -39,3 +39,10 @@ BatchFeature
.. autoclass:: transformers.BatchFeature
:members:
ImageFeatureExtractionMixin
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.image_utils.ImageFeatureExtractionMixin
:members:

Просмотреть файл

@ -48,6 +48,7 @@ from .file_utils import (
is_tf_available,
is_tokenizers_available,
is_torch_available,
is_vision_available,
)
from .utils import logging
@ -105,6 +106,7 @@ _import_structure = {
"is_tokenizers_available",
"is_torch_available",
"is_torch_tpu_available",
"is_vision_available",
],
"hf_argparser": ["HfArgumentParser"],
"integrations": [
@ -341,6 +343,16 @@ else:
name for name in dir(dummy_tokenizers_objects) if not name.startswith("_")
]
# Vision-specific objects
if is_vision_available():
_import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
else:
from .utils import dummy_vision_objects
_import_structure["utils.dummy_vision_objects"] = [
name for name in dir(dummy_vision_objects) if not name.startswith("_")
]
# PyTorch-backed objects
if is_torch_available():
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
@ -1317,6 +1329,7 @@ if TYPE_CHECKING:
is_tokenizers_available,
is_torch_available,
is_torch_tpu_available,
is_vision_available,
)
from .hf_argparser import HfArgumentParser
@ -1544,6 +1557,11 @@ if TYPE_CHECKING:
else:
from .utils.dummy_tokenizers_objects import *
if is_vision_available():
from .image_utils import ImageFeatureExtractionMixin
else:
from .utils.dummy_vision_objects import *
# Modeling
if is_torch_available():

Просмотреть файл

@ -326,6 +326,10 @@ def is_tokenizers_available():
return importlib.util.find_spec("tokenizers") is not None
def is_vision_available():
return importlib.util.find_spec("PIL") is not None
def is_in_notebook():
try:
# Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py
@ -490,6 +494,13 @@ explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/ins
"""
# docstyle-ignore
VISION_IMPORT_ERROR = """
{0} requires the PIL library but it was not found in your environment. You can install it with pip:
`pip install pillow`
"""
def requires_datasets(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_datasets_available():
@ -556,6 +567,12 @@ def requires_scatter(obj):
raise ImportError(SCATTER_IMPORT_ERROR.format(name))
def requires_vision(obj):
name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
if not is_vision_available():
raise ImportError(VISION_IMPORT_ERROR.format(name))
def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")

Просмотреть файл

@ -0,0 +1,158 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import PIL.Image
from .file_utils import _is_torch, is_torch_available
def is_torch_tensor(obj):
return _is_torch(obj) if is_torch_available() else False
# In the future we can add a TF implementation here when we have TF models.
class ImageFeatureExtractionMixin:
"""
Mixin that contain utilities for preparing image features.
"""
def _ensure_format_supported(self, image):
if not isinstance(image, (PIL.Image.Image, np.ndarray)) and not is_torch_tensor(image):
raise ValueError(
f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
"`torch.Tensor` are."
)
def to_pil_image(self, image, rescale=None):
"""
Converts :obj:`image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last
axis if needed.
Args:
image (:obj:`PIL.Image.Image` or :obj:`numpy.ndarray` or :obj:`torch.Tensor`):
The image to convert to the PIL Image format.
rescale (:obj:`bool`, `optional`):
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
default to :obj:`True` if the image type is a floating type, :obj:`False` otherwise.
"""
self._ensure_format_supported(image)
if is_torch_tensor(image):
image = image.numpy()
if isinstance(image, np.ndarray):
if rescale is None:
# rescale default to the array being of floating type.
rescale = isinstance(image.flat[0], np.floating)
# If the channel as been moved to first dim, we put it back at the end.
if image.ndim == 3 and image.shape[0] in [1, 3]:
image = image.transpose(1, 2, 0)
if rescale:
image = image * 255
image = image.astype(np.uint8)
return PIL.Image.fromarray(image)
return image
def to_numpy_array(self, image, rescale=None, channel_first=True):
"""
Converts :obj:`image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
dimension.
Args:
image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
The image to convert to a NumPy array.
rescale (:obj:`bool`, `optional`):
Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
default to :obj:`True` if the image is a PIL Image or an array/tensor of integers, :obj:`False`
otherwise.
channel_first (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to permute the dimensions of the image to put the channel dimension first.
"""
self._ensure_format_supported(image)
if isinstance(image, PIL.Image.Image):
image = np.array(image)
if is_torch_tensor(image):
image = image.numpy()
if rescale is None:
rescale = isinstance(image.flat[0], np.integer)
if rescale:
image = image.astype(np.float32) / 255.0
if channel_first:
image = image.transpose(2, 0, 1)
return image
def normalize(self, image, mean, std):
"""
Normalizes :obj:`image` with :obj:`mean` and :obj:`std`. Note that this will trigger a conversion of
:obj:`image` to a NumPy array if it's a PIL Image.
Args:
image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
The image to normalize.
mean (:obj:`List[float]` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
The mean (per channel) to use for normalization.
std (:obj:`List[float]` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
The standard deviation (per channel) to use for normalization.
"""
self._ensure_format_supported(image)
if isinstance(image, PIL.Image.Image):
image = self.to_numpy_array(image)
if isinstance(image, np.ndarray):
if not isinstance(mean, np.ndarray):
mean = np.array(mean)
if not isinstance(std, np.ndarray):
std = np.array(std)
elif is_torch_tensor(image):
import torch
if not isinstance(mean, torch.Tensor):
mean = torch.tensor(mean)
if not isinstance(std, torch.Tensor):
std = torch.tensor(std)
if image.ndim == 3 and image.shape[0] in [1, 3]:
return (image - mean[:, None, None]) / std[:, None, None]
else:
return (image - mean) / std
def resize(self, image, size, resample=PIL.Image.BILINEAR):
"""
Resizes :obj:`image`. Note that this will trigger a conversion of :obj:`image` to a PIL Image.
Args:
image (:obj:`PIL.Image.Image` or :obj:`np.ndarray` or :obj:`torch.Tensor`):
The image to resize.
size (:obj:`int` or :obj:`Tuple[int, int]`):
The size to use for resizing the image.
resample (:obj:`int`, `optional`, defaults to :obj:`PIL.Image.BILINEAR`):
The filter to user for resampling.
"""
self._ensure_format_supported(image)
if not isinstance(size, tuple):
size = (size, size)
if not isinstance(image, PIL.Image.Image):
image = self.to_pil_image(image)
return image.resize(size, resample=resample)

Просмотреть файл

@ -39,6 +39,7 @@ from .file_utils import (
is_torch_available,
is_torch_tpu_available,
is_torchaudio_available,
is_vision_available,
)
from .integrations import is_optuna_available, is_ray_available
@ -229,12 +230,9 @@ def require_torch_scatter(test_case):
def require_torchaudio(test_case):
"""
Decorator marking a test that requires torchaudio.
These tests are skipped when torchaudio isn't installed.
Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed.
"""
if not is_torchaudio_available:
if not is_torchaudio_available():
return unittest.skip("test requires torchaudio")(test_case)
else:
return test_case
@ -242,10 +240,7 @@ def require_torchaudio(test_case):
def require_tf(test_case):
"""
Decorator marking a test that requires TensorFlow.
These tests are skipped when TensorFlow isn't installed.
Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed.
"""
if not is_tf_available():
return unittest.skip("test requires TensorFlow")(test_case)
@ -255,10 +250,7 @@ def require_tf(test_case):
def require_flax(test_case):
"""
Decorator marking a test that requires JAX & Flax
These tests are skipped when one / both are not installed
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
"""
if not is_flax_available():
test_case = unittest.skip("test requires JAX & Flax")(test_case)
@ -267,10 +259,7 @@ def require_flax(test_case):
def require_sentencepiece(test_case):
"""
Decorator marking a test that requires SentencePiece.
These tests are skipped when SentencePiece isn't installed.
Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
"""
if not is_sentencepiece_available():
return unittest.skip("test requires SentencePiece")(test_case)
@ -280,10 +269,7 @@ def require_sentencepiece(test_case):
def require_tokenizers(test_case):
"""
Decorator marking a test that requires 🤗 Tokenizers.
These tests are skipped when 🤗 Tokenizers isn't installed.
Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed.
"""
if not is_tokenizers_available():
return unittest.skip("test requires tokenizers")(test_case)
@ -312,11 +298,21 @@ def require_scatter(test_case):
return test_case
def require_vision(test_case):
"""
Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't
installed.
"""
if not is_vision_available():
return unittest.skip("test requires vision")(test_case)
else:
return test_case
def require_torch_multi_gpu(test_case):
"""
Decorator marking a test that requires a multi-GPU setup (in PyTorch).
These tests are skipped on a machine without multiple GPUs.
Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
multiple GPUs.
To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
"""

Просмотреть файл

@ -0,0 +1,7 @@
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..file_utils import requires_vision
class ImageFeatureExtractionMixin:
def __init__(self, *args, **kwargs):
requires_vision(self)

315
tests/test_image_utils.py Normal file
Просмотреть файл

@ -0,0 +1,315 @@
# coding=utf-8
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from transformers import is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision
if is_torch_available():
import torch
if is_vision_available():
import PIL.Image
from transformers import ImageFeatureExtractionMixin
def get_random_image(height, width):
random_array = np.random.randint(0, 256, (height, width, 3), dtype=np.uint8)
return PIL.Image.fromarray(random_array)
@require_vision
class ImageFeatureExtractionTester(unittest.TestCase):
def test_conversion_image_to_array(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
# Conversion with defaults (rescale + channel first)
array1 = feature_extractor.to_numpy_array(image)
self.assertTrue(array1.dtype, np.float32)
self.assertEqual(array1.shape, (3, 16, 32))
# Conversion with rescale and not channel first
array2 = feature_extractor.to_numpy_array(image, channel_first=False)
self.assertTrue(array2.dtype, np.float32)
self.assertEqual(array2.shape, (16, 32, 3))
self.assertTrue(np.array_equal(array1, array2.transpose(2, 0, 1)))
# Conversion with no rescale and channel first
array3 = feature_extractor.to_numpy_array(image, rescale=False)
self.assertTrue(array3.dtype, np.uint8)
self.assertEqual(array3.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array1, array3.astype(np.float32) / 255.0))
# Conversion with no rescale and not channel first
array4 = feature_extractor.to_numpy_array(image, rescale=False, channel_first=False)
self.assertTrue(array4.dtype, np.uint8)
self.assertEqual(array4.shape, (16, 32, 3))
self.assertTrue(np.array_equal(array2, array4.astype(np.float32) / 255.0))
def test_conversion_array_to_array(self):
feature_extractor = ImageFeatureExtractionMixin()
array = np.random.randint(0, 256, (16, 32, 3), dtype=np.uint8)
# By default, rescale (for an array of ints) and channel permute
array1 = feature_extractor.to_numpy_array(array)
self.assertTrue(array1.dtype, np.float32)
self.assertEqual(array1.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array1, array.transpose(2, 0, 1).astype(np.float32) / 255.0))
# Same with no permute
array2 = feature_extractor.to_numpy_array(array, channel_first=False)
self.assertTrue(array2.dtype, np.float32)
self.assertEqual(array2.shape, (16, 32, 3))
self.assertTrue(np.array_equal(array2, array.astype(np.float32) / 255.0))
# Force rescale to False
array3 = feature_extractor.to_numpy_array(array, rescale=False)
self.assertTrue(array3.dtype, np.uint8)
self.assertEqual(array3.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array3, array.transpose(2, 0, 1)))
# Force rescale to False and no channel permute
array4 = feature_extractor.to_numpy_array(array, rescale=False, channel_first=False)
self.assertTrue(array4.dtype, np.uint8)
self.assertEqual(array4.shape, (16, 32, 3))
self.assertTrue(np.array_equal(array4, array))
# Now test the default rescale for a float array (defaults to False)
array5 = feature_extractor.to_numpy_array(array2)
self.assertTrue(array5.dtype, np.float32)
self.assertEqual(array5.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array5, array1))
@require_torch
def test_conversion_torch_to_array(self):
feature_extractor = ImageFeatureExtractionMixin()
tensor = torch.randint(0, 256, (16, 32, 3))
array = tensor.numpy()
# By default, rescale (for a tensor of ints) and channel permute
array1 = feature_extractor.to_numpy_array(array)
self.assertTrue(array1.dtype, np.float32)
self.assertEqual(array1.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array1, array.transpose(2, 0, 1).astype(np.float32) / 255.0))
# Same with no permute
array2 = feature_extractor.to_numpy_array(array, channel_first=False)
self.assertTrue(array2.dtype, np.float32)
self.assertEqual(array2.shape, (16, 32, 3))
self.assertTrue(np.array_equal(array2, array.astype(np.float32) / 255.0))
# Force rescale to False
array3 = feature_extractor.to_numpy_array(array, rescale=False)
self.assertTrue(array3.dtype, np.uint8)
self.assertEqual(array3.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array3, array.transpose(2, 0, 1)))
# Force rescale to False and no channel permute
array4 = feature_extractor.to_numpy_array(array, rescale=False, channel_first=False)
self.assertTrue(array4.dtype, np.uint8)
self.assertEqual(array4.shape, (16, 32, 3))
self.assertTrue(np.array_equal(array4, array))
# Now test the default rescale for a float tensor (defaults to False)
array5 = feature_extractor.to_numpy_array(array2)
self.assertTrue(array5.dtype, np.float32)
self.assertEqual(array5.shape, (3, 16, 32))
self.assertTrue(np.array_equal(array5, array1))
def test_conversion_image_to_image(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
# On an image, `to_pil_image1` is a noop.
image1 = feature_extractor.to_pil_image(image)
self.assertTrue(isinstance(image, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image), np.array(image1)))
def test_conversion_array_to_image(self):
feature_extractor = ImageFeatureExtractionMixin()
array = np.random.randint(0, 256, (16, 32, 3), dtype=np.uint8)
# By default, no rescale (for an array of ints)
image1 = feature_extractor.to_pil_image(array)
self.assertTrue(isinstance(image1, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image1), array))
# If the array is channel-first, proper reordering of the channels is done.
image2 = feature_extractor.to_pil_image(array.transpose(2, 0, 1))
self.assertTrue(isinstance(image2, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image2), array))
# If the array has floating type, it's rescaled by default.
image3 = feature_extractor.to_pil_image(array.astype(np.float32) / 255.0)
self.assertTrue(isinstance(image3, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image3), array))
# You can override the default to rescale.
image4 = feature_extractor.to_pil_image(array.astype(np.float32), rescale=False)
self.assertTrue(isinstance(image4, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image4), array))
# And with floats + channel first.
image5 = feature_extractor.to_pil_image(array.transpose(2, 0, 1).astype(np.float32) / 255.0)
self.assertTrue(isinstance(image5, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image5), array))
@require_torch
def test_conversion_tensor_to_image(self):
feature_extractor = ImageFeatureExtractionMixin()
tensor = torch.randint(0, 256, (16, 32, 3))
array = tensor.numpy()
# By default, no rescale (for a tensor of ints)
image1 = feature_extractor.to_pil_image(tensor)
self.assertTrue(isinstance(image1, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image1), array))
# If the tensor is channel-first, proper reordering of the channels is done.
image2 = feature_extractor.to_pil_image(tensor.permute(2, 0, 1))
self.assertTrue(isinstance(image2, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image2), array))
# If the tensor has floating type, it's rescaled by default.
image3 = feature_extractor.to_pil_image(tensor.float() / 255.0)
self.assertTrue(isinstance(image3, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image3), array))
# You can override the default to rescale.
image4 = feature_extractor.to_pil_image(tensor.float(), rescale=False)
self.assertTrue(isinstance(image4, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image4), array))
# And with floats + channel first.
image5 = feature_extractor.to_pil_image(tensor.permute(2, 0, 1).float() / 255.0)
self.assertTrue(isinstance(image5, PIL.Image.Image))
self.assertTrue(np.array_equal(np.array(image5), array))
def test_resize_image_and_array(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
array = np.array(image)
# Size can be an int or a tuple of ints.
resized_image = feature_extractor.resize(image, 8)
self.assertTrue(isinstance(resized_image, PIL.Image.Image))
self.assertEqual(resized_image.size, (8, 8))
resized_image1 = feature_extractor.resize(image, (8, 16))
self.assertTrue(isinstance(resized_image1, PIL.Image.Image))
self.assertEqual(resized_image1.size, (8, 16))
# Passing and array converts it to a PIL Image.
resized_image2 = feature_extractor.resize(array, 8)
self.assertTrue(isinstance(resized_image2, PIL.Image.Image))
self.assertEqual(resized_image2.size, (8, 8))
self.assertTrue(np.array_equal(np.array(resized_image), np.array(resized_image2)))
resized_image3 = feature_extractor.resize(image, (8, 16))
self.assertTrue(isinstance(resized_image3, PIL.Image.Image))
self.assertEqual(resized_image3.size, (8, 16))
self.assertTrue(np.array_equal(np.array(resized_image1), np.array(resized_image3)))
@require_torch
def test_resize_tensor(self):
feature_extractor = ImageFeatureExtractionMixin()
tensor = torch.randint(0, 256, (16, 32, 3))
array = tensor.numpy()
# Size can be an int or a tuple of ints.
resized_image = feature_extractor.resize(tensor, 8)
self.assertTrue(isinstance(resized_image, PIL.Image.Image))
self.assertEqual(resized_image.size, (8, 8))
resized_image1 = feature_extractor.resize(tensor, (8, 16))
self.assertTrue(isinstance(resized_image1, PIL.Image.Image))
self.assertEqual(resized_image1.size, (8, 16))
# Check we get the same results as with NumPy arrays.
resized_image2 = feature_extractor.resize(array, 8)
self.assertTrue(np.array_equal(np.array(resized_image), np.array(resized_image2)))
resized_image3 = feature_extractor.resize(array, (8, 16))
self.assertTrue(np.array_equal(np.array(resized_image1), np.array(resized_image3)))
def test_normalize_image(self):
feature_extractor = ImageFeatureExtractionMixin()
image = get_random_image(16, 32)
array = np.array(image)
mean = [0.1, 0.5, 0.9]
std = [0.2, 0.4, 0.6]
# PIL Image are converted to NumPy arrays for the normalization
normalized_image = feature_extractor.normalize(image, mean, std)
self.assertTrue(isinstance(normalized_image, np.ndarray))
self.assertEqual(normalized_image.shape, (3, 16, 32))
# During the conversion rescale and channel first will be applied.
expected = array.transpose(2, 0, 1).astype(np.float32) / 255.0
expected = (expected - np.array(mean)[:, None, None]) / np.array(std)[:, None, None]
self.assertTrue(np.array_equal(normalized_image, expected))
def test_normalize_array(self):
feature_extractor = ImageFeatureExtractionMixin()
array = np.random.random((16, 32, 3))
mean = [0.1, 0.5, 0.9]
std = [0.2, 0.4, 0.6]
# mean and std can be passed as lists or NumPy arrays.
expected = (array - np.array(mean)) / np.array(std)
normalized_array = feature_extractor.normalize(array, mean, std)
self.assertTrue(np.array_equal(normalized_array, expected))
normalized_array = feature_extractor.normalize(array, np.array(mean), np.array(std))
self.assertTrue(np.array_equal(normalized_array, expected))
# Normalize will detect automatically if channel first or channel last is used.
array = np.random.random((3, 16, 32))
expected = (array - np.array(mean)[:, None, None]) / np.array(std)[:, None, None]
normalized_array = feature_extractor.normalize(array, mean, std)
self.assertTrue(np.array_equal(normalized_array, expected))
normalized_array = feature_extractor.normalize(array, np.array(mean), np.array(std))
self.assertTrue(np.array_equal(normalized_array, expected))
@require_torch
def test_normalize_tensor(self):
feature_extractor = ImageFeatureExtractionMixin()
tensor = torch.rand(16, 32, 3)
mean = [0.1, 0.5, 0.9]
std = [0.2, 0.4, 0.6]
# mean and std can be passed as lists or tensors.
expected = (tensor - torch.tensor(mean)) / torch.tensor(std)
normalized_tensor = feature_extractor.normalize(tensor, mean, std)
self.assertTrue(torch.equal(normalized_tensor, expected))
normalized_tensor = feature_extractor.normalize(tensor, torch.tensor(mean), torch.tensor(std))
self.assertTrue(torch.equal(normalized_tensor, expected))
# Normalize will detect automatically if channel first or channel last is used.
tensor = torch.rand(3, 16, 32)
expected = (tensor - torch.tensor(mean)[:, None, None]) / torch.tensor(std)[:, None, None]
normalized_tensor = feature_extractor.normalize(tensor, mean, std)
self.assertTrue(torch.equal(normalized_tensor, expected))
normalized_tensor = feature_extractor.normalize(tensor, torch.tensor(mean), torch.tensor(std))
self.assertTrue(torch.equal(normalized_tensor, expected))

Просмотреть файл

@ -26,7 +26,7 @@ _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
_re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$")
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers"]
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers", "vision"]
DUMMY_CONSTANT = """
@ -68,7 +68,7 @@ def read_init():
backend_specific_objects = {}
# Go through the end of the file
while line_index < len(lines):
# If the line is an if is_backemd_available, we grab all objects associated.
# If the line is an if is_backend_available, we grab all objects associated.
if _re_test_backend.search(lines[line_index]) is not None:
backend = _re_test_backend.search(lines[line_index]).groups()[0]
line_index += 1