Mirror of https://github.com/microsoft/archai.git
chore(scripts): Improves the modeling_codegen_flash implementation to support xFormers.
Parent: e936de854b
Commit: 6d106ae478
@@ -7,15 +7,17 @@ from typing import Dict, Optional

 import torch
 import torch.nn as nn
 from flash_attn.modules.mha import MHA
-from flash_attn.modules.mlp import FusedMLP, Mlp
-from transformers.activations import ACT2FN
+from flash_attn.modules.mlp import FusedMLP
 from transformers.modeling_outputs import CausalLMOutput
 from transformers.models.codegen.configuration_codegen import CodeGenConfig
 from transformers.models.codegen.modeling_codegen import (
     CodeGenAttention,
+    CodeGenMLP,
     CodeGenPreTrainedModel,
+    apply_rotary_pos_emb,
+    fixed_pos_embedding,
 )
+from xformers.ops import LowerTriangularMask, memory_efficient_attention


 class CodeGenFlashConfig(CodeGenConfig):
@@ -25,15 +27,20 @@ class CodeGenFlashConfig(CodeGenConfig):
         self,
         *args,
         pad_vocab_size_multiple: Optional[int] = 1,
-        use_flash_attn: Optional[bool] = False,
-        use_flash_fused_mlp: Optional[bool] = False,
+        attn_type: Optional[str] = "default",
+        use_fused_mlp: Optional[bool] = False,
         **kwargs
     ) -> None:
         super().__init__(*args, **kwargs)

         self.vocab_size = int(math.ceil(self.vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
-        self.use_flash_attn = use_flash_attn
-        self.use_flash_fused_mlp = use_flash_fused_mlp
+        assert attn_type in [
+            "default",
+            "flash",
+            "xformer",
+        ], "`attn_type` should be one of: `default`, `flash` or `xformer`."
+        self.attn_type = attn_type
+        self.use_fused_mlp = use_fused_mlp


 class CodeGenFlashEmbedding(nn.Module):
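Note: the vocab_size rounding in CodeGenFlashConfig.__init__ keeps the embedding and LM-head shapes divisible by pad_vocab_size_multiple, which GPU kernels prefer. A minimal worked sketch (the values are illustrative, not from this commit):

import math

# Same rounding as CodeGenFlashConfig.__init__: round vocab_size up to the
# nearest multiple of pad_vocab_size_multiple.
vocab_size, pad_vocab_size_multiple = 50257, 64  # illustrative GPT-2-style vocab
padded = int(math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple)
print(padded)  # 50304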
@@ -53,6 +60,63 @@
         return hidden_states


+class CodeGenXAttention(CodeGenAttention):
+    def __init__(self, config):
+        super().__init__(config)
+
+    def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+        new_shape = tensor.size()[:-2] + (num_attention_heads * attn_head_size,)
+        return tensor.view(new_shape)
+
+    def forward(self, hidden_states: Optional[torch.FloatTensor]) -> torch.Tensor:
+        qkv = self.qkv_proj(hidden_states)
+        # TODO(enijkamp): factor out number of logical TPU-v4 cores or make forward pass agnostic
+        mp_num = 4
+        qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
+
+        local_dim = self.head_dim * self.num_attention_heads // mp_num
+        query, value, key = torch.split(qkv_split, local_dim, dim=-1)
+
+        query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
+        key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
+        value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
+
+        seq_len = key.shape[1]
+        offset = 0
+
+        if self.rotary_dim is not None:
+            k_rot = key[:, :, :, : self.rotary_dim]
+            k_pass = key[:, :, :, self.rotary_dim :]
+
+            q_rot = query[:, :, :, : self.rotary_dim]
+            q_pass = query[:, :, :, self.rotary_dim :]
+
+            sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
+            k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
+            q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
+
+            key = torch.cat([k_rot, k_pass], dim=-1)
+            query = torch.cat([q_rot, q_pass], dim=-1)
+        else:
+            sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
+            key = apply_rotary_pos_emb(key, sincos, offset=offset)
+            query = apply_rotary_pos_emb(query, sincos, offset=offset)
+
+        # compute self-attention: V x Softmax(QK^T)
+        attn_output = memory_efficient_attention(
+            query.to(torch.float16),
+            key.to(torch.float16),
+            value.to(torch.float16),
+            attn_bias=LowerTriangularMask(),
+        )
+
+        attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
+        attn_output = self.out_proj(attn_output)
+        attn_output = self.resid_dropout(attn_output)
+
+        return attn_output
+
+
 class CodeGenFlashBlock(nn.Module):
     def __init__(self, config: CodeGenFlashConfig) -> None:
         super().__init__()
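Note on the xformers call at the heart of CodeGenXAttention: memory_efficient_attention expects tensors laid out as (batch, seq_len, num_heads, head_dim), and LowerTriangularMask() supplies the causal bias. A minimal standalone sketch, assuming the xformers package is installed and a CUDA device is available (the shapes are illustrative):

import torch
from xformers.ops import LowerTriangularMask, memory_efficient_attention

# Illustrative shapes: batch=1, seq_len=128, heads=16, head_dim=64.
q = torch.randn(1, 128, 16, 64, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# Causal attention: each position may only attend to itself and the past.
out = memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMask())
print(out.shape)  # torch.Size([1, 128, 16, 64])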
@@ -60,14 +124,14 @@
         inner_dim = config.n_inner if config.n_inner is not None else 4 * config.n_embd
         rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)

-        self.use_flash_attn = config.use_flash_attn
+        self.attn_type = config.attn_type
+        self.use_fused_mlp = config.use_fused_mlp
         self.resid_pdrop = config.resid_pdrop
         self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

-        if not self.use_flash_attn:
+        if self.attn_type == "default":
             self.attn = CodeGenAttention(config)
+            self.mlp = CodeGenMLP(inner_dim, config)
-        else:
+        elif self.attn_type == "flash":
             head_dim = config.n_embd // config.n_head
             self.attn = MHA(
                 embed_dim=config.n_embd,
@@ -82,16 +146,16 @@
                 use_flash_attn=True,
                 return_residual=False,
             )
+        elif self.attn_type == "xformer":
+            self.attn = CodeGenXAttention(config)

-        if not config.use_flash_fused_mlp:
-            self.mlp = Mlp(
-                in_features=config.n_embd, hidden_features=inner_dim, activation=ACT2FN[config.activation_function]
-            )
-        else:
-            activation = (
-                "gelu_approx" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"] else "relu"
-            )
-            self.mlp = FusedMLP(in_features=config.n_embd, hidden_features=inner_dim, activation=activation)
+        if not self.use_fused_mlp:
+            self.mlp = CodeGenMLP(inner_dim, config)
+        else:
+            activation = (
+                "gelu_approx" if config.activation_function in ["gelu_new", "gelu_fast", "gelu_approx"] else "relu"
+            )
+            self.mlp = FusedMLP(in_features=config.n_embd, hidden_features=inner_dim, activation=activation)

     def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         residual = hidden_states
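The FusedMLP branch above maps Hugging Face activation names onto the names flash-attn's FusedMLP accepts: the "gelu_new"/"gelu_fast"/"gelu_approx" variants collapse to "gelu_approx", and anything else falls back to "relu". The same mapping in isolation:

def fused_mlp_activation(activation_function: str) -> str:
    # Mirrors the conditional expression in CodeGenFlashBlock.__init__.
    if activation_function in ["gelu_new", "gelu_fast", "gelu_approx"]:
        return "gelu_approx"
    return "relu"

assert fused_mlp_activation("gelu_new") == "gelu_approx"
assert fused_mlp_activation("silu") == "relu"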
@@ -103,7 +167,7 @@

         feed_forward_hidden_states = self.mlp(hidden_states)

-        if self.use_flash_attn:
+        if self.attn_type == "flash":
             attn_outputs = nn.Dropout(self.resid_pdrop)(attn_outputs)
             feed_forward_hidden_states = nn.Dropout(self.resid_pdrop)(feed_forward_hidden_states)
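For context, the forward pass above follows the upstream CodeGen block's parallel-residual pattern: attention and MLP both consume the ln_1 output, and their (dropout-regularized) outputs are summed with the residual. A minimal sketch of that pattern, with stand-ins replacing the real attention and MLP modules:

import torch
import torch.nn as nn

ln_1 = nn.LayerNorm(512)
x = torch.randn(2, 8, 512)  # (batch, seq_len, n_embd), illustrative sizes

residual = x
h = ln_1(x)
attn_out = h  # stand-in for self.attn(h)
ff_out = h    # stand-in for self.mlp(h)

# Parallel residual: both branches read the same normalized input and are
# added together with the unnormalized residual.
out = attn_out + ff_out + residual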
@@ -100,8 +100,8 @@ if __name__ == "__main__":
         n_head=16,
         rotary_dim=32,
         pad_vocab_size_multiple=64,
-        use_flash_attn=True,
-        use_flash_fused_mlp=True,
+        attn_type="flash",
+        use_fused_mlp=True,
     )
     model = CodeGenFlashSequential(config)
@@ -96,8 +96,8 @@ if __name__ == "__main__":
         n_head=16,
         rotary_dim=32,
         pad_vocab_size_multiple=64,
-        use_flash_attn=True,
-        use_flash_fused_mlp=True,
+        attn_type="flash",
+        use_fused_mlp=True,
     )
     model = CodeGenFlashSequential(config)
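A hypothetical end-to-end usage mirroring the test scripts above, but selecting the new xformers path; it assumes CodeGenFlashConfig and CodeGenFlashSequential are importable from the module this commit modifies, and the unlisted parameters keep the scripts' values:

# attn_type="xformer" routes attention through CodeGenXAttention, i.e.
# xformers' memory_efficient_attention with a causal mask.
config = CodeGenFlashConfig(
    n_head=16,
    rotary_dim=32,
    pad_vocab_size_multiple=64,
    attn_type="xformer",
    use_fused_mlp=False,
)
model = CodeGenFlashSequential(config)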
@@ -4,6 +4,7 @@
 [flake8]
 max-line-length = 120
 extend-ignore = E111, E203, E402, E501, E721, E722, E741, F401, F403, F405, F407, W503, W504, W605
 max-complexity = 12

 [isort]
 profile = black
setup.py (15 changes)
@@ -58,6 +58,7 @@ dependencies = [
     "torchvision",
     "tqdm",
     "transformers>=4.27.1",
+    "xformers",
 ]
 dependencies_dict = {y: x for x, y in (re.findall(r"^(([^!=<>~ ]+)(?:[!=<>~ ].*)?$)", x)[0] for x in dependencies)}
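The dependencies_dict comprehension keys each full requirement string by its bare package name, so extras can be assembled by name. In isolation:

import re

# Each entry maps bare package name -> full requirement string.
dependencies = ["tqdm", "transformers>=4.27.1", "xformers"]
dependencies_dict = {y: x for x, y in (re.findall(r"^(([^!=<>~ ]+)(?:[!=<>~ ].*)?$)", x)[0] for x in dependencies)}
print(dependencies_dict)
# {'tqdm': 'tqdm', 'transformers': 'transformers>=4.27.1', 'xformers': 'xformers'}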
@@ -78,7 +79,9 @@ extras_require["cv"] = filter_dependencies(
     "scikit-learn",
     "torchvision",
 )
-extras_require["nlp"] = filter_dependencies("datasets", "einops", "opt_einsum", "tokenizers", "transformers")
+extras_require["nlp"] = filter_dependencies(
+    "datasets", "einops", "opt_einsum", "tokenizers", "transformers", "xformers"
+)

 extras_require["deepspeed"] = filter_dependencies("deepspeed", "mlflow")
 extras_require["flash-attn"] = filter_dependencies("flash-attn", "fftconv")
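filter_dependencies is defined elsewhere in setup.py and is not shown in this diff; a plausible sketch of its shape (an assumption, for illustration only) would pick the named packages' full requirement strings out of dependencies_dict:

def filter_dependencies(*packages):
    # Hypothetical helper: look up each requested package's pinned
    # requirement string; a KeyError flags a package missing from
    # the master dependencies list.
    return [dependencies_dict[p] for p in packages]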
@@ -118,10 +121,16 @@ extras_require["aml"] = filter_dependencies(
     "mldesigner",
     "mlflow",
     "pytorch-lightning",
-    "torchvision"
+    "torchvision",
 )

-extras_require["dev"] = extras_require["cv"] + extras_require["nlp"] + extras_require["docs"] + extras_require["tests"] + extras_require["aml"]
+extras_require["dev"] = (
+    extras_require["cv"]
+    + extras_require["nlp"]
+    + extras_require["docs"]
+    + extras_require["tests"]
+    + extras_require["aml"]
+)
 if os.name != "nt":
     # Support for DeepSpeed is not available on native Windows
     extras_require["dev"] += extras_require["deepspeed"]