NeuronBlocks/block_zoo/op/Match.py

82 строки
2.4 KiB
Python

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT license.
import torch.nn as nn
import copy
from block_zoo.BaseLayer import BaseLayer, BaseConf
from utils.DocInherit import DocInherit
from utils.exceptions import ConfigurationError
class MatchConf(BaseConf):
"""Configuration for MatchAttention layer
"""
def __init__(self, **kwargs):
super(MatchConf, self).__init__(**kwargs)
@DocInherit
def default(self):
self.activation = 'PReLU'
@DocInherit
def declare(self):
self.num_of_inputs = 2
self.input_ranks = [3, 3]
@DocInherit
def inference(self):
self.output_dim = copy.deepcopy(self.input_dims[0])
if self.input_dims[0][1] == -1 or self.input_dims[1][1] == -1:
raise ConfigurationError("For Match layer, the sequence length should be fixed")
self.output_dim[-1] = self.input_dims[1][1] # y_len
super(MatchConf, self).inference() # PUT THIS LINE AT THE END OF inference()
@DocInherit
def verify(self):
super(MatchConf, self).verify()
class Match(BaseLayer):
""" Match layer
Given sequences X and Y, match sequence Y to each element in X.
Args:
layer_conf (MatchConf): configuration of a layer
"""
def __init__(self, layer_conf):
super(Match, self).__init__(layer_conf)
assert layer_conf.input_dims[0][-1] == layer_conf.input_dims[1][-1]
self.linear = nn.Linear(layer_conf.input_dims[0][-1], layer_conf.input_dims[0][-1])
if layer_conf.activation:
self.activation = eval("nn." + self.layer_conf.activation)()
else:
self.activation = None
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, x_len, y, y_len):
"""
Args:
x: [batch_size, x_max_len, dim].
x_len: [batch_size], default is None.
y: [batch_size, y_max_len, dim].
y_len: [batch_size], default is None.
Returns:
output: has the same shape as x.
"""
x_proj = self.linear(x) # [batch_size, x_max_len, dim]
y_proj = self.linear(y) # [batch_size, y_max_len, dim]
if self.activation:
x_proj = self.activation(x_proj)
y_proj = self.activation(y_proj)
scores = x_proj.bmm(y_proj.transpose(2, 1)) # [batch_size, x_max_len, y_max_len]
return scores, x_len