Родитель
ccc7eadd9d
Коммит
e66ee67ebb
|
@ -39,6 +39,10 @@ def _input_format_classification(
|
|||
Returns:
|
||||
preds: tensor with labels
|
||||
target: tensor with labels
|
||||
|
||||
Example:
|
||||
>>> _input_format_classification(torch.tensor([[0.45, 0.55], [0.3, 0.7], [0.9, 0.1]]), torch.tensor([1, 0, 0]))
|
||||
(tensor([1, 1, 0]), tensor([1, 0, 0]))
|
||||
"""
|
||||
if not (preds.ndim == target.ndim or preds.ndim == target.ndim + 1):
|
||||
raise ValueError("preds and target must have same number of dimensions, or one additional dimension for preds")
|
||||
|
|
|
@ -88,6 +88,7 @@ def select_topk(prob_tensor: torch.Tensor, topk: int = 1, dim: int = 1) -> torch
|
|||
A binary tensor of the same shape as the input tensor of type torch.int32
|
||||
|
||||
Example:
|
||||
|
||||
>>> x = torch.tensor([[1.1, 2.0, 3.0], [2.0, 1.0, 0.5]])
|
||||
>>> select_topk(x, topk=2)
|
||||
tensor([[0, 1, 1],
|
||||
|
@ -151,20 +152,29 @@ def get_num_classes(
|
|||
return num_classes
|
||||
|
||||
|
||||
def _stable_1d_sort(x: torch, N: int = 2049):
|
||||
def _stable_1d_sort(x: torch, nb: int = 2049):
|
||||
"""
|
||||
Stable sort of 1d tensors. Pytorch defaults to a stable sorting algorithm
|
||||
if number of elements are larger than 2048. This function pads the tensors,
|
||||
makes the sort and returns the sorted array (with the padding removed)
|
||||
See this discussion: https://discuss.pytorch.org/t/is-torch-sort-stable/20714
|
||||
|
||||
Example:
|
||||
>>> data = torch.tensor([8, 7, 2, 6, 4, 5, 3, 1, 9, 0])
|
||||
>>> _stable_1d_sort(data)
|
||||
(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([9, 7, 2, 6, 4, 5, 3, 1, 0, 8]))
|
||||
>>> _stable_1d_sort(data, nb=5)
|
||||
(tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([9, 7, 2, 6, 4, 5, 3, 1, 0, 8]))
|
||||
"""
|
||||
if x.ndim > 1:
|
||||
raise ValueError('Stable sort only works on 1d tensors')
|
||||
n = x.numel()
|
||||
if N - n > 0:
|
||||
if n < nb:
|
||||
x_max = x.max()
|
||||
x_pad = torch.cat([x, (x_max + 1) * torch.ones(2049 - n, dtype=x.dtype, device=x.device)], 0)
|
||||
x_sort = x_pad.sort()
|
||||
x_sort = x_pad.sort()
|
||||
else:
|
||||
x_sort = x.sort()
|
||||
return x_sort.values[:n], x_sort.indices[:n]
|
||||
|
||||
|
||||
|
@ -190,6 +200,14 @@ def apply_to_collection(
|
|||
|
||||
Returns:
|
||||
the resulting collection
|
||||
|
||||
Example:
|
||||
>>> apply_to_collection(torch.tensor([8, 0, 2, 6, 7]), dtype=torch.Tensor, function=lambda x: x ** 2)
|
||||
tensor([64, 0, 4, 36, 49])
|
||||
>>> apply_to_collection([8, 0, 2, 6, 7], dtype=int, function=lambda x: x ** 2)
|
||||
[64, 0, 4, 36, 49]
|
||||
>>> apply_to_collection(dict(abc=123), dtype=int, function=lambda x: x ** 2)
|
||||
{'abc': 15129}
|
||||
"""
|
||||
elem_type = type(data)
|
||||
|
||||
|
|
|
@ -16,7 +16,16 @@ from typing import Union
|
|||
|
||||
|
||||
class EnumStr(str, Enum):
|
||||
""" Type of any enumerator with allowed comparison to string invariant to cases. """
|
||||
""" Type of any enumerator with allowed comparison to string invariant to cases.
|
||||
|
||||
Example:
|
||||
>>> class MyEnum(EnumStr):
|
||||
... ABC = 'abc'
|
||||
>>> MyEnum.from_str('Abc')
|
||||
<MyEnum.ABC: 'abc'>
|
||||
>>> {MyEnum.ABC: 123}
|
||||
{<MyEnum.ABC: 'abc'>: 123}
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, value: str) -> 'EnumStr':
|
||||
|
|
Загрузка…
Ссылка в новой задаче