add tests for uncovered functions (#43)

* data

* checks

* enum

* rand
This commit is contained in:
Jirka Borovec 2021-03-11 10:39:47 +01:00 коммит произвёл GitHub
Родитель ccc7eadd9d
Коммит e66ee67ebb
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
3 изменённых файлов: 35 добавлений и 4 удалений

Просмотреть файл

@ -39,6 +39,10 @@ def _input_format_classification(
Returns:
preds: tensor with labels
target: tensor with labels
Example:
>>> _input_format_classification(torch.tensor([[0.45, 0.55], [0.3, 0.7], [0.9, 0.1]]), torch.tensor([1, 0, 0]))
(tensor([1, 1, 0]), tensor([1, 0, 0]))
"""
if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1):
raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")

Просмотреть файл

@ -88,6 +88,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch
A binary tensor of the same shape as the input tensor of type torch.int32
Example:
>>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
>>> select_topk(x, topk=2)
tensor([[0, 1, 1],
@ -151,20 +152,29 @@ def get_num_classes(
return num_classes
def _stable_1d_sort(x: torch, N: int = 2049):
def _stable_1d_sort(x: torch, nb: int = 2049):
"""
Stable sort of 1d tensors. Pytorch defaults to a stable sorting algorithm
if number of elements are larger than 2048. This function pads the tensors,
makes the sort and returns the sorted array (with the padding removed)
See this discussion: https://discuss.pytorch.org/t/is-torch-sort-stable/20714
Example:
>>> data = torch.tensor([8, 7, 2, 6, 4, 5, 3, 1, 9, 0])
>>> _stable_1d_sort(data)
(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([9, 7, 2, 6, 4, 5, 3, 1, 0, 8]))
>>> _stable_1d_sort(data, nb=5)
(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([9, 7, 2, 6, 4, 5, 3, 1, 0, 8]))
"""
if x.ndim > 1:
raise ValueError('Stable sort only works on 1d tensors')
n = x.numel()
if N - n > 0:
if n < nb:
x_max = x.max()
x_pad = torch.cat([x, (x_max + 1) * torch.ones(2049 - n, dtype=x.dtype, device=x.device)], 0)
x_sort = x_pad.sort()
x_sort = x_pad.sort()
else:
x_sort = x.sort()
return x_sort.values[:n], x_sort.indices[:n]
@ -190,6 +200,14 @@ def apply_to_collection(
Returns:
the resulting collection
Example:
>>> apply_to_collection(torch.tensor([8, 0, 2, 6, 7]), dtype=torch.Tensor, function=lambda x: x ** 2)
tensor([64, 0, 4, 36, 49])
>>> apply_to_collection([8, 0, 2, 6, 7], dtype=int, function=lambda x: x ** 2)
[64, 0, 4, 36, 49]
>>> apply_to_collection(dict(abc=123), dtype=int, function=lambda x: x ** 2)
{'abc': 15129}
"""
elem_type = type(data)

Просмотреть файл

@ -16,7 +16,16 @@ from typing import Union
class EnumStr(str, Enum):
""" Type of any enumerator with allowed comparison to string invariant to cases. """
""" Type of any enumerator with allowed comparison to string invariant to cases.
Example:
>>> class MyEnum(EnumStr):
... ABC = 'abc'
>>> MyEnum.from_str('Abc')
<MyEnum.ABC: 'abc'>
>>> {MyEnum.ABC: 123}
{<MyEnum.ABC: 'abc'>: 123}
"""
@classmethod
def from_str(cls, value: str) -> 'EnumStr':