зеркало из https://github.com/microsoft/TextNAS.git
28 строки
644 B
Python
28 строки
644 B
Python
# Copyright (c) Microsoft Corporation.
|
|
# Licensed under the MIT License.
|
|
|
|
import sys
|
|
import numpy as np
|
|
import torch
|
|
|
|
def global_avg_pool(x, mask):
    """Average ``x`` over dimension 2, counting only positions where ``mask`` is set.

    Args:
        x: 3-D tensor whose dimension 2 is reduced. Presumably laid out as
            (batch, channels, seq_len) -- TODO confirm against callers.
        mask: 2-D tensor of shape (x.size(0), x.size(2)). Assumed to hold 1
            for valid positions and 0 for padding; its sum over dim 1 is used
            as the divisor.

    Returns:
        2-D floating tensor of shape (x.size(0), x.size(1)). Rows whose mask
        sums to zero yield 0 (for finite ``x``) instead of NaN/inf.
    """
    # Zero out padded positions so they cannot leak into the sum; the divisor
    # below only counts valid positions.
    keep = torch.unsqueeze(mask, dim=1).to(x.dtype)
    pooled = torch.sum(x * keep, 2)
    if not pooled.is_floating_point():
        # An integer sum cannot be divided in place by a float tensor.
        pooled = pooled.float()
    length = torch.sum(mask, 1, keepdim=True).float()
    # Nudge empty rows' length off zero: their masked sum is 0, so 0 / 1e-12 = 0.
    length += torch.eq(length, 0.0).float() * 1e-12
    # In-place on our own freshly created tensor (never the caller's input);
    # length has shape (batch, 1) and broadcasts over dimension 1.
    pooled /= length
    return pooled
|
|
|
|
def global_max_pool(x, mask):
    """Take the maximum of ``x`` over dimension 2, ignoring positions where ``mask`` is 0.

    Padded positions (mask == 0) receive a large negative penalty of
    -(2**32 - 1) before the max, so they lose to any realistic valid value.
    The caller's ``x`` is left untouched.

    Args:
        x: 3-D tensor whose dimension 2 is reduced. Presumably laid out as
            (batch, channels, seq_len) -- TODO confirm against callers.
        mask: 2-D tensor of shape (x.size(0), x.size(2)); positions equal to
            0 are treated as padding.

    Returns:
        2-D tensor of shape (x.size(0), x.size(1)). For a row with no valid
        positions, the result is the max of the penalised values (still very
        negative), not a meaningful maximum.
    """
    is_pad = torch.eq(mask.float(), 0.0).long()
    # (batch, 1, seq_len) broadcasts over dimension 1 of x; no repeat needed.
    penalty = torch.unsqueeze(is_pad, dim=1) * (-(2 ** 32) + 1)
    # Out-of-place add: the original `x += mask` mutated the caller's tensor
    # and fails under autograd for leaf tensors that require grad.
    # NOTE(review): in float16 this penalty overflows to -inf, as before.
    penalised = x + penalty.to(x.dtype)
    return torch.max(penalised, 2)[0]
|
|
|
|
def get_length(mask):
    """Sum ``mask`` along dimension 1 and return the result as an int64 tensor.

    When ``mask`` holds 0/1 values (assumed; confirm against callers) this is
    the number of valid positions in each row.
    """
    return mask.sum(dim=1).to(torch.long)
|