Update docs for torch-directml 0.2.2 (#593)

* update docs for next torch-directml release

* Minor readme spacing issues

---------

Co-authored-by: Sheil Kumar <sheilk@microsoft.com>
Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>
Sheil Kumar 2024-06-14 18:12:21 -07:00 committed by GitHub
Parent 4d65cad0be
Commit 372a622890
No known key found for this signature
GPG key ID: B5690EEEBB952194
32 changed files: 106714 additions and 633 deletions


@@ -22,6 +22,7 @@ For `torch-directml` samples find brief summaries below or explore the [cv](./cv
* [resnet50 - an image classification model](./cv/resnet50)
* [maskrcnn - an object detection model](./cv/objectDetection/maskrcnn/)
* [llm - a text generation and chatbot app supporting various language models](./llm/)
* [whisper - a general-purpose speech recognition model](./audio/whisper/)
## External Links


@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2022 OpenAI
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


@@ -0,0 +1,94 @@
# Speech Recognition with Whisper
This sample shows how to run OpenAI's automatic speech recognition (ASR) [Whisper model](https://github.com/openai/whisper/blob/main/README.md) on the DirectML backend.
- [About Whisper](#about-whisper)
- [Setup](#setup)
- [Run the Whisper model](#run-the-whisper-model)
- [Basic Settings](#basic-settings)
- [External Links](#external-links)
- [Model License](#model-license)
## About Whisper
The [OpenAI Whisper](https://github.com/openai/whisper/) model is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification.
Whisper supports five model sizes, four with English-only versions and all five with multilingual versions.
| Size | Parameters | English-only model | Multilingual model | Required VRAM |
|:---------:|:----------:|:------------------:|:------------------:|:-------------:|
| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB |
| base | 74 M | `base.en` | `base` | ~1 GB |
| small | 244 M | `small.en` | `small` | ~2 GB |
| medium | 769 M | `medium.en` | `medium` | ~5 GB |
| large v3 | 1550 M | N/A | `large-v3` | ~10 GB |
For more information on the model, please refer to the [OpenAI Whisper GitHub repo](https://github.com/openai/whisper/).
## Setup
Once you've set up `torch-directml` following our [Windows](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-windows) and [WSL](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-wsl) guidance, install the following requirements for running the app:
```bash
conda install ffmpeg
pip install -r requirements.txt
```
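Before running the sample, you can sanity check the environment with a short snippet like the one below. This is a minimal sketch that only assumes `torch-directml` installed correctly; it selects the DirectML device and runs a trivial tensor op on it:
```python
import torch
import torch_directml

# Select the default DirectML device, the same way run.py does.
dml = torch_directml.device(torch_directml.default_device())
print(torch_directml.device_name(torch_directml.default_device()))

# A tiny tensor op on the device verifies the setup end to end.
x = torch.ones(2, 2, device=dml)
print((x + x).cpu())
```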
## Run the Whisper model
Run Whisper on the DirectML backend against a sample audio file with the following command:
```bash
python run.py --input_file <audio_file> --model_size "tiny.en"
```
For example, you should see output like the following:
```
> python run.py --input_file test/samples_jfk.wav --model_size "tiny.en"
100%|█████████████████████████████████████| 72.1M/72.1M [00:09<00:00, 7.90MiB/s]
test/samples_jfk.wav
And so my fellow Americans ask not what your country can do for you ask what you can do for your country.
```
Note: by default, [OpenAI Whisper](https://github.com/openai/whisper/) uses a naive implementation of scaled dot product attention. To improve performance further by leveraging DirectML's scaled dot product attention, run `run.py` with the `--use_dml_attn` flag:
```bash
python run.py --input_file <audio_file> --model_size "tiny.en" --use_dml_attn
```
Based on this flag, the `MultiHeadAttention` module in `model.py` chooses between the naive Whisper scaled dot product attention and DirectML's scaled dot product attention:
```python
if use_dml_attn:
    wv, qk = self.dml_sdp_attn(q, k, v, mask, cross_attention=cross_attention)
else:
    wv, qk = self.qkv_attention(q, k, v, mask)
```
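You can also drive the same flow from Python rather than through the `run.py` CLI. The sketch below mirrors what `run.py` does; the audio path and model size are placeholders:
```python
import whisper
import torch_directml

device = torch_directml.device(torch_directml.default_device())
model = whisper.load_model("tiny.en", device=device, use_dml_attn=True)

# Turn 30 seconds of audio into a log-Mel spectrogram on the model's device.
audio = whisper.pad_or_trim(whisper.load_audio("test/samples_jfk.wav"))
mel = whisper.log_mel_spectrogram(audio, n_mels=80).to(model.device)

result = whisper.decode(model, mel, whisper.DecodingOptions(language="en", fp16=False))
print(result.text)
```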
## Basic Settings
The following basic settings are supported by `run.py`:
| Flag | Description | Default |
| ---------------------- | ------------------------------------------------------------ | ------- |
| `--help` | Show this help message. | - |
| `--input_file` | [Required] Path to input audio file | - |
| `--model_size` | Size of Whisper model to use. Options: [`tiny.en`, `tiny`, `base.en`, `base`, `small.en`, `small`, `medium.en`, `medium`, `large-v3`] | `tiny.en` |
| `--fp16`               | Runs inference with fp16 precision.                            | False   |
| `--use_dml_attn`       | Runs inference with DirectML's scaled dot product attention implementation. | False   |
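For example, to run the multilingual `large-v3` model with fp16 and DirectML attention enabled on your own recording (hypothetical path), combine the flags:
```bash
python run.py --input_file audio/interview.wav --model_size "large-v3" --fp16 --use_dml_attn
```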
## External Links
- [Whisper Base Hugging Face Repository](https://huggingface.co/openai/whisper-base.en)
- [Whisper Tiny Hugging Face Repository](https://huggingface.co/openai/whisper-tiny.en)
- [Whisper Small Hugging Face Repository](https://huggingface.co/openai/whisper-small.en)
- [Whisper Medium Hugging Face Repository](https://huggingface.co/openai/whisper-medium.en)
- [Whisper Large v3 Hugging Face Repository](https://huggingface.co/openai/whisper-large-v3)
- [Whisper GitHub Repo](https://github.com/openai/whisper)
## Model License
Whisper's code and model weights are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.


@@ -0,0 +1,6 @@
numba
numpy
tqdm
more-itertools
tiktoken
ffmpeg-python


@@ -0,0 +1,45 @@
# Copyright (c) Microsoft Corporation.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import whisper
import torch_directml
import argparse
def main(args):
    device = torch_directml.device(torch_directml.default_device())
    model = whisper.load_model(args.model_size, device=device, use_dml_attn=args.use_dml_attn)

    # Load audio and pad/trim it to fit 30 seconds
    audio = whisper.load_audio(args.input_file)
    audio = whisper.pad_or_trim(audio)

    n_mels = 80
    if args.model_size == "large-v3":
        n_mels = 128
    mel = whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device)

    language = "en"
    if "en" not in args.model_size:
        _, probs = model.detect_language(mel)
        language = max(probs, key=probs.get)
        print(f"Detected language: {language}")

    options = whisper.DecodingOptions(language=language, fp16=args.fp16)
    result = whisper.decode(model, mel, options)
    print(result.text)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run Whisper model on specified audio file with warmup.')
    parser.add_argument('--model_size', type=str, default='tiny.en', help='Size of the Whisper model to use.')
    parser.add_argument('--input_file', type=str, required=True, help='Path to the input audio file.')
    parser.add_argument('--fp16', action="store_true", help='Runs inference with fp16 precision.')
    parser.add_argument('--use_dml_attn', action="store_true", help='Use DirectML attention implementation.')
    args = parser.parse_args()
    main(args)
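# Example invocation (sketch; the path below is the sample clip shipped under test/):
#   python run.py --input_file test/samples_jfk.wav --model_size "tiny.en" --use_dml_attn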

Binary data
PyTorch/audio/whisper/test/samples_jfk.wav Normal file

Binary file not shown.


@@ -0,0 +1,160 @@
import hashlib
import io
import os
import urllib
import warnings
from typing import List, Optional, Union
import torch
from tqdm import tqdm
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
from .model import ModelDimensions, Whisper
from .transcribe import transcribe
# from .version import __version__
_MODELS = {
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
}
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
_ALIGNMENT_HEADS = {
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
}
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
os.makedirs(root, exist_ok=True)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, os.path.basename(url))
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
with open(download_target, "rb") as f:
model_bytes = f.read()
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
return model_bytes if in_memory else download_target
else:
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(
total=int(source.info().get("Content-Length")),
ncols=80,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
model_bytes = open(download_target, "rb").read()
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
raise RuntimeError(
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
)
return model_bytes if in_memory else download_target
def available_models() -> List[str]:
"""Returns the names of available models"""
return list(_MODELS.keys())
def load_model(
name: str,
device: Optional[Union[str, torch.device]] = None,
download_root: str = None,
in_memory: bool = False,
use_dml_attn: bool = False,
) -> Whisper:
"""
Load a Whisper ASR model
Parameters
----------
name : str
one of the official model names listed by `whisper.available_models()`, or
path to a model checkpoint containing the model dimensions and the model state_dict.
device : Union[str, torch.device]
the PyTorch device to put the model into
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"
in_memory: bool
whether to preload the model weights into host memory
use_dml_attn: bool
whether to use DirectML's scaled dot product attention implementation
Returns
-------
model : Whisper
The Whisper ASR model instance
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if download_root is None:
default = os.path.join(os.path.expanduser("~"), ".cache")
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
if name in _MODELS:
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
alignment_heads = _ALIGNMENT_HEADS[name]
elif os.path.isfile(name):
checkpoint_file = open(name, "rb").read() if in_memory else name
alignment_heads = None
else:
raise RuntimeError(
f"Model {name} not found; available models = {available_models()}"
)
# with (
# io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
# ) as fp:
# # checkpoint = torch.load(fp, map_location=device)
# checkpoint = torch.load(fp, mmap=True, weights_only=True)
# del checkpoint_file
checkpoint = torch.load(checkpoint_file, mmap=True, weights_only=True)
dims = ModelDimensions(**checkpoint["dims"])
model = Whisper(dims, use_dml_attn=use_dml_attn)
model.load_state_dict(checkpoint["model_state_dict"])
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)
return model.to(device)
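# Usage sketch (mirrors run.py; the download_root below is a hypothetical cache directory):
#   model = load_model("tiny.en", device=torch_directml.device(), use_dml_attn=True)
#   model = load_model("medium", download_root="/tmp/whisper", in_memory=True)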

The diff for this file is not shown because it is too large.

Binary data
PyTorch/audio/whisper/whisper/assets/mel_filters.npz Normal file

Binary file not shown.

The diff for this file is not shown because it is too large.


@@ -0,0 +1,158 @@
import os
from functools import lru_cache
from subprocess import CalledProcessError, run
from typing import Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from .utils import exact_div
# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
def load_audio(file: str, sr: int = SAMPLE_RATE):
"""
Open an audio file and read as mono waveform, resampling as necessary
Parameters
----------
file: str
The audio file to open
sr: int
The sample rate to resample the audio if necessary
Returns
-------
A NumPy array containing the audio waveform, in float32 dtype.
"""
# This launches a subprocess to decode audio while down-mixing
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
# fmt: off
print(file)
cmd = [
"ffmpeg",
"-nostdin",
"-threads", "0",
"-i", file,
"-f", "s16le",
"-ac", "1",
"-acodec", "pcm_s16le",
"-ar", str(sr),
"-"
]
# fmt: on
try:
out = run(cmd, capture_output=True, check=True).stdout
except CalledProcessError as e:
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
if torch.is_tensor(array):
if array.shape[axis] > length:
array = array.index_select(
dim=axis, index=torch.arange(length, device=array.device)
)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
else:
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)
if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)
return array
@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
"""
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
Allows decoupling librosa dependency; saved using:
np.savez_compressed(
"mel_filters.npz",
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
)
"""
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
with np.load(filters_path, allow_pickle=False) as f:
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
def log_mel_spectrogram(
audio: Union[str, np.ndarray, torch.Tensor],
n_mels: int = 80,
padding: int = 0,
device: Optional[Union[str, torch.device]] = None,
):
"""
Compute the log-Mel spectrogram of an audio waveform.
Parameters
----------
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
n_mels: int
The number of Mel-frequency filters, only 80 and 128 are supported
padding: int
Number of zero samples to pad to the right
device: Optional[Union[str, torch.device]]
If given, the audio tensor is moved to this device before STFT
Returns
-------
torch.Tensor, shape = (n_mels, n_frames)
A Tensor that contains the Mel spectrogram
"""
if not torch.is_tensor(audio):
if isinstance(audio, str):
audio = load_audio(audio)
audio = torch.from_numpy(audio)
if device is not None:
audio = audio.to(device)
if padding > 0:
audio = F.pad(audio, (0, padding))
window = torch.hann_window(N_FFT).to(audio.device)
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
magnitudes = stft[..., :-1].abs() ** 2
filters = mel_filters(audio.device, n_mels)
mel_spec = filters @ magnitudes
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0
return log_spec
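# Pipeline sketch showing how these helpers fit together (mirrors run.py; the path is the bundled test clip):
#   audio = load_audio("test/samples_jfk.wav")      # ffmpeg-decoded 16 kHz mono float32
#   audio = pad_or_trim(audio)                      # pad/trim to N_SAMPLES (30 seconds)
#   mel = log_mel_spectrogram(audio, n_mels=80)     # (80, 3000) log-Mel features for the encoder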


@@ -0,0 +1,809 @@
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.distributions import Categorical
from .audio import CHUNK_LENGTH
from .tokenizer import Tokenizer, get_tokenizer
from .utils import compression_ratio
if TYPE_CHECKING:
from .model import Whisper
@torch.no_grad()
def detect_language(
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
) -> Tuple[Tensor, List[dict]]:
"""
Detect the spoken language in the audio, and return them as a list of strings, along with the ids
of the most probable language tokens and the probability distribution over all language tokens.
This is performed outside the main decode loop in order to not interfere with kv-caching.
Returns
-------
language_tokens : Tensor, shape = (n_audio,)
ids of the most probable language tokens, which appears after the startoftranscript token.
language_probs : List[Dict[str, float]], length = n_audio
list of dictionaries containing the probability distribution over all languages.
"""
if tokenizer is None:
tokenizer = get_tokenizer(
model.is_multilingual, num_languages=model.num_languages
)
if (
tokenizer.language is None
or tokenizer.language_token not in tokenizer.sot_sequence
):
raise ValueError(
"This model doesn't have language tokens so it can't perform lang id"
)
single = mel.ndim == 2
if single:
mel = mel.unsqueeze(0)
# skip encoder forward pass if already-encoded audio features were given
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
mel = model.encoder(mel)
# forward pass using a single token, startoftranscript
n_audio = mel.shape[0]
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
logits = model.logits(x, mel)[:, 0]
# collect detected languages; suppress all non-language tokens
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
mask[list(tokenizer.all_language_tokens)] = False
logits[:, mask] = -np.inf
language_tokens = logits.argmax(dim=-1)
language_token_probs = logits.softmax(dim=-1).cpu()
language_probs = [
{
c: language_token_probs[i, j].item()
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
}
for i in range(n_audio)
]
if single:
language_tokens = language_tokens[0]
language_probs = language_probs[0]
return language_tokens, language_probs
@dataclass(frozen=True)
class DecodingOptions:
# whether to perform X->X "transcribe" or X->English "translate"
task: str = "transcribe"
# language that the audio is in; uses detected language if None
language: Optional[str] = None
# sampling-related options
temperature: float = 0.0
sample_len: Optional[int] = None # maximum number of tokens to sample
best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
beam_size: Optional[int] = None # number of beams in beam search, if t == 0
patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
# "alpha" in Google NMT, or None for length norm, when ranking generations
# to select which to return among the beams or best-of-N samples
length_penalty: Optional[float] = None
# text or tokens to feed as the prompt or the prefix; for more info:
# https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
prompt: Optional[Union[str, List[int]]] = None # for the previous context
prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
# list of tokens ids (or comma-separated token ids) to suppress
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
suppress_blank: bool = True # this will suppress blank outputs
# timestamp sampling options
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
max_initial_timestamp: Optional[float] = 1.0
# implementation details
fp16: bool = False # use fp16 for most of the calculation
@dataclass(frozen=True)
class DecodingResult:
audio_features: Tensor
language: str
language_probs: Optional[Dict[str, float]] = None
tokens: List[int] = field(default_factory=list)
text: str = ""
avg_logprob: float = np.nan
no_speech_prob: float = np.nan
temperature: float = np.nan
compression_ratio: float = np.nan
class Inference:
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
"""Perform a forward pass on the decoder and return per-token logits"""
raise NotImplementedError
def rearrange_kv_cache(self, source_indices) -> None:
"""Update the key-value cache according to the updated beams"""
raise NotImplementedError
def cleanup_caching(self) -> None:
"""Clean up any resources or hooks after decoding is finished"""
pass
class PyTorchInference(Inference):
def __init__(self, model: "Whisper", initial_token_length: int):
self.model: "Whisper" = model
self.initial_token_length = initial_token_length
self.kv_cache = {}
self.hooks = []
key_modules = [block.attn.key for block in self.model.decoder.blocks]
value_modules = [block.attn.value for block in self.model.decoder.blocks]
self.kv_modules = key_modules + value_modules
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
if tokens.shape[-1] > self.initial_token_length:
# only need to use the last token except in the first forward pass
tokens = tokens[:, -1:]
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
def cleanup_caching(self):
for hook in self.hooks:
hook.remove()
self.kv_cache = {}
self.hooks = []
class SequenceRanker:
def rank(
self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
) -> List[int]:
"""
Given a list of groups of samples and their cumulative log probabilities,
return the indices of the samples in each group to select as the final result
"""
raise NotImplementedError
class MaximumLikelihoodRanker(SequenceRanker):
"""
Select the sample with the highest log probabilities, penalized using either
a simple length normalization or Google NMT paper's length penalty
"""
def __init__(self, length_penalty: Optional[float]):
self.length_penalty = length_penalty
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
def scores(logprobs, lengths):
result = []
for logprob, length in zip(logprobs, lengths):
if self.length_penalty is None:
penalty = length
else:
# from the Google NMT paper
penalty = ((5 + length) / 6) ** self.length_penalty
result.append(logprob / penalty)
return result
# get the sequence with the highest score
lengths = [[len(t) for t in s] for s in tokens]
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
class TokenDecoder:
def reset(self):
"""Initialize any stateful variables for decoding a new sequence"""
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
"""Specify how to select the next token, based on the current trace and logits
Parameters
----------
tokens : Tensor, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
logits : Tensor, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
sum_logprobs : Tensor, shape = (n_batch)
cumulative log probabilities for each sequence
Returns
-------
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
the tokens, appended with the selected next token
completed : bool
True if all sequences have reached the end of text
"""
raise NotImplementedError
def finalize(
self, tokens: Tensor, sum_logprobs: Tensor
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
"""Finalize search and return the final candidate sequences
Parameters
----------
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence
sum_logprobs : Tensor, shape = (n_audio, n_group)
cumulative log probabilities for each sequence
Returns
-------
tokens : Sequence[Sequence[Tensor]], length = n_audio
sequence of Tensors containing candidate token sequences, for each audio input
sum_logprobs : List[List[float]], length = n_audio
sequence of cumulative log probabilities corresponding to the above
"""
raise NotImplementedError
class GreedyDecoder(TokenDecoder):
def __init__(self, temperature: float, eot: int):
self.temperature = temperature
self.eot = eot
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
if self.temperature == 0:
next_tokens = logits.argmax(dim=-1)
else:
next_tokens = Categorical(logits=logits / self.temperature).sample()
logprobs = F.log_softmax(logits.float(), dim=-1)
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
next_tokens[tokens[:, -1] == self.eot] = self.eot
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
completed = (tokens[:, -1] == self.eot).all()
return tokens, completed
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
# make sure each sequence has at least one EOT token at the end
tokens = F.pad(tokens, (0, 1), value=self.eot)
return tokens, sum_logprobs.tolist()
class BeamSearchDecoder(TokenDecoder):
def __init__(
self,
beam_size: int,
eot: int,
inference: Inference,
patience: Optional[float] = None,
):
self.beam_size = beam_size
self.eot = eot
self.inference = inference
self.patience = patience or 1.0
self.max_candidates: int = round(beam_size * self.patience)
self.finished_sequences = None
assert (
self.max_candidates > 0
), f"Invalid beam size ({beam_size}) or patience ({patience})"
def reset(self):
self.finished_sequences = None
def update(
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
) -> Tuple[Tensor, bool]:
if tokens.shape[0] % self.beam_size != 0:
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
n_audio = tokens.shape[0] // self.beam_size
if self.finished_sequences is None: # for the first update
self.finished_sequences = [{} for _ in range(n_audio)]
logprobs = F.log_softmax(logits.float(), dim=-1)
next_tokens, source_indices, finished_sequences = [], [], []
for i in range(n_audio):
scores, sources, finished = {}, {}, {}
# STEP 1: calculate the cumulative log probabilities for possible candidates
for j in range(self.beam_size):
idx = i * self.beam_size + j
prefix = tokens[idx].tolist()
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
new_logprob = (sum_logprobs[idx] + logprob).item()
sequence = tuple(prefix + [token.item()])
scores[sequence] = new_logprob
sources[sequence] = idx
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
saved = 0
for sequence in sorted(scores, key=scores.get, reverse=True):
if sequence[-1] == self.eot:
finished[sequence] = scores[sequence]
else:
sum_logprobs[len(next_tokens)] = scores[sequence]
next_tokens.append(sequence)
source_indices.append(sources[sequence])
saved += 1
if saved == self.beam_size:
break
finished_sequences.append(finished)
tokens = torch.tensor(next_tokens, device=tokens.device)
# self.inference.rearrange_kv_cache(source_indices)
# add newly finished sequences to self.finished_sequences
assert len(self.finished_sequences) == len(finished_sequences)
for previously_finished, newly_finished in zip(
self.finished_sequences, finished_sequences
):
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
if len(previously_finished) >= self.max_candidates:
break # the candidate list is full
previously_finished[seq] = newly_finished[seq]
# mark as completed if all audio has enough number of samples
completed = all(
len(sequences) >= self.max_candidates
for sequences in self.finished_sequences
)
return tokens, completed
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
# collect all finished sequences, including patience, and add unfinished ones if not enough
sum_logprobs = sum_logprobs.cpu()
for i, sequences in enumerate(self.finished_sequences):
if (
len(sequences) < self.beam_size
): # when not enough sequences are finished
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
sequence = preceding_tokens[i, j].tolist() + [self.eot]
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
if len(sequences) >= self.beam_size:
break
tokens: List[List[Tensor]] = [
[torch.tensor(seq) for seq in sequences.keys()]
for sequences in self.finished_sequences
]
sum_logprobs: List[List[float]] = [
list(sequences.values()) for sequences in self.finished_sequences
]
return tokens, sum_logprobs
class LogitFilter:
def apply(self, logits: Tensor, tokens: Tensor) -> None:
"""Apply any filtering or masking to logits in-place
Parameters
----------
logits : Tensor, shape = (n_batch, vocab_size)
per-token logits of the probability distribution at the current step
tokens : Tensor, shape = (n_batch, current_sequence_length)
all tokens in the context so far, including the prefix and sot_sequence tokens
"""
raise NotImplementedError
class SuppressBlank(LogitFilter):
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
def apply(self, logits: Tensor, tokens: Tensor):
if tokens.shape[1] == self.sample_begin:
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
class SuppressTokens(LogitFilter):
def __init__(self, suppress_tokens: Sequence[int]):
self.suppress_tokens = list(suppress_tokens)
def apply(self, logits: Tensor, tokens: Tensor):
logits[:, self.suppress_tokens] = -np.inf
class ApplyTimestampRules(LogitFilter):
def __init__(
self,
tokenizer: Tokenizer,
sample_begin: int,
max_initial_timestamp_index: Optional[int],
):
self.tokenizer = tokenizer
self.sample_begin = sample_begin
self.max_initial_timestamp_index = max_initial_timestamp_index
def apply(self, logits: Tensor, tokens: Tensor):
# suppress <|notimestamps|> which is handled by without_timestamps
if self.tokenizer.no_timestamps is not None:
logits[:, self.tokenizer.no_timestamps] = -np.inf
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
for k in range(tokens.shape[0]):
sampled_tokens = tokens[k, self.sample_begin :]
seq = [t for t in sampled_tokens.tolist()]
last_was_timestamp = (
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
)
penultimate_was_timestamp = (
len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
)
if last_was_timestamp:
if penultimate_was_timestamp: # has to be non-timestamp
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
else: # cannot be normal text tokens
logits[k, : self.tokenizer.eot] = -np.inf
timestamps = sampled_tokens[
sampled_tokens.ge(self.tokenizer.timestamp_begin)
]
if timestamps.numel() > 0:
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
# also force each segment to have a nonzero length, to prevent infinite looping
if last_was_timestamp and not penultimate_was_timestamp:
timestamp_last = timestamps[-1]
else:
timestamp_last = timestamps[-1] + 1
logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
if tokens.shape[1] == self.sample_begin:
# suppress generating non-timestamp tokens at the beginning
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
# apply the `max_initial_timestamp` option
if self.max_initial_timestamp_index is not None:
last_allowed = (
self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
)
logits[:, last_allowed + 1 :] = -np.inf
# if sum of probability over timestamps is above any other token, sample timestamp
logprobs = F.log_softmax(logits.float(), dim=-1)
for k in range(tokens.shape[0]):
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
dim=-1
)
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
if timestamp_logprob > max_text_token_logprob:
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
class DecodingTask:
inference: Inference
sequence_ranker: SequenceRanker
decoder: TokenDecoder
logit_filters: List[LogitFilter]
def __init__(self, model: "Whisper", options: DecodingOptions):
self.model = model
language = options.language or "en"
tokenizer = get_tokenizer(
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=options.task,
)
self.tokenizer: Tokenizer = tokenizer
self.options: DecodingOptions = self._verify_options(options)
self.n_group: int = options.beam_size or options.best_of or 1
self.n_ctx: int = model.dims.n_text_ctx
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
if self.options.without_timestamps:
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
self.sample_begin: int = len(self.initial_tokens)
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
# inference: implements the forward pass through the decoder, including kv caching
self.inference = PyTorchInference(model, len(self.initial_tokens))
# sequence ranker: implements how to rank a group of sampled sequences
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
# decoder: implements how to select the next tokens, given the autoregressive distribution
if options.beam_size is not None:
self.decoder = BeamSearchDecoder(
options.beam_size, tokenizer.eot, self.inference, options.patience
)
else:
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
# logit filters: applies various rules to suppress or penalize certain tokens
self.logit_filters = []
if self.options.suppress_blank:
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
if self.options.suppress_tokens:
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
if not options.without_timestamps:
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
max_initial_timestamp_index = None
if options.max_initial_timestamp:
max_initial_timestamp_index = round(
self.options.max_initial_timestamp / precision
)
self.logit_filters.append(
ApplyTimestampRules(
tokenizer, self.sample_begin, max_initial_timestamp_index
)
)
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
if options.beam_size is not None and options.best_of is not None:
raise ValueError("beam_size and best_of can't be given together")
if options.temperature == 0:
if options.best_of is not None:
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
if options.patience is not None and options.beam_size is None:
raise ValueError("patience requires beam_size to be given")
if options.length_penalty is not None and not (
0 <= options.length_penalty <= 1
):
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
return options
def _get_initial_tokens(self) -> Tuple[int]:
tokens = list(self.sot_sequence)
if prefix := self.options.prefix:
prefix_tokens = (
self.tokenizer.encode(" " + prefix.strip())
if isinstance(prefix, str)
else prefix
)
if self.sample_len is not None:
max_prefix_len = self.n_ctx // 2 - self.sample_len
prefix_tokens = prefix_tokens[-max_prefix_len:]
tokens = tokens + prefix_tokens
if prompt := self.options.prompt:
prompt_tokens = (
self.tokenizer.encode(" " + prompt.strip())
if isinstance(prompt, str)
else prompt
)
tokens = (
[self.tokenizer.sot_prev]
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
+ tokens
)
return tuple(tokens)
def _get_suppress_tokens(self) -> Tuple[int]:
suppress_tokens = self.options.suppress_tokens
if isinstance(suppress_tokens, str):
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
if -1 in suppress_tokens:
suppress_tokens = [t for t in suppress_tokens if t >= 0]
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
elif suppress_tokens is None or len(suppress_tokens) == 0:
suppress_tokens = [] # interpret empty string as an empty list
else:
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
suppress_tokens.extend(
[
self.tokenizer.transcribe,
self.tokenizer.translate,
self.tokenizer.sot,
self.tokenizer.sot_prev,
self.tokenizer.sot_lm,
]
)
if self.tokenizer.no_speech is not None:
# no-speech probability is collected separately
suppress_tokens.append(self.tokenizer.no_speech)
return tuple(sorted(set(suppress_tokens)))
def _get_audio_features(self, mel: Tensor):
if self.options.fp16:
mel = mel.half()
if mel.shape[-2:] == (
self.model.dims.n_audio_ctx,
self.model.dims.n_audio_state,
):
# encoded audio features are given; skip audio encoding
audio_features = mel
else:
audio_features = self.model.encoder(mel)
if audio_features.dtype != (
torch.float16 if self.options.fp16 else torch.float32
):
raise TypeError(
f"audio_features has an incorrect dtype: {audio_features.dtype}"
)
return audio_features
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
languages = [self.options.language] * audio_features.shape[0]
lang_probs = None
if self.options.language is None or self.options.task == "lang_id":
lang_tokens, lang_probs = self.model.detect_language(
audio_features, self.tokenizer
)
languages = [max(probs, key=probs.get) for probs in lang_probs]
if self.options.language is None:
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
return languages, lang_probs
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
n_batch = tokens.shape[0]
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
no_speech_probs = [np.nan] * n_batch
try:
for i in range(self.sample_len):
logits = self.inference.logits(tokens, audio_features)
if (
i == 0 and self.tokenizer.no_speech is not None
): # save no_speech_probs
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
# now we need to consider the logits at the last token only
logits = logits[:, -1]
# apply the logit filters, e.g. for suppressing or applying penalty to
for logit_filter in self.logit_filters:
logit_filter.apply(logits, tokens)
# expand the tokens tensor with the selected next tokens
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
if completed or tokens.shape[-1] > self.n_ctx:
break
finally:
self.inference.cleanup_caching()
return tokens, sum_logprobs, no_speech_probs
@torch.no_grad()
def run(self, mel: Tensor) -> List[DecodingResult]:
self.decoder.reset()
tokenizer: Tokenizer = self.tokenizer
n_audio: int = mel.shape[0]
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
# detect language if requested, overwriting the language token
languages, language_probs = self._detect_language(audio_features, tokens)
if self.options.task == "lang_id":
return [
DecodingResult(
audio_features=features, language=language, language_probs=probs
)
for features, language, probs in zip(
audio_features, languages, language_probs
)
]
# repeat text tensors by the group size, for beam search or best-of-n sampling
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
# call the main sampling loop
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
audio_features = audio_features[:: self.n_group]
no_speech_probs = no_speech_probs[:: self.n_group]
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
tokens = tokens.reshape(n_audio, self.n_group, -1)
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
# get the final candidates for each group, and slice between the first sampled token and EOT
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
tokens: List[List[Tensor]] = [
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
for s in tokens
]
# select the top-ranked sample in each group
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
avg_logprobs: List[float] = [
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
]
fields = (
texts,
languages,
tokens,
audio_features,
avg_logprobs,
no_speech_probs,
)
if len(set(map(len, fields))) != 1:
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
return [
DecodingResult(
audio_features=features,
language=language,
tokens=tokens,
text=text,
avg_logprob=avg_logprob,
no_speech_prob=no_speech_prob,
temperature=self.options.temperature,
compression_ratio=compression_ratio(text),
)
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
*fields
)
]
@torch.no_grad()
def decode(
model: "Whisper",
mel: Tensor,
options: DecodingOptions = DecodingOptions(),
**kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
"""
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
Parameters
----------
model: Whisper
the Whisper model instance
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
A tensor containing the Mel spectrogram(s)
options: DecodingOptions
A dataclass that contains all necessary options for decoding 30-second segments
Returns
-------
result: Union[DecodingResult, List[DecodingResult]]
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
"""
if single := mel.ndim == 2:
mel = mel.unsqueeze(0)
if kwargs:
options = replace(options, **kwargs)
result = DecodingTask(model, options).run(mel)
return result[0] if single else result
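# Usage sketch (mirrors run.py): decode a single 30-second mel segment with greedy sampling.
#   options = DecodingOptions(language="en", fp16=False)
#   result = decode(model, mel, options)   # returns a DecodingResult
#   print(result.text, result.avg_logprob)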


@@ -0,0 +1,364 @@
import base64
import gzip
from dataclasses import dataclass
from typing import Dict, Iterable, Optional
import time
import torch_directml
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from .decoding import decode as decode_function
from .decoding import detect_language as detect_language_function
from .transcribe import transcribe as transcribe_function
@dataclass
class ModelDimensions:
n_mels: int
n_audio_ctx: int
n_audio_state: int
n_audio_head: int
n_audio_layer: int
n_vocab: int
n_text_ctx: int
n_text_state: int
n_text_head: int
n_text_layer: int
class LayerNorm(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
return super().forward(x.float()).type(x.dtype)
class Linear(nn.Linear):
def forward(self, x: Tensor) -> Tensor:
return F.linear(
x,
self.weight.to(x.dtype),
None if self.bias is None else self.bias.to(x.dtype),
)
class Conv1d(nn.Conv1d):
def _conv_forward(
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor:
return super()._conv_forward(
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
)
def sinusoids(length, channels, max_timescale=10000):
"""Returns sinusoids for positional embedding"""
assert channels % 2 == 0
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
self.n_head = n_head
self.query = Linear(n_state, n_state)
self.key = Linear(n_state, n_state, bias=False)
self.value = Linear(n_state, n_state)
self.out = Linear(n_state, n_state)
self.past_key_tensor = None
self.past_value_tensor = None
self.cross_key = None
self.cross_value = None
self.total_seq_len = 0
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
cross_attention=False,
use_dml_attn=False,
):
q = self.query(x)
if xa is None:
k = self.key(x)
v = self.value(x)
else:
cross_attention = True
if self.past_key_tensor is None:
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
self.cross_key = self.key(xa)
self.cross_value = self.value(xa)
k = self.cross_key
v = self.cross_value
if use_dml_attn:
wv, qk = self.dml_sdp_attn(q, k, v, mask, cross_attention=cross_attention)
else:
wv, qk = self.qkv_attention(q, k, v, mask)
return self.out(wv), qk
def qkv_attention(
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
):
_, n_ctx, n_state = q.shape
scale = (n_state // self.n_head) ** -0.25
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
qk = q @ k
if mask is not None:
qk = qk + mask[:n_ctx, :n_ctx]
qk = qk.float()
w = F.softmax(qk, dim=-1).to(q.dtype)
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
def dml_sdp_attn(self, q, k, v, mask, cross_attention=False):
_, n_ctx, n_state = q.shape
if mask is not None:
self.total_seq_len += n_ctx
mask = mask.expand(-1, -1, n_ctx, self.total_seq_len)
# cross attention i.e. encoder/decoder attn uses same k and v but the query changes
if cross_attention:
y, self.past_key_tensor, self.past_value_tensor = torch_directml.multi_head_attention(
q, k, v, n_state, self.n_head, None, None, mask
)
else:
y, self.past_key_tensor, self.past_value_tensor = torch_directml.multi_head_attention(
q, k, v, n_state, self.n_head, self.past_key_tensor, self.past_value_tensor, mask
)
return y, None
class ResidualAttentionBlock(nn.Module):
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
)
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp = nn.Sequential(
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
)
self.mlp_ln = LayerNorm(n_state)
def forward(
self,
x: Tensor,
xa: Optional[Tensor] = None,
mask: Optional[Tensor] = None,
kv_cache: Optional[dict] = None,
use_dml_attn: bool = False,
):
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, use_dml_attn=use_dml_attn)[0]
if self.cross_attn:
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, use_dml_attn=use_dml_attn)[0]
x = x + self.mlp(self.mlp_ln(x))
return x
class AudioEncoder(nn.Module):
def __init__(
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, use_dml_attn: bool = False,
):
super().__init__()
self.use_dml_attn = use_dml_attn
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
)
self.ln_post = LayerNorm(n_state)
def forward(self, x: Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
the mel spectrogram of the audio
"""
x = F.gelu(self.conv1(x))
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
# assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
x = (x + self.positional_embedding).to(x.dtype)
for block in self.blocks:
x = block(x, use_dml_attn=self.use_dml_attn)
x = self.ln_post(x)
return x
class TextDecoder(nn.Module):
def __init__(
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, use_dml_attn: bool = False,
):
super().__init__()
self.token_embedding = nn.Embedding(n_vocab, n_state)
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
self.use_dml_attn = use_dml_attn
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
[
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
)
self.ln = LayerNorm(n_state)
if self.use_dml_attn:
mask = torch.ones([1, n_head, 1, 1], dtype=torch.int32)
else:
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
self.register_buffer("mask", mask, persistent=False)
self.pos = 0
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
"""
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
the text tokens
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
the encoded audio features to be attended on
"""
offset = self.pos
self.pos += x.shape[1]
x = (
self.token_embedding(x)
+ self.positional_embedding[offset : offset + x.shape[-1]]
)
x = x.to(xa.dtype)
for i, block in enumerate(self.blocks):
x = block(x, xa, mask=self.mask, kv_cache=kv_cache, use_dml_attn=self.use_dml_attn)
x = self.ln(x)
logits = (
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
).float()
return logits
class Whisper(nn.Module):
def __init__(self, dims: ModelDimensions, use_dml_attn=False):
super().__init__()
self.dims = dims
self.encoder = AudioEncoder(
self.dims.n_mels,
self.dims.n_audio_ctx,
self.dims.n_audio_state,
self.dims.n_audio_head,
self.dims.n_audio_layer,
use_dml_attn
)
self.decoder = TextDecoder(
self.dims.n_vocab,
self.dims.n_text_ctx,
self.dims.n_text_state,
self.dims.n_text_head,
self.dims.n_text_layer,
use_dml_attn,
)
# use the last half among the decoder layers for time alignment by default;
# to use a specific set of heads, see `set_alignment_heads()` below.
all_heads = torch.zeros(
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
)
all_heads[self.dims.n_text_layer // 2 :] = True
self.register_buffer("alignment_heads", all_heads, persistent=False)
def set_alignment_heads(self, dump: bytes):
array = np.frombuffer(
gzip.decompress(base64.b85decode(dump)), dtype=bool
).copy()
mask = torch.from_numpy(array).reshape(
self.dims.n_text_layer, self.dims.n_text_head
)
self.register_buffer("alignment_heads", mask, persistent=False)
def embed_audio(self, mel: torch.Tensor):
return self.encoder(mel)
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
return self.decoder(tokens, audio_features)
def forward(
self, mel: torch.Tensor, tokens: torch.Tensor
) -> Dict[str, torch.Tensor]:
return self.decoder(tokens, self.encoder(mel))
@property
def device(self):
return next(self.parameters()).device
@property
def is_multilingual(self):
return self.dims.n_vocab >= 51865
@property
def num_languages(self):
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
"""
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
tensors calculated for the previous positions. This method returns a dictionary that stores
all caches, and the necessary hooks for the key and value projection modules that save the
intermediate tensors to be reused during later calculations.
Returns
-------
cache : Dict[nn.Module, torch.Tensor]
A dictionary object mapping the key/value projection modules to its cache
hooks : List[RemovableHandle]
List of PyTorch RemovableHandle objects to stop the hooks to be called
"""
cache = {**cache} if cache is not None else {}
hooks = []
def save_to_cache(module, _, output):
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
# save as-is, for the first token or cross attention
cache[module] = output
else:
cache[module] = torch.cat([cache[module], output], dim=1).detach()
return cache[module]
def install_hooks(layer: nn.Module):
if isinstance(layer, MultiHeadAttention):
if self.decoder.use_dml_attn:
layer.total_seq_len = 0
layer.past_key_tensor = None
layer.past_value_tensor = None
else:
hooks.append(layer.key.register_forward_hook(save_to_cache))
hooks.append(layer.value.register_forward_hook(save_to_cache))
self.decoder.apply(install_hooks)
self.decoder.pos = 0
return cache, hooks
detect_language = detect_language_function
transcribe = transcribe_function
decode = decode_function
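# kv-cache usage sketch (this is what PyTorchInference in decoding.py does during decoding):
#   kv_cache, hooks = model.install_kv_cache_hooks()
#   logits = model.decoder(tokens, audio_features, kv_cache=kv_cache)
#   for hook in hooks:
#       hook.remove()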


@@ -0,0 +1,2 @@
from .basic import BasicTextNormalizer as BasicTextNormalizer
from .english import EnglishTextNormalizer as EnglishTextNormalizer


@@ -0,0 +1,76 @@
import re
import unicodedata
import regex
# non-ASCII letters that are not separated by "NFKD" normalization
ADDITIONAL_DIACRITICS = {
"œ": "oe",
"Œ": "OE",
"ø": "o",
"Ø": "O",
"æ": "ae",
"Æ": "AE",
"ß": "ss",
"": "SS",
"đ": "d",
"Đ": "D",
"ð": "d",
"Ð": "D",
"þ": "th",
"Þ": "th",
"ł": "l",
"Ł": "L",
}
def remove_symbols_and_diacritics(s: str, keep=""):
"""
Replace any other markers, symbols, and punctuations with a space,
and drop any diacritics (category 'Mn' and some manual mappings)
"""
return "".join(
c
if c in keep
else ADDITIONAL_DIACRITICS[c]
if c in ADDITIONAL_DIACRITICS
else ""
if unicodedata.category(c) == "Mn"
else " "
if unicodedata.category(c)[0] in "MSP"
else c
for c in unicodedata.normalize("NFKD", s)
)
def remove_symbols(s: str):
"""
Replace any other markers, symbols, punctuations with a space, keeping diacritics
"""
return "".join(
" " if unicodedata.category(c)[0] in "MSP" else c
for c in unicodedata.normalize("NFKC", s)
)
class BasicTextNormalizer:
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
self.clean = (
remove_symbols_and_diacritics if remove_diacritics else remove_symbols
)
self.split_letters = split_letters
def __call__(self, s: str):
s = s.lower()
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
s = self.clean(s).lower()
if self.split_letters:
s = " ".join(regex.findall(r"\X", s, regex.U))
s = re.sub(
r"\s+", " ", s
) # replace any successive whitespace characters with a space
return s

The diff for this file is not shown because it is too large.


@@ -0,0 +1,550 @@
import json
import os
import re
from fractions import Fraction
from typing import Iterator, List, Match, Optional, Union
from more_itertools import windowed
from .basic import remove_symbols_and_diacritics
class EnglishNumberNormalizer:
"""
Convert any spelled-out numbers into arabic numbers, while handling:
- remove any commas
- keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
- spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
- spell out `one` and `ones`
- interpret successive single-digit numbers as nominal: `one oh one` -> `101`
"""
def __init__(self):
super().__init__()
self.zeros = {"o", "oh", "zero"}
self.ones = {
name: i
for i, name in enumerate(
[
"one",
"two",
"three",
"four",
"five",
"six",
"seven",
"eight",
"nine",
"ten",
"eleven",
"twelve",
"thirteen",
"fourteen",
"fifteen",
"sixteen",
"seventeen",
"eighteen",
"nineteen",
],
start=1,
)
}
self.ones_plural = {
"sixes" if name == "six" else name + "s": (value, "s")
for name, value in self.ones.items()
}
self.ones_ordinal = {
"zeroth": (0, "th"),
"first": (1, "st"),
"second": (2, "nd"),
"third": (3, "rd"),
"fifth": (5, "th"),
"twelfth": (12, "th"),
**{
name + ("h" if name.endswith("t") else "th"): (value, "th")
for name, value in self.ones.items()
if value > 3 and value != 5 and value != 12
},
}
self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
self.tens = {
"twenty": 20,
"thirty": 30,
"forty": 40,
"fifty": 50,
"sixty": 60,
"seventy": 70,
"eighty": 80,
"ninety": 90,
}
self.tens_plural = {
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
}
self.tens_ordinal = {
name.replace("y", "ieth"): (value, "th")
for name, value in self.tens.items()
}
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
self.multipliers = {
"hundred": 100,
"thousand": 1_000,
"million": 1_000_000,
"billion": 1_000_000_000,
"trillion": 1_000_000_000_000,
"quadrillion": 1_000_000_000_000_000,
"quintillion": 1_000_000_000_000_000_000,
"sextillion": 1_000_000_000_000_000_000_000,
"septillion": 1_000_000_000_000_000_000_000_000,
"octillion": 1_000_000_000_000_000_000_000_000_000,
"nonillion": 1_000_000_000_000_000_000_000_000_000_000,
"decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
}
self.multipliers_plural = {
name + "s": (value, "s") for name, value in self.multipliers.items()
}
self.multipliers_ordinal = {
name + "th": (value, "th") for name, value in self.multipliers.items()
}
self.multipliers_suffixed = {
**self.multipliers_plural,
**self.multipliers_ordinal,
}
self.decimals = {*self.ones, *self.tens, *self.zeros}
self.preceding_prefixers = {
"minus": "-",
"negative": "-",
"plus": "+",
"positive": "+",
}
self.following_prefixers = {
"pound": "£",
"pounds": "£",
"euro": "",
"euros": "",
"dollar": "$",
"dollars": "$",
"cent": "¢",
"cents": "¢",
}
self.prefixes = set(
list(self.preceding_prefixers.values())
+ list(self.following_prefixers.values())
)
self.suffixers = {
"per": {"cent": "%"},
"percent": "%",
}
self.specials = {"and", "double", "triple", "point"}
self.words = set(
[
key
for mapping in [
self.zeros,
self.ones,
self.ones_suffixed,
self.tens,
self.tens_suffixed,
self.multipliers,
self.multipliers_suffixed,
self.preceding_prefixers,
self.following_prefixers,
self.suffixers,
self.specials,
]
for key in mapping
]
)
self.literal_words = {"one", "ones"}
def process_words(self, words: List[str]) -> Iterator[str]:
prefix: Optional[str] = None
value: Optional[Union[str, int]] = None
skip = False
def to_fraction(s: str):
try:
return Fraction(s)
except ValueError:
return None
def output(result: Union[str, int]):
nonlocal prefix, value
result = str(result)
if prefix is not None:
result = prefix + result
value = None
prefix = None
return result
if len(words) == 0:
return
for prev, current, next in windowed([None] + words + [None], 3):
if skip:
skip = False
continue
next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
has_prefix = current[0] in self.prefixes
current_without_prefix = current[1:] if has_prefix else current
if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
# arabic numbers (potentially with signs and fractions)
f = to_fraction(current_without_prefix)
assert f is not None
if value is not None:
if isinstance(value, str) and value.endswith("."):
# concatenate decimals / ip address components
value = str(value) + str(current)
continue
else:
yield output(value)
prefix = current[0] if has_prefix else prefix
if f.denominator == 1:
value = f.numerator # store integers as int
else:
value = current_without_prefix
elif current not in self.words:
# non-numeric words
if value is not None:
yield output(value)
yield output(current)
elif current in self.zeros:
value = str(value or "") + "0"
elif current in self.ones:
ones = self.ones[current]
if value is None:
value = ones
elif isinstance(value, str) or prev in self.ones:
if (
prev in self.tens and ones < 10
): # replace the last zero with the digit
assert value[-1] == "0"
value = value[:-1] + str(ones)
else:
value = str(value) + str(ones)
elif ones < 10:
if value % 10 == 0:
value += ones
else:
value = str(value) + str(ones)
else: # eleven to nineteen
if value % 100 == 0:
value += ones
else:
value = str(value) + str(ones)
elif current in self.ones_suffixed:
# ordinal or cardinal; yield the number right away
ones, suffix = self.ones_suffixed[current]
if value is None:
yield output(str(ones) + suffix)
elif isinstance(value, str) or prev in self.ones:
if prev in self.tens and ones < 10:
assert value[-1] == "0"
yield output(value[:-1] + str(ones) + suffix)
else:
yield output(str(value) + str(ones) + suffix)
elif ones < 10:
if value % 10 == 0:
yield output(str(value + ones) + suffix)
else:
yield output(str(value) + str(ones) + suffix)
else: # eleven to nineteen
if value % 100 == 0:
yield output(str(value + ones) + suffix)
else:
yield output(str(value) + str(ones) + suffix)
value = None
elif current in self.tens:
tens = self.tens[current]
if value is None:
value = tens
elif isinstance(value, str):
value = str(value) + str(tens)
else:
if value % 100 == 0:
value += tens
else:
value = str(value) + str(tens)
elif current in self.tens_suffixed:
# ordinal or cardinal; yield the number right away
tens, suffix = self.tens_suffixed[current]
if value is None:
yield output(str(tens) + suffix)
elif isinstance(value, str):
yield output(str(value) + str(tens) + suffix)
else:
if value % 100 == 0:
yield output(str(value + tens) + suffix)
else:
yield output(str(value) + str(tens) + suffix)
elif current in self.multipliers:
multiplier = self.multipliers[current]
if value is None:
value = multiplier
elif isinstance(value, str) or value == 0:
f = to_fraction(value)
p = f * multiplier if f is not None else None
if f is not None and p.denominator == 1:
value = p.numerator
else:
yield output(value)
value = multiplier
else:
before = value // 1000 * 1000
residual = value % 1000
value = before + residual * multiplier
elif current in self.multipliers_suffixed:
multiplier, suffix = self.multipliers_suffixed[current]
if value is None:
yield output(str(multiplier) + suffix)
elif isinstance(value, str):
f = to_fraction(value)
p = f * multiplier if f is not None else None
if f is not None and p.denominator == 1:
yield output(str(p.numerator) + suffix)
else:
yield output(value)
yield output(str(multiplier) + suffix)
else: # int
before = value // 1000 * 1000
residual = value % 1000
value = before + residual * multiplier
yield output(str(value) + suffix)
value = None
elif current in self.preceding_prefixers:
# apply prefix (positive, minus, etc.) if it precedes a number
if value is not None:
yield output(value)
if next in self.words or next_is_numeric:
prefix = self.preceding_prefixers[current]
else:
yield output(current)
elif current in self.following_prefixers:
# apply prefix (dollars, cents, etc.) only after a number
if value is not None:
prefix = self.following_prefixers[current]
yield output(value)
else:
yield output(current)
elif current in self.suffixers:
# apply suffix symbols (percent -> '%')
if value is not None:
suffix = self.suffixers[current]
if isinstance(suffix, dict):
if next in suffix:
yield output(str(value) + suffix[next])
skip = True
else:
yield output(value)
yield output(current)
else:
yield output(str(value) + suffix)
else:
yield output(current)
elif current in self.specials:
if next not in self.words and not next_is_numeric:
# apply special handling only if the next word can be numeric
if value is not None:
yield output(value)
yield output(current)
elif current == "and":
# ignore "and" after hundreds, thousands, etc.
if prev not in self.multipliers:
if value is not None:
yield output(value)
yield output(current)
elif current == "double" or current == "triple":
if next in self.ones or next in self.zeros:
repeats = 2 if current == "double" else 3
ones = self.ones.get(next, 0)
value = str(value or "") + str(ones) * repeats
skip = True
else:
if value is not None:
yield output(value)
yield output(current)
elif current == "point":
if next in self.decimals or next_is_numeric:
value = str(value or "") + "."
else:
# should all have been covered at this point
raise ValueError(f"Unexpected token: {current}")
else:
# all should have been covered at this point
raise ValueError(f"Unexpected token: {current}")
if value is not None:
yield output(value)
def preprocess(self, s: str):
# replace "<number> and a half" with "<number> point five"
results = []
segments = re.split(r"\band\s+a\s+half\b", s)
for i, segment in enumerate(segments):
if len(segment.strip()) == 0:
continue
if i == len(segments) - 1:
results.append(segment)
else:
results.append(segment)
last_word = segment.rsplit(maxsplit=2)[-1]
if last_word in self.decimals or last_word in self.multipliers:
results.append("point five")
else:
results.append("and a half")
s = " ".join(results)
# put a space at number/letter boundary
s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
# but remove spaces which could be a suffix
s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
return s
def postprocess(self, s: str):
def combine_cents(m: Match):
try:
currency = m.group(1)
integer = m.group(2)
cents = int(m.group(3))
return f"{currency}{integer}.{cents:02d}"
except ValueError:
return m.string
def extract_cents(m: Match):
try:
return f"¢{int(m.group(1))}"
except ValueError:
return m.string
# apply currency postprocessing; "$2 and ¢7" -> "$2.07"
s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
# write "one(s)" instead of "1(s)", just for the readability
s = re.sub(r"\b1(s?)\b", r"one\1", s)
return s
def __call__(self, s: str):
s = self.preprocess(s)
s = " ".join(word for word in self.process_words(s.split()) if word is not None)
s = self.postprocess(s)
return s
class EnglishSpellingNormalizer:
"""
Applies British-American spelling mappings as listed in [1].
[1] https://www.tysto.com/uk-us-spelling-list.html
"""
def __init__(self):
mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
self.mapping = json.load(open(mapping_path))
def __call__(self, s: str):
return " ".join(self.mapping.get(word, word) for word in s.split())
class EnglishTextNormalizer:
def __init__(self):
self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
self.replacers = {
# common contractions
r"\bwon't\b": "will not",
r"\bcan't\b": "can not",
r"\blet's\b": "let us",
r"\bain't\b": "aint",
r"\by'all\b": "you all",
r"\bwanna\b": "want to",
r"\bgotta\b": "got to",
r"\bgonna\b": "going to",
r"\bi'ma\b": "i am going to",
r"\bimma\b": "i am going to",
r"\bwoulda\b": "would have",
r"\bcoulda\b": "could have",
r"\bshoulda\b": "should have",
r"\bma'am\b": "madam",
# contractions in titles/prefixes
r"\bmr\b": "mister ",
r"\bmrs\b": "missus ",
r"\bst\b": "saint ",
r"\bdr\b": "doctor ",
r"\bprof\b": "professor ",
r"\bcapt\b": "captain ",
r"\bgov\b": "governor ",
r"\bald\b": "alderman ",
r"\bgen\b": "general ",
r"\bsen\b": "senator ",
r"\brep\b": "representative ",
r"\bpres\b": "president ",
r"\brev\b": "reverend ",
r"\bhon\b": "honorable ",
r"\basst\b": "assistant ",
r"\bassoc\b": "associate ",
r"\blt\b": "lieutenant ",
r"\bcol\b": "colonel ",
r"\bjr\b": "junior ",
r"\bsr\b": "senior ",
r"\besq\b": "esquire ",
# perfect tenses; ideally this would cover any past participle, but that's harder..
r"'d been\b": " had been",
r"'s been\b": " has been",
r"'d gone\b": " had gone",
r"'s gone\b": " has gone",
r"'d done\b": " had done", # "'s done" is ambiguous
r"'s got\b": " has got",
# general contractions
r"n't\b": " not",
r"'re\b": " are",
r"'s\b": " is",
r"'d\b": " would",
r"'ll\b": " will",
r"'t\b": " not",
r"'ve\b": " have",
r"'m\b": " am",
}
self.standardize_numbers = EnglishNumberNormalizer()
self.standardize_spellings = EnglishSpellingNormalizer()
def __call__(self, s: str):
s = s.lower()
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
s = re.sub(self.ignore_patterns, "", s)
s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
for pattern, replacement in self.replacers.items():
s = re.sub(pattern, replacement, s)
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
s = self.standardize_numbers(s)
s = self.standardize_spellings(s)
# now remove prefix/suffix symbols that are not preceded/followed by numbers
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
s = re.sub(r"([^0-9])%", r"\1 ", s)
s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
return s
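# A minimal usage sketch (not part of the upstream file): the English normalizer is typically
# applied to both reference and hypothesis text before computing word error rate.
if __name__ == "__main__":
    normalizer = EnglishTextNormalizer()
    # expands the title and contraction and converts the spelled-out amount,
    # yielding roughly: "mister smith could not pay $20"
    print(normalizer("Mr. Smith couldn't pay twenty dollars."))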


@ -0,0 +1,365 @@
import itertools
import subprocess
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, List
import numba
import numpy as np
import torch
import torch.nn.functional as F
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
from .tokenizer import Tokenizer
if TYPE_CHECKING:
from .model import Whisper
def median_filter(x: torch.Tensor, filter_width: int):
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
pad_width = filter_width // 2
if x.shape[-1] <= pad_width:
# F.pad requires the padding width to be smaller than the input dimension
return x
if (ndim := x.ndim) <= 2:
# `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
x = x[None, None, :]
assert (
filter_width > 0 and filter_width % 2 == 1
), "`filter_width` should be an odd number"
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
if ndim <= 2:
result = result[0, 0]
return result
@numba.jit(nopython=True)
def backtrace(trace: np.ndarray):
i = trace.shape[0] - 1
j = trace.shape[1] - 1
trace[0, :] = 2
trace[:, 0] = 1
result = []
while i > 0 or j > 0:
result.append((i - 1, j - 1))
if trace[i, j] == 0:
i -= 1
j -= 1
elif trace[i, j] == 1:
i -= 1
elif trace[i, j] == 2:
j -= 1
else:
raise ValueError("Unexpected trace[i, j]")
result = np.array(result)
return result[::-1, :].T
@numba.jit(nopython=True, parallel=True)
def dtw_cpu(x: np.ndarray):
N, M = x.shape
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
cost[0, 0] = 0
for j in range(1, M + 1):
for i in range(1, N + 1):
c0 = cost[i - 1, j - 1]
c1 = cost[i - 1, j]
c2 = cost[i, j - 1]
if c0 < c1 and c0 < c2:
c, t = c0, 0
elif c1 < c0 and c1 < c2:
c, t = c1, 1
else:
c, t = c2, 2
cost[i, j] = x[i - 1, j - 1] + c
trace[i, j] = t
return backtrace(trace)
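# A minimal sketch (not part of the upstream file): `dtw_cpu` takes a 2-D cost matrix
# (text tokens x audio frames) and returns the monotonic alignment path as two index rows;
# with a diagonal-dominant cost matrix the recovered path simply follows the diagonal.
def _dtw_cpu_example() -> np.ndarray:
    costs = np.array(
        [[0.0, 1.0, 2.0],
         [1.0, 0.0, 1.0],
         [2.0, 1.0, 0.0]],
        dtype=np.float32,
    )
    return dtw_cpu(costs)  # -> [[0, 1, 2], [0, 1, 2]], i.e. (0, 0), (1, 1), (2, 2)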
def dtw_cuda(x, BLOCK_SIZE=1024):
from .triton_ops import dtw_kernel
M, N = x.shape
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
x_skew = (
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
)
x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0
cost = cost.cuda()
trace = torch.zeros_like(cost, dtype=torch.int32)
dtw_kernel[(1,)](
cost,
trace,
x_skew,
x_skew.stride(0),
cost.stride(0),
trace.stride(0),
N,
M,
BLOCK_SIZE=BLOCK_SIZE,
)
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
:, : N + 1
]
return backtrace(trace.cpu().numpy())
def dtw(x: torch.Tensor) -> np.ndarray:
return dtw_cpu(x.double().cpu().numpy())
@dataclass
class WordTiming:
word: str
tokens: List[int]
start: float
end: float
probability: float
def find_alignment(
model: "Whisper",
tokenizer: Tokenizer,
text_tokens: List[int],
mel: torch.Tensor,
num_frames: int,
*,
medfilt_width: int = 7,
qk_scale: float = 1.0,
) -> List[WordTiming]:
if len(text_tokens) == 0:
return []
tokens = torch.tensor(
[
*tokenizer.sot_sequence,
tokenizer.no_timestamps,
*text_tokens,
tokenizer.eot,
]
).to(model.device)
# install hooks on the cross attention layers to retrieve the attention weights
QKs = [None] * model.dims.n_text_layer
hooks = [
block.cross_attn.register_forward_hook(
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
)
for i, block in enumerate(model.decoder.blocks)
]
with torch.no_grad():
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
token_probs = sampled_logits.softmax(dim=-1)
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
text_token_probs = text_token_probs.tolist()
for hook in hooks:
hook.remove()
# heads * tokens * frames
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
weights = weights[:, :, : num_frames // 2]
weights = (weights * qk_scale).softmax(dim=-1)
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
weights = (weights - mean) / std
weights = median_filter(weights, medfilt_width)
matrix = weights.mean(axis=0)
matrix = matrix[len(tokenizer.sot_sequence) : -1]
text_indices, time_indices = dtw(-matrix)
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
if len(word_tokens) <= 1:
# return on eot only
# >>> np.pad([], (1, 0))
# array([0.])
# This results in crashes when we lookup jump_times with float, like
# IndexError: arrays used as indices must be of integer (or boolean) type
return []
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
word_probabilities = [
np.mean(text_token_probs[i:j])
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]
return [
WordTiming(word, tokens, start, end, probability)
for word, tokens, start, end, probability in zip(
words, word_tokens, start_times, end_times, word_probabilities
)
]
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
# merge prepended punctuations
i = len(alignment) - 2
j = len(alignment) - 1
while i >= 0:
previous = alignment[i]
following = alignment[j]
if previous.word.startswith(" ") and previous.word.strip() in prepended:
# prepend it to the following word
following.word = previous.word + following.word
following.tokens = previous.tokens + following.tokens
previous.word = ""
previous.tokens = []
else:
j = i
i -= 1
# merge appended punctuations
i = 0
j = 1
while j < len(alignment):
previous = alignment[i]
following = alignment[j]
if not previous.word.endswith(" ") and following.word in appended:
# append it to the previous word
previous.word = previous.word + following.word
previous.tokens = previous.tokens + following.tokens
following.word = ""
following.tokens = []
else:
i = j
j += 1
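# A minimal sketch (not part of the upstream file): opening punctuation is folded into the word
# that follows it and closing punctuation into the word that precedes it, so no word timing
# ends up containing punctuation alone.
def _merge_punctuations_example() -> List[str]:
    timings = [
        WordTiming(" \u201c", [1], 0.00, 0.10, 1.0),  # opening quote
        WordTiming(" hello", [2], 0.10, 0.50, 0.9),
        WordTiming(",", [3], 0.50, 0.60, 0.8),
    ]
    merge_punctuations(timings, prepended="\"'\u201c¿([{-", appended="\"'.。,!?::”)]}、")
    # the quote and the comma are merged into the middle word; their own entries are emptied
    return [t.word for t in timings if t.word]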
def add_word_timestamps(
*,
segments: List[dict],
model: "Whisper",
tokenizer: Tokenizer,
mel: torch.Tensor,
num_frames: int,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
last_speech_timestamp: float,
**kwargs,
):
if len(segments) == 0:
return
text_tokens_per_segment = [
[token for token in segment["tokens"] if token < tokenizer.eot]
for segment in segments
]
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
word_durations = np.array([t.end - t.start for t in alignment])
word_durations = word_durations[word_durations.nonzero()]
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
median_duration = min(0.7, float(median_duration))
max_duration = median_duration * 2
# hack: truncate long words at sentence boundaries.
# a better segmentation algorithm based on VAD should be able to replace this.
if len(word_durations) > 0:
sentence_end_marks = ".。!?"
# ensure words at sentence boundaries are not longer than twice the median word duration.
for i in range(1, len(alignment)):
if alignment[i].end - alignment[i].start > max_duration:
if alignment[i].word in sentence_end_marks:
alignment[i].end = alignment[i].start + max_duration
elif alignment[i - 1].word in sentence_end_marks:
alignment[i].start = alignment[i].end - max_duration
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
word_index = 0
for segment, text_tokens in zip(segments, text_tokens_per_segment):
saved_tokens = 0
words = []
while word_index < len(alignment) and saved_tokens < len(text_tokens):
timing = alignment[word_index]
if timing.word:
words.append(
dict(
word=timing.word,
start=round(time_offset + timing.start, 2),
end=round(time_offset + timing.end, 2),
probability=timing.probability,
)
)
saved_tokens += len(timing.tokens)
word_index += 1
# hack: truncate long words at segment boundaries.
# a better segmentation algorithm based on VAD should be able to replace this.
if len(words) > 0:
# ensure the first and second word after a pause is not longer than
# twice the median word duration.
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
words[0]["end"] - words[0]["start"] > max_duration
or (
len(words) > 1
and words[1]["end"] - words[0]["start"] > max_duration * 2
)
):
if (
len(words) > 1
and words[1]["end"] - words[1]["start"] > max_duration
):
boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
words[0]["end"] = words[1]["start"] = boundary
words[0]["start"] = max(0, words[0]["end"] - max_duration)
# prefer the segment-level start timestamp if the first word is too long.
if (
segment["start"] < words[0]["end"]
and segment["start"] - 0.5 > words[0]["start"]
):
words[0]["start"] = max(
0, min(words[0]["end"] - median_duration, segment["start"])
)
else:
segment["start"] = words[0]["start"]
# prefer the segment-level end timestamp if the last word is too long.
if (
segment["end"] > words[-1]["start"]
and segment["end"] + 0.5 < words[-1]["end"]
):
words[-1]["end"] = max(
words[-1]["start"] + median_duration, segment["end"]
)
else:
segment["end"] = words[-1]["end"]
last_speech_timestamp = segment["end"]
segment["words"] = words


@ -0,0 +1,395 @@
import base64
import os
import string
from dataclasses import dataclass, field
from functools import cached_property, lru_cache
from typing import Dict, List, Optional, Tuple
import tiktoken
LANGUAGES = {
"en": "english",
"zh": "chinese",
"de": "german",
"es": "spanish",
"ru": "russian",
"ko": "korean",
"fr": "french",
"ja": "japanese",
"pt": "portuguese",
"tr": "turkish",
"pl": "polish",
"ca": "catalan",
"nl": "dutch",
"ar": "arabic",
"sv": "swedish",
"it": "italian",
"id": "indonesian",
"hi": "hindi",
"fi": "finnish",
"vi": "vietnamese",
"he": "hebrew",
"uk": "ukrainian",
"el": "greek",
"ms": "malay",
"cs": "czech",
"ro": "romanian",
"da": "danish",
"hu": "hungarian",
"ta": "tamil",
"no": "norwegian",
"th": "thai",
"ur": "urdu",
"hr": "croatian",
"bg": "bulgarian",
"lt": "lithuanian",
"la": "latin",
"mi": "maori",
"ml": "malayalam",
"cy": "welsh",
"sk": "slovak",
"te": "telugu",
"fa": "persian",
"lv": "latvian",
"bn": "bengali",
"sr": "serbian",
"az": "azerbaijani",
"sl": "slovenian",
"kn": "kannada",
"et": "estonian",
"mk": "macedonian",
"br": "breton",
"eu": "basque",
"is": "icelandic",
"hy": "armenian",
"ne": "nepali",
"mn": "mongolian",
"bs": "bosnian",
"kk": "kazakh",
"sq": "albanian",
"sw": "swahili",
"gl": "galician",
"mr": "marathi",
"pa": "punjabi",
"si": "sinhala",
"km": "khmer",
"sn": "shona",
"yo": "yoruba",
"so": "somali",
"af": "afrikaans",
"oc": "occitan",
"ka": "georgian",
"be": "belarusian",
"tg": "tajik",
"sd": "sindhi",
"gu": "gujarati",
"am": "amharic",
"yi": "yiddish",
"lo": "lao",
"uz": "uzbek",
"fo": "faroese",
"ht": "haitian creole",
"ps": "pashto",
"tk": "turkmen",
"nn": "nynorsk",
"mt": "maltese",
"sa": "sanskrit",
"lb": "luxembourgish",
"my": "myanmar",
"bo": "tibetan",
"tl": "tagalog",
"mg": "malagasy",
"as": "assamese",
"tt": "tatar",
"haw": "hawaiian",
"ln": "lingala",
"ha": "hausa",
"ba": "bashkir",
"jw": "javanese",
"su": "sundanese",
"yue": "cantonese",
}
# language code lookup by name, with a few language aliases
TO_LANGUAGE_CODE = {
**{language: code for code, language in LANGUAGES.items()},
"burmese": "my",
"valencian": "ca",
"flemish": "nl",
"haitian": "ht",
"letzeburgesch": "lb",
"pushto": "ps",
"panjabi": "pa",
"moldavian": "ro",
"moldovan": "ro",
"sinhalese": "si",
"castilian": "es",
"mandarin": "zh",
}
@dataclass
class Tokenizer:
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
encoding: tiktoken.Encoding
num_languages: int
language: Optional[str] = None
task: Optional[str] = None
sot_sequence: Tuple[int] = ()
special_tokens: Dict[str, int] = field(default_factory=dict)
def __post_init__(self):
for special in self.encoding.special_tokens_set:
special_token = self.encoding.encode_single_token(special)
self.special_tokens[special] = special_token
sot: int = self.special_tokens["<|startoftranscript|>"]
translate: int = self.special_tokens["<|translate|>"]
transcribe: int = self.special_tokens["<|transcribe|>"]
langs = tuple(LANGUAGES.keys())[: self.num_languages]
sot_sequence = [sot]
if self.language is not None:
sot_sequence.append(sot + 1 + langs.index(self.language))
if self.task is not None:
task_token: int = transcribe if self.task == "transcribe" else translate
sot_sequence.append(task_token)
self.sot_sequence = tuple(sot_sequence)
def encode(self, text, **kwargs):
return self.encoding.encode(text, **kwargs)
def decode(self, token_ids: List[int], **kwargs) -> str:
token_ids = [t for t in token_ids if t < self.timestamp_begin]
return self.encoding.decode(token_ids, **kwargs)
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
"""
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
"""
return self.encoding.decode(token_ids, **kwargs)
@cached_property
def eot(self) -> int:
return self.encoding.eot_token
@cached_property
def transcribe(self) -> int:
return self.special_tokens["<|transcribe|>"]
@cached_property
def translate(self) -> int:
return self.special_tokens["<|translate|>"]
@cached_property
def sot(self) -> int:
return self.special_tokens["<|startoftranscript|>"]
@cached_property
def sot_lm(self) -> int:
return self.special_tokens["<|startoflm|>"]
@cached_property
def sot_prev(self) -> int:
return self.special_tokens["<|startofprev|>"]
@cached_property
def no_speech(self) -> int:
return self.special_tokens["<|nospeech|>"]
@cached_property
def no_timestamps(self) -> int:
return self.special_tokens["<|notimestamps|>"]
@cached_property
def timestamp_begin(self) -> int:
return self.special_tokens["<|0.00|>"]
@cached_property
def language_token(self) -> int:
"""Returns the token id corresponding to the value of the `language` field"""
if self.language is None:
raise ValueError("This tokenizer does not have language token configured")
return self.to_language_token(self.language)
def to_language_token(self, language):
if token := self.special_tokens.get(f"<|{language}|>", None):
return token
raise KeyError(f"Language {language} not found in tokenizer.")
@cached_property
def all_language_tokens(self) -> Tuple[int]:
result = []
for token, token_id in self.special_tokens.items():
if token.strip("<|>") in LANGUAGES:
result.append(token_id)
return tuple(result)[: self.num_languages]
@cached_property
def all_language_codes(self) -> Tuple[str]:
return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
@cached_property
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
return tuple(list(self.sot_sequence) + [self.no_timestamps])
@cached_property
def non_speech_tokens(self) -> Tuple[int]:
"""
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
- ♪♪♪
- ( SPEAKING FOREIGN LANGUAGE )
- [DAVID] Hey there,
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
"""
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
symbols += (
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
)
# symbols that may be a single token or multiple tokens depending on the tokenizer.
# In case they're multiple tokens, suppress the first token, which is safe because:
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
miscellaneous = set("♩♪♫♬♭♮♯")
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
for symbol in symbols + list(miscellaneous):
for tokens in [
self.encoding.encode(symbol),
self.encoding.encode(" " + symbol),
]:
if len(tokens) == 1 or symbol in miscellaneous:
result.add(tokens[0])
return tuple(sorted(result))
def split_to_word_tokens(self, tokens: List[int]):
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
# These languages don't typically use spaces, so it is difficult to split words
# without morpheme analysis. Here, we instead split words at any
# position where the tokens are decoded as valid unicode points
return self.split_tokens_on_unicode(tokens)
return self.split_tokens_on_spaces(tokens)
def split_tokens_on_unicode(self, tokens: List[int]):
decoded_full = self.decode_with_timestamps(tokens)
replacement_char = "\ufffd"
words = []
word_tokens = []
current_tokens = []
unicode_offset = 0
for token in tokens:
current_tokens.append(token)
decoded = self.decode_with_timestamps(current_tokens)
if (
replacement_char not in decoded
or decoded_full[unicode_offset + decoded.index(replacement_char)]
== replacement_char
):
words.append(decoded)
word_tokens.append(current_tokens)
current_tokens = []
unicode_offset += len(decoded)
return words, word_tokens
def split_tokens_on_spaces(self, tokens: List[int]):
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
words = []
word_tokens = []
for subword, subword_tokens in zip(subwords, subword_tokens_list):
special = subword_tokens[0] >= self.eot
with_space = subword.startswith(" ")
punctuation = subword.strip() in string.punctuation
if special or with_space or punctuation or len(words) == 0:
words.append(subword)
word_tokens.append(subword_tokens)
else:
words[-1] = words[-1] + subword
word_tokens[-1].extend(subword_tokens)
return words, word_tokens
@lru_cache(maxsize=None)
def get_encoding(name: str = "gpt2", num_languages: int = 99):
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
ranks = {
base64.b64decode(token): int(rank)
for token, rank in (line.split() for line in open(vocab_path) if line)
}
n_vocab = len(ranks)
special_tokens = {}
specials = [
"<|endoftext|>",
"<|startoftranscript|>",
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
"<|translate|>",
"<|transcribe|>",
"<|startoflm|>",
"<|startofprev|>",
"<|nospeech|>",
"<|notimestamps|>",
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
]
for token in specials:
special_tokens[token] = n_vocab
n_vocab += 1
return tiktoken.Encoding(
name=os.path.basename(vocab_path),
explicit_n_vocab=n_vocab,
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
mergeable_ranks=ranks,
special_tokens=special_tokens,
)
@lru_cache(maxsize=None)
def get_tokenizer(
multilingual: bool,
*,
num_languages: int = 99,
language: Optional[str] = None,
task: Optional[str] = None, # Literal["transcribe", "translate", None]
) -> Tokenizer:
if language is not None:
language = language.lower()
if language not in LANGUAGES:
if language in TO_LANGUAGE_CODE:
language = TO_LANGUAGE_CODE[language]
else:
raise ValueError(f"Unsupported language: {language}")
if multilingual:
encoding_name = "multilingual"
language = language or "en"
task = task or "transcribe"
else:
encoding_name = "gpt2"
language = None
task = None
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
return Tokenizer(
encoding=encoding, num_languages=num_languages, language=language, task=task
)
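# A minimal usage sketch (not part of the upstream file): building the multilingual tokenizer
# and round-tripping text through it; `sot_sequence` is the start-of-transcript prompt
# (<|startoftranscript|>, <|en|>, <|transcribe|>) that conditions the decoder.
if __name__ == "__main__":
    tokenizer = get_tokenizer(multilingual=True, language="en", task="transcribe")
    ids = tokenizer.encode(" Hello world")
    print(ids, "->", tokenizer.decode(ids))
    print(tokenizer.sot_sequence)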


@ -0,0 +1,605 @@
import argparse
import os
import traceback
import warnings
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import numpy as np
import torch
import tqdm
from .audio import (
FRAMES_PER_SECOND,
HOP_LENGTH,
N_FRAMES,
N_SAMPLES,
SAMPLE_RATE,
log_mel_spectrogram,
pad_or_trim,
)
from .decoding import DecodingOptions, DecodingResult
from .timing import add_word_timestamps
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from .utils import (
exact_div,
format_timestamp,
get_end,
get_writer,
make_safe,
optional_float,
optional_int,
str2bool,
)
if TYPE_CHECKING:
from .model import Whisper
def transcribe(
model: "Whisper",
audio: Union[str, np.ndarray, torch.Tensor],
*,
verbose: Optional[bool] = None,
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
compression_ratio_threshold: Optional[float] = 2.4,
logprob_threshold: Optional[float] = -1.0,
no_speech_threshold: Optional[float] = 0.6,
condition_on_previous_text: bool = True,
initial_prompt: Optional[str] = None,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,!?::”)]}、",
clip_timestamps: Union[str, List[float]] = "0",
hallucination_silence_threshold: Optional[float] = None,
**decode_options,
):
"""
Transcribe an audio file using Whisper
Parameters
----------
model: Whisper
The Whisper model instance
audio: Union[str, np.ndarray, torch.Tensor]
The path to the audio file to open, or the audio waveform
verbose: bool
Whether to display the text being decoded to the console. If True, displays all the details,
If False, displays minimal details. If None, does not display anything
temperature: Union[float, Tuple[float, ...]]
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
compression_ratio_threshold: float
If the gzip compression ratio is above this value, treat as failed
logprob_threshold: float
If the average log probability over sampled tokens is below this value, treat as failed
no_speech_threshold: float
If the no_speech probability is higher than this value AND the average log probability
over sampled tokens is below `logprob_threshold`, consider the segment as silent
condition_on_previous_text: bool
if True, the previous output of the model is provided as a prompt for the next window;
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
word_timestamps: bool
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
and include the timestamps for each word in each segment.
prepend_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the next word
append_punctuations: str
If word_timestamps is True, merge these punctuation symbols with the previous word
initial_prompt: Optional[str]
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those words correctly.
decode_options: dict
Keyword arguments to construct `DecodingOptions` instances
clip_timestamps: Union[str, List[float]]
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
The last end timestamp defaults to the end of the file.
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
when a possible hallucination is detected
Returns
-------
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
"""
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
if model.device == torch.device("cpu"):
if torch.cuda.is_available():
warnings.warn("Performing inference on CPU when CUDA is available")
if dtype == torch.float16:
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
dtype = torch.float32
if dtype == torch.float32:
decode_options["fp16"] = False
# Pad 30-seconds of silence to the input audio, for slicing
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
content_frames = mel.shape[-1] - N_FRAMES
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
if decode_options.get("language", None) is None:
if not model.is_multilingual:
decode_options["language"] = "en"
else:
if verbose:
print(
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
)
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
print(
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
)
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
tokenizer = get_tokenizer(
model.is_multilingual,
num_languages=model.num_languages,
language=language,
task=task,
)
if isinstance(clip_timestamps, str):
clip_timestamps = [
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
]
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
if len(seek_points) == 0:
seek_points.append(0)
if len(seek_points) % 2 == 1:
seek_points.append(content_frames)
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
punctuation = "\"'“¿([{-\"'.。,!?::”)]}、"
if word_timestamps and task == "translate":
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
temperatures = (
[temperature] if isinstance(temperature, (int, float)) else temperature
)
decode_result = None
for t in temperatures:
kwargs = {**decode_options}
if t > 0:
# disable beam_size and patience when t > 0
kwargs.pop("beam_size", None)
kwargs.pop("patience", None)
else:
# disable best_of when t == 0
kwargs.pop("best_of", None)
options = DecodingOptions(**kwargs, temperature=t)
decode_result = model.decode(segment, options)
needs_fallback = False
if (
compression_ratio_threshold is not None
and decode_result.compression_ratio > compression_ratio_threshold
):
needs_fallback = True # too repetitive
if (
logprob_threshold is not None
and decode_result.avg_logprob < logprob_threshold
):
needs_fallback = True # average log probability is too low
if (
no_speech_threshold is not None
and decode_result.no_speech_prob > no_speech_threshold
):
needs_fallback = False # silence
if not needs_fallback:
break
return decode_result
clip_idx = 0
seek = seek_clips[clip_idx][0]
input_stride = exact_div(
N_FRAMES, model.dims.n_audio_ctx
) # mel frames per output token: 2
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
all_tokens = []
all_segments = []
prompt_reset_since = 0
if initial_prompt is not None:
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
all_tokens.extend(initial_prompt_tokens)
else:
initial_prompt_tokens = []
def new_segment(
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
):
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return {
"seek": seek,
"start": start,
"end": end,
"text": tokenizer.decode(text_tokens),
"tokens": tokens,
"temperature": result.temperature,
"avg_logprob": result.avg_logprob,
"compression_ratio": result.compression_ratio,
"no_speech_prob": result.no_speech_prob,
}
# show the progress bar when verbose is False (if True, transcribed text will be printed)
with tqdm.tqdm(
total=content_frames, unit="frames", disable=verbose is not False
) as pbar:
last_speech_timestamp = 0.0
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
# for seek_clip_start, seek_clip_end in seek_clips:
# while seek < seek_clip_end
while clip_idx < len(seek_clips):
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
if seek < seek_clip_start:
seek = seek_clip_start
if seek >= seek_clip_end:
clip_idx += 1
if clip_idx < len(seek_clips):
seek = seek_clips[clip_idx][0]
continue
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
mel_segment = mel[:, seek : seek + segment_size]
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
decode_options["prompt"] = all_tokens[prompt_reset_since:]
result: DecodingResult = decode_with_fallback(mel_segment)
tokens = torch.tensor(result.tokens)
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
if (
logprob_threshold is not None
and result.avg_logprob > logprob_threshold
):
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
if should_skip:
seek += segment_size # fast-forward to the next segment boundary
continue
previous_seek = seek
current_segments = []
# anomalous words are very long/short/improbable
def word_anomaly_score(word: dict) -> float:
probability = word.get("probability", 0.0)
duration = word["end"] - word["start"]
score = 0.0
if probability < 0.15:
score += 1.0
if duration < 0.133:
score += (0.133 - duration) * 15
if duration > 2.0:
score += duration - 2.0
return score
def is_segment_anomaly(segment: Optional[dict]) -> bool:
if segment is None or not segment["words"]:
return False
words = [w for w in segment["words"] if w["word"] not in punctuation]
words = words[:8]
score = sum(word_anomaly_score(w) for w in words)
return score >= 3 or score + 0.01 >= len(words)
def next_words_segment(segments: List[dict]) -> Optional[dict]:
return next((s for s in segments if s["words"]), None)
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
consecutive.add_(1)
if len(consecutive) > 0:
# if the output contains two consecutive timestamp tokens
slices = consecutive.tolist()
if single_timestamp_ending:
slices.append(len(tokens))
last_slice = 0
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
start_timestamp_pos = (
sliced_tokens[0].item() - tokenizer.timestamp_begin
)
end_timestamp_pos = (
sliced_tokens[-1].item() - tokenizer.timestamp_begin
)
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos * time_precision,
end=time_offset + end_timestamp_pos * time_precision,
tokens=sliced_tokens,
result=result,
)
)
last_slice = current_slice
if single_timestamp_ending:
# single timestamp at the end means no speech after the last timestamp.
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
last_timestamp_pos = (
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
)
seek += last_timestamp_pos * input_stride
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
if (
len(timestamps) > 0
and timestamps[-1].item() != tokenizer.timestamp_begin
):
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = (
timestamps[-1].item() - tokenizer.timestamp_begin
)
duration = last_timestamp_pos * time_precision
current_segments.append(
new_segment(
start=time_offset,
end=time_offset + duration,
tokens=tokens,
result=result,
)
)
seek += segment_size
if word_timestamps:
add_word_timestamps(
segments=current_segments,
model=model,
tokenizer=tokenizer,
mel=mel_segment,
num_frames=segment_size,
prepend_punctuations=prepend_punctuations,
append_punctuations=append_punctuations,
last_speech_timestamp=last_speech_timestamp,
)
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
seek = round(last_word_end * FRAMES_PER_SECOND)
# skip silence before possible hallucinations
if hallucination_silence_threshold is not None:
threshold = hallucination_silence_threshold
if not single_timestamp_ending:
last_word_end = get_end(current_segments)
if last_word_end is not None and last_word_end > time_offset:
remaining_duration = window_end_time - last_word_end
if remaining_duration > threshold:
seek = round(last_word_end * FRAMES_PER_SECOND)
else:
seek = previous_seek + segment_size
# if first segment might be a hallucination, skip leading silence
first_segment = next_words_segment(current_segments)
if first_segment is not None and is_segment_anomaly(first_segment):
gap = first_segment["start"] - time_offset
if gap > threshold:
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
continue
# skip silence before any possible hallucination that is surrounded
# by silence or more hallucinations
hal_last_end = last_speech_timestamp
for si in range(len(current_segments)):
segment = current_segments[si]
if not segment["words"]:
continue
if is_segment_anomaly(segment):
next_segment = next_words_segment(
current_segments[si + 1 :]
)
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
else:
hal_next_start = time_offset + segment_duration
silence_before = (
segment["start"] - hal_last_end > threshold
or segment["start"] < threshold
or segment["start"] - time_offset < 2.0
)
silence_after = (
hal_next_start - segment["end"] > threshold
or is_segment_anomaly(next_segment)
or window_end_time - segment["end"] < 2.0
)
if silence_before and silence_after:
seek = round(
max(time_offset + 1, segment["start"])
* FRAMES_PER_SECOND
)
if content_duration - segment["end"] < threshold:
seek = content_frames
current_segments[si:] = []
break
hal_last_end = segment["end"]
last_word_end = get_end(current_segments)
if last_word_end is not None:
last_speech_timestamp = last_word_end
if verbose:
for segment in current_segments:
start, end, text = segment["start"], segment["end"], segment["text"]
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
print(make_safe(line))
# if a segment is instantaneous or does not contain text, clear it
for i, segment in enumerate(current_segments):
if segment["start"] == segment["end"] or segment["text"].strip() == "":
segment["text"] = ""
segment["tokens"] = []
segment["words"] = []
all_segments.extend(
[
{"id": i, **segment}
for i, segment in enumerate(
current_segments, start=len(all_segments)
)
]
)
all_tokens.extend(
[token for segment in current_segments for token in segment["tokens"]]
)
if not condition_on_previous_text or result.temperature > 0.5:
# do not feed the prompt tokens if a high temperature was used
prompt_reset_since = len(all_tokens)
# update progress bar
pbar.update(min(content_frames, seek) - previous_seek)
return dict(
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
segments=all_segments,
language=language,
)
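# A minimal usage sketch (not part of the upstream file): the same function is available as a
# Python API; `load_model` is the loader exposed by this package's __init__, and "audio.wav"
# is a placeholder path.
def _transcribe_example(audio_path: str = "audio.wav") -> dict:
    from . import load_model

    model = load_model("base")
    result = transcribe(model, audio_path, fp16=False, word_timestamps=True)
    for segment in result["segments"]:
        print(f"[{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}] {segment['text']}")
    return result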
def cli():
from . import available_models
def valid_model_name(name):
if name in available_models() or os.path.exists(name):
return name
raise ValueError(
f"model should be one of {available_models()} or path to a model checkpoint"
)
# fmt: off
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,!?::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
# fmt: on
args = parser.parse_args().__dict__
model_name: str = args.pop("model")
model_dir: str = args.pop("model_dir")
output_dir: str = args.pop("output_dir")
output_format: str = args.pop("output_format")
device: str = args.pop("device")
os.makedirs(output_dir, exist_ok=True)
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
if args["language"] is not None:
warnings.warn(
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
)
args["language"] = "en"
temperature = args.pop("temperature")
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
else:
temperature = [temperature]
if (threads := args.pop("threads")) > 0:
torch.set_num_threads(threads)
from . import load_model
model = load_model(model_name, device=device, download_root=model_dir)
writer = get_writer(output_format, output_dir)
word_options = [
"highlight_words",
"max_line_count",
"max_line_width",
"max_words_per_line",
]
if not args["word_timestamps"]:
for option in word_options:
if args[option]:
parser.error(f"--{option} requires --word_timestamps True")
if args["max_line_count"] and not args["max_line_width"]:
warnings.warn("--max_line_count has no effect without --max_line_width")
if args["max_words_per_line"] and args["max_line_width"]:
warnings.warn("--max_words_per_line has no effect with --max_line_width")
writer_args = {arg: args.pop(arg) for arg in word_options}
for audio_path in args.pop("audio"):
try:
result = transcribe(model, audio_path, temperature=temperature, **args)
writer(result, audio_path, **writer_args)
except Exception as e:
traceback.print_exc()
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
if __name__ == "__main__":
cli()


@ -0,0 +1,316 @@
import json
import os
import re
import sys
import zlib
from typing import Callable, List, Optional, TextIO
system_encoding = sys.getdefaultencoding()
if system_encoding != "utf-8":
def make_safe(string):
# replaces any character not representable using the system default encoding with an '?',
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
return string.encode(system_encoding, errors="replace").decode(system_encoding)
else:
def make_safe(string):
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
return string
def exact_div(x, y):
assert x % y == 0
return x // y
def str2bool(string):
str2val = {"True": True, "False": False}
if string in str2val:
return str2val[string]
else:
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
def optional_int(string):
return None if string == "None" else int(string)
def optional_float(string):
return None if string == "None" else float(string)
def compression_ratio(text) -> float:
text_bytes = text.encode("utf-8")
return len(text_bytes) / len(zlib.compress(text_bytes))
def format_timestamp(
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
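# e.g. format_timestamp(3661.5) -> "01:01:01.500" (the hours field appears because hours > 0)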
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
)
def get_start(segments: List[dict]) -> Optional[float]:
return next(
(w["start"] for s in segments for w in s["words"]),
segments[0]["start"] if segments else None,
)
def get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)
class ResultWriter:
extension: str
def __init__(self, output_dir: str):
self.output_dir = output_dir
def __call__(
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
):
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
output_path = os.path.join(
self.output_dir, audio_basename + "." + self.extension
)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f, options=options, **kwargs)
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
raise NotImplementedError
class WriteTXT(ResultWriter):
extension: str = "txt"
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True)
class SubtitlesWriter(ResultWriter):
always_include_hours: bool
decimal_marker: str
def iterate_result(
self,
result: dict,
options: Optional[dict] = None,
*,
max_line_width: Optional[int] = None,
max_line_count: Optional[int] = None,
highlight_words: bool = False,
max_words_per_line: Optional[int] = None,
):
options = options or {}
max_line_width = max_line_width or options.get("max_line_width")
max_line_count = max_line_count or options.get("max_line_count")
highlight_words = highlight_words or options.get("highlight_words", False)
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
preserve_segments = max_line_count is None or max_line_width is None
max_line_width = max_line_width or 1000
max_words_per_line = max_words_per_line or 1000
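# unset limits fall back to effectively unlimited values so the chunking loop below stays uniform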
def iterate_subtitles():
line_len = 0
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: List[dict] = []
last: float = get_start(result["segments"]) or 0.0
for segment in result["segments"]:
chunk_index = 0
words_count = max_words_per_line
while chunk_index < len(segment["words"]):
remaining_words = len(segment["words"]) - chunk_index
if max_words_per_line > len(segment["words"]) - chunk_index:
words_count = remaining_words
for i, original_timing in enumerate(
segment["words"][chunk_index : chunk_index + words_count]
):
timing = original_timing.copy()
long_pause = (
not preserve_segments and timing["start"] - last > 3.0
)
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if (
line_len > 0
and has_room
and not long_pause
and not seg_break
):
# line continuation
line_len += len(timing["word"])
else:
# new line
timing["word"] = timing["word"].strip()
if (
len(subtitle) > 0
and max_line_count is not None
and (long_pause or line_count >= max_line_count)
or seg_break
):
# subtitle break
yield subtitle
subtitle = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
last = timing["start"]
chunk_index += max_words_per_line
if len(subtitle) > 0:
yield subtitle
if len(result["segments"]) > 0 and "words" in result["segments"][0]:
for subtitle in iterate_subtitles():
subtitle_start = self.format_timestamp(subtitle[0]["start"])
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
subtitle_text = "".join([word["word"] for word in subtitle])
if highlight_words:
last = subtitle_start
all_words = [timing["word"] for timing in subtitle]
for i, this_word in enumerate(subtitle):
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, subtitle_text
yield start, end, "".join(
[
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i
else word
for j, word in enumerate(all_words)
]
)
last = end
else:
yield subtitle_start, subtitle_end, subtitle_text
else:
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float):
return format_timestamp(
seconds=seconds,
always_include_hours=self.always_include_hours,
decimal_marker=self.decimal_marker,
)
class WriteVTT(SubtitlesWriter):
extension: str = "vtt"
always_include_hours: bool = False
decimal_marker: str = "."
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result, options, **kwargs):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteSRT(SubtitlesWriter):
extension: str = "srt"
always_include_hours: bool = True
decimal_marker: str = ","
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for i, (start, end, text) in enumerate(
self.iterate_result(result, options, **kwargs), start=1
):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
class WriteTSV(ResultWriter):
"""
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
Using integer milliseconds as start and end times means there's no chance of interference from
an environment setting a language encoding that causes the decimal in a floating point number
to appear as a comma; it is also faster and more efficient to parse & store, e.g., in C++.
"""
extension: str = "tsv"
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t")
print(round(1000 * segment["end"]), file=file, end="\t")
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
extension: str = "json"
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
json.dump(result, file)
def get_writer(
output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
"srt": WriteSRT,
"tsv": WriteTSV,
"json": WriteJSON,
}
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for writer in all_writers:
writer(result, file, options, **kwargs)
return write_all
return writers[output_format](output_dir)


@ -44,7 +44,7 @@ To use the Llama and Mistral models, you will need to go through an extra step t
1. Visit
- LLaMA 2: [https://huggingface.co/meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
- LLaMA 3: [https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
- Mistral: [https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
- Mistral: [https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
2. Follow the steps on the Hugging Face page to obtain access
3. Run `huggingface-cli login`
4. Paste your [Hugging Face User Access Token](https://huggingface.co/docs/hub/en/security-tokens) to login
@ -57,7 +57,7 @@ Run the chatbot app using the following command:
> python app.py
```
The chatbot app will start with the default settings, which uses `DirectML` as the backend to run the `Phi-3` model for inference using `float16`. The app will automatically download `Phi-3-4k-instruct` on the first run from the default `model_repo` which is set to `microsoft/Phi-3-4k-instruct`.
The chatbot app will start with the default settings, which use `DirectML` as the backend to run the `Phi-3` model for inference using `float16`. The app will automatically download `Phi-3-mini-4k-instruct` on the first run from the default `hf_model`, which is set to `microsoft/Phi-3-mini-4k-instruct`.
This model is optimized to take advantage of DirectML operators and to use the custom DirectML graph implementations for Rotary Positional Embedding (RoPE), Multi-Head Attention (MHA), and the Feedforward layers (MLP).
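As a rough sketch (the variable names below are placeholders, not an excerpt from the app), the model code in this repository invokes these fused DirectML kernels like so:
```
import torch_directml

# fused RoPE: rotate q and k using the cached cos/sin tables, starting at the current position
q, k = torch_directml.apply_rotary_position_emb(q, k, cos, sin, position, seqlen, head_dim)

# fused multi-head attention: the returned past key/value tensors act as the KV cache
y, past_key, past_value = torch_directml.multi_head_attention(
    q, k, v, dim, n_head, past_key, past_value, mask
)
```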
@ -109,13 +109,13 @@ You can also select another model to run (`microsoft/Phi-3-mini-4k-instruct`, `m
For example, to run `Mistral-7B-Instruct-v0.1`, use the following command:
```
> python app.py --precision float16 --model_repo "mistralai/Mistral-7B-Instruct-v0.1"
> python app.py --precision float16 --hf_model "mistralai/Mistral-7B-Instruct-v0.1"
```
You should see a result such as this:
```
> python app.py --precision float16 --model_repo "mistralai/Mistral-7B-Instruct-v0.1"
> python app.py --precision float16 --hf_model "mistralai/Mistral-7B-Instruct-v0.1"
checkpoints\mistralai\Mistral-7B-Instruct-v0.1\model.pth doesnt exist. Downloading and converting from huggingface hub
README.md: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3.90k/3.90k [00:00<?, ?B/s]
model.safetensors.index.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25.1k/25.1k [00:00<?, ?B/s]
@ -146,9 +146,9 @@ Following is a list of the basic settings supported by `app.py`:
| Flag | Description | Default |
| ---------------------- | ------------------------------------------------------------ | ------- |
| `--help` | Show this help message. | N/A |
| `--model_repo` | Specify the model to downloading using the Hugging Face Repository ID. | `microsoft/Phi-3-mini-4k-instruct` |
| `--hf_model` | Specify the model to download using the Hugging Face repository ID. | `microsoft/Phi-3-mini-4k-instruct` |
| `--precision` | Model precision to use during generation. Options: [`float16`, `float32`] | `float16` |
| `--checkpoint_path` | Path to converted PyTorch model checkpoint. | `checkpoints/{model_repo}/model.pth` |
| `--checkpoint_path` | Path to converted PyTorch model checkpoint. | `checkpoints/{hf_model}/model.pth` |
| `--max_context_length` | Max prompt length including the history. If exceeded, history is clipped starting from the first (user, assistant) pair. | `1500` |
| `--disable_history` | Disable the chat history during generation. | Enabled |
@ -170,25 +170,25 @@ We offer two methods for preparing PyTorch models:
### Use `download_and_convert.py` to download a language model:
```
> python .\scripts\download_and_convert.py --model_repo "microsoft/Phi-3-mini-4k-instruct"
> python .\scripts\download_and_convert.py --hf_model "microsoft/Phi-3-mini-4k-instruct"
```
After the model is downloaded and converted, you can pass the following parameter to `app.py` to run the language model:
```
> python app.py --model_repo "microsoft/Phi-3-mini-4k-instruct"
> python app.py --hf_model "microsoft/Phi-3-mini-4k-instruct"
```
### Download a DirectML optimized PyTorch model from the [Microsoft Hugging Face repo](https://huggingface.co/microsoft):
1. cd checkpoints
2. git clone https://huggingface.co/{model_repo} {model_repo}
2. git clone https://huggingface.co/{hf_model} {hf_model}
3. cd ../
After the model is downloaded, you can pass the following parameter to `app.py` to run the language model:
```
> python app.py --checkpoint_path "checkpoints/{model_repo}/model.pth"
> python app.py --checkpoint_path "checkpoints/{hf_model}/model.pth"
```
## External Links


@ -1,13 +1,19 @@
# Copyright (c) Microsoft Corporation.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys
from pathlib import Path
from typing import Union
from typing import Union, List, Tuple, Iterator
import torch_directml
import torch
import gradio as gr
from transformers import AutoTokenizer
from transformers import AutoTokenizer, PreTrainedTokenizerFast
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
@ -18,18 +24,19 @@ from models.phi3 import Transformer as Phi3Transformer
from models.phi2 import Transformer as Phi2Transformer
from models.llama import Transformer as LlamaTransformer
device = torch_directml.device(torch_directml.default_device())
def decode_n_tokens(
model: Union[Phi2Transformer, Phi3Transformer, LlamaTransformer],
model: Union[Phi2Transformer, Phi3Transformer, LlamaTransformer],
cur_token: torch.Tensor,
input_pos: torch.Tensor,
num_new_tokens: int,
tokenizer,
tokenizer: PreTrainedTokenizerFast,
stream_every_n: int,
is_llama_3: bool = False,
**sampling_kwargs
):
) -> Iterator[str]:
res = tokenizer.decode(cur_token[0][0].item(), skip_special_tokens=True).strip() + " "
yield res
@ -48,7 +55,7 @@ def decode_n_tokens(
new_tokens.append(next_token.clone())
cur_token = next_token.view(1, -1)
# Handle output and overlap at the specified intervals or at the last token for adding
# Handle output and overlap at the specified intervals or at the last token for adding
# the space correctly between stream batches
if ((i + 1) % stream_every_n == 0 or i == num_new_tokens - 1):
# Determine the range of tokens to decode, including the overlap
@ -68,8 +75,6 @@ def decode_n_tokens(
from_index = max(0, start_pos - overlap_size)
yield decode_with_overlap(tokenizer, new_tokens, from_index, overlap_text)
break
return new_tokens
@torch.no_grad()
def generate(
@ -77,11 +82,11 @@ def generate(
prompt: torch.Tensor,
max_new_tokens: int,
*,
tokenizer = None,
tokenizer: PreTrainedTokenizerFast,
stream_every_n: int = 10,
is_llama_3: bool = False,
**sampling_kwargs
) -> torch.Tensor:
) -> Iterator[str]:
"""
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
"""
@ -98,36 +103,38 @@ def generate(
# create an empty tensor of the expected final shape and fill in the current tokens
input_pos = torch.arange(0, T, device=device)
next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
input_pos = torch.tensor([T], device=device, dtype=torch.int)
# generated_tokens = decode_n_tokens(
return decode_n_tokens(
yield from decode_n_tokens(
model, next_token.view(1, -1), input_pos, max_new_tokens - 1, tokenizer, stream_every_n, is_llama_3=is_llama_3, **sampling_kwargs)
class LLM_Model:
def __init__(self,
prompt: str = "Hello, my name is",
interactive: bool = False,
num_samples: int = 5,
max_new_tokens: int = 100,
top_k: int = 200,
temperature: float = 0.01,
model_repo: str = "microsoft/Phi-3-mini-4k-instruct",
checkpoint_path: str = None,
precision: str = 'float32',
stream_every_n: int = 7,
max_context_length: int = 3500,
use_history: bool = False):
def __init__(
self,
prompt: str = "Hello, my name is",
interactive: bool = False,
num_samples: int = 5,
max_new_tokens: int = 100,
top_k: int = 200,
temperature: float = 0.01,
hf_model: str = "microsoft/Phi-3-mini-4k-instruct",
checkpoint_path: str = None,
precision: str = 'float32',
stream_every_n: int = 7,
max_context_length: int = 3500,
use_history: bool = False
):
self.prompt = prompt
self.interactive = interactive
self.num_samples = num_samples
self.max_new_tokens = max_new_tokens
self.top_k = top_k
self.temperature = temperature
self.model_repo = model_repo
self.checkpoint_path = Path(f"checkpoints/{model_repo}/model.pth") if checkpoint_path is None else Path(checkpoint_path)
self.hf_model = hf_model
self.checkpoint_path = Path(f"checkpoints/{hf_model}/model.pth") if checkpoint_path is None else Path(checkpoint_path)
self.precision = torch.float32 if precision == 'float32' else torch.float16
self.stream_every_n = stream_every_n
self.max_context_length = max_context_length
@ -136,21 +143,34 @@ class LLM_Model:
self.tokenizer = None
self.model = None
def encode_tokens(self, prompt, conversation_history, device=None, max_context_length=1500, bos=True):
def encode_tokens(
self,
prompt: str,
conversation_history: List[List[str]],
device: torch.device = None,
max_context_length: int = 1500,
bos: bool = True
) -> torch.Tensor:
if self.is_phi_2:
tokens = self.format_prompt_phi2_chat_and_encode(
prompt, conversation_history, device, max_context_length, bos
)
else:
tokens = self.format_prompt_and_encode(
prompt, conversation_history, device, max_context_length,
prompt, conversation_history, device, max_context_length,
)
return tokens
def format_prompt_and_encode(self, prompt, history, device=None, max_context_length=3500):
def format_prompt_and_encode(
self,
prompt: str,
conversation_history: List[List[str]],
device: torch.device = None,
max_context_length: int = 1500,
) -> torch.Tensor:
messages = []
if len(history) and self.use_history:
for user, assistant in history:
if len(conversation_history) and self.use_history:
for user, assistant in conversation_history:
user = {"role": "user", "content": user}
assistant = {"role": "assistant", "content": assistant}
messages.append(user)
@ -172,7 +192,14 @@ class LLM_Model:
return tokens
def format_prompt_phi2_chat_and_encode(self, prompt, conversation_history, device=None, max_context_length=1500, bos=True):
def format_prompt_phi2_chat_and_encode(
self,
prompt: str,
conversation_history: List[List[str]],
device: torch.device = None,
max_context_length: int = 1500,
bos: bool = True
) -> torch.Tensor:
formatted_prompt = ""
if self.use_history:
for user_prompt, llm_response in conversation_history:
@ -194,7 +221,14 @@ class LLM_Model:
token_tensor = torch.tensor(tokens, dtype=torch.int, device=device)
return token_tensor
def format_prompt_phi2_qa_and_encode(self, prompt, conversation_history, max_context_length=1500, bos=True, device=None):
def format_prompt_phi2_qa_and_encode(
self,
prompt: str,
conversation_history: List[List[str]],
device: torch.device = None,
max_context_length: int = 1500,
bos: bool = True
) -> torch.Tensor:
formatted_prompt = ""
if self.use_history:
for user_prompt, llm_response in conversation_history:
@ -218,22 +252,22 @@ class LLM_Model:
token_tensor = torch.tensor(tokens, dtype=torch.int, device=device)
return token_tensor
def download_and_convert(self):
checkpoint_dir = hf_download(self.model_repo)
def download_and_convert(self) -> None:
checkpoint_dir = hf_download(self.hf_model)
convert_hf_checkpoint(
checkpoint_dir=Path(checkpoint_dir),
)
self.checkpoint_path = Path(f"checkpoints/{self.model_repo}/model.pth")
self.checkpoint_path = Path(f"{checkpoint_dir}/model.pth")
def load_model(self):
def load_model(self) -> None:
if not self.checkpoint_path.is_file():
print(f"{self.checkpoint_path} doesnt exist. Downloading and converting {self.model_repo} from huggingface hub. "
"Specify a different model with --model_repo or valid pre-converted checkpoint with --checkpoint_path")
print(f"{self.checkpoint_path} doesnt exist. Downloading and converting {self.hf_model} from huggingface hub. "
"Specify a different model with --hf_model or valid pre-converted checkpoint with --checkpoint_path")
self.download_and_convert()
print("Running app...")
print(f"Loading model from {self.checkpoint_path}")
self.is_llama_3 = "Llama-3" in self.checkpoint_path.parent.name
self.is_phi_2 = "phi-2" in self.checkpoint_path.parent.name
print(f"Using device={device}, is_llama_3={self.is_llama_3}, is_phi_2={self.is_phi_2}")
@ -245,9 +279,14 @@ class LLM_Model:
if self.max_context_length > self.model.config.block_size - (self.max_new_tokens+1):
raise ValueError(
f"Expected max_context_length to be less than {self.model.config.block_size - (self.max_new_tokens+1)} but got {self.max_context_length}")
@torch.no_grad()
def chat(self, prompt, history, **sampling_kwargs):
def chat(
self,
prompt: str,
history: List[List[str]],
**sampling_kwargs
) -> Iterator[str]:
torch.manual_seed(1235)
encoded = self.encode_tokens(
prompt,
@ -256,7 +295,7 @@ class LLM_Model:
max_context_length=self.max_context_length,
)
toks = generate(
yield from generate(
self.model,
encoded,
self.max_new_tokens,
@ -266,32 +305,30 @@ class LLM_Model:
temperature=self.temperature,
top_k=self.top_k,
)
return toks
def chat(message, history):
def chat(message: str, history: List[List[str]]) -> Iterator[str]:
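# Streaming wrapper for the chat UI: accumulate the chunks yielded by the model so every yield carries the full response so far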
total_msg = ""
for msg in llm_model.chat(message, history):
total_msg += msg
yield total_msg
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description='Your CLI description.')
parser.add_argument(
'--model_repo',
'--hf_model',
type=str,
default="microsoft/Phi-3-mini-4k-instruct",
help='Huggingface Repository ID to download from.'
default="phi-3",
help='Huggingface Repository ID to download from, or one of the model names from ["phi-2", "phi-3", "llama-2", "llama-3", "mistral"]'
)
parser.add_argument(
'--checkpoint_path',
type=str,
default=None,
help='Converted pytorch model checkpoint path. Defaults to `checkpoints/{model_repo}/model.pth`.'
help='Converted pytorch model checkpoint path. Defaults to `checkpoints/{hf_model}/model.pth`.'
)
parser.add_argument(
'--max_context_length',
@ -319,7 +356,7 @@ if __name__ == "__main__":
max_new_tokens = 500,
top_k = 200,
temperature = 0.8,
model_repo = args.model_repo,
hf_model = args.hf_model,
checkpoint_path = args.checkpoint_path,
precision = args.precision,
max_context_length = args.max_context_length,


@ -0,0 +1,53 @@
# Copyright (c) Microsoft Corporation.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from models.configs import ModelArgs, find_multiple
class Transformer(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
self.max_position_embeddings = config.block_size
self.rope_base = config.rope_base
self.max_batch_size = -1
self.max_seq_length = -1
def setup_caches(self, max_batch_size, max_seq_length):
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
dtype=self.output.weight.dtype
for b in self.layers:
b.attention._init_rope(self.max_position_embeddings, self.rope_base, dtype=dtype)
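# full lower-triangular attention mask, built once here; forward() slices out the rows for the current input positions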
self.causal_mask = torch.tril(
torch.ones(self.config.n_head, self.max_seq_length, self.max_seq_length, dtype=torch.int32)
)
def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
mask = self.causal_mask[None, :, input_pos, :input_pos[-1].item()+1]
x = self.tok_embeddings(idx)
for _, layer in enumerate(self.layers):
x = layer(x, input_pos, mask)
x = self.norm(x)
logits = self.output(x)
return logits
@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))


@ -0,0 +1,70 @@
# Copyright (c) Microsoft Corporation.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)
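# e.g. find_multiple(30, 8) == 32 and find_multiple(32, 8) == 32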
@dataclass
class ModelArgs:
block_size: int = 2048
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
dim: int = 4096
intermediate_size: int = None
n_local_heads: int = -1
head_dim: int = 64
rope_base: float = 10000
norm_eps: float = 1e-5
partial_rotary_factor: float = 1.0
def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
if self.intermediate_size is None:
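# default FFN width: two thirds of 4*dim, rounded up to a multiple of 256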
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
self.head_dim = self.dim // self.n_head
@classmethod
def from_name(cls, name: str):
if name in transformer_configs:
return cls(**transformer_configs[name])
# fuzzy search
config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
# We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
# take the longer name (as it has more symbols matched)
if len(config) > 1:
config.sort(key=len, reverse=True)
assert len(config[0]) != len(config[1]), name # make sure only one 'best' match
if not config:
raise ValueError(f"No configuration found for the model named '{name}'. Supported models: {list(transformer_configs.keys())}")
return cls(**transformer_configs[config[0]])
transformer_configs = {
"7B": dict(block_size=4096, n_layer=32, n_head=32, dim=4096),
"phi-2": dict(block_size=2048, n_layer=32, n_head=32, dim=2560, intermediate_size=10240, rope_base=10000, vocab_size=51200, partial_rotary_factor=0.4),
"Phi-3-mini-4k-instruct": dict(block_size=4096, n_layer=32, n_head=32, dim=3072, intermediate_size=8192, rope_base=10000, vocab_size=32064),
"Mistral-7B": dict(block_size=4096, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
"Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000.0),
}
default_models = {
"llama-3": "meta-llama/Meta-Llama-3-8B-Instruct",
"llama-2": "meta-llama/Llama-2-7b-chat-hf",
"phi-3": "microsoft/Phi-3-mini-4k-instruct",
"phi-2": "microsoft/phi-2",
"mistral": "mistralai/Mistral-7B-Instruct-v0.1"
}


@ -0,0 +1,137 @@
# Copyright (c) Microsoft Corporation.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Union, Tuple, Dict
import torch_directml
import torch
import torch.nn as nn
from torch import Tensor
from models.configs import ModelArgs
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.rmsnorm = torch_directml.rmsnorm
def forward(self, x: Tensor) -> Tensor:
output = self.rmsnorm(x.float(), self.weight.float(), self.eps)
return output.type_as(x)
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim: int,
max_position_embeddings: int = 4096,
base: Union[int, float] = 10000,
dtype: torch.dtype = torch.float16,
device: Optional[torch.device] = None,
):
super().__init__()
self.dtype = dtype
device = device if device is not None else torch_directml.device(torch_directml.default_device())
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device
)
def _set_cos_sin_cache(self, seq_len: int, device: torch.device) -> None:
self.max_seq_len_cached = seq_len
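# outer product of positions and inverse frequencies gives the rotation angles; the cos/sin tables are cached once and reused at every step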
t = torch.arange(self.max_seq_len_cached, device=device)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0).to(self.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0).to(self.dtype), persistent=False)
def forward(self) -> Tuple[torch.Tensor, torch.Tensor] :
return (
self.cos_cached,
self.sin_cached
)
class LlamaAttention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
self.wo = nn.Linear(config.dim, config.dim, bias=False)
self.kv_cache = None
self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict: Dict[str, torch.Tensor], prefix: str, *argspy):
if prefix + "wq.weight" in state_dict:
wq = state_dict.pop(prefix + "wq.weight")
wk = state_dict.pop(prefix + "wk.weight")
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
def _init_rope( self, max_position_embeddings: int = 4096, rope_base: Union[int, float] = 10000.0, dtype: torch.dtype = torch.float16) -> None:
self.min_position = 0
self.past_key_tensor = None
self.past_value_tensor = None
self.rotary_emb = RotaryEmbedding(
self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_base, dtype=dtype
)
def forward(self, x: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
bsz, seqlen, _ = x.shape
kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
q = q.reshape(bsz, seqlen, self.n_head, self.head_dim).transpose(1, 2)
k = k.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
v = v.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb()
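# fused DirectML RoPE kernel: rotates q and k starting at absolute position self.min_position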
q, k = torch_directml.apply_rotary_position_emb(
q, k, cos, sin, self.min_position, seqlen, self.head_dim)
self.min_position += seqlen
if self.n_head != self.n_local_heads:
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
q, k, v = map(lambda x: x.transpose(1, 2).reshape(bsz, -1, self.dim), (q, k, v))
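# fused DirectML multi-head attention; past_key_tensor/past_value_tensor are the KV cache and come back updated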
y, self.past_key_tensor, self.past_value_tensor = torch_directml.multi_head_attention(
q, k, v, self.dim, self.n_head, self.past_key_tensor, self.past_value_tensor, mask
)
y = self.wo(y)
return y
class LlamaTransformerBlock(nn.Module):
def __init__(self, config: ModelArgs, feed_forward_module: nn.Module):
super().__init__()
self.attention = LlamaAttention(config)
self.feed_forward = feed_forward_module(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor:
h = x + self.attention(self.attention_norm(x), mask, input_pos)
out = h + self.feed_forward(self.ffn_norm(h))
return out


@ -1,180 +1,30 @@
from dataclasses import dataclass
# Copyright (c) Microsoft Corporation.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch_directml
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from models.configs import ModelArgs
from models.layers import RMSNorm
from models.layers import LlamaTransformerBlock as TransformerBlock
from models.base import Transformer
def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)
@dataclass
class ModelArgs:
block_size: int = 2048
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
dim: int = 4096
intermediate_size: int = None
n_local_heads: int = -1
head_dim: int = 64
rope_base: float = 10000
norm_eps: float = 1e-5
def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
if self.intermediate_size is None:
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
self.head_dim = self.dim // self.n_head
@classmethod
def from_name(cls, name: str):
if name in transformer_configs:
return cls(**transformer_configs[name])
# fuzzy search
config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
# We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
# take longer name (as it have more symbols matched)
if len(config) > 1:
config.sort(key=len, reverse=True)
assert len(config[0]) != len(config[1]), name # make sure only one 'best' match
return cls(**transformer_configs[config[0]])
transformer_configs = {
"7B": dict(block_size=4096, n_layer=32, n_head=32, dim=4096),
"Mistral-7B": dict(block_size=4096, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
"Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000.0),
}
class Transformer(nn.Module):
class LlamaTransformer(Transformer):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
self.max_position_embeddings = config.block_size
self.rope_base = config.rope_base
super().__init__(config)
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
self.layers = nn.ModuleList(TransformerBlock(config, FeedForward) for _ in range(config.n_layer))
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
self.max_batch_size = -1
self.max_seq_length = -1
def setup_caches(self, max_batch_size, max_seq_length):
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
dtype=self.output.weight.dtype
for b in self.layers:
b.attention._init_rope(self.max_position_embeddings, self.rope_base, dtype=dtype)
self.causal_mask = torch.tril(
torch.ones(self.config.n_head, self.max_seq_length, self.max_seq_length, dtype=torch.int32)
)
def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
mask = self.causal_mask[None, :, input_pos, :input_pos[-1].item()+1]
x = self.tok_embeddings(idx)
for _, layer in enumerate(self.layers):
x = layer(x, input_pos, mask)
x = self.norm(x)
logits = self.output(x)
return logits
@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))
class TransformerBlock(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.attention = Attention(config)
self.feed_forward = FeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor:
h = x + self.attention(self.attention_norm(x), mask, input_pos)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
self.wo = nn.Linear(config.dim, config.dim, bias=False)
self.kv_cache = None
self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict, prefix, *argspy):
if prefix + "wq.weight" in state_dict:
wq = state_dict.pop(prefix + "wq.weight")
wk = state_dict.pop(prefix + "wk.weight")
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
def _init_rope(self, max_position_embeddings=4096, rope_base=10000.0, dtype=torch.float16):
self.min_position = 0
self.past_key_tensor = None
self.past_value_tensor = None
self.rotary_emb = RotaryEmbedding(
self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_base, dtype=dtype)
def forward(self, x: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
bsz, seqlen, _ = x.shape
kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
q = q.reshape(bsz, seqlen, self.n_head, self.head_dim).transpose(1, 2)
k = k.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
v = v.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb()
q, k = torch_directml.apply_rotary_position_emb(
q, k, cos, sin, self.min_position, seqlen, self.head_dim)
self.min_position += seqlen
if self.n_head != self.n_local_heads:
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
q, k, v = map(lambda x: x.transpose(1, 2).reshape(bsz, -1, self.dim), (q, k, v))
y, self.past_key_tensor, self.past_value_tensor = torch_directml.multi_head_attention(
q, k, v, self.dim, self.n_head, self.past_key_tensor, self.past_value_tensor, mask
)
y = self.wo(y)
return y
class FeedForward(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
@ -185,47 +35,3 @@ class FeedForward(nn.Module):
def forward(self, x: Tensor) -> Tensor:
return self.mlp(x, self.w1.weight, self.w3.weight, self.w2.weight)
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.rmsnorm = torch_directml.rmsnorm
def forward(self, x: Tensor) -> Tensor:
output = self.rmsnorm(x.float(), self.weight.float(), self.eps)
return output.type_as(x)
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=4096, base=10000, dtype=torch.float16, device=None):
super().__init__()
self.dtype = dtype
device = device if device is not None else torch_directml.device(torch_directml.default_device())
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device
)
def _set_cos_sin_cache(self, seq_len, device):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0).to(self.dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0).to(self.dtype), persistent=False)
def forward(self):
return (
self.cos_cached,
self.sin_cached
)


@ -1,94 +1,28 @@
from dataclasses import dataclass
from typing import Optional
import math
# Copyright (c) Microsoft Corporation.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Dict, Union
import torch_directml
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)
@dataclass
class ModelArgs:
block_size: int = 2048
vocab_size: int = 51200
n_layer: int = 32
n_head: int = 32
dim: int = 2560
intermediate_size: int = 10240
n_local_heads: int = -1
head_dim: int = 64
rope_base: float = 10000
norm_eps: float = 1e-5
def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
else:
self.n_kv_groups = self.n_head // self.n_local_heads
if self.intermediate_size is None:
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
self.head_dim = self.dim // self.n_head
@classmethod
def from_name(cls, name: str):
if name in transformer_configs:
return cls(**transformer_configs[name])
# fuzzy search
config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
assert len(config) == 1, name
return cls(**transformer_configs[config[0]])
transformer_configs = {
"phi-2": dict(block_size=2048, n_layer=32, n_head=32, dim=2560, intermediate_size=10240, rope_base=10000),
}
from models.configs import ModelArgs
from models.layers import RotaryEmbedding
from models.base import Transformer
class Transformer(nn.Module):
class Phi2Transformer(Transformer):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
super().__init__(config)
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
self.norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size)
self.max_batch_size = -1
self.max_seq_length = -1
def setup_caches(self, max_batch_size, max_seq_length, dtype=torch.float32):
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
for b in self.layers:
b.attention._init_rope(dtype=dtype)
self.causal_mask = torch.tril(
torch.ones(self.config.n_head, self.max_seq_length, self.max_seq_length, dtype=torch.int32)
)
def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
mask = self.causal_mask[None, :, input_pos, :input_pos[-1].item()+1]
x = self.tok_embeddings(idx)
for _, layer in enumerate(self.layers):
x = layer(x, input_pos, mask)
x = self.norm(x)
logits = self.output(x)
return logits
@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))
class TransformerBlock(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
@ -102,7 +36,6 @@ class TransformerBlock(nn.Module):
ffn_hidden_states = self.feed_forward(hidden_states)
return attn_outputs + ffn_hidden_states + x
class Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
@ -111,26 +44,37 @@ class Attention(nn.Module):
self.wqkv = nn.Linear(config.dim, 3 * config.dim)
self.wo = nn.Linear(config.dim, config.dim)
self.kv_cache = None
self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self.partial_rotary_factor = config.partial_rotary_factor
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict, prefix, *argspy):
def load_hook(self, state_dict: Dict[str, torch.Tensor], prefix: str, *argspy):
if prefix + "wq.weight" in state_dict:
wq = state_dict.pop(prefix + "wq.weight")
wk = state_dict.pop(prefix + "wk.weight")
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
def _init_rope(self, dtype=torch.float32):
def _init_rope(
self,
max_position_embeddings: int = 4096,
rope_base: Union[int, float] = 10000.0,
dtype: torch.dtype = torch.float32
) -> None:
self.min_position = 0
self.past_key_tensor = None
self.past_value_tensor = None
self.rotary_emb = PhiRotaryEmbedding(self.n_head, dtype=dtype)
self.rotary_emb = RotaryEmbedding(
int(self.head_dim * self.partial_rotary_factor),
max_position_embeddings=max_position_embeddings,
base=rope_base,
dtype=dtype
)
def forward(self, x: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
bsz, seqlen, _ = x.shape
@ -139,9 +83,9 @@ class Attention(nn.Module):
q = q.reshape(bsz, seqlen, self.n_head, self.head_dim).transpose(1,2)
k = k.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1,2)
cos, sin = self.rotary_emb()
q, k = torch_directml.apply_rotary_position_emb(
q, k, cos, sin, self.min_position, seqlen, self.rotary_emb.dim)
@ -155,7 +99,6 @@ class Attention(nn.Module):
y = self.wo(y)
return y
class FeedForward(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
@ -165,34 +108,3 @@ class FeedForward(nn.Module):
def forward(self, x: Tensor) -> Tensor:
return self.mlp(x, self.w1.weight, self.w2.weight, self.w1.bias, self.w2.bias)
class PhiRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=torch.float32, device=None):
super().__init__()
device = device if device is not None else torch_directml.device(torch_directml.default_device())
self.dtype = dtype
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=dtype).to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=dtype
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0).to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0).to(dtype), persistent=False)
def forward(self):
return (
self.cos_cached,
self.sin_cached,
)


@ -1,86 +1,29 @@
from dataclasses import dataclass
# Copyright (c) Microsoft Corporation.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional
import torch_directml
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from models.configs import ModelArgs
from models.layers import RMSNorm
from models.layers import LlamaTransformerBlock as TransformerBlock
from models.base import Transformer
def find_multiple(n: int, k: int) -> int:
if n % k == 0:
return n
return n + k - (n % k)
@dataclass
class ModelArgs:
block_size: int = 2048
vocab_size: int = 32000
n_layer: int = 32
n_head: int = 32
dim: int = 4096
intermediate_size: int = None
n_local_heads: int = -1
head_dim: int = 64
rope_base: float = 10000
norm_eps: float = 1e-5
def __post_init__(self):
if self.n_local_heads == -1:
self.n_local_heads = self.n_head
if self.intermediate_size is None:
hidden_dim = 4 * self.dim
n_hidden = int(2 * hidden_dim / 3)
self.intermediate_size = find_multiple(n_hidden, 256)
self.head_dim = self.dim // self.n_head
@classmethod
def from_name(cls, name: str):
if name in transformer_configs:
return cls(**transformer_configs[name])
# fuzzy search
config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
# We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
# take longer name (as it have more symbols matched)
if len(config) > 1:
config.sort(key=len, reverse=True)
assert len(config[0]) != len(config[1]), name # make sure only one 'best' match
return cls(**transformer_configs[config[0]])
transformer_configs = {
"Phi-3-mini-4k-instruct": dict(block_size=4096, n_layer=32, n_head=32, dim=3072, intermediate_size=8192, rope_base=10000, vocab_size=32064),
}
class Transformer(nn.Module):
class Phi3Transformer(Transformer):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.config = config
super().__init__(config)
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
self.layers = nn.ModuleList(TransformerBlock(config, FeedForward) for _ in range(config.n_layer))
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
self.mask_cache: Optional[Tensor] = None
self.max_batch_size = -1
self.max_seq_length = -1
def setup_caches(self, max_batch_size, max_seq_length):
head_dim = self.config.dim // self.config.n_head
max_seq_length = find_multiple(max_seq_length, 8)
self.max_seq_length = max_seq_length
self.max_batch_size = max_batch_size
dtype=self.output.weight.dtype
for b in self.layers:
b.attention._init_rope(dtype=dtype)
self.causal_mask = torch.tril(
torch.ones(self.config.n_head, self.max_seq_length, self.max_seq_length, dtype=torch.int32)
)
def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
mask = self.causal_mask[None, :, input_pos, :input_pos[-1].item()+1]
x = self.tok_embeddings(idx)
@ -92,80 +35,6 @@ class Transformer(nn.Module):
logits = self.output(x)
return logits
@classmethod
def from_name(cls, name: str):
return cls(ModelArgs.from_name(name))
class TransformerBlock(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.attention = Attention(config)
self.feed_forward = FeedForward(config)
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor:
h = x + self.attention(self.attention_norm(x), mask, input_pos)
out = h + self.feed_forward(self.ffn_norm(h))
return out
class Attention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
assert config.dim % config.n_head == 0
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
self.wo = nn.Linear(config.dim, config.dim, bias=False)
self.kv_cache = None
self.n_head = config.n_head
self.head_dim = config.head_dim
self.n_local_heads = config.n_local_heads
self.dim = config.dim
self._register_load_state_dict_pre_hook(self.load_hook)
def load_hook(self, state_dict, prefix, *argspy):
if prefix + "wq.weight" in state_dict:
wq = state_dict.pop(prefix + "wq.weight")
wk = state_dict.pop(prefix + "wk.weight")
wv = state_dict.pop(prefix + "wv.weight")
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
def _init_rope(self, dtype=torch.float32):
self.min_position = 0
self.past_key_tensor = None
self.past_value_tensor = None
self.rotary_emb = RotaryEmbedding(self.head_dim, dtype=dtype)
def forward(self, x: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
bsz, seqlen, _ = x.shape
kv_size = self.n_local_heads * self.head_dim
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
q = q.reshape(bsz, seqlen, self.n_head, self.head_dim).transpose(1, 2)
k = k.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
v = v.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb()
q, k = torch_directml.apply_rotary_position_emb(
q, k, cos, sin, self.min_position, seqlen, self.head_dim)
self.min_position += seqlen
q, k, v = map(lambda x: x.transpose(1, 2).reshape(bsz, -1, self.dim), (q, k, v))
y, self.past_key_tensor, self.past_value_tensor = torch_directml.multi_head_attention(
q, k, v, self.dim, self.n_head, self.past_key_tensor, self.past_value_tensor, mask
)
y = self.wo(y)
return y
class FeedForward(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
@ -176,46 +45,3 @@ class FeedForward(nn.Module):
def forward(self, x: Tensor) -> Tensor:
return self.mlp(x, self.w1.weight, self.w2.weight)
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.rmsnorm = torch_directml.rmsnorm
def forward(self, x: Tensor) -> Tensor:
output = self.rmsnorm(x.float(), self.weight.float(), self.eps)
return output.type_as(x)
class RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=4096, base=10000, dtype=torch.float32, device=None):
super().__init__()
self.dtype = dtype
device = device if device is not None else torch_directml.device(torch_directml.default_device())
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=dtype).to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=dtype
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0).to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0).to(dtype), persistent=False)
def forward(self):
return (
self.cos_cached,
self.sin_cached
)


@ -1,8 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright (c) Microsoft Corporation.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import json
import re
@ -13,61 +15,73 @@ from typing import Optional
import torch
from safetensors.torch import load_file
from requests.exceptions import HTTPError
from huggingface_hub.utils._errors import RepositoryNotFoundError
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from models.phi2 import ModelArgs as Phi2ModelArgs
from models.phi3 import ModelArgs as Phi3ModelArgs
from models.llama import ModelArgs as LlamaModelArgs
from models.configs import ModelArgs, default_models
def hf_download(model_repo: Optional[str] = None, hf_token: Optional[str] = None) -> None:
def is_dir_empty(directory: str) -> bool:
return not any(os.scandir(directory))
def download_model_from_hf(hf_model: str, checkpoint_dir: str, hf_token: Optional[str]) -> str:
from huggingface_hub import snapshot_download
checkpoint_dir = f"checkpoints/{model_repo}"
checkpoint_dir = f"{checkpoint_dir}/{hf_model}"
os.makedirs(checkpoint_dir, exist_ok=True)
if os.listdir(checkpoint_dir):
if not is_dir_empty(checkpoint_dir):
print(f"The directory {checkpoint_dir} is not empty. Skipping download.")
else:
try:
snapshot_download(model_repo, local_dir=checkpoint_dir, local_dir_use_symlinks=False, token=hf_token)
snapshot_download(hf_model, local_dir=checkpoint_dir, local_dir_use_symlinks=False, token=hf_token)
print(f"Downloaded {hf_model} successfully.")
except HTTPError as e:
if e.response.status_code == 401:
print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
else:
raise e
print("You need to pass a valid Hugging Face token to download private models.")
raise e
return checkpoint_dir
def hf_download(hf_model: Optional[str] = None, hf_token: Optional[str] = None, checkpoint_dir: str = "checkpoints") -> str:
try:
checkpoint_dir_download = download_model_from_hf(hf_model, checkpoint_dir, hf_token)
except RepositoryNotFoundError as e:
# invalid repo passed, try to search for a default repo from the given hf_model
os.rmdir(f"{checkpoint_dir}/{hf_model}")
print(f"Couldn't find {hf_model} on HuggingFace. Searching for the closest supported match ...")
if hf_model in default_models:
hf_model = default_models[hf_model]
else:
raise ValueError(f"Please provide a valid hf_model to download from Huggingface. {hf_model} doesnt exist on Huggingface.")
print(f"Found closest match on Huggingface: {hf_model}")
checkpoint_dir_download = download_model_from_hf(hf_model, checkpoint_dir, hf_token)
return checkpoint_dir_download
@torch.inference_mode()
def convert_hf_checkpoint(
*,
checkpoint_dir: Path = Path("microsoft/Phi-3-mini-4k-instruct"),
checkpoint_dir: Path = Path("checkpoints/microsoft/Phi-3-mini-4k-instruct"),
weight_map_path: str = "config/weight_map.json",
model_name: Optional[str] = None,
) -> None:
if not os.path.exists(checkpoint_dir):
raise ValueError("Please download you model first with the hf_download function.")
if os.path.exists(checkpoint_dir / "model.pth"):
print(f"Converted checkpoint already exists here {checkpoint_dir / 'model.pth'}. Skipping Conversion.")
return
model_name = checkpoint_dir.name
config = ModelArgs.from_name(model_name)
with open(weight_map_path, 'r') as file:
weight_maps = json.load(file)
if model_name is None:
model_name = checkpoint_dir.name
is_llama3 = "Llama-3" in model_name
is_phi3 = "Phi-3" in model_name
model_name = checkpoint_dir.name
model_name = "llama" if "phi" not in model_name.lower() else model_name
weight_map = weight_maps[model_name]
if "phi" not in model_name.lower():
weight_map = weight_maps["llama"]
config = LlamaModelArgs.from_name(model_name)
else:
weight_map = weight_maps[model_name]
if is_phi3:
config = Phi3ModelArgs.from_name(model_name)
else:
config = Phi2ModelArgs.from_name(model_name)
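# At this point `config` holds the hyperparameters for the detected model family
# (Llama, Phi-2, or Phi-3) and `weight_map` is the matching per-family entry loaded
# from weight_map.json.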
# Load the json file containing weight mapping
model_map_json = checkpoint_dir / "model.safetensors.index.json"
@ -77,12 +91,12 @@ def convert_hf_checkpoint(
bin_index = json.load(json_map)
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
merged_result = {}
for file in sorted(bin_files):
state_dict = load_file(str(file))
merged_result.update(state_dict)
final_result = {}
for key, value in merged_result.items():
if "layers" in key:
@ -125,13 +139,15 @@ def convert_hf_checkpoint(
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='Download and convert HuggingFace checkpoint.')
parser.add_argument('--model_repo', type=str, default="microsoft/Phi-3-mini-4k-instruct", help='Huggingface Repository ID to download from.')
parser.add_argument('--hf_model', type=str, default="microsoft/Phi-3-mini-4k-instruct", help='Huggingface Repository ID to download from.')
parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.')
parser.add_argument('--model_name', type=str, default=None)
parser.add_argument(
'--checkpoint_dir', type=str, default="checkpoints",
help="Directory to download the Hugging Face repo to. The model will be downloaded and converted to '{checkpoint_dir}/{hf_model}/'."
)
args = parser.parse_args()
checkpoint_dir = hf_download(args.model_repo, args.hf_token)
checkpoint_dir = hf_download(args.hf_model, args.hf_token, args.checkpoint_dir)
convert_hf_checkpoint(
checkpoint_dir=Path(checkpoint_dir),
model_name=args.model_name,
)
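# Example invocation (illustrative; the script file name is an assumption):
#   python download_and_convert.py --hf_model microsoft/Phi-3-mini-4k-instruct --checkpoint_dir checkpoints
# The repo is downloaded into checkpoints/microsoft/Phi-3-mini-4k-instruct and, per the
# existence check in convert_hf_checkpoint, the converted model.pth is written there too.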

View file

@ -1,27 +1,33 @@
# Copyright (c) Microsoft Corporation.
#
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import sys
import time
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, List
import torch
from torch import Tensor
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizerFast
# support running without installing as a package
wd = Path(__file__).parent.parent.resolve()
sys.path.append(str(wd))
from models.phi3 import Transformer as Phi3Transformer
from models.phi2 import Transformer as Phi2Transformer
from models.llama import Transformer as LlamaTransformer
from models.phi3 import Phi3Transformer
from models.phi2 import Phi2Transformer
from models.llama import LlamaTransformer
def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
def multinomial_sample_one_no_sync(probs_sort: Tensor) -> Tensor: # Does multinomial sampling without a cuda synchronization
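# Dividing each probability by an independent Exponential(1) draw and taking the argmax
# selects index i with probability proportional to probs_sort[i] (an exponential-race /
# Gumbel-max style trick), so the torch.multinomial call, and the device synchronization
# it can trigger, is avoided.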
q = torch.empty_like(probs_sort).exponential_(1)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
def logits_to_probs(logits: Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> Tensor:
logits = logits / max(temperature, 1e-5)
if top_k is not None:
@ -31,8 +37,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
probs = torch.nn.functional.softmax(logits, dim=-1)
return probs
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
def sample(logits: Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> Tuple[Tensor, Tensor]:
probs = logits_to_probs(logits[0, -1], temperature, top_k)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
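# Minimal usage sketch (illustrative shapes and values, not from the original code):
#   logits = torch.randn(1, 8, 32000)                      # [batch, seq, vocab]
#   next_token, probs = sample(logits, temperature=0.8, top_k=50)
# sample() only looks at the logits of the last position (logits[0, -1]).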
@ -40,10 +45,10 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
def prefill(
model: Union[Phi2Transformer, Phi3Transformer, LlamaTransformer],
x: torch.Tensor,
input_pos: torch.Tensor,
x: Tensor,
input_pos: Tensor,
**sampling_kwargs
) -> torch.Tensor:
) -> Tensor:
# input_pos: [B, S]
logits = model(x, input_pos)
return sample(logits, **sampling_kwargs)[0]
@ -51,17 +56,16 @@ def prefill(
def decode_one_token(
model: Union[Phi2Transformer, Phi3Transformer, LlamaTransformer],
x: torch.Tensor,
input_pos: torch.Tensor,
x: Tensor,
input_pos: Tensor,
**sampling_kwargs
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[Tensor, Tensor]:
# input_pos: [B, 1]
assert input_pos.shape[-1] == 1
logits = model(x, input_pos)
return sample(logits, **sampling_kwargs)
def decode_with_overlap(tokenizer, tokens, start, overlap):
def decode_with_overlap(tokenizer: PreTrainedTokenizerFast, tokens: List[Tensor], start: int, overlap: str) -> str:
"""Helper function to decode text, managing overlap."""
current_decoded = tokenizer.decode(torch.IntTensor(tokens[start:]).tolist(), skip_special_tokens=True)
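# `overlap` holds the text already emitted for earlier tokens; stripping it from this
# fresh decode leaves only the newly generated text, which keeps streamed output from
# repeating itself.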
if overlap and current_decoded.startswith(overlap):
@ -70,8 +74,7 @@ def decode_with_overlap(tokenizer, tokens, start, overlap):
text_output = current_decoded
return text_output
def _load_model(checkpoint_path, device, precision):
def _load_model(checkpoint_path: Path, device: torch.device, precision: torch.dtype) -> torch.nn.Module:
model_name = checkpoint_path.parent.name
with torch.device('meta'):
if 'phi-2' in model_name.lower():