From 1735a23ffeb3dab98bbe75777f4e1f16304747d9 Mon Sep 17 00:00:00 2001 From: Avi Singhal <97785770+avisinghal6@users.noreply.github.com> Date: Sat, 20 May 2023 00:27:49 -0700 Subject: [PATCH] add avgpool2_formula to shape_formula.py (#5565) --- nni/nas/profiler/pytorch/utils/shape_formula.py | 13 +++++++++++++ test/ut/nas/profiler/test_shape.py | 6 ++++++ 2 files changed, 19 insertions(+) diff --git a/nni/nas/profiler/pytorch/utils/shape_formula.py b/nni/nas/profiler/pytorch/utils/shape_formula.py index 05e81081b..b5cd0b7ec 100644 --- a/nni/nas/profiler/pytorch/utils/shape_formula.py +++ b/nni/nas/profiler/pytorch/utils/shape_formula.py @@ -171,6 +171,18 @@ def maxpool2d_formula(module: nn.MaxPool2d | nas_nn.MutableMaxPool2d, input: Sha shape[-1] = (shape[-1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1 return MutableShape(*shape) +def avgpool2d_formula(module: nn.AvgPool2d , input: ShapeTensor) -> MutableShape: + shape = list(input.real_shape) # type: ignore + + padding, kernel_size, stride = map( + lambda name: _getattr(module, name, expected_type=tuple_2_t), + ['padding', 'kernel_size', 'stride'] + ) + + # H_out and W_out + shape[-2] = (shape[-2] + 2 * padding[0] - (kernel_size[0])) // stride[0] + 1 + shape[-1] = (shape[-1] + 2 * padding[1] - (kernel_size[1])) // stride[1] + 1 + return MutableShape(*shape) def multihead_attention_formula(module: nn.MultiheadAttention | nas_nn.MutableMultiheadAttention, query: ShapeTensor, key: ShapeTensor, *args: Any, **kwargs) -> tuple[MutableShape, MutableShape | None]: @@ -347,6 +359,7 @@ _shape_inference_formulas: dict[Type[nn.Module], Formula] = { nn.Linear: linear_formula, nn.Conv2d: conv2d_formula, nn.MaxPool2d: maxpool2d_formula, + nn.AvgPool2d: avgpool2d_formula, nn.BatchNorm2d: keep_shape_formula, nn.LayerNorm: keep_shape_formula, nn.MultiheadAttention: multihead_attention_formula, diff --git a/test/ut/nas/profiler/test_shape.py b/test/ut/nas/profiler/test_shape.py index ffe53e1a6..0686a6815 100644 --- a/test/ut/nas/profiler/test_shape.py +++ b/test/ut/nas/profiler/test_shape.py @@ -377,6 +377,12 @@ def test_adaptive_avg_pool2d(): assert shape_inference(nn.AdaptiveAvgPool2d(3), t).real_shape == MutableShape(4, 2, 3, 3) assert shape_inference(nn.AdaptiveAvgPool2d((3, 4)), t).real_shape == MutableShape(4, 2, 3, 4) +def test_avg_pool2d(): + t = ShapeTensor(torch.randn(4, 2, 5, 5), True) + assert shape_inference(nn.AvgPool2d(1), t).real_shape == MutableShape(4, 2, 5, 5) + assert shape_inference(nn.AvgPool2d(3,stride=1), t).real_shape == MutableShape(4, 2, 3, 3) + assert shape_inference(nn.AvgPool2d((3, 4),stride=1), t).real_shape == MutableShape(4, 2, 3, 2) + def test_linear(): input = ShapeTensor(torch.randn(4, 2), True)