зеркало из https://github.com/microsoft/nni.git
add avgpool2_formula to shape_formula.py (#5565)
This commit is contained in:
Родитель
5676de40d7
Коммит
1735a23ffe
|
@ -171,6 +171,18 @@ def maxpool2d_formula(module: nn.MaxPool2d | nas_nn.MutableMaxPool2d, input: Sha
|
|||
shape[-1] = (shape[-1] + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1
|
||||
return MutableShape(*shape)
|
||||
|
||||
def avgpool2d_formula(module: nn.AvgPool2d , input: ShapeTensor) -> MutableShape:
|
||||
shape = list(input.real_shape) # type: ignore
|
||||
|
||||
padding, kernel_size, stride = map(
|
||||
lambda name: _getattr(module, name, expected_type=tuple_2_t),
|
||||
['padding', 'kernel_size', 'stride']
|
||||
)
|
||||
|
||||
# H_out and W_out
|
||||
shape[-2] = (shape[-2] + 2 * padding[0] - (kernel_size[0])) // stride[0] + 1
|
||||
shape[-1] = (shape[-1] + 2 * padding[1] - (kernel_size[1])) // stride[1] + 1
|
||||
return MutableShape(*shape)
|
||||
|
||||
def multihead_attention_formula(module: nn.MultiheadAttention | nas_nn.MutableMultiheadAttention,
|
||||
query: ShapeTensor, key: ShapeTensor, *args: Any, **kwargs) -> tuple[MutableShape, MutableShape | None]:
|
||||
|
@ -347,6 +359,7 @@ _shape_inference_formulas: dict[Type[nn.Module], Formula] = {
|
|||
nn.Linear: linear_formula,
|
||||
nn.Conv2d: conv2d_formula,
|
||||
nn.MaxPool2d: maxpool2d_formula,
|
||||
nn.AvgPool2d: avgpool2d_formula,
|
||||
nn.BatchNorm2d: keep_shape_formula,
|
||||
nn.LayerNorm: keep_shape_formula,
|
||||
nn.MultiheadAttention: multihead_attention_formula,
|
||||
|
|
|
@ -377,6 +377,12 @@ def test_adaptive_avg_pool2d():
|
|||
assert shape_inference(nn.AdaptiveAvgPool2d(3), t).real_shape == MutableShape(4, 2, 3, 3)
|
||||
assert shape_inference(nn.AdaptiveAvgPool2d((3, 4)), t).real_shape == MutableShape(4, 2, 3, 4)
|
||||
|
||||
def test_avg_pool2d():
|
||||
t = ShapeTensor(torch.randn(4, 2, 5, 5), True)
|
||||
assert shape_inference(nn.AvgPool2d(1), t).real_shape == MutableShape(4, 2, 5, 5)
|
||||
assert shape_inference(nn.AvgPool2d(3,stride=1), t).real_shape == MutableShape(4, 2, 3, 3)
|
||||
assert shape_inference(nn.AvgPool2d((3, 4),stride=1), t).real_shape == MutableShape(4, 2, 3, 2)
|
||||
|
||||
|
||||
def test_linear():
|
||||
input = ShapeTensor(torch.randn(4, 2), True)
|
||||
|
|
Загрузка…
Ссылка в новой задаче