add avgpool2_formula to shape_formula.py (#5565)

This commit is contained in:
Avi Singhal 2023-05-20 00:27:49 -07:00 коммит произвёл GitHub
Родитель 5676de40d7
Коммит 1735a23ffe
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
2 изменённых файлов: 19 добавлений и 0 удалений

Просмотреть файл

@ -171,6 +171,18 @@ def maxpool2d_formula(module: nn.MaxPool2d | nas_nn.MutableMaxPool2d, input: Sha
shape[-1] = (shape[-1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1
return MutableShape(*shape)
def avgpool2d_formula(module: nn.AvgPool2d , input: ShapeTensor) -> MutableShape:
shape = list(input.real_shape) # type: ignore
padding, kernel_size, stride = map(
lambda name: _getattr(module, name, expected_type=tuple_2_t),
['padding', 'kernel_size', 'stride']
)
# H_out and W_out
shape[-2] = (shape[-2] + 2 * padding[0] - (kernel_size[0])) // stride[0] + 1
shape[-1] = (shape[-1] + 2 * padding[1] - (kernel_size[1])) // stride[1] + 1
return MutableShape(*shape)
def multihead_attention_formula(module: nn.MultiheadAttention | nas_nn.MutableMultiheadAttention,
query: ShapeTensor, key: ShapeTensor, *args: Any, **kwargs) -> tuple[MutableShape, MutableShape | None]:
@ -347,6 +359,7 @@ _shape_inference_formulas: dict[Type[nn.Module], Formula] = {
nn.Linear: linear_formula,
nn.Conv2d: conv2d_formula,
nn.MaxPool2d: maxpool2d_formula,
nn.AvgPool2d: avgpool2d_formula,
nn.BatchNorm2d: keep_shape_formula,
nn.LayerNorm: keep_shape_formula,
nn.MultiheadAttention: multihead_attention_formula,

Просмотреть файл

@ -377,6 +377,12 @@ def test_adaptive_avg_pool2d():
assert shape_inference(nn.AdaptiveAvgPool2d(3), t).real_shape == MutableShape(4, 2, 3, 3)
assert shape_inference(nn.AdaptiveAvgPool2d((3, 4)), t).real_shape == MutableShape(4, 2, 3, 4)
def test_avg_pool2d():
t = ShapeTensor(torch.randn(4, 2, 5, 5), True)
assert shape_inference(nn.AvgPool2d(1), t).real_shape == MutableShape(4, 2, 5, 5)
assert shape_inference(nn.AvgPool2d(3,stride=1), t).real_shape == MutableShape(4, 2, 3, 3)
assert shape_inference(nn.AvgPool2d((3, 4),stride=1), t).real_shape == MutableShape(4, 2, 3, 2)
def test_linear():
input = ShapeTensor(torch.randn(4, 2), True)