register ATen fallback for instance_norm2d gradient

This commit is contained in:
Prathik Rao 2022-10-06 16:52:03 -07:00
Родитель 43766ee36d
Коммит ceaf1da088
1 изменённых файлов: 12 добавлений и 1 удалений

Просмотреть файл

@ -114,6 +114,18 @@ def diagonal_gradient():
]
@register_gradient("org.pytorch.aten", "ATen", "instance_norm2d", "")
def instance_norm2d_gradient():
return [
(
("ATen", "org.pytorch.aten"),
["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "I(4)", "I(5)", "O(1)"],
["GI(0)"],
{"operator": {"value": "instance_norm2d_backward", "dtype": "string"}},
),
]
@register_gradient("org.pytorch.aten", "ATen", "max_pool2d_with_indices", "")
def max_pool2d_gradient():
return [
@ -125,7 +137,6 @@ def max_pool2d_gradient():
),
]
def minmax_gradient():
# Gradient of torch.min(input) (and max)
# In PyTorch, when there are multiple maxima/minima, the gradient is evenly distributed among them.