chore(scripts): Improves the modeling_codegen_flash implementation to support xFormers.

Gustavo Rosa 2023-04-03 16:21:59 -03:00
Parent e936de854b
Commit 6d106ae478
5 changed files with 101 additions and 27 deletions

View file

@@ -7,15 +7,17 @@ from typing import Dict, Optional
import torch
import torch.nn as nn
from flash_attn.modules.mha import MHA
from flash_attn.modules.mlp import FusedMLP, Mlp
from transformers.activations import ACT2FN
from flash_attn.modules.mlp import FusedMLP
from transformers.modeling_outputs import CausalLMOutput
from transformers.models.codegen.configuration_codegen import CodeGenConfig
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
CodeGenMLP,
CodeGenPreTrainedModel,
apply_rotary_pos_emb,
fixed_pos_embedding,
)
from xformers.ops import LowerTriangularMask, memory_efficient_attention
class CodeGenFlashConfig(CodeGenConfig):
@@ -25,15 +27,20 @@ class CodeGenFlashConfig(CodeGenConfig):
self,
*args,
pad_vocab_size_multiple: Optional[int] = 1,
use_flash_attn: Optional[bool] = False,
use_flash_fused_mlp: Optional[bool] = False,
attn_type: Optional[str] = "default",
use_fused_mlp: Optional[bool] = False,
**kwargs
) -> None:
super().__init__(*args, **kwargs)
self.vocab_size = int(math.ceil(self.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
self.use_flash_attn = use_flash_attn
self.use_flash_fused_mlp = use_flash_fused_mlp
assert attn_type in [
"default",
"flash",
"xformer",
], "`attn_type` should be one of: `default`, `flash` or `xformer`."
self.attn_type = attn_type
self.use_fused_mlp = use_fused_mlp
class CodeGenFlashEmbedding(nn.Module):
@@ -53,6 +60,63 @@ class CodeGenFlashEmbedding(nn.Module):
return hidden_states
class CodeGenXAttention(CodeGenAttention):
def __init__(self, config):
super().__init__(config)
def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
return tensor.view(new_shape)
def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.Tensor:
qkv = self.qkv_proj(hidden_states)
# TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
mp_num = 4
qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
local_dim = self.head_dim * self.num_attention_heads // mp_num
query, value, key = torch.split(qkv_split, local_dim, dim=-1)
query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
seq_len = key.shape[1]
offset = 0
if self.rotary_dim is not None:
k_rot = key[:, :, :, : self.rotary_dim]
k_pass = key[:, :, :, self.rotary_dim :]
q_rot = query[:, :, :, : self.rotary_dim]
q_pass = query[:, :, :, self.rotary_dim :]
sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
key = torch.cat([k_rot, k_pass], dim=-1)
query = torch.cat([q_rot, q_pass], dim=-1)
else:
sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
key = apply_rotary_pos_emb(key, sincos, offset=offset)
query = apply_rotary_pos_emb(query, sincos, offset=offset)
# compute self-attention: V x Softmax(QK^T)
attn_output = memory_efficient_attention(
query.to(torch.float16),
key.to(torch.float16),
value.to(torch.float16),
attn_bias=LowerTriangularMask(),
)
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
return attn_output
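Aside (not part of the diff): the attention computed by CodeGenXAttention.forward reduces to a single xFormers call. The sketch below is a minimal, self-contained illustration; the tensor sizes, CUDA placement and float16 dtype are assumptions, chosen because the memory-efficient kernels expect half-precision inputs on GPU, and LowerTriangularMask supplies the causal masking.

# Standalone sketch of the xFormers primitive used above (illustrative only).
# Shapes follow the (batch, seq_len, num_heads, head_dim) layout that
# memory_efficient_attention accepts; the sizes below are made up.
import torch
from xformers.ops import LowerTriangularMask, memory_efficient_attention

batch, seq_len, num_heads, head_dim = 2, 128, 16, 64
query = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
key = torch.randn_like(query)
value = torch.randn_like(query)

# LowerTriangularMask() enforces causal (autoregressive) attention, playing the
# role of the explicit causal bias in the stock CodeGenAttention.
attn_output = memory_efficient_attention(query, key, value, attn_bias=LowerTriangularMask())
print(attn_output.shape)  # torch.Size([2, 128, 16, 64])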
class CodeGenFlashBlock(nn.Module):
def __init__(self, config: CodeGenFlashConfig) -> None:
super().__init__()
@@ -60,14 +124,14 @@ class CodeGenFlashBlock(nn.Module):
inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
self.use_flash_attn = config.use_flash_attn
self.attn_type = config.attn_type
self.use_fused_mlp = config.use_fused_mlp
self.resid_pdrop = config.resid_pdrop
self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
if not self.use_flash_attn:
if self.attn_type == "default":
self.attn = CodeGenAttention(config)
self.mlp = CodeGenMLP(inner_dim, config)
else:
elif self.attn_type == "flash":
head_dim = config.n_embd // config.n_head
self.attn = MHA(
embed_dim=config.n_embd,
@@ -82,16 +146,16 @@ class CodeGenFlashBlock(nn.Module):
use_flash_attn=True,
return_residual=False,
)
elif self.attn_type == "xformer":
self.attn = CodeGenXAttention(config)
if not config.use_flash_fused_mlp:
self.mlp = Mlp(
in_features=config.n_embd, hidden_features=inner_dim, activation=ACT2FN[config.activation_function]
)
else:
activation = (
"gelu_approx" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"] else "relu"
)
self.mlp = FusedMLP(in_features=config.n_embd, hidden_features=inner_dim, activation=activation)
if not self.use_fused_mlp:
self.mlp = CodeGenMLP(inner_dim, config)
else:
activation = (
"gelu_approx" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"] else "relu"
)
self.mlp = FusedMLP(in_features=config.n_embd, hidden_features=inner_dim, activation=activation)
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
residual = hidden_states
@@ -103,7 +167,7 @@ class CodeGenFlashBlock(nn.Module):
feed_forward_hidden_states = self.mlp(hidden_states)
if self.use_flash_attn:
if self.attn_type == "flash":
attn_outputs = nn.Dropout(self.resid_pdrop)(attn_outputs)
feed_forward_hidden_states = nn.Dropout(self.resid_pdrop)(feed_forward_hidden_states)
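To make the new switch concrete, here is a hedged construction sketch (not part of the commit): the import path is a placeholder, since the diff does not show where modeling_codegen_flash lives, and the optional flash-attn and xformers dependencies are assumed to be installed because the module imports them at the top level.

# Illustrative only: config values mirror the test scripts further below, and the
# import path is hypothetical.
from modeling_codegen_flash import CodeGenFlashBlock, CodeGenFlashConfig  # hypothetical path

config = CodeGenFlashConfig(
    n_head=16,
    rotary_dim=32,
    pad_vocab_size_multiple=64,
    attn_type="xformer",   # one of "default", "flash", "xformer"
    use_fused_mlp=False,   # keep the stock CodeGenMLP instead of FusedMLP
)
block = CodeGenFlashBlock(config)

# The block now routes attention through the xFormers-backed class while keeping
# the regular MLP.
print(type(block.attn).__name__)  # CodeGenXAttention
print(type(block.mlp).__name__)   # CodeGenMLP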

View file

@@ -100,8 +100,8 @@ if __name__ == "__main__":
n_head=16,
rotary_dim=32,
pad_vocab_size_multiple=64,
use_flash_attn=True,
use_flash_fused_mlp=True,
attn_type="flash",
use_fused_mlp=True,
)
model = CodeGenFlashSequential(config)

View file

@@ -96,8 +96,8 @@ if __name__ == "__main__":
n_head=16,
rotary_dim=32,
pad_vocab_size_multiple=64,
use_flash_attn=True,
use_flash_fused_mlp=True,
attn_type="flash",
use_fused_mlp=True,
)
model = CodeGenFlashSequential(config)
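Both scripts above still exercise the FlashAttention path; a variant exercising the new xFormers path would only change the attention selector, along these lines (a hedged sketch reusing the scripts' existing imports, not part of this diff):

config = CodeGenFlashConfig(
    n_head=16,
    rotary_dim=32,
    pad_vocab_size_multiple=64,
    attn_type="xformer",  # swap "flash" for the xFormers-backed attention
    use_fused_mlp=True,
)
model = CodeGenFlashSequential(config)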

View file

@@ -4,6 +4,7 @@
[flake8]
max-line-length = 120
extend-ignore = E111, E203, E402, E501, E721, E722, E741, F401, F403, F405, F407, W503, W504, W605
max-complexity = 12
[isort]
profile = black

View file

@@ -58,6 +58,7 @@ dependencies = [
"torchvision",
"tqdm",
"transformers>=4.27.1",
"xformers",
]
dependencies_dict = {y: x for x, y in (re.findall(r"^(([^!=<>~ ]+)(?:[!=<>~ ].*)?$)", x)[0] for x in dependencies)}
@@ -78,7 +79,9 @@ extras_require["cv"] = filter_dependencies(
"scikit-learn",
"torchvision",
)
extras_require["nlp"] = filter_dependencies("datasets", "einops", "opt_einsum", "tokenizers", "transformers")
extras_require["nlp"] = filter_dependencies(
"datasets", "einops", "opt_einsum", "tokenizers", "transformers", "xformers"
)
extras_require["deepspeed"] = filter_dependencies("deepspeed", "mlflow")
extras_require["flash-attn"] = filter_dependencies("flash-attn", "fftconv")
@@ -118,10 +121,16 @@ extras_require["aml"] = filter_dependencies(
"mldesigner",
"mlflow",
"pytorch-lightning",
"torchvision"
"torchvision",
)
extras_require["dev"] = extras_require["cv"] + extras_require["nlp"] + extras_require["docs"] + extras_require["tests"] + extras_require["aml"]
extras_require["dev"] = (
extras_require["cv"]
+ extras_require["nlp"]
+ extras_require["docs"]
+ extras_require["tests"]
+ extras_require["aml"]
)
if os.name != "nt":
# Support for DeepSpeed is not available on native Windows
extras_require["dev"] += extras_require["deepspeed"]