register ATen fallback for instance_norm2d gradient
This commit is contained in:
Родитель
43766ee36d
Коммит
ceaf1da088
|
@ -114,6 +114,18 @@ def diagonal_gradient():
|
|||
]
|
||||
|
||||
|
||||
@register_gradient("org.pytorch.aten", "ATen", "instance_norm2d", "")
def instance_norm2d_gradient():
    """Gradient definition for the ATen ``instance_norm2d`` fallback op.

    Returns a single gradient node spec: it feeds the upstream gradient
    GO(0), the forward inputs I(0)..I(5), and forward output O(1) into the
    ATen ``instance_norm2d_backward`` operator to produce the input
    gradient GI(0).
    """
    # Node type is the ATen fallback op in the org.pytorch.aten domain.
    node_type = ("ATen", "org.pytorch.aten")
    # Backward op consumes the output gradient, all six forward inputs,
    # and the second forward output (saved statistics, presumably —
    # TODO(review): confirm against the instance_norm2d_backward kernel).
    node_inputs = ["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "I(4)", "I(5)", "O(1)"]
    node_outputs = ["GI(0)"]
    # The concrete ATen operator is selected via a string attribute.
    node_attrs = {"operator": {"value": "instance_norm2d_backward", "dtype": "string"}}
    return [(node_type, node_inputs, node_outputs, node_attrs)]
|
||||
|
||||
|
||||
@register_gradient("org.pytorch.aten", "ATen", "max_pool2d_with_indices", "")
|
||||
def max_pool2d_gradient():
|
||||
return [
|
||||
|
@ -125,7 +137,6 @@ def max_pool2d_gradient():
|
|||
),
|
||||
]
|
||||
|
||||
|
||||
def minmax_gradient():
|
||||
# Gradient of torch.min(input) (and max)
|
||||
# In PyTorch, when there are multiple maxima/minima, the gradient is evenly distributed among them.
|
||||
|
|
Загрузка…
Ссылка в новой задаче