esvit/layers/performer.py

# mainly modified from
# https://github.com/lucidrains/performer-pytorch/blob/main/performer_pytorch/performer_pytorch.py
import math
from scipy.stats import ortho_group
import torch
from torch import nn
from einops import rearrange, repeat
from functools import partial

# helpers
def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d

# kernel functions
# transcribed from jax to pytorch from
# https://github.com/google-research/google-research/blob/master/performer/fast_attention/jax/fast_attention.py
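# softmax_kernel maps queries/keys onto the positive random features of FAVOR+
# (phi(x) = exp(w^T x - ||x||^2 / 2), up to scaling), so that exp(q . k / sqrt(d))
# is approximated by phi(q) . phi(k); the max subtracted inside the exp
# (per row for queries, global for keys) is only for numerical stability.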
def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True,
                   eps=1e-4, device=None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.
    ratio = (projection_matrix.shape[0] ** -0.5)

    projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h)
    projection = projection.type_as(data)

    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data),
                             projection)

    diag_data = data ** 2
    diag_data = torch.sum(diag_data, dim=-1)
    diag_data = (diag_data / 2.0) * (data_normalizer ** 2)
    diag_data = diag_data.unsqueeze(dim=-1)

    if is_query:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data -
                      torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
    else:
        data_dash = ratio * (
            torch.exp(data_dash - diag_data - torch.max(data_dash)) + eps)

    return data_dash.type_as(data)
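
# generalized_kernel applies an arbitrary elementwise kernel_fn (ReLU by default)
# to the randomly projected inputs, giving the "generalized attention" variant of
# the Performer; with projection_matrix=None it reduces to kernel_fn(data) + eps.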
def generalized_kernel(data, *, projection_matrix, kernel_fn=nn.ReLU(),
                       kernel_epsilon=0.001, normalize_data=True, device=None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.

    if projection_matrix is None:
        return kernel_fn(data_normalizer * data) + kernel_epsilon

    projection = repeat(projection_matrix, 'j d -> b h j d', b=b, h=h)
    projection = projection.type_as(data)

    data_dash = torch.einsum('...id,...jd->...ij', (data_normalizer * data),
                             projection)

    data_prime = kernel_fn(data_dash) + kernel_epsilon
    return data_prime.type_as(data)
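
# gaussian_orthogonal_random_matrix builds the (nb_rows x nb_columns) projection
# by stacking square orthogonal blocks drawn with scipy's ortho_group, then
# rescaling the rows: scaling=0 matches the norms of i.i.d. Gaussian rows,
# scaling=1 uses sqrt(nb_columns). Orthogonal (rather than i.i.d.) rows reduce
# the variance of the kernel estimate.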
def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0,
                                      device=None):
    nb_full_blocks = int(nb_rows / nb_columns)
    block_list = []

    for _ in range(nb_full_blocks):
        q = torch.FloatTensor(ortho_group.rvs(nb_columns), device='cpu').to(device)
        block_list.append(q)

    remaining_rows = nb_rows - nb_full_blocks * nb_columns
    if remaining_rows > 0:
        q = torch.FloatTensor(ortho_group.rvs(nb_columns), device='cpu').to(device)
        block_list.append(q[:remaining_rows])

    final_matrix = torch.cat(block_list)

    if scaling == 0:
        multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
    elif scaling == 1:
        multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,),
                                                                 device=device)
    else:
        raise ValueError(f'Invalid scaling {scaling}')

    return torch.diag(multiplier) @ final_matrix
# linear attention classes with softmax kernel
# non-causal linear attention
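# Instead of materializing the N x N attention matrix, compute phi(K)^T V first
# (a d x e "context"), then multiply by phi(Q) and normalize by
# D = phi(Q) (phi(K)^T 1); cost is O(N * d * e) rather than O(N^2).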
def linear_attention(q, k, v):
    k_cumsum = k.sum(dim=-2)
    D_inv = 1. / torch.einsum('...nd,...d->...n', q, k_cumsum.type_as(q))
    context = torch.einsum('...nd,...ne->...de', k, v)
    out = torch.einsum('...de,...nd,...n->...ne', context, q, D_inv)
    return out
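

# FastAttention: maps q/k through the chosen random-feature kernel and applies
# linear_attention. The projection matrix is stored as a (non-trainable) buffer,
# so it moves with the module across devices and is saved in checkpoints.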
class FastAttention(nn.Module):
    def __init__(self, dim_heads, nb_features=None, ortho_scaling=0,
                 generalized_attention=False, kernel_fn=nn.ReLU(),
                 no_projection=False):
        super().__init__()
        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))

        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling

        self.create_projection = partial(gaussian_orthogonal_random_matrix,
                                         nb_rows=self.nb_features,
                                         nb_columns=dim_heads,
                                         scaling=ortho_scaling)
        projection_matrix = self.create_projection()
        self.register_buffer('projection_matrix', projection_matrix)

        self.generalized_attention = generalized_attention
        self.kernel_fn = kernel_fn

        # if this is turned on, no projection will be used
        # queries and keys will be softmax-ed as in the original efficient attention paper
        self.no_projection = no_projection
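
    # Redraw the random projection in place (e.g. periodically during training)
    # so the kernel estimate is not tied to one fixed sample of random features.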
    @torch.no_grad()
    def redraw_projection_matrix(self, device):
        projections = self.create_projection(device=device)
        self.projection_matrix.copy_(projections)
        del projections

    def forward(self, q, k, v):
        device = q.device

        if self.no_projection:
            q = q.softmax(dim=-1)
            k = k.softmax(dim=-2)
        elif self.generalized_attention:
            create_kernel = partial(generalized_kernel,
                                    kernel_fn=self.kernel_fn,
                                    projection_matrix=self.projection_matrix,
                                    device=device)
            q, k = map(create_kernel, (q, k))
        else:
            create_kernel = partial(softmax_kernel,
                                    projection_matrix=self.projection_matrix,
                                    device=device)
            q = create_kernel(q, is_query=True)
            k = create_kernel(k, is_query=False)

        out = linear_attention(q, k, v)
        return out
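

# PerformerSelfAttention: drop-in replacement for the standard ViT self-attention
# block; it keeps the usual qkv/proj linear layers but replaces softmax attention
# with FastAttention, so the cost is linear in the number of tokens N.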
class PerformerSelfAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., nb_features=None,
                 generalized_attention=False,
                 kernel_fn=nn.ReLU(), no_projection=False):
        super().__init__()
        assert dim % num_heads == 0, 'dimension must be divisible by number of heads'
        head_dim = dim // num_heads

        self.fast_attention = FastAttention(
            head_dim, nb_features, generalized_attention=generalized_attention,
            kernel_fn=kernel_fn, no_projection=no_projection
        )

        self.num_heads = num_heads
        self.scale = qk_scale or head_dim ** -0.5  # not used in performer

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, nx=None, ny=None):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        x = self.fast_attention(q, k, v)
        x = rearrange(x, 'b h n d -> b n (h d)')
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
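

# Minimal smoke test (illustrative sketch, not part of the upstream EsViT file;
# the shapes and hyperparameters below are arbitrary examples).
if __name__ == '__main__':
    attn = PerformerSelfAttention(dim=384, num_heads=6, nb_features=128)
    x = torch.randn(2, 196, 384)   # (batch, num_tokens, embed_dim)
    out = attn(x)
    assert out.shape == x.shape    # output keeps the (B, N, C) shape of the input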