Update docs for torch-directml 0.2.2 (#593)
* update docs for next torch-directml release * Minor readme spacing issues --------- Co-authored-by: Sheil Kumar <sheilk@microsoft.com> Co-authored-by: Dwayne Robinson <dwayner@microsoft.com>
This commit is contained in:
Родитель
4d65cad0be
Коммит
372a622890
|
@ -22,6 +22,7 @@ For `torch-directml` samples find brief summaries below or explore the [cv](./cv
|
|||
* [resnet50 - an image classification model](./cv/resnet50)
|
||||
* [maskrcnn - an object detection model](./cv/objectDetection/maskrcnn/)
|
||||
* [llm - a text generation and chatbot app supporting various language models](./llm/)
|
||||
* [whisper - a general-purpose speech recognition model](./audio/whisper/)
|
||||
|
||||
## External Links
|
||||
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2022 OpenAI
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
|
@ -0,0 +1,94 @@
|
|||
# Speech Recognition with Whisper
|
||||
This sample guides you on how to run OpenAI's automatic speech recognition (ASR) [Whisper model](https://github.com/openai/whisper/blob/main/README.md) with our DirectML-backend.
|
||||
|
||||
- [Setup](#setup)
|
||||
- [About Whisper](#run-the-whisper-model)
|
||||
- [Basic Settings](#basic-settings)
|
||||
- [External Links](#external-links)
|
||||
- [Model License](#model-license)
|
||||
|
||||
|
||||
## About Whisper
|
||||
|
||||
The [OpenAI Whisper](https://github.com/openai/whisper/) model is a general-purpose speech recognition model. It is trained on a large dataset of diverse audio and is also a multitasking model that can perform multilingual speech recognition, speech translation, and language identification.
|
||||
|
||||
Whisper supports five model sizes, four with English-only versions and all five with multilingual versions.
|
||||
| Size | Parameters | English-only model | Multilingual model | Required VRAM
|
||||
|:---------:|:----------:|:------------------:|:------------------:|:-------------:|
|
||||
| tiny | 39 M | `tiny.en` | `tiny` | ~1 GB |
|
||||
| base | 74 M | `base.en` | `base` | ~1 GB |
|
||||
| small | 244 M | `small.en` | `small` | ~2 GB |
|
||||
| medium | 769 M | `medium.en` | `medium` | ~5 GB |
|
||||
| large v3 | 1550 M | N/A | `large-v3` | ~10 GB |
|
||||
|
||||
For more information on the model, please refer to the [OpenAI Whisper GitHub repo](https://github.com/openai/whisper/).
|
||||
|
||||
|
||||
## Setup
|
||||
Once you've setup `torch-directml` following our [Windows](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-windows) and [WSL](https://learn.microsoft.com/en-us/windows/ai/directml/pytorch-wsl) guidance, install the following requirements for running the app:
|
||||
|
||||
|
||||
```
|
||||
conda install ffmpeg
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
|
||||
## Run the Whisper model
|
||||
Run Whisper with DirectML backend with a sample audio file with the following command:
|
||||
```bash
|
||||
python run.py --input_file <audio_file> --model_size "tiny.en"
|
||||
```
|
||||
|
||||
|
||||
For example, you should see the result output as below:
|
||||
```
|
||||
> python run.py --input_file test/samples_jfk.wav --model_size "tiny.en"
|
||||
100%|█████████████████████████████████████| 72.1M/72.1M [00:09<00:00, 7.90MiB/s]
|
||||
test/samples_jfk.wav
|
||||
|
||||
And so my fellow Americans ask not what your country can do for you ask what you can do for your country.
|
||||
```
|
||||
|
||||
|
||||
Note, by default [OpenAI Whisper](https://github.com/openai/whisper/) uses a naive implementation for the scaled dot product attention. If you want to improve performance further to leverage DirectML's scaled dot product attention, execute `run.py` with `--use_dml_attn` flag:
|
||||
|
||||
```bash
|
||||
python run.py --input_file <audio_file> --model_size "tiny.en" --use_dml_attn
|
||||
```
|
||||
Based on this flag `MultiHeadAttention` module in `model.py` would choose between naive whisper scaled dot product attention and DirectML's scaled dot product attention:
|
||||
```python
|
||||
if use_dml_attn:
|
||||
wv, qk = self.dml_sdp_attn(q, k, v, mask, cross_attention=cross_attention)
|
||||
else:
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
```
|
||||
|
||||
## Basic Settings
|
||||
|
||||
Following is a list of the basic settings supported by `run.py`:
|
||||
|
||||
|
||||
|
||||
| Flag | Description | Default |
|
||||
| ---------------------- | ------------------------------------------------------------ | ------- |
|
||||
| `--help` | Show this help message. | - |
|
||||
| `--input_file` | [Required] Path to input audio file | - |
|
||||
| `--model_size` | Size of Whisper model to use. Options: [`tiny.en`, `tiny`, `base.en`, `base`, `small.en`, `small`, `medium.en`, `medium`, `large-v3`] | `tiny.en` |
|
||||
| `--fp16` | Runs inference with fp16 precision. | True |
|
||||
| `--use_dml_attn` | Runs inference with DirectML Scaled dot product attention impl. | False |
|
||||
|
||||
|
||||
## External Links
|
||||
- [Whisper Base Hugging Face Repository](https://huggingface.co/openai/whisper-base.en)
|
||||
- [Whisper Tiny Hugging Face Repository](https://huggingface.co/openai/whisper-tiny.en)
|
||||
- [Whisper Small Hugging Face Repository](https://huggingface.co/openai/whisper-small.en)
|
||||
- [Whisper Medium Hugging Face Repository](https://huggingface.co/openai/whisper-medium.en)
|
||||
- [Whisper Large v3 Hugging Face Repository](https://huggingface.co/openai/whisper-large-v3)
|
||||
- [Whisper GitHub Repo](https://github.com/openai/whisper)
|
||||
|
||||
|
||||
|
||||
## Model License
|
||||
|
||||
Whisper's code and model weights are released under the MIT License. See [LICENSE](https://github.com/openai/whisper/blob/main/LICENSE) for further details.
|
|
@ -0,0 +1,6 @@
|
|||
numba
|
||||
numpy
|
||||
tqdm
|
||||
more-itertools
|
||||
tiktoken
|
||||
ffmpeg-python
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
#
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import whisper
|
||||
import torch_directml
|
||||
import argparse
|
||||
|
||||
|
||||
def main(args):
|
||||
device = torch_directml.device(torch_directml.default_device())
|
||||
model = whisper.load_model(args.model_size, device=device, use_dml_attn=args.use_dml_attn)
|
||||
|
||||
# Load audio and pad/trim it to fit 30 seconds
|
||||
audio = whisper.load_audio(args.input_file)
|
||||
audio = whisper.pad_or_trim(audio)
|
||||
|
||||
n_mels = 80
|
||||
if args.model_size == "large-v3":
|
||||
n_mels = 128
|
||||
|
||||
mel = whisper.log_mel_spectrogram(audio, n_mels=n_mels).to(model.device)
|
||||
language = "en"
|
||||
if "en" not in args.model_size:
|
||||
_, probs = model.detect_language(mel)
|
||||
language = max(probs, key=probs.get)
|
||||
print(f"Detected language: {language}")
|
||||
|
||||
options = whisper.DecodingOptions(language=language, fp16=args.fp16)
|
||||
result = whisper.decode(model, mel, options)
|
||||
|
||||
print(result.text)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Run Whisper model on specified audio file with warmup.')
|
||||
parser.add_argument('--model_size', type=str, default='tiny.en', help='Size of the Whisper model to use.')
|
||||
parser.add_argument('--input_file', type=str, required=True, help='Path to the input audio file.')
|
||||
parser.add_argument('--fp16', action="store_true", help='Runs inference with fp16 precision.')
|
||||
parser.add_argument('--use_dml_attn', action="store_true", help='Use DirectML attention implementation.')
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
Двоичный файл не отображается.
|
@ -0,0 +1,160 @@
|
|||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from .audio import load_audio, log_mel_spectrogram, pad_or_trim
|
||||
from .decoding import DecodingOptions, DecodingResult, decode, detect_language
|
||||
from .model import ModelDimensions, Whisper
|
||||
from .transcribe import transcribe
|
||||
# from .version import __version__
|
||||
|
||||
_MODELS = {
|
||||
"tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
|
||||
"tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
|
||||
"base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
|
||||
"base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
|
||||
"small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
|
||||
"small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
|
||||
"medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
|
||||
"medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
|
||||
"large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
|
||||
"large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
|
||||
"large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
"large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
|
||||
}
|
||||
|
||||
# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
|
||||
# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
|
||||
_ALIGNMENT_HEADS = {
|
||||
"tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
|
||||
"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
|
||||
"base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
|
||||
"base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
|
||||
"small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
|
||||
"small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
|
||||
"medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
|
||||
"medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
|
||||
"large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
|
||||
"large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
|
||||
"large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
"large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
|
||||
os.makedirs(root, exist_ok=True)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, os.path.basename(url))
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
with open(download_target, "rb") as f:
|
||||
model_bytes = f.read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
|
||||
return model_bytes if in_memory else download_target
|
||||
else:
|
||||
warnings.warn(
|
||||
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
|
||||
)
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(
|
||||
total=int(source.info().get("Content-Length")),
|
||||
ncols=80,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
model_bytes = open(download_target, "rb").read()
|
||||
if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
|
||||
raise RuntimeError(
|
||||
"Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
|
||||
)
|
||||
|
||||
return model_bytes if in_memory else download_target
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load_model(
|
||||
name: str,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
download_root: str = None,
|
||||
in_memory: bool = False,
|
||||
use_dml_attn: bool = False,
|
||||
) -> Whisper:
|
||||
"""
|
||||
Load a Whisper ASR model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
one of the official model names listed by `whisper.available_models()`, or
|
||||
path to a model checkpoint containing the model dimensions and the model state_dict.
|
||||
device : Union[str, torch.device]
|
||||
the PyTorch device to put the model into
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/whisper"
|
||||
in_memory: bool
|
||||
whether to preload the model weights into host memory
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : Whisper
|
||||
The Whisper ASR model instance
|
||||
"""
|
||||
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
if download_root is None:
|
||||
default = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
|
||||
|
||||
if name in _MODELS:
|
||||
checkpoint_file = _download(_MODELS[name], download_root, in_memory)
|
||||
alignment_heads = _ALIGNMENT_HEADS[name]
|
||||
elif os.path.isfile(name):
|
||||
checkpoint_file = open(name, "rb").read() if in_memory else name
|
||||
alignment_heads = None
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Model {name} not found; available models = {available_models()}"
|
||||
)
|
||||
|
||||
# with (
|
||||
# io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb")
|
||||
# ) as fp:
|
||||
# # checkpoint = torch.load(fp, map_location=device)
|
||||
# checkpoint = torch.load(fp, mmap=True, weights_only=True)
|
||||
# del checkpoint_file
|
||||
checkpoint = torch.load(checkpoint_file, mmap=True, weights_only=True)
|
||||
|
||||
dims = ModelDimensions(**checkpoint["dims"])
|
||||
model = Whisper(dims, use_dml_attn=use_dml_attn)
|
||||
|
||||
|
||||
model.load_state_dict(checkpoint["model_state_dict"])
|
||||
if alignment_heads is not None:
|
||||
model.set_alignment_heads(alignment_heads)
|
||||
|
||||
return model.to(device)
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
Двоичный файл не отображается.
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,158 @@
|
|||
import os
|
||||
from functools import lru_cache
|
||||
from subprocess import CalledProcessError, run
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import exact_div
|
||||
|
||||
# hard-coded audio hyperparameters
|
||||
SAMPLE_RATE = 16000
|
||||
N_FFT = 400
|
||||
HOP_LENGTH = 160
|
||||
CHUNK_LENGTH = 30
|
||||
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
|
||||
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
|
||||
|
||||
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
|
||||
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
|
||||
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
|
||||
|
||||
|
||||
def load_audio(file: str, sr: int = SAMPLE_RATE):
|
||||
"""
|
||||
Open an audio file and read as mono waveform, resampling as necessary
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file: str
|
||||
The audio file to open
|
||||
|
||||
sr: int
|
||||
The sample rate to resample the audio if necessary
|
||||
|
||||
Returns
|
||||
-------
|
||||
A NumPy array containing the audio waveform, in float32 dtype.
|
||||
"""
|
||||
|
||||
# This launches a subprocess to decode audio while down-mixing
|
||||
# and resampling as necessary. Requires the ffmpeg CLI in PATH.
|
||||
# fmt: off
|
||||
print(file)
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-nostdin",
|
||||
"-threads", "0",
|
||||
"-i", file,
|
||||
"-f", "s16le",
|
||||
"-ac", "1",
|
||||
"-acodec", "pcm_s16le",
|
||||
"-ar", str(sr),
|
||||
"-"
|
||||
]
|
||||
# fmt: on
|
||||
try:
|
||||
out = run(cmd, capture_output=True, check=True).stdout
|
||||
except CalledProcessError as e:
|
||||
raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
|
||||
|
||||
return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
|
||||
|
||||
|
||||
def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
|
||||
"""
|
||||
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
||||
"""
|
||||
if torch.is_tensor(array):
|
||||
if array.shape[axis] > length:
|
||||
array = array.index_select(
|
||||
dim=axis, index=torch.arange(length, device=array.device)
|
||||
)
|
||||
|
||||
if array.shape[axis] < length:
|
||||
pad_widths = [(0, 0)] * array.ndim
|
||||
pad_widths[axis] = (0, length - array.shape[axis])
|
||||
array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
|
||||
else:
|
||||
if array.shape[axis] > length:
|
||||
array = array.take(indices=range(length), axis=axis)
|
||||
|
||||
if array.shape[axis] < length:
|
||||
pad_widths = [(0, 0)] * array.ndim
|
||||
pad_widths[axis] = (0, length - array.shape[axis])
|
||||
array = np.pad(array, pad_widths)
|
||||
|
||||
return array
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def mel_filters(device, n_mels: int) -> torch.Tensor:
|
||||
"""
|
||||
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
||||
Allows decoupling librosa dependency; saved using:
|
||||
|
||||
np.savez_compressed(
|
||||
"mel_filters.npz",
|
||||
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
||||
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
|
||||
)
|
||||
"""
|
||||
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
|
||||
|
||||
filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||
with np.load(filters_path, allow_pickle=False) as f:
|
||||
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
||||
|
||||
|
||||
def log_mel_spectrogram(
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
n_mels: int = 80,
|
||||
padding: int = 0,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
"""
|
||||
Compute the log-Mel spectrogram of
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
|
||||
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
||||
|
||||
n_mels: int
|
||||
The number of Mel-frequency filters, only 80 is supported
|
||||
|
||||
padding: int
|
||||
Number of zero samples to pad to the right
|
||||
|
||||
device: Optional[Union[str, torch.device]]
|
||||
If given, the audio tensor is moved to this device before STFT
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor, shape = (80, n_frames)
|
||||
A Tensor that contains the Mel spectrogram
|
||||
"""
|
||||
if not torch.is_tensor(audio):
|
||||
if isinstance(audio, str):
|
||||
audio = load_audio(audio)
|
||||
audio = torch.from_numpy(audio)
|
||||
|
||||
if device is not None:
|
||||
audio = audio.to(device)
|
||||
if padding > 0:
|
||||
audio = F.pad(audio, (0, padding))
|
||||
window = torch.hann_window(N_FFT).to(audio.device)
|
||||
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
|
||||
filters = mel_filters(audio.device, n_mels)
|
||||
mel_spec = filters @ magnitudes
|
||||
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec
|
|
@ -0,0 +1,809 @@
|
|||
from dataclasses import dataclass, field, replace
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
from torch.distributions import Categorical
|
||||
|
||||
from .audio import CHUNK_LENGTH
|
||||
from .tokenizer import Tokenizer, get_tokenizer
|
||||
from .utils import compression_ratio
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def detect_language(
|
||||
model: "Whisper", mel: Tensor, tokenizer: Tokenizer = None
|
||||
) -> Tuple[Tensor, List[dict]]:
|
||||
"""
|
||||
Detect the spoken language in the audio, and return them as list of strings, along with the ids
|
||||
of the most probable language tokens and the probability distribution over all language tokens.
|
||||
This is performed outside the main decode loop in order to not interfere with kv-caching.
|
||||
|
||||
Returns
|
||||
-------
|
||||
language_tokens : Tensor, shape = (n_audio,)
|
||||
ids of the most probable language tokens, which appears after the startoftranscript token.
|
||||
language_probs : List[Dict[str, float]], length = n_audio
|
||||
list of dictionaries containing the probability distribution over all languages.
|
||||
"""
|
||||
if tokenizer is None:
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual, num_languages=model.num_languages
|
||||
)
|
||||
if (
|
||||
tokenizer.language is None
|
||||
or tokenizer.language_token not in tokenizer.sot_sequence
|
||||
):
|
||||
raise ValueError(
|
||||
"This model doesn't have language tokens so it can't perform lang id"
|
||||
)
|
||||
|
||||
single = mel.ndim == 2
|
||||
if single:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
# skip encoder forward pass if already-encoded audio features were given
|
||||
if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
|
||||
mel = model.encoder(mel)
|
||||
|
||||
# forward pass using a single token, startoftranscript
|
||||
n_audio = mel.shape[0]
|
||||
x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1]
|
||||
logits = model.logits(x, mel)[:, 0]
|
||||
|
||||
# collect detected languages; suppress all non-language tokens
|
||||
mask = torch.ones(logits.shape[-1], dtype=torch.bool)
|
||||
mask[list(tokenizer.all_language_tokens)] = False
|
||||
logits[:, mask] = -np.inf
|
||||
language_tokens = logits.argmax(dim=-1)
|
||||
language_token_probs = logits.softmax(dim=-1).cpu()
|
||||
language_probs = [
|
||||
{
|
||||
c: language_token_probs[i, j].item()
|
||||
for j, c in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
|
||||
}
|
||||
for i in range(n_audio)
|
||||
]
|
||||
|
||||
if single:
|
||||
language_tokens = language_tokens[0]
|
||||
language_probs = language_probs[0]
|
||||
|
||||
return language_tokens, language_probs
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingOptions:
|
||||
# whether to perform X->X "transcribe" or X->English "translate"
|
||||
task: str = "transcribe"
|
||||
|
||||
# language that the audio is in; uses detected language if None
|
||||
language: Optional[str] = None
|
||||
|
||||
# sampling-related options
|
||||
temperature: float = 0.0
|
||||
sample_len: Optional[int] = None # maximum number of tokens to sample
|
||||
best_of: Optional[int] = None # number of independent sample trajectories, if t > 0
|
||||
beam_size: Optional[int] = None # number of beams in beam search, if t == 0
|
||||
patience: Optional[float] = None # patience in beam search (arxiv:2204.05424)
|
||||
|
||||
# "alpha" in Google NMT, or None for length norm, when ranking generations
|
||||
# to select which to return among the beams or best-of-N samples
|
||||
length_penalty: Optional[float] = None
|
||||
|
||||
# text or tokens to feed as the prompt or the prefix; for more info:
|
||||
# https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
|
||||
prompt: Optional[Union[str, List[int]]] = None # for the previous context
|
||||
prefix: Optional[Union[str, List[int]]] = None # to prefix the current context
|
||||
|
||||
# list of tokens ids (or comma-separated token ids) to suppress
|
||||
# "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
|
||||
suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
|
||||
suppress_blank: bool = True # this will suppress blank outputs
|
||||
|
||||
# timestamp sampling options
|
||||
without_timestamps: bool = False # use <|notimestamps|> to sample text tokens only
|
||||
max_initial_timestamp: Optional[float] = 1.0
|
||||
|
||||
# implementation details
|
||||
fp16: bool = False # use fp16 for most of the calculation
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DecodingResult:
|
||||
audio_features: Tensor
|
||||
language: str
|
||||
language_probs: Optional[Dict[str, float]] = None
|
||||
tokens: List[int] = field(default_factory=list)
|
||||
text: str = ""
|
||||
avg_logprob: float = np.nan
|
||||
no_speech_prob: float = np.nan
|
||||
temperature: float = np.nan
|
||||
compression_ratio: float = np.nan
|
||||
|
||||
|
||||
class Inference:
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||
"""Perform a forward pass on the decoder and return per-token logits"""
|
||||
raise NotImplementedError
|
||||
|
||||
def rearrange_kv_cache(self, source_indices) -> None:
|
||||
"""Update the key-value cache according to the updated beams"""
|
||||
raise NotImplementedError
|
||||
|
||||
def cleanup_caching(self) -> None:
|
||||
"""Clean up any resources or hooks after decoding is finished"""
|
||||
pass
|
||||
|
||||
|
||||
class PyTorchInference(Inference):
|
||||
def __init__(self, model: "Whisper", initial_token_length: int):
|
||||
self.model: "Whisper" = model
|
||||
self.initial_token_length = initial_token_length
|
||||
self.kv_cache = {}
|
||||
self.hooks = []
|
||||
|
||||
key_modules = [block.attn.key for block in self.model.decoder.blocks]
|
||||
value_modules = [block.attn.value for block in self.model.decoder.blocks]
|
||||
self.kv_modules = key_modules + value_modules
|
||||
self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
|
||||
|
||||
def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
|
||||
if tokens.shape[-1] > self.initial_token_length:
|
||||
# only need to use the last token except in the first forward pass
|
||||
tokens = tokens[:, -1:]
|
||||
|
||||
return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)
|
||||
|
||||
def cleanup_caching(self):
|
||||
for hook in self.hooks:
|
||||
hook.remove()
|
||||
|
||||
self.kv_cache = {}
|
||||
self.hooks = []
|
||||
|
||||
class SequenceRanker:
|
||||
def rank(
|
||||
self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]
|
||||
) -> List[int]:
|
||||
"""
|
||||
Given a list of groups of samples and their cumulative log probabilities,
|
||||
return the indices of the samples in each group to select as the final result
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MaximumLikelihoodRanker(SequenceRanker):
|
||||
"""
|
||||
Select the sample with the highest log probabilities, penalized using either
|
||||
a simple length normalization or Google NMT paper's length penalty
|
||||
"""
|
||||
|
||||
def __init__(self, length_penalty: Optional[float]):
|
||||
self.length_penalty = length_penalty
|
||||
|
||||
def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
|
||||
def scores(logprobs, lengths):
|
||||
result = []
|
||||
for logprob, length in zip(logprobs, lengths):
|
||||
if self.length_penalty is None:
|
||||
penalty = length
|
||||
else:
|
||||
# from the Google NMT paper
|
||||
penalty = ((5 + length) / 6) ** self.length_penalty
|
||||
result.append(logprob / penalty)
|
||||
return result
|
||||
|
||||
# get the sequence with the highest score
|
||||
lengths = [[len(t) for t in s] for s in tokens]
|
||||
return [np.argmax(scores(p, l)) for p, l in zip(sum_logprobs, lengths)]
|
||||
|
||||
|
||||
class TokenDecoder:
|
||||
def reset(self):
|
||||
"""Initialize any stateful variables for decoding a new sequence"""
|
||||
|
||||
def update(
|
||||
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Tensor, bool]:
|
||||
"""Specify how to select the next token, based on the current trace and logits
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||
|
||||
logits : Tensor, shape = (n_batch, vocab_size)
|
||||
per-token logits of the probability distribution at the current step
|
||||
|
||||
sum_logprobs : Tensor, shape = (n_batch)
|
||||
cumulative log probabilities for each sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
|
||||
the tokens, appended with the selected next token
|
||||
|
||||
completed : bool
|
||||
True if all sequences has reached the end of text
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def finalize(
|
||||
self, tokens: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
|
||||
"""Finalize search and return the final candidate sequences
|
||||
|
||||
Parameters
|
||||
----------
|
||||
tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence
|
||||
|
||||
sum_logprobs : Tensor, shape = (n_audio, n_group)
|
||||
cumulative log probabilities for each sequence
|
||||
|
||||
Returns
|
||||
-------
|
||||
tokens : Sequence[Sequence[Tensor]], length = n_audio
|
||||
sequence of Tensors containing candidate token sequences, for each audio input
|
||||
|
||||
sum_logprobs : List[List[float]], length = n_audio
|
||||
sequence of cumulative log probabilities corresponding to the above
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class GreedyDecoder(TokenDecoder):
|
||||
def __init__(self, temperature: float, eot: int):
|
||||
self.temperature = temperature
|
||||
self.eot = eot
|
||||
|
||||
def update(
|
||||
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Tensor, bool]:
|
||||
if self.temperature == 0:
|
||||
next_tokens = logits.argmax(dim=-1)
|
||||
else:
|
||||
next_tokens = Categorical(logits=logits / self.temperature).sample()
|
||||
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
|
||||
sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
|
||||
|
||||
next_tokens[tokens[:, -1] == self.eot] = self.eot
|
||||
tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
|
||||
|
||||
completed = (tokens[:, -1] == self.eot).all()
|
||||
return tokens, completed
|
||||
|
||||
def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
|
||||
# make sure each sequence has at least one EOT token at the end
|
||||
tokens = F.pad(tokens, (0, 1), value=self.eot)
|
||||
return tokens, sum_logprobs.tolist()
|
||||
|
||||
|
||||
class BeamSearchDecoder(TokenDecoder):
|
||||
def __init__(
|
||||
self,
|
||||
beam_size: int,
|
||||
eot: int,
|
||||
inference: Inference,
|
||||
patience: Optional[float] = None,
|
||||
):
|
||||
self.beam_size = beam_size
|
||||
self.eot = eot
|
||||
self.inference = inference
|
||||
self.patience = patience or 1.0
|
||||
self.max_candidates: int = round(beam_size * self.patience)
|
||||
self.finished_sequences = None
|
||||
|
||||
assert (
|
||||
self.max_candidates > 0
|
||||
), f"Invalid beam size ({beam_size}) or patience ({patience})"
|
||||
|
||||
def reset(self):
|
||||
self.finished_sequences = None
|
||||
|
||||
def update(
|
||||
self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor
|
||||
) -> Tuple[Tensor, bool]:
|
||||
if tokens.shape[0] % self.beam_size != 0:
|
||||
raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
|
||||
|
||||
n_audio = tokens.shape[0] // self.beam_size
|
||||
if self.finished_sequences is None: # for the first update
|
||||
self.finished_sequences = [{} for _ in range(n_audio)]
|
||||
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
next_tokens, source_indices, finished_sequences = [], [], []
|
||||
for i in range(n_audio):
|
||||
scores, sources, finished = {}, {}, {}
|
||||
|
||||
# STEP 1: calculate the cumulative log probabilities for possible candidates
|
||||
for j in range(self.beam_size):
|
||||
idx = i * self.beam_size + j
|
||||
prefix = tokens[idx].tolist()
|
||||
for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
|
||||
new_logprob = (sum_logprobs[idx] + logprob).item()
|
||||
sequence = tuple(prefix + [token.item()])
|
||||
scores[sequence] = new_logprob
|
||||
sources[sequence] = idx
|
||||
|
||||
# STEP 2: rank the candidates and keep the top beam_size sequences for each audio
|
||||
saved = 0
|
||||
for sequence in sorted(scores, key=scores.get, reverse=True):
|
||||
if sequence[-1] == self.eot:
|
||||
finished[sequence] = scores[sequence]
|
||||
else:
|
||||
sum_logprobs[len(next_tokens)] = scores[sequence]
|
||||
next_tokens.append(sequence)
|
||||
source_indices.append(sources[sequence])
|
||||
|
||||
saved += 1
|
||||
if saved == self.beam_size:
|
||||
break
|
||||
|
||||
finished_sequences.append(finished)
|
||||
|
||||
tokens = torch.tensor(next_tokens, device=tokens.device)
|
||||
# self.inference.rearrange_kv_cache(source_indices)
|
||||
|
||||
# add newly finished sequences to self.finished_sequences
|
||||
assert len(self.finished_sequences) == len(finished_sequences)
|
||||
for previously_finished, newly_finished in zip(
|
||||
self.finished_sequences, finished_sequences
|
||||
):
|
||||
for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
|
||||
if len(previously_finished) >= self.max_candidates:
|
||||
break # the candidate list is full
|
||||
previously_finished[seq] = newly_finished[seq]
|
||||
|
||||
# mark as completed if all audio has enough number of samples
|
||||
completed = all(
|
||||
len(sequences) >= self.max_candidates
|
||||
for sequences in self.finished_sequences
|
||||
)
|
||||
return tokens, completed
|
||||
|
||||
def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
|
||||
# collect all finished sequences, including patience, and add unfinished ones if not enough
|
||||
sum_logprobs = sum_logprobs.cpu()
|
||||
for i, sequences in enumerate(self.finished_sequences):
|
||||
if (
|
||||
len(sequences) < self.beam_size
|
||||
): # when not enough sequences are finished
|
||||
for j in list(np.argsort(sum_logprobs[i]))[::-1]:
|
||||
sequence = preceding_tokens[i, j].tolist() + [self.eot]
|
||||
sequences[tuple(sequence)] = sum_logprobs[i][j].item()
|
||||
if len(sequences) >= self.beam_size:
|
||||
break
|
||||
|
||||
tokens: List[List[Tensor]] = [
|
||||
[torch.tensor(seq) for seq in sequences.keys()]
|
||||
for sequences in self.finished_sequences
|
||||
]
|
||||
sum_logprobs: List[List[float]] = [
|
||||
list(sequences.values()) for sequences in self.finished_sequences
|
||||
]
|
||||
return tokens, sum_logprobs
|
||||
|
||||
|
||||
class LogitFilter:
|
||||
def apply(self, logits: Tensor, tokens: Tensor) -> None:
|
||||
"""Apply any filtering or masking to logits in-place
|
||||
|
||||
Parameters
|
||||
----------
|
||||
logits : Tensor, shape = (n_batch, vocab_size)
|
||||
per-token logits of the probability distribution at the current step
|
||||
|
||||
tokens : Tensor, shape = (n_batch, current_sequence_length)
|
||||
all tokens in the context so far, including the prefix and sot_sequence tokens
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SuppressBlank(LogitFilter):
|
||||
def __init__(self, tokenizer: Tokenizer, sample_begin: int):
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_begin = sample_begin
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
logits[:, self.tokenizer.encode(" ") + [self.tokenizer.eot]] = -np.inf
|
||||
|
||||
|
||||
class SuppressTokens(LogitFilter):
|
||||
def __init__(self, suppress_tokens: Sequence[int]):
|
||||
self.suppress_tokens = list(suppress_tokens)
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
logits[:, self.suppress_tokens] = -np.inf
|
||||
|
||||
|
||||
class ApplyTimestampRules(LogitFilter):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: Tokenizer,
|
||||
sample_begin: int,
|
||||
max_initial_timestamp_index: Optional[int],
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.sample_begin = sample_begin
|
||||
self.max_initial_timestamp_index = max_initial_timestamp_index
|
||||
|
||||
def apply(self, logits: Tensor, tokens: Tensor):
|
||||
# suppress <|notimestamps|> which is handled by without_timestamps
|
||||
if self.tokenizer.no_timestamps is not None:
|
||||
logits[:, self.tokenizer.no_timestamps] = -np.inf
|
||||
|
||||
# timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
|
||||
for k in range(tokens.shape[0]):
|
||||
sampled_tokens = tokens[k, self.sample_begin :]
|
||||
seq = [t for t in sampled_tokens.tolist()]
|
||||
last_was_timestamp = (
|
||||
len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
|
||||
)
|
||||
penultimate_was_timestamp = (
|
||||
len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin
|
||||
)
|
||||
|
||||
if last_was_timestamp:
|
||||
if penultimate_was_timestamp: # has to be non-timestamp
|
||||
logits[k, self.tokenizer.timestamp_begin :] = -np.inf
|
||||
else: # cannot be normal text tokens
|
||||
logits[k, : self.tokenizer.eot] = -np.inf
|
||||
|
||||
timestamps = sampled_tokens[
|
||||
sampled_tokens.ge(self.tokenizer.timestamp_begin)
|
||||
]
|
||||
if timestamps.numel() > 0:
|
||||
# timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
|
||||
# also force each segment to have a nonzero length, to prevent infinite looping
|
||||
if last_was_timestamp and not penultimate_was_timestamp:
|
||||
timestamp_last = timestamps[-1]
|
||||
else:
|
||||
timestamp_last = timestamps[-1] + 1
|
||||
logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf
|
||||
|
||||
if tokens.shape[1] == self.sample_begin:
|
||||
# suppress generating non-timestamp tokens at the beginning
|
||||
logits[:, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
# apply the `max_initial_timestamp` option
|
||||
if self.max_initial_timestamp_index is not None:
|
||||
last_allowed = (
|
||||
self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
|
||||
)
|
||||
logits[:, last_allowed + 1 :] = -np.inf
|
||||
|
||||
# if sum of probability over timestamps is above any other token, sample timestamp
|
||||
logprobs = F.log_softmax(logits.float(), dim=-1)
|
||||
for k in range(tokens.shape[0]):
|
||||
timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(
|
||||
dim=-1
|
||||
)
|
||||
max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
|
||||
if timestamp_logprob > max_text_token_logprob:
|
||||
logits[k, : self.tokenizer.timestamp_begin] = -np.inf
|
||||
|
||||
|
||||
class DecodingTask:
|
||||
inference: Inference
|
||||
sequence_ranker: SequenceRanker
|
||||
decoder: TokenDecoder
|
||||
logit_filters: List[LogitFilter]
|
||||
|
||||
def __init__(self, model: "Whisper", options: DecodingOptions):
|
||||
self.model = model
|
||||
|
||||
language = options.language or "en"
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual,
|
||||
num_languages=model.num_languages,
|
||||
language=language,
|
||||
task=options.task,
|
||||
)
|
||||
self.tokenizer: Tokenizer = tokenizer
|
||||
self.options: DecodingOptions = self._verify_options(options)
|
||||
|
||||
self.n_group: int = options.beam_size or options.best_of or 1
|
||||
self.n_ctx: int = model.dims.n_text_ctx
|
||||
self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2
|
||||
|
||||
self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
|
||||
if self.options.without_timestamps:
|
||||
self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
|
||||
|
||||
self.initial_tokens: Tuple[int] = self._get_initial_tokens()
|
||||
self.sample_begin: int = len(self.initial_tokens)
|
||||
self.sot_index: int = self.initial_tokens.index(tokenizer.sot)
|
||||
|
||||
# inference: implements the forward pass through the decoder, including kv caching
|
||||
self.inference = PyTorchInference(model, len(self.initial_tokens))
|
||||
|
||||
# sequence ranker: implements how to rank a group of sampled sequences
|
||||
self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
|
||||
|
||||
# decoder: implements how to select the next tokens, given the autoregressive distribution
|
||||
if options.beam_size is not None:
|
||||
self.decoder = BeamSearchDecoder(
|
||||
options.beam_size, tokenizer.eot, self.inference, options.patience
|
||||
)
|
||||
else:
|
||||
self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
|
||||
|
||||
# logit filters: applies various rules to suppress or penalize certain tokens
|
||||
self.logit_filters = []
|
||||
if self.options.suppress_blank:
|
||||
self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
|
||||
if self.options.suppress_tokens:
|
||||
self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
|
||||
if not options.without_timestamps:
|
||||
precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds
|
||||
max_initial_timestamp_index = None
|
||||
if options.max_initial_timestamp:
|
||||
max_initial_timestamp_index = round(
|
||||
self.options.max_initial_timestamp / precision
|
||||
)
|
||||
self.logit_filters.append(
|
||||
ApplyTimestampRules(
|
||||
tokenizer, self.sample_begin, max_initial_timestamp_index
|
||||
)
|
||||
)
|
||||
|
||||
def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
|
||||
if options.beam_size is not None and options.best_of is not None:
|
||||
raise ValueError("beam_size and best_of can't be given together")
|
||||
if options.temperature == 0:
|
||||
if options.best_of is not None:
|
||||
raise ValueError("best_of with greedy sampling (T=0) is not compatible")
|
||||
if options.patience is not None and options.beam_size is None:
|
||||
raise ValueError("patience requires beam_size to be given")
|
||||
if options.length_penalty is not None and not (
|
||||
0 <= options.length_penalty <= 1
|
||||
):
|
||||
raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
|
||||
|
||||
return options
|
||||
|
||||
def _get_initial_tokens(self) -> Tuple[int]:
|
||||
tokens = list(self.sot_sequence)
|
||||
|
||||
if prefix := self.options.prefix:
|
||||
prefix_tokens = (
|
||||
self.tokenizer.encode(" " + prefix.strip())
|
||||
if isinstance(prefix, str)
|
||||
else prefix
|
||||
)
|
||||
if self.sample_len is not None:
|
||||
max_prefix_len = self.n_ctx // 2 - self.sample_len
|
||||
prefix_tokens = prefix_tokens[-max_prefix_len:]
|
||||
tokens = tokens + prefix_tokens
|
||||
|
||||
if prompt := self.options.prompt:
|
||||
prompt_tokens = (
|
||||
self.tokenizer.encode(" " + prompt.strip())
|
||||
if isinstance(prompt, str)
|
||||
else prompt
|
||||
)
|
||||
tokens = (
|
||||
[self.tokenizer.sot_prev]
|
||||
+ prompt_tokens[-(self.n_ctx // 2 - 1) :]
|
||||
+ tokens
|
||||
)
|
||||
|
||||
return tuple(tokens)
|
||||
|
||||
def _get_suppress_tokens(self) -> Tuple[int]:
|
||||
suppress_tokens = self.options.suppress_tokens
|
||||
|
||||
if isinstance(suppress_tokens, str):
|
||||
suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
|
||||
|
||||
if -1 in suppress_tokens:
|
||||
suppress_tokens = [t for t in suppress_tokens if t >= 0]
|
||||
suppress_tokens.extend(self.tokenizer.non_speech_tokens)
|
||||
elif suppress_tokens is None or len(suppress_tokens) == 0:
|
||||
suppress_tokens = [] # interpret empty string as an empty list
|
||||
else:
|
||||
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
|
||||
|
||||
suppress_tokens.extend(
|
||||
[
|
||||
self.tokenizer.transcribe,
|
||||
self.tokenizer.translate,
|
||||
self.tokenizer.sot,
|
||||
self.tokenizer.sot_prev,
|
||||
self.tokenizer.sot_lm,
|
||||
]
|
||||
)
|
||||
if self.tokenizer.no_speech is not None:
|
||||
# no-speech probability is collected separately
|
||||
suppress_tokens.append(self.tokenizer.no_speech)
|
||||
|
||||
return tuple(sorted(set(suppress_tokens)))
|
||||
|
||||
def _get_audio_features(self, mel: Tensor):
|
||||
if self.options.fp16:
|
||||
mel = mel.half()
|
||||
if mel.shape[-2:] == (
|
||||
self.model.dims.n_audio_ctx,
|
||||
self.model.dims.n_audio_state,
|
||||
):
|
||||
# encoded audio features are given; skip audio encoding
|
||||
audio_features = mel
|
||||
else:
|
||||
audio_features = self.model.encoder(mel)
|
||||
if audio_features.dtype != (
|
||||
torch.float16 if self.options.fp16 else torch.float32
|
||||
):
|
||||
return TypeError(
|
||||
f"audio_features has an incorrect dtype: {audio_features.dtype}"
|
||||
)
|
||||
return audio_features
|
||||
|
||||
def _detect_language(self, audio_features: Tensor, tokens: Tensor):
|
||||
languages = [self.options.language] * audio_features.shape[0]
|
||||
lang_probs = None
|
||||
|
||||
if self.options.language is None or self.options.task == "lang_id":
|
||||
lang_tokens, lang_probs = self.model.detect_language(
|
||||
audio_features, self.tokenizer
|
||||
)
|
||||
languages = [max(probs, key=probs.get) for probs in lang_probs]
|
||||
if self.options.language is None:
|
||||
tokens[:, self.sot_index + 1] = lang_tokens # write language tokens
|
||||
|
||||
return languages, lang_probs
|
||||
|
||||
def _main_loop(self, audio_features: Tensor, tokens: Tensor):
|
||||
n_batch = tokens.shape[0]
|
||||
sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
|
||||
no_speech_probs = [np.nan] * n_batch
|
||||
try:
|
||||
for i in range(self.sample_len):
|
||||
logits = self.inference.logits(tokens, audio_features)
|
||||
if (
|
||||
i == 0 and self.tokenizer.no_speech is not None
|
||||
): # save no_speech_probs
|
||||
probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
|
||||
no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
|
||||
|
||||
# now we need to consider the logits at the last token only
|
||||
logits = logits[:, -1]
|
||||
|
||||
# apply the logit filters, e.g. for suppressing or applying penalty to
|
||||
for logit_filter in self.logit_filters:
|
||||
logit_filter.apply(logits, tokens)
|
||||
|
||||
# expand the tokens tensor with the selected next tokens
|
||||
tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
|
||||
if completed or tokens.shape[-1] > self.n_ctx:
|
||||
break
|
||||
finally:
|
||||
self.inference.cleanup_caching()
|
||||
return tokens, sum_logprobs, no_speech_probs
|
||||
|
||||
@torch.no_grad()
|
||||
def run(self, mel: Tensor) -> List[DecodingResult]:
|
||||
self.decoder.reset()
|
||||
tokenizer: Tokenizer = self.tokenizer
|
||||
n_audio: int = mel.shape[0]
|
||||
audio_features: Tensor = self._get_audio_features(mel) # encoder forward pass
|
||||
|
||||
tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
|
||||
|
||||
# detect language if requested, overwriting the language token
|
||||
languages, language_probs = self._detect_language(audio_features, tokens)
|
||||
if self.options.task == "lang_id":
|
||||
return [
|
||||
DecodingResult(
|
||||
audio_features=features, language=language, language_probs=probs
|
||||
)
|
||||
for features, language, probs in zip(
|
||||
audio_features, languages, language_probs
|
||||
)
|
||||
]
|
||||
|
||||
# repeat text tensors by the group size, for beam search or best-of-n sampling
|
||||
tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
|
||||
|
||||
# call the main sampling loop
|
||||
tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
|
||||
|
||||
# reshape the tensors to have (n_audio, n_group) as the first two dimensions
|
||||
audio_features = audio_features[:: self.n_group]
|
||||
no_speech_probs = no_speech_probs[:: self.n_group]
|
||||
assert audio_features.shape[0] == len(no_speech_probs) == n_audio
|
||||
|
||||
tokens = tokens.reshape(n_audio, self.n_group, -1)
|
||||
sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
|
||||
|
||||
# get the final candidates for each group, and slice between the first sampled token and EOT
|
||||
tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
|
||||
tokens: List[List[Tensor]] = [
|
||||
[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s]
|
||||
for s in tokens
|
||||
]
|
||||
|
||||
# select the top-ranked sample in each group
|
||||
selected = self.sequence_ranker.rank(tokens, sum_logprobs)
|
||||
tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
|
||||
texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
|
||||
|
||||
sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
|
||||
avg_logprobs: List[float] = [
|
||||
lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)
|
||||
]
|
||||
|
||||
fields = (
|
||||
texts,
|
||||
languages,
|
||||
tokens,
|
||||
audio_features,
|
||||
avg_logprobs,
|
||||
no_speech_probs,
|
||||
)
|
||||
if len(set(map(len, fields))) != 1:
|
||||
raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
|
||||
|
||||
return [
|
||||
DecodingResult(
|
||||
audio_features=features,
|
||||
language=language,
|
||||
tokens=tokens,
|
||||
text=text,
|
||||
avg_logprob=avg_logprob,
|
||||
no_speech_prob=no_speech_prob,
|
||||
temperature=self.options.temperature,
|
||||
compression_ratio=compression_ratio(text),
|
||||
)
|
||||
for text, language, tokens, features, avg_logprob, no_speech_prob in zip(
|
||||
*fields
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(
|
||||
model: "Whisper",
|
||||
mel: Tensor,
|
||||
options: DecodingOptions = DecodingOptions(),
|
||||
**kwargs,
|
||||
) -> Union[DecodingResult, List[DecodingResult]]:
|
||||
"""
|
||||
Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Whisper
|
||||
the Whisper model instance
|
||||
|
||||
mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
|
||||
A tensor containing the Mel spectrogram(s)
|
||||
|
||||
options: DecodingOptions
|
||||
A dataclass that contains all necessary options for decoding 30-second segments
|
||||
|
||||
Returns
|
||||
-------
|
||||
result: Union[DecodingResult, List[DecodingResult]]
|
||||
The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
|
||||
"""
|
||||
if single := mel.ndim == 2:
|
||||
mel = mel.unsqueeze(0)
|
||||
|
||||
if kwargs:
|
||||
options = replace(options, **kwargs)
|
||||
result = DecodingTask(model, options).run(mel)
|
||||
return result[0] if single else result
|
|
@ -0,0 +1,364 @@
|
|||
import base64
|
||||
import gzip
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Iterable, Optional
|
||||
import time
|
||||
|
||||
import torch_directml
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .decoding import decode as decode_function
|
||||
from .decoding import detect_language as detect_language_function
|
||||
from .transcribe import transcribe as transcribe_function
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelDimensions:
|
||||
n_mels: int
|
||||
n_audio_ctx: int
|
||||
n_audio_state: int
|
||||
n_audio_head: int
|
||||
n_audio_layer: int
|
||||
n_vocab: int
|
||||
n_text_ctx: int
|
||||
n_text_state: int
|
||||
n_text_head: int
|
||||
n_text_layer: int
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(
|
||||
x,
|
||||
self.weight.to(x.dtype),
|
||||
None if self.bias is None else self.bias.to(x.dtype),
|
||||
)
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
def _conv_forward(
|
||||
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
||||
) -> Tensor:
|
||||
return super()._conv_forward(
|
||||
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||
)
|
||||
|
||||
|
||||
def sinusoids(length, channels, max_timescale=10000):
|
||||
"""Returns sinusoids for positional embedding"""
|
||||
assert channels % 2 == 0
|
||||
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
||||
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
||||
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.query = Linear(n_state, n_state)
|
||||
self.key = Linear(n_state, n_state, bias=False)
|
||||
self.value = Linear(n_state, n_state)
|
||||
self.out = Linear(n_state, n_state)
|
||||
self.past_key_tensor = None
|
||||
self.past_value_tensor = None
|
||||
|
||||
self.cross_key = None
|
||||
self.cross_value = None
|
||||
self.total_seq_len = 0
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
cross_attention=False,
|
||||
use_dml_attn=False,
|
||||
):
|
||||
q = self.query(x)
|
||||
if xa is None:
|
||||
k = self.key(x)
|
||||
v = self.value(x)
|
||||
else:
|
||||
cross_attention = True
|
||||
if self.past_key_tensor is None:
|
||||
# for cross-attention, calculate keys and values once and reuse in subsequent calls.
|
||||
self.cross_key = self.key(xa)
|
||||
self.cross_value = self.value(xa)
|
||||
|
||||
k = self.cross_key
|
||||
v = self.cross_value
|
||||
|
||||
if use_dml_attn:
|
||||
wv, qk = self.dml_sdp_attn(q, k, v, mask, cross_attention=cross_attention)
|
||||
else:
|
||||
wv, qk = self.qkv_attention(q, k, v, mask)
|
||||
|
||||
return self.out(wv), qk
|
||||
|
||||
def qkv_attention(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
|
||||
):
|
||||
_, n_ctx, n_state = q.shape
|
||||
scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
|
||||
k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
|
||||
v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
|
||||
|
||||
qk = q @ k
|
||||
if mask is not None:
|
||||
qk = qk + mask[:n_ctx, :n_ctx]
|
||||
qk = qk.float()
|
||||
|
||||
w = F.softmax(qk, dim=-1).to(q.dtype)
|
||||
return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
|
||||
|
||||
def dml_sdp_attn(self, q, k, v, mask, cross_attention=False):
|
||||
_, n_ctx, n_state = q.shape
|
||||
if mask is not None:
|
||||
self.total_seq_len += n_ctx
|
||||
mask = mask.expand(-1, -1, n_ctx, self.total_seq_len)
|
||||
|
||||
# cross attention i.e. encoder/decoder attn uses same k and v but the query changes
|
||||
if cross_attention:
|
||||
y, self.past_key_tensor, self.past_value_tensor= torch_directml.multi_head_attention(
|
||||
q, k, v, n_state, self.n_head, None, None, mask
|
||||
)
|
||||
else:
|
||||
y, self.past_key_tensor, self.past_value_tensor = torch_directml.multi_head_attention(
|
||||
q, k, v, n_state, self.n_head, self.past_key_tensor, self.past_value_tensor, mask
|
||||
)
|
||||
return y, None
|
||||
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
|
||||
super().__init__()
|
||||
|
||||
self.attn = MultiHeadAttention(n_state, n_head)
|
||||
self.attn_ln = LayerNorm(n_state)
|
||||
|
||||
self.cross_attn = (
|
||||
MultiHeadAttention(n_state, n_head) if cross_attention else None
|
||||
)
|
||||
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
|
||||
|
||||
n_mlp = n_state * 4
|
||||
self.mlp = nn.Sequential(
|
||||
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
|
||||
)
|
||||
self.mlp_ln = LayerNorm(n_state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
xa: Optional[Tensor] = None,
|
||||
mask: Optional[Tensor] = None,
|
||||
kv_cache: Optional[dict] = None,
|
||||
use_dml_attn: bool = False,
|
||||
):
|
||||
x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, use_dml_attn=use_dml_attn)[0]
|
||||
if self.cross_attn:
|
||||
x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, use_dml_attn=use_dml_attn)[0]
|
||||
x = x + self.mlp(self.mlp_ln(x))
|
||||
return x
|
||||
|
||||
|
||||
class AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, use_dml_attn: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_dml_attn = use_dml_attn
|
||||
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
|
||||
)
|
||||
self.ln_post = LayerNorm(n_state)
|
||||
|
||||
def forward(self, x: Tensor):
|
||||
"""
|
||||
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
|
||||
the mel spectrogram of the audio
|
||||
"""
|
||||
x = F.gelu(self.conv1(x))
|
||||
x = F.gelu(self.conv2(x))
|
||||
x = x.permute(0, 2, 1)
|
||||
|
||||
# assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
|
||||
x = (x + self.positional_embedding).to(x.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, use_dml_attn=self.use_dml_attn)
|
||||
|
||||
x = self.ln_post(x)
|
||||
return x
|
||||
|
||||
|
||||
class TextDecoder(nn.Module):
|
||||
def __init__(
|
||||
self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, use_dml_attn: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(n_vocab, n_state)
|
||||
self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
|
||||
self.use_dml_attn = use_dml_attn
|
||||
|
||||
self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
|
||||
[
|
||||
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
|
||||
for _ in range(n_layer)
|
||||
]
|
||||
)
|
||||
self.ln = LayerNorm(n_state)
|
||||
|
||||
if self.use_dml_attn:
|
||||
mask = torch.ones([1, n_head, 1, 1], dtype=torch.int32)
|
||||
else:
|
||||
mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
|
||||
self.register_buffer("mask", mask, persistent=False)
|
||||
self.pos = 0
|
||||
|
||||
def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
|
||||
"""
|
||||
x : torch.LongTensor, shape = (batch_size, <= n_ctx)
|
||||
the text tokens
|
||||
xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
|
||||
the encoded audio features to be attended on
|
||||
"""
|
||||
offset = self.pos
|
||||
self.pos += x.shape[1]
|
||||
x = (
|
||||
self.token_embedding(x)
|
||||
+ self.positional_embedding[offset : offset + x.shape[-1]]
|
||||
)
|
||||
x = x.to(xa.dtype)
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
x = block(x, xa, mask=self.mask, kv_cache=kv_cache, use_dml_attn=self.use_dml_attn)
|
||||
|
||||
x = self.ln(x)
|
||||
logits = (
|
||||
x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
|
||||
).float()
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
class Whisper(nn.Module):
|
||||
def __init__(self, dims: ModelDimensions, use_dml_attn=False):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
self.encoder = AudioEncoder(
|
||||
self.dims.n_mels,
|
||||
self.dims.n_audio_ctx,
|
||||
self.dims.n_audio_state,
|
||||
self.dims.n_audio_head,
|
||||
self.dims.n_audio_layer,
|
||||
use_dml_attn
|
||||
)
|
||||
self.decoder = TextDecoder(
|
||||
self.dims.n_vocab,
|
||||
self.dims.n_text_ctx,
|
||||
self.dims.n_text_state,
|
||||
self.dims.n_text_head,
|
||||
self.dims.n_text_layer,
|
||||
use_dml_attn,
|
||||
)
|
||||
# use the last half among the decoder layers for time alignment by default;
|
||||
# to use a specific set of heads, see `set_alignment_heads()` below.
|
||||
all_heads = torch.zeros(
|
||||
self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
|
||||
)
|
||||
all_heads[self.dims.n_text_layer // 2 :] = True
|
||||
self.register_buffer("alignment_heads", all_heads, persistent=False)
|
||||
|
||||
def set_alignment_heads(self, dump: bytes):
|
||||
array = np.frombuffer(
|
||||
gzip.decompress(base64.b85decode(dump)), dtype=bool
|
||||
).copy()
|
||||
mask = torch.from_numpy(array).reshape(
|
||||
self.dims.n_text_layer, self.dims.n_text_head
|
||||
)
|
||||
self.register_buffer("alignment_heads", mask, persistent=False)
|
||||
|
||||
def embed_audio(self, mel: torch.Tensor):
|
||||
return self.encoder(mel)
|
||||
|
||||
def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
|
||||
return self.decoder(tokens, audio_features)
|
||||
|
||||
def forward(
|
||||
self, mel: torch.Tensor, tokens: torch.Tensor
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
return self.decoder(tokens, self.encoder(mel))
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return next(self.parameters()).device
|
||||
|
||||
@property
|
||||
def is_multilingual(self):
|
||||
return self.dims.n_vocab >= 51865
|
||||
|
||||
@property
|
||||
def num_languages(self):
|
||||
return self.dims.n_vocab - 51765 - int(self.is_multilingual)
|
||||
|
||||
def install_kv_cache_hooks(self, cache: Optional[dict] = None):
|
||||
"""
|
||||
The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
|
||||
tensors calculated for the previous positions. This method returns a dictionary that stores
|
||||
all caches, and the necessary hooks for the key and value projection modules that save the
|
||||
intermediate tensors to be reused during later calculations.
|
||||
|
||||
Returns
|
||||
-------
|
||||
cache : Dict[nn.Module, torch.Tensor]
|
||||
A dictionary object mapping the key/value projection modules to its cache
|
||||
hooks : List[RemovableHandle]
|
||||
List of PyTorch RemovableHandle objects to stop the hooks to be called
|
||||
"""
|
||||
cache = {**cache} if cache is not None else {}
|
||||
hooks = []
|
||||
|
||||
def save_to_cache(module, _, output):
|
||||
if module not in cache or output.shape[1] > self.dims.n_text_ctx:
|
||||
# save as-is, for the first token or cross attention
|
||||
cache[module] = output
|
||||
else:
|
||||
cache[module] = torch.cat([cache[module], output], dim=1).detach()
|
||||
return cache[module]
|
||||
|
||||
def install_hooks(layer: nn.Module):
|
||||
if isinstance(layer, MultiHeadAttention):
|
||||
if self.decoder.use_dml_attn:
|
||||
layer.total_seq_len = 0
|
||||
layer.past_key_tensor = None
|
||||
layer.past_value_tensor = None
|
||||
else:
|
||||
hooks.append(layer.key.register_forward_hook(save_to_cache))
|
||||
hooks.append(layer.value.register_forward_hook(save_to_cache))
|
||||
|
||||
self.decoder.apply(install_hooks)
|
||||
self.decoder.pos = 0
|
||||
return cache, hooks
|
||||
|
||||
detect_language = detect_language_function
|
||||
transcribe = transcribe_function
|
||||
decode = decode_function
|
|
@ -0,0 +1,2 @@
|
|||
from .basic import BasicTextNormalizer as BasicTextNormalizer
|
||||
from .english import EnglishTextNormalizer as EnglishTextNormalizer
|
|
@ -0,0 +1,76 @@
|
|||
import re
|
||||
import unicodedata
|
||||
|
||||
import regex
|
||||
|
||||
# non-ASCII letters that are not separated by "NFKD" normalization
|
||||
ADDITIONAL_DIACRITICS = {
|
||||
"œ": "oe",
|
||||
"Œ": "OE",
|
||||
"ø": "o",
|
||||
"Ø": "O",
|
||||
"æ": "ae",
|
||||
"Æ": "AE",
|
||||
"ß": "ss",
|
||||
"ẞ": "SS",
|
||||
"đ": "d",
|
||||
"Đ": "D",
|
||||
"ð": "d",
|
||||
"Ð": "D",
|
||||
"þ": "th",
|
||||
"Þ": "th",
|
||||
"ł": "l",
|
||||
"Ł": "L",
|
||||
}
|
||||
|
||||
|
||||
def remove_symbols_and_diacritics(s: str, keep=""):
|
||||
"""
|
||||
Replace any other markers, symbols, and punctuations with a space,
|
||||
and drop any diacritics (category 'Mn' and some manual mappings)
|
||||
"""
|
||||
return "".join(
|
||||
c
|
||||
if c in keep
|
||||
else ADDITIONAL_DIACRITICS[c]
|
||||
if c in ADDITIONAL_DIACRITICS
|
||||
else ""
|
||||
if unicodedata.category(c) == "Mn"
|
||||
else " "
|
||||
if unicodedata.category(c)[0] in "MSP"
|
||||
else c
|
||||
for c in unicodedata.normalize("NFKD", s)
|
||||
)
|
||||
|
||||
|
||||
def remove_symbols(s: str):
|
||||
"""
|
||||
Replace any other markers, symbols, punctuations with a space, keeping diacritics
|
||||
"""
|
||||
return "".join(
|
||||
" " if unicodedata.category(c)[0] in "MSP" else c
|
||||
for c in unicodedata.normalize("NFKC", s)
|
||||
)
|
||||
|
||||
|
||||
class BasicTextNormalizer:
|
||||
def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
|
||||
self.clean = (
|
||||
remove_symbols_and_diacritics if remove_diacritics else remove_symbols
|
||||
)
|
||||
self.split_letters = split_letters
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = s.lower()
|
||||
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||
s = self.clean(s).lower()
|
||||
|
||||
if self.split_letters:
|
||||
s = " ".join(regex.findall(r"\X", s, regex.U))
|
||||
|
||||
s = re.sub(
|
||||
r"\s+", " ", s
|
||||
) # replace any successive whitespace characters with a space
|
||||
|
||||
return s
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,550 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from fractions import Fraction
|
||||
from typing import Iterator, List, Match, Optional, Union
|
||||
|
||||
from more_itertools import windowed
|
||||
|
||||
from .basic import remove_symbols_and_diacritics
|
||||
|
||||
|
||||
class EnglishNumberNormalizer:
|
||||
"""
|
||||
Convert any spelled-out numbers into arabic numbers, while handling:
|
||||
|
||||
- remove any commas
|
||||
- keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
|
||||
- spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
|
||||
- spell out `one` and `ones`
|
||||
- interpret successive single-digit numbers as nominal: `one oh one` -> `101`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.zeros = {"o", "oh", "zero"}
|
||||
self.ones = {
|
||||
name: i
|
||||
for i, name in enumerate(
|
||||
[
|
||||
"one",
|
||||
"two",
|
||||
"three",
|
||||
"four",
|
||||
"five",
|
||||
"six",
|
||||
"seven",
|
||||
"eight",
|
||||
"nine",
|
||||
"ten",
|
||||
"eleven",
|
||||
"twelve",
|
||||
"thirteen",
|
||||
"fourteen",
|
||||
"fifteen",
|
||||
"sixteen",
|
||||
"seventeen",
|
||||
"eighteen",
|
||||
"nineteen",
|
||||
],
|
||||
start=1,
|
||||
)
|
||||
}
|
||||
self.ones_plural = {
|
||||
"sixes" if name == "six" else name + "s": (value, "s")
|
||||
for name, value in self.ones.items()
|
||||
}
|
||||
self.ones_ordinal = {
|
||||
"zeroth": (0, "th"),
|
||||
"first": (1, "st"),
|
||||
"second": (2, "nd"),
|
||||
"third": (3, "rd"),
|
||||
"fifth": (5, "th"),
|
||||
"twelfth": (12, "th"),
|
||||
**{
|
||||
name + ("h" if name.endswith("t") else "th"): (value, "th")
|
||||
for name, value in self.ones.items()
|
||||
if value > 3 and value != 5 and value != 12
|
||||
},
|
||||
}
|
||||
self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
|
||||
|
||||
self.tens = {
|
||||
"twenty": 20,
|
||||
"thirty": 30,
|
||||
"forty": 40,
|
||||
"fifty": 50,
|
||||
"sixty": 60,
|
||||
"seventy": 70,
|
||||
"eighty": 80,
|
||||
"ninety": 90,
|
||||
}
|
||||
self.tens_plural = {
|
||||
name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_ordinal = {
|
||||
name.replace("y", "ieth"): (value, "th")
|
||||
for name, value in self.tens.items()
|
||||
}
|
||||
self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
|
||||
|
||||
self.multipliers = {
|
||||
"hundred": 100,
|
||||
"thousand": 1_000,
|
||||
"million": 1_000_000,
|
||||
"billion": 1_000_000_000,
|
||||
"trillion": 1_000_000_000_000,
|
||||
"quadrillion": 1_000_000_000_000_000,
|
||||
"quintillion": 1_000_000_000_000_000_000,
|
||||
"sextillion": 1_000_000_000_000_000_000_000,
|
||||
"septillion": 1_000_000_000_000_000_000_000_000,
|
||||
"octillion": 1_000_000_000_000_000_000_000_000_000,
|
||||
"nonillion": 1_000_000_000_000_000_000_000_000_000_000,
|
||||
"decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
|
||||
}
|
||||
self.multipliers_plural = {
|
||||
name + "s": (value, "s") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_ordinal = {
|
||||
name + "th": (value, "th") for name, value in self.multipliers.items()
|
||||
}
|
||||
self.multipliers_suffixed = {
|
||||
**self.multipliers_plural,
|
||||
**self.multipliers_ordinal,
|
||||
}
|
||||
self.decimals = {*self.ones, *self.tens, *self.zeros}
|
||||
|
||||
self.preceding_prefixers = {
|
||||
"minus": "-",
|
||||
"negative": "-",
|
||||
"plus": "+",
|
||||
"positive": "+",
|
||||
}
|
||||
self.following_prefixers = {
|
||||
"pound": "£",
|
||||
"pounds": "£",
|
||||
"euro": "€",
|
||||
"euros": "€",
|
||||
"dollar": "$",
|
||||
"dollars": "$",
|
||||
"cent": "¢",
|
||||
"cents": "¢",
|
||||
}
|
||||
self.prefixes = set(
|
||||
list(self.preceding_prefixers.values())
|
||||
+ list(self.following_prefixers.values())
|
||||
)
|
||||
self.suffixers = {
|
||||
"per": {"cent": "%"},
|
||||
"percent": "%",
|
||||
}
|
||||
self.specials = {"and", "double", "triple", "point"}
|
||||
|
||||
self.words = set(
|
||||
[
|
||||
key
|
||||
for mapping in [
|
||||
self.zeros,
|
||||
self.ones,
|
||||
self.ones_suffixed,
|
||||
self.tens,
|
||||
self.tens_suffixed,
|
||||
self.multipliers,
|
||||
self.multipliers_suffixed,
|
||||
self.preceding_prefixers,
|
||||
self.following_prefixers,
|
||||
self.suffixers,
|
||||
self.specials,
|
||||
]
|
||||
for key in mapping
|
||||
]
|
||||
)
|
||||
self.literal_words = {"one", "ones"}
|
||||
|
||||
def process_words(self, words: List[str]) -> Iterator[str]:
|
||||
prefix: Optional[str] = None
|
||||
value: Optional[Union[str, int]] = None
|
||||
skip = False
|
||||
|
||||
def to_fraction(s: str):
|
||||
try:
|
||||
return Fraction(s)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
def output(result: Union[str, int]):
|
||||
nonlocal prefix, value
|
||||
result = str(result)
|
||||
if prefix is not None:
|
||||
result = prefix + result
|
||||
value = None
|
||||
prefix = None
|
||||
return result
|
||||
|
||||
if len(words) == 0:
|
||||
return
|
||||
|
||||
for prev, current, next in windowed([None] + words + [None], 3):
|
||||
if skip:
|
||||
skip = False
|
||||
continue
|
||||
|
||||
next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
|
||||
has_prefix = current[0] in self.prefixes
|
||||
current_without_prefix = current[1:] if has_prefix else current
|
||||
if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
|
||||
# arabic numbers (potentially with signs and fractions)
|
||||
f = to_fraction(current_without_prefix)
|
||||
assert f is not None
|
||||
if value is not None:
|
||||
if isinstance(value, str) and value.endswith("."):
|
||||
# concatenate decimals / ip address components
|
||||
value = str(value) + str(current)
|
||||
continue
|
||||
else:
|
||||
yield output(value)
|
||||
|
||||
prefix = current[0] if has_prefix else prefix
|
||||
if f.denominator == 1:
|
||||
value = f.numerator # store integers as int
|
||||
else:
|
||||
value = current_without_prefix
|
||||
elif current not in self.words:
|
||||
# non-numeric words
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current in self.zeros:
|
||||
value = str(value or "") + "0"
|
||||
elif current in self.ones:
|
||||
ones = self.ones[current]
|
||||
|
||||
if value is None:
|
||||
value = ones
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if (
|
||||
prev in self.tens and ones < 10
|
||||
): # replace the last zero with the digit
|
||||
assert value[-1] == "0"
|
||||
value = value[:-1] + str(ones)
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
elif ones < 10:
|
||||
if value % 10 == 0:
|
||||
value += ones
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
else: # eleven to nineteen
|
||||
if value % 100 == 0:
|
||||
value += ones
|
||||
else:
|
||||
value = str(value) + str(ones)
|
||||
elif current in self.ones_suffixed:
|
||||
# ordinal or cardinal; yield the number right away
|
||||
ones, suffix = self.ones_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(ones) + suffix)
|
||||
elif isinstance(value, str) or prev in self.ones:
|
||||
if prev in self.tens and ones < 10:
|
||||
assert value[-1] == "0"
|
||||
yield output(value[:-1] + str(ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
elif ones < 10:
|
||||
if value % 10 == 0:
|
||||
yield output(str(value + ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
else: # eleven to nineteen
|
||||
if value % 100 == 0:
|
||||
yield output(str(value + ones) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(ones) + suffix)
|
||||
value = None
|
||||
elif current in self.tens:
|
||||
tens = self.tens[current]
|
||||
if value is None:
|
||||
value = tens
|
||||
elif isinstance(value, str):
|
||||
value = str(value) + str(tens)
|
||||
else:
|
||||
if value % 100 == 0:
|
||||
value += tens
|
||||
else:
|
||||
value = str(value) + str(tens)
|
||||
elif current in self.tens_suffixed:
|
||||
# ordinal or cardinal; yield the number right away
|
||||
tens, suffix = self.tens_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(tens) + suffix)
|
||||
elif isinstance(value, str):
|
||||
yield output(str(value) + str(tens) + suffix)
|
||||
else:
|
||||
if value % 100 == 0:
|
||||
yield output(str(value + tens) + suffix)
|
||||
else:
|
||||
yield output(str(value) + str(tens) + suffix)
|
||||
elif current in self.multipliers:
|
||||
multiplier = self.multipliers[current]
|
||||
if value is None:
|
||||
value = multiplier
|
||||
elif isinstance(value, str) or value == 0:
|
||||
f = to_fraction(value)
|
||||
p = f * multiplier if f is not None else None
|
||||
if f is not None and p.denominator == 1:
|
||||
value = p.numerator
|
||||
else:
|
||||
yield output(value)
|
||||
value = multiplier
|
||||
else:
|
||||
before = value // 1000 * 1000
|
||||
residual = value % 1000
|
||||
value = before + residual * multiplier
|
||||
elif current in self.multipliers_suffixed:
|
||||
multiplier, suffix = self.multipliers_suffixed[current]
|
||||
if value is None:
|
||||
yield output(str(multiplier) + suffix)
|
||||
elif isinstance(value, str):
|
||||
f = to_fraction(value)
|
||||
p = f * multiplier if f is not None else None
|
||||
if f is not None and p.denominator == 1:
|
||||
yield output(str(p.numerator) + suffix)
|
||||
else:
|
||||
yield output(value)
|
||||
yield output(str(multiplier) + suffix)
|
||||
else: # int
|
||||
before = value // 1000 * 1000
|
||||
residual = value % 1000
|
||||
value = before + residual * multiplier
|
||||
yield output(str(value) + suffix)
|
||||
value = None
|
||||
elif current in self.preceding_prefixers:
|
||||
# apply prefix (positive, minus, etc.) if it precedes a number
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
|
||||
if next in self.words or next_is_numeric:
|
||||
prefix = self.preceding_prefixers[current]
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.following_prefixers:
|
||||
# apply prefix (dollars, cents, etc.) only after a number
|
||||
if value is not None:
|
||||
prefix = self.following_prefixers[current]
|
||||
yield output(value)
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.suffixers:
|
||||
# apply suffix symbols (percent -> '%')
|
||||
if value is not None:
|
||||
suffix = self.suffixers[current]
|
||||
if isinstance(suffix, dict):
|
||||
if next in suffix:
|
||||
yield output(str(value) + suffix[next])
|
||||
skip = True
|
||||
else:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
else:
|
||||
yield output(str(value) + suffix)
|
||||
else:
|
||||
yield output(current)
|
||||
elif current in self.specials:
|
||||
if next not in self.words and not next_is_numeric:
|
||||
# apply special handling only if the next word can be numeric
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "and":
|
||||
# ignore "and" after hundreds, thousands, etc.
|
||||
if prev not in self.multipliers:
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "double" or current == "triple":
|
||||
if next in self.ones or next in self.zeros:
|
||||
repeats = 2 if current == "double" else 3
|
||||
ones = self.ones.get(next, 0)
|
||||
value = str(value or "") + str(ones) * repeats
|
||||
skip = True
|
||||
else:
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
yield output(current)
|
||||
elif current == "point":
|
||||
if next in self.decimals or next_is_numeric:
|
||||
value = str(value or "") + "."
|
||||
else:
|
||||
# should all have been covered at this point
|
||||
raise ValueError(f"Unexpected token: {current}")
|
||||
else:
|
||||
# all should have been covered at this point
|
||||
raise ValueError(f"Unexpected token: {current}")
|
||||
|
||||
if value is not None:
|
||||
yield output(value)
|
||||
|
||||
def preprocess(self, s: str):
|
||||
# replace "<number> and a half" with "<number> point five"
|
||||
results = []
|
||||
|
||||
segments = re.split(r"\band\s+a\s+half\b", s)
|
||||
for i, segment in enumerate(segments):
|
||||
if len(segment.strip()) == 0:
|
||||
continue
|
||||
if i == len(segments) - 1:
|
||||
results.append(segment)
|
||||
else:
|
||||
results.append(segment)
|
||||
last_word = segment.rsplit(maxsplit=2)[-1]
|
||||
if last_word in self.decimals or last_word in self.multipliers:
|
||||
results.append("point five")
|
||||
else:
|
||||
results.append("and a half")
|
||||
|
||||
s = " ".join(results)
|
||||
|
||||
# put a space at number/letter boundary
|
||||
s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
|
||||
s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
|
||||
|
||||
# but remove spaces which could be a suffix
|
||||
s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
|
||||
|
||||
return s
|
||||
|
||||
def postprocess(self, s: str):
|
||||
def combine_cents(m: Match):
|
||||
try:
|
||||
currency = m.group(1)
|
||||
integer = m.group(2)
|
||||
cents = int(m.group(3))
|
||||
return f"{currency}{integer}.{cents:02d}"
|
||||
except ValueError:
|
||||
return m.string
|
||||
|
||||
def extract_cents(m: Match):
|
||||
try:
|
||||
return f"¢{int(m.group(1))}"
|
||||
except ValueError:
|
||||
return m.string
|
||||
|
||||
# apply currency postprocessing; "$2 and ¢7" -> "$2.07"
|
||||
s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s)
|
||||
s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s)
|
||||
|
||||
# write "one(s)" instead of "1(s)", just for the readability
|
||||
s = re.sub(r"\b1(s?)\b", r"one\1", s)
|
||||
|
||||
return s
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = self.preprocess(s)
|
||||
s = " ".join(word for word in self.process_words(s.split()) if word is not None)
|
||||
s = self.postprocess(s)
|
||||
|
||||
return s
|
||||
|
||||
|
||||
class EnglishSpellingNormalizer:
|
||||
"""
|
||||
Applies British-American spelling mappings as listed in [1].
|
||||
|
||||
[1] https://www.tysto.com/uk-us-spelling-list.html
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
|
||||
self.mapping = json.load(open(mapping_path))
|
||||
|
||||
def __call__(self, s: str):
|
||||
return " ".join(self.mapping.get(word, word) for word in s.split())
|
||||
|
||||
|
||||
class EnglishTextNormalizer:
|
||||
def __init__(self):
|
||||
self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
|
||||
self.replacers = {
|
||||
# common contractions
|
||||
r"\bwon't\b": "will not",
|
||||
r"\bcan't\b": "can not",
|
||||
r"\blet's\b": "let us",
|
||||
r"\bain't\b": "aint",
|
||||
r"\by'all\b": "you all",
|
||||
r"\bwanna\b": "want to",
|
||||
r"\bgotta\b": "got to",
|
||||
r"\bgonna\b": "going to",
|
||||
r"\bi'ma\b": "i am going to",
|
||||
r"\bimma\b": "i am going to",
|
||||
r"\bwoulda\b": "would have",
|
||||
r"\bcoulda\b": "could have",
|
||||
r"\bshoulda\b": "should have",
|
||||
r"\bma'am\b": "madam",
|
||||
# contractions in titles/prefixes
|
||||
r"\bmr\b": "mister ",
|
||||
r"\bmrs\b": "missus ",
|
||||
r"\bst\b": "saint ",
|
||||
r"\bdr\b": "doctor ",
|
||||
r"\bprof\b": "professor ",
|
||||
r"\bcapt\b": "captain ",
|
||||
r"\bgov\b": "governor ",
|
||||
r"\bald\b": "alderman ",
|
||||
r"\bgen\b": "general ",
|
||||
r"\bsen\b": "senator ",
|
||||
r"\brep\b": "representative ",
|
||||
r"\bpres\b": "president ",
|
||||
r"\brev\b": "reverend ",
|
||||
r"\bhon\b": "honorable ",
|
||||
r"\basst\b": "assistant ",
|
||||
r"\bassoc\b": "associate ",
|
||||
r"\blt\b": "lieutenant ",
|
||||
r"\bcol\b": "colonel ",
|
||||
r"\bjr\b": "junior ",
|
||||
r"\bsr\b": "senior ",
|
||||
r"\besq\b": "esquire ",
|
||||
# prefect tenses, ideally it should be any past participles, but it's harder..
|
||||
r"'d been\b": " had been",
|
||||
r"'s been\b": " has been",
|
||||
r"'d gone\b": " had gone",
|
||||
r"'s gone\b": " has gone",
|
||||
r"'d done\b": " had done", # "'s done" is ambiguous
|
||||
r"'s got\b": " has got",
|
||||
# general contractions
|
||||
r"n't\b": " not",
|
||||
r"'re\b": " are",
|
||||
r"'s\b": " is",
|
||||
r"'d\b": " would",
|
||||
r"'ll\b": " will",
|
||||
r"'t\b": " not",
|
||||
r"'ve\b": " have",
|
||||
r"'m\b": " am",
|
||||
}
|
||||
self.standardize_numbers = EnglishNumberNormalizer()
|
||||
self.standardize_spellings = EnglishSpellingNormalizer()
|
||||
|
||||
def __call__(self, s: str):
|
||||
s = s.lower()
|
||||
|
||||
s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
|
||||
s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
|
||||
s = re.sub(self.ignore_patterns, "", s)
|
||||
s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
|
||||
|
||||
for pattern, replacement in self.replacers.items():
|
||||
s = re.sub(pattern, replacement, s)
|
||||
|
||||
s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
|
||||
s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
|
||||
s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols
|
||||
|
||||
s = self.standardize_numbers(s)
|
||||
s = self.standardize_spellings(s)
|
||||
|
||||
# now remove prefix/suffix symbols that are not preceded/followed by numbers
|
||||
s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s)
|
||||
s = re.sub(r"([^0-9])%", r"\1 ", s)
|
||||
|
||||
s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
|
||||
|
||||
return s
|
|
@ -0,0 +1,365 @@
|
|||
import itertools
|
||||
import subprocess
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
def median_filter(x: torch.Tensor, filter_width: int):
|
||||
"""Apply a median filter of width `filter_width` along the last dimension of `x`"""
|
||||
pad_width = filter_width // 2
|
||||
if x.shape[-1] <= pad_width:
|
||||
# F.pad requires the padding width to be smaller than the input dimension
|
||||
return x
|
||||
|
||||
if (ndim := x.ndim) <= 2:
|
||||
# `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
|
||||
x = x[None, None, :]
|
||||
|
||||
assert (
|
||||
filter_width > 0 and filter_width % 2 == 1
|
||||
), "`filter_width` should be an odd number"
|
||||
|
||||
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
|
||||
|
||||
# sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
|
||||
result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
|
||||
|
||||
if ndim <= 2:
|
||||
result = result[0, 0]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@numba.jit(nopython=True)
|
||||
def backtrace(trace: np.ndarray):
|
||||
i = trace.shape[0] - 1
|
||||
j = trace.shape[1] - 1
|
||||
trace[0, :] = 2
|
||||
trace[:, 0] = 1
|
||||
|
||||
result = []
|
||||
while i > 0 or j > 0:
|
||||
result.append((i - 1, j - 1))
|
||||
|
||||
if trace[i, j] == 0:
|
||||
i -= 1
|
||||
j -= 1
|
||||
elif trace[i, j] == 1:
|
||||
i -= 1
|
||||
elif trace[i, j] == 2:
|
||||
j -= 1
|
||||
else:
|
||||
raise ValueError("Unexpected trace[i, j]")
|
||||
|
||||
result = np.array(result)
|
||||
return result[::-1, :].T
|
||||
|
||||
|
||||
@numba.jit(nopython=True, parallel=True)
|
||||
def dtw_cpu(x: np.ndarray):
|
||||
N, M = x.shape
|
||||
cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
|
||||
trace = -np.ones((N + 1, M + 1), dtype=np.float32)
|
||||
|
||||
cost[0, 0] = 0
|
||||
for j in range(1, M + 1):
|
||||
for i in range(1, N + 1):
|
||||
c0 = cost[i - 1, j - 1]
|
||||
c1 = cost[i - 1, j]
|
||||
c2 = cost[i, j - 1]
|
||||
|
||||
if c0 < c1 and c0 < c2:
|
||||
c, t = c0, 0
|
||||
elif c1 < c0 and c1 < c2:
|
||||
c, t = c1, 1
|
||||
else:
|
||||
c, t = c2, 2
|
||||
|
||||
cost[i, j] = x[i - 1, j - 1] + c
|
||||
trace[i, j] = t
|
||||
|
||||
return backtrace(trace)
|
||||
|
||||
|
||||
def dtw_cuda(x, BLOCK_SIZE=1024):
|
||||
from .triton_ops import dtw_kernel
|
||||
|
||||
M, N = x.shape
|
||||
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
|
||||
|
||||
x_skew = (
|
||||
F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
|
||||
)
|
||||
x_skew = x_skew.T.contiguous()
|
||||
cost = torch.ones(N + M + 2, M + 2) * np.inf
|
||||
cost[0, 0] = 0
|
||||
cost = cost.cuda()
|
||||
trace = torch.zeros_like(cost, dtype=torch.int32)
|
||||
|
||||
dtw_kernel[(1,)](
|
||||
cost,
|
||||
trace,
|
||||
x_skew,
|
||||
x_skew.stride(0),
|
||||
cost.stride(0),
|
||||
trace.stride(0),
|
||||
N,
|
||||
M,
|
||||
BLOCK_SIZE=BLOCK_SIZE,
|
||||
)
|
||||
|
||||
trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
|
||||
:, : N + 1
|
||||
]
|
||||
return backtrace(trace.cpu().numpy())
|
||||
|
||||
|
||||
def dtw(x: torch.Tensor) -> np.ndarray:
|
||||
return dtw_cpu(x.double().cpu().numpy())
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordTiming:
|
||||
word: str
|
||||
tokens: List[int]
|
||||
start: float
|
||||
end: float
|
||||
probability: float
|
||||
|
||||
|
||||
def find_alignment(
|
||||
model: "Whisper",
|
||||
tokenizer: Tokenizer,
|
||||
text_tokens: List[int],
|
||||
mel: torch.Tensor,
|
||||
num_frames: int,
|
||||
*,
|
||||
medfilt_width: int = 7,
|
||||
qk_scale: float = 1.0,
|
||||
) -> List[WordTiming]:
|
||||
if len(text_tokens) == 0:
|
||||
return []
|
||||
|
||||
tokens = torch.tensor(
|
||||
[
|
||||
*tokenizer.sot_sequence,
|
||||
tokenizer.no_timestamps,
|
||||
*text_tokens,
|
||||
tokenizer.eot,
|
||||
]
|
||||
).to(model.device)
|
||||
|
||||
# install hooks on the cross attention layers to retrieve the attention weights
|
||||
QKs = [None] * model.dims.n_text_layer
|
||||
hooks = [
|
||||
block.cross_attn.register_forward_hook(
|
||||
lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
|
||||
)
|
||||
for i, block in enumerate(model.decoder.blocks)
|
||||
]
|
||||
|
||||
with torch.no_grad():
|
||||
logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
|
||||
sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
|
||||
token_probs = sampled_logits.softmax(dim=-1)
|
||||
text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
|
||||
text_token_probs = text_token_probs.tolist()
|
||||
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# heads * tokens * frames
|
||||
weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
|
||||
weights = weights[:, :, : num_frames // 2]
|
||||
weights = (weights * qk_scale).softmax(dim=-1)
|
||||
std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
|
||||
weights = (weights - mean) / std
|
||||
weights = median_filter(weights, medfilt_width)
|
||||
|
||||
matrix = weights.mean(axis=0)
|
||||
matrix = matrix[len(tokenizer.sot_sequence) : -1]
|
||||
text_indices, time_indices = dtw(-matrix)
|
||||
|
||||
words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
|
||||
if len(word_tokens) <= 1:
|
||||
# return on eot only
|
||||
# >>> np.pad([], (1, 0))
|
||||
# array([0.])
|
||||
# This results in crashes when we lookup jump_times with float, like
|
||||
# IndexError: arrays used as indices must be of integer (or boolean) type
|
||||
return []
|
||||
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
||||
|
||||
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
||||
jump_times = time_indices[jumps] / TOKENS_PER_SECOND
|
||||
start_times = jump_times[word_boundaries[:-1]]
|
||||
end_times = jump_times[word_boundaries[1:]]
|
||||
word_probabilities = [
|
||||
np.mean(text_token_probs[i:j])
|
||||
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
||||
]
|
||||
|
||||
return [
|
||||
WordTiming(word, tokens, start, end, probability)
|
||||
for word, tokens, start, end, probability in zip(
|
||||
words, word_tokens, start_times, end_times, word_probabilities
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
|
||||
# merge prepended punctuations
|
||||
i = len(alignment) - 2
|
||||
j = len(alignment) - 1
|
||||
while i >= 0:
|
||||
previous = alignment[i]
|
||||
following = alignment[j]
|
||||
if previous.word.startswith(" ") and previous.word.strip() in prepended:
|
||||
# prepend it to the following word
|
||||
following.word = previous.word + following.word
|
||||
following.tokens = previous.tokens + following.tokens
|
||||
previous.word = ""
|
||||
previous.tokens = []
|
||||
else:
|
||||
j = i
|
||||
i -= 1
|
||||
|
||||
# merge appended punctuations
|
||||
i = 0
|
||||
j = 1
|
||||
while j < len(alignment):
|
||||
previous = alignment[i]
|
||||
following = alignment[j]
|
||||
if not previous.word.endswith(" ") and following.word in appended:
|
||||
# append it to the previous word
|
||||
previous.word = previous.word + following.word
|
||||
previous.tokens = previous.tokens + following.tokens
|
||||
following.word = ""
|
||||
following.tokens = []
|
||||
else:
|
||||
i = j
|
||||
j += 1
|
||||
|
||||
|
||||
def add_word_timestamps(
|
||||
*,
|
||||
segments: List[dict],
|
||||
model: "Whisper",
|
||||
tokenizer: Tokenizer,
|
||||
mel: torch.Tensor,
|
||||
num_frames: int,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
last_speech_timestamp: float,
|
||||
**kwargs,
|
||||
):
|
||||
if len(segments) == 0:
|
||||
return
|
||||
|
||||
text_tokens_per_segment = [
|
||||
[token for token in segment["tokens"] if token < tokenizer.eot]
|
||||
for segment in segments
|
||||
]
|
||||
|
||||
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
|
||||
alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
|
||||
word_durations = np.array([t.end - t.start for t in alignment])
|
||||
word_durations = word_durations[word_durations.nonzero()]
|
||||
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
||||
median_duration = min(0.7, float(median_duration))
|
||||
max_duration = median_duration * 2
|
||||
|
||||
# hack: truncate long words at sentence boundaries.
|
||||
# a better segmentation algorithm based on VAD should be able to replace this.
|
||||
if len(word_durations) > 0:
|
||||
sentence_end_marks = ".。!!??"
|
||||
# ensure words at sentence boundaries are not longer than twice the median word duration.
|
||||
for i in range(1, len(alignment)):
|
||||
if alignment[i].end - alignment[i].start > max_duration:
|
||||
if alignment[i].word in sentence_end_marks:
|
||||
alignment[i].end = alignment[i].start + max_duration
|
||||
elif alignment[i - 1].word in sentence_end_marks:
|
||||
alignment[i].start = alignment[i].end - max_duration
|
||||
|
||||
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
||||
|
||||
time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
|
||||
word_index = 0
|
||||
|
||||
for segment, text_tokens in zip(segments, text_tokens_per_segment):
|
||||
saved_tokens = 0
|
||||
words = []
|
||||
|
||||
while word_index < len(alignment) and saved_tokens < len(text_tokens):
|
||||
timing = alignment[word_index]
|
||||
|
||||
if timing.word:
|
||||
words.append(
|
||||
dict(
|
||||
word=timing.word,
|
||||
start=round(time_offset + timing.start, 2),
|
||||
end=round(time_offset + timing.end, 2),
|
||||
probability=timing.probability,
|
||||
)
|
||||
)
|
||||
|
||||
saved_tokens += len(timing.tokens)
|
||||
word_index += 1
|
||||
|
||||
# hack: truncate long words at segment boundaries.
|
||||
# a better segmentation algorithm based on VAD should be able to replace this.
|
||||
if len(words) > 0:
|
||||
# ensure the first and second word after a pause is not longer than
|
||||
# twice the median word duration.
|
||||
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
|
||||
words[0]["end"] - words[0]["start"] > max_duration
|
||||
or (
|
||||
len(words) > 1
|
||||
and words[1]["end"] - words[0]["start"] > max_duration * 2
|
||||
)
|
||||
):
|
||||
if (
|
||||
len(words) > 1
|
||||
and words[1]["end"] - words[1]["start"] > max_duration
|
||||
):
|
||||
boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
|
||||
words[0]["end"] = words[1]["start"] = boundary
|
||||
words[0]["start"] = max(0, words[0]["end"] - max_duration)
|
||||
|
||||
# prefer the segment-level start timestamp if the first word is too long.
|
||||
if (
|
||||
segment["start"] < words[0]["end"]
|
||||
and segment["start"] - 0.5 > words[0]["start"]
|
||||
):
|
||||
words[0]["start"] = max(
|
||||
0, min(words[0]["end"] - median_duration, segment["start"])
|
||||
)
|
||||
else:
|
||||
segment["start"] = words[0]["start"]
|
||||
|
||||
# prefer the segment-level end timestamp if the last word is too long.
|
||||
if (
|
||||
segment["end"] > words[-1]["start"]
|
||||
and segment["end"] + 0.5 < words[-1]["end"]
|
||||
):
|
||||
words[-1]["end"] = max(
|
||||
words[-1]["start"] + median_duration, segment["end"]
|
||||
)
|
||||
else:
|
||||
segment["end"] = words[-1]["end"]
|
||||
|
||||
last_speech_timestamp = segment["end"]
|
||||
|
||||
segment["words"] = words
|
|
@ -0,0 +1,395 @@
|
|||
import base64
|
||||
import os
|
||||
import string
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property, lru_cache
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import tiktoken
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english",
|
||||
"zh": "chinese",
|
||||
"de": "german",
|
||||
"es": "spanish",
|
||||
"ru": "russian",
|
||||
"ko": "korean",
|
||||
"fr": "french",
|
||||
"ja": "japanese",
|
||||
"pt": "portuguese",
|
||||
"tr": "turkish",
|
||||
"pl": "polish",
|
||||
"ca": "catalan",
|
||||
"nl": "dutch",
|
||||
"ar": "arabic",
|
||||
"sv": "swedish",
|
||||
"it": "italian",
|
||||
"id": "indonesian",
|
||||
"hi": "hindi",
|
||||
"fi": "finnish",
|
||||
"vi": "vietnamese",
|
||||
"he": "hebrew",
|
||||
"uk": "ukrainian",
|
||||
"el": "greek",
|
||||
"ms": "malay",
|
||||
"cs": "czech",
|
||||
"ro": "romanian",
|
||||
"da": "danish",
|
||||
"hu": "hungarian",
|
||||
"ta": "tamil",
|
||||
"no": "norwegian",
|
||||
"th": "thai",
|
||||
"ur": "urdu",
|
||||
"hr": "croatian",
|
||||
"bg": "bulgarian",
|
||||
"lt": "lithuanian",
|
||||
"la": "latin",
|
||||
"mi": "maori",
|
||||
"ml": "malayalam",
|
||||
"cy": "welsh",
|
||||
"sk": "slovak",
|
||||
"te": "telugu",
|
||||
"fa": "persian",
|
||||
"lv": "latvian",
|
||||
"bn": "bengali",
|
||||
"sr": "serbian",
|
||||
"az": "azerbaijani",
|
||||
"sl": "slovenian",
|
||||
"kn": "kannada",
|
||||
"et": "estonian",
|
||||
"mk": "macedonian",
|
||||
"br": "breton",
|
||||
"eu": "basque",
|
||||
"is": "icelandic",
|
||||
"hy": "armenian",
|
||||
"ne": "nepali",
|
||||
"mn": "mongolian",
|
||||
"bs": "bosnian",
|
||||
"kk": "kazakh",
|
||||
"sq": "albanian",
|
||||
"sw": "swahili",
|
||||
"gl": "galician",
|
||||
"mr": "marathi",
|
||||
"pa": "punjabi",
|
||||
"si": "sinhala",
|
||||
"km": "khmer",
|
||||
"sn": "shona",
|
||||
"yo": "yoruba",
|
||||
"so": "somali",
|
||||
"af": "afrikaans",
|
||||
"oc": "occitan",
|
||||
"ka": "georgian",
|
||||
"be": "belarusian",
|
||||
"tg": "tajik",
|
||||
"sd": "sindhi",
|
||||
"gu": "gujarati",
|
||||
"am": "amharic",
|
||||
"yi": "yiddish",
|
||||
"lo": "lao",
|
||||
"uz": "uzbek",
|
||||
"fo": "faroese",
|
||||
"ht": "haitian creole",
|
||||
"ps": "pashto",
|
||||
"tk": "turkmen",
|
||||
"nn": "nynorsk",
|
||||
"mt": "maltese",
|
||||
"sa": "sanskrit",
|
||||
"lb": "luxembourgish",
|
||||
"my": "myanmar",
|
||||
"bo": "tibetan",
|
||||
"tl": "tagalog",
|
||||
"mg": "malagasy",
|
||||
"as": "assamese",
|
||||
"tt": "tatar",
|
||||
"haw": "hawaiian",
|
||||
"ln": "lingala",
|
||||
"ha": "hausa",
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
"yue": "cantonese",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
TO_LANGUAGE_CODE = {
|
||||
**{language: code for code, language in LANGUAGES.items()},
|
||||
"burmese": "my",
|
||||
"valencian": "ca",
|
||||
"flemish": "nl",
|
||||
"haitian": "ht",
|
||||
"letzeburgesch": "lb",
|
||||
"pushto": "ps",
|
||||
"panjabi": "pa",
|
||||
"moldavian": "ro",
|
||||
"moldovan": "ro",
|
||||
"sinhalese": "si",
|
||||
"castilian": "es",
|
||||
"mandarin": "zh",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tokenizer:
|
||||
"""A thin wrapper around `tiktoken` providing quick access to special tokens"""
|
||||
|
||||
encoding: tiktoken.Encoding
|
||||
num_languages: int
|
||||
language: Optional[str] = None
|
||||
task: Optional[str] = None
|
||||
sot_sequence: Tuple[int] = ()
|
||||
special_tokens: Dict[str, int] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
for special in self.encoding.special_tokens_set:
|
||||
special_token = self.encoding.encode_single_token(special)
|
||||
self.special_tokens[special] = special_token
|
||||
|
||||
sot: int = self.special_tokens["<|startoftranscript|>"]
|
||||
translate: int = self.special_tokens["<|translate|>"]
|
||||
transcribe: int = self.special_tokens["<|transcribe|>"]
|
||||
|
||||
langs = tuple(LANGUAGES.keys())[: self.num_languages]
|
||||
sot_sequence = [sot]
|
||||
if self.language is not None:
|
||||
sot_sequence.append(sot + 1 + langs.index(self.language))
|
||||
if self.task is not None:
|
||||
task_token: int = transcribe if self.task == "transcribe" else translate
|
||||
sot_sequence.append(task_token)
|
||||
|
||||
self.sot_sequence = tuple(sot_sequence)
|
||||
|
||||
def encode(self, text, **kwargs):
|
||||
return self.encoding.encode(text, **kwargs)
|
||||
|
||||
def decode(self, token_ids: List[int], **kwargs) -> str:
|
||||
token_ids = [t for t in token_ids if t < self.timestamp_begin]
|
||||
return self.encoding.decode(token_ids, **kwargs)
|
||||
|
||||
def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
|
||||
"""
|
||||
Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
|
||||
This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
|
||||
"""
|
||||
return self.encoding.decode(token_ids, **kwargs)
|
||||
|
||||
@cached_property
|
||||
def eot(self) -> int:
|
||||
return self.encoding.eot_token
|
||||
|
||||
@cached_property
|
||||
def transcribe(self) -> int:
|
||||
return self.special_tokens["<|transcribe|>"]
|
||||
|
||||
@cached_property
|
||||
def translate(self) -> int:
|
||||
return self.special_tokens["<|translate|>"]
|
||||
|
||||
@cached_property
|
||||
def sot(self) -> int:
|
||||
return self.special_tokens["<|startoftranscript|>"]
|
||||
|
||||
@cached_property
|
||||
def sot_lm(self) -> int:
|
||||
return self.special_tokens["<|startoflm|>"]
|
||||
|
||||
@cached_property
|
||||
def sot_prev(self) -> int:
|
||||
return self.special_tokens["<|startofprev|>"]
|
||||
|
||||
@cached_property
|
||||
def no_speech(self) -> int:
|
||||
return self.special_tokens["<|nospeech|>"]
|
||||
|
||||
@cached_property
|
||||
def no_timestamps(self) -> int:
|
||||
return self.special_tokens["<|notimestamps|>"]
|
||||
|
||||
@cached_property
|
||||
def timestamp_begin(self) -> int:
|
||||
return self.special_tokens["<|0.00|>"]
|
||||
|
||||
@cached_property
|
||||
def language_token(self) -> int:
|
||||
"""Returns the token id corresponding to the value of the `language` field"""
|
||||
if self.language is None:
|
||||
raise ValueError("This tokenizer does not have language token configured")
|
||||
|
||||
return self.to_language_token(self.language)
|
||||
|
||||
def to_language_token(self, language):
|
||||
if token := self.special_tokens.get(f"<|{language}|>", None):
|
||||
return token
|
||||
|
||||
raise KeyError(f"Language {language} not found in tokenizer.")
|
||||
|
||||
@cached_property
|
||||
def all_language_tokens(self) -> Tuple[int]:
|
||||
result = []
|
||||
for token, token_id in self.special_tokens.items():
|
||||
if token.strip("<|>") in LANGUAGES:
|
||||
result.append(token_id)
|
||||
return tuple(result)[: self.num_languages]
|
||||
|
||||
@cached_property
|
||||
def all_language_codes(self) -> Tuple[str]:
|
||||
return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
|
||||
|
||||
@cached_property
|
||||
def sot_sequence_including_notimestamps(self) -> Tuple[int]:
|
||||
return tuple(list(self.sot_sequence) + [self.no_timestamps])
|
||||
|
||||
@cached_property
|
||||
def non_speech_tokens(self) -> Tuple[int]:
|
||||
"""
|
||||
Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
|
||||
annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
|
||||
|
||||
- ♪♪♪
|
||||
- ( SPEAKING FOREIGN LANGUAGE )
|
||||
- [DAVID] Hey there,
|
||||
|
||||
keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
|
||||
"""
|
||||
symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
|
||||
symbols += (
|
||||
"<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split()
|
||||
)
|
||||
|
||||
# symbols that may be a single token or multiple tokens depending on the tokenizer.
|
||||
# In case they're multiple tokens, suppress the first token, which is safe because:
|
||||
# These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
|
||||
# in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
|
||||
miscellaneous = set("♩♪♫♬♭♮♯")
|
||||
assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
|
||||
|
||||
# allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
|
||||
result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
|
||||
for symbol in symbols + list(miscellaneous):
|
||||
for tokens in [
|
||||
self.encoding.encode(symbol),
|
||||
self.encoding.encode(" " + symbol),
|
||||
]:
|
||||
if len(tokens) == 1 or symbol in miscellaneous:
|
||||
result.add(tokens[0])
|
||||
|
||||
return tuple(sorted(result))
|
||||
|
||||
def split_to_word_tokens(self, tokens: List[int]):
|
||||
if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
|
||||
# These languages don't typically use spaces, so it is difficult to split words
|
||||
# without morpheme analysis. Here, we instead split words at any
|
||||
# position where the tokens are decoded as valid unicode points
|
||||
return self.split_tokens_on_unicode(tokens)
|
||||
|
||||
return self.split_tokens_on_spaces(tokens)
|
||||
|
||||
def split_tokens_on_unicode(self, tokens: List[int]):
|
||||
decoded_full = self.decode_with_timestamps(tokens)
|
||||
replacement_char = "\ufffd"
|
||||
|
||||
words = []
|
||||
word_tokens = []
|
||||
current_tokens = []
|
||||
unicode_offset = 0
|
||||
|
||||
for token in tokens:
|
||||
current_tokens.append(token)
|
||||
decoded = self.decode_with_timestamps(current_tokens)
|
||||
|
||||
if (
|
||||
replacement_char not in decoded
|
||||
or decoded_full[unicode_offset + decoded.index(replacement_char)]
|
||||
== replacement_char
|
||||
):
|
||||
words.append(decoded)
|
||||
word_tokens.append(current_tokens)
|
||||
current_tokens = []
|
||||
unicode_offset += len(decoded)
|
||||
|
||||
return words, word_tokens
|
||||
|
||||
def split_tokens_on_spaces(self, tokens: List[int]):
|
||||
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
|
||||
words = []
|
||||
word_tokens = []
|
||||
|
||||
for subword, subword_tokens in zip(subwords, subword_tokens_list):
|
||||
special = subword_tokens[0] >= self.eot
|
||||
with_space = subword.startswith(" ")
|
||||
punctuation = subword.strip() in string.punctuation
|
||||
if special or with_space or punctuation or len(words) == 0:
|
||||
words.append(subword)
|
||||
word_tokens.append(subword_tokens)
|
||||
else:
|
||||
words[-1] = words[-1] + subword
|
||||
word_tokens[-1].extend(subword_tokens)
|
||||
|
||||
return words, word_tokens
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_encoding(name: str = "gpt2", num_languages: int = 99):
|
||||
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
||||
ranks = {
|
||||
base64.b64decode(token): int(rank)
|
||||
for token, rank in (line.split() for line in open(vocab_path) if line)
|
||||
}
|
||||
n_vocab = len(ranks)
|
||||
special_tokens = {}
|
||||
|
||||
specials = [
|
||||
"<|endoftext|>",
|
||||
"<|startoftranscript|>",
|
||||
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
|
||||
"<|translate|>",
|
||||
"<|transcribe|>",
|
||||
"<|startoflm|>",
|
||||
"<|startofprev|>",
|
||||
"<|nospeech|>",
|
||||
"<|notimestamps|>",
|
||||
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
||||
]
|
||||
|
||||
for token in specials:
|
||||
special_tokens[token] = n_vocab
|
||||
n_vocab += 1
|
||||
|
||||
return tiktoken.Encoding(
|
||||
name=os.path.basename(vocab_path),
|
||||
explicit_n_vocab=n_vocab,
|
||||
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
||||
mergeable_ranks=ranks,
|
||||
special_tokens=special_tokens,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_tokenizer(
|
||||
multilingual: bool,
|
||||
*,
|
||||
num_languages: int = 99,
|
||||
language: Optional[str] = None,
|
||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||
) -> Tokenizer:
|
||||
if language is not None:
|
||||
language = language.lower()
|
||||
if language not in LANGUAGES:
|
||||
if language in TO_LANGUAGE_CODE:
|
||||
language = TO_LANGUAGE_CODE[language]
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {language}")
|
||||
|
||||
if multilingual:
|
||||
encoding_name = "multilingual"
|
||||
language = language or "en"
|
||||
task = task or "transcribe"
|
||||
else:
|
||||
encoding_name = "gpt2"
|
||||
language = None
|
||||
task = None
|
||||
|
||||
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
||||
|
||||
return Tokenizer(
|
||||
encoding=encoding, num_languages=num_languages, language=language, task=task
|
||||
)
|
|
@ -0,0 +1,605 @@
|
|||
import argparse
|
||||
import os
|
||||
import traceback
|
||||
import warnings
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
|
||||
from .audio import (
|
||||
FRAMES_PER_SECOND,
|
||||
HOP_LENGTH,
|
||||
N_FRAMES,
|
||||
N_SAMPLES,
|
||||
SAMPLE_RATE,
|
||||
log_mel_spectrogram,
|
||||
pad_or_trim,
|
||||
)
|
||||
from .decoding import DecodingOptions, DecodingResult
|
||||
from .timing import add_word_timestamps
|
||||
from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
|
||||
from .utils import (
|
||||
exact_div,
|
||||
format_timestamp,
|
||||
get_end,
|
||||
get_writer,
|
||||
make_safe,
|
||||
optional_float,
|
||||
optional_int,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .model import Whisper
|
||||
|
||||
|
||||
def transcribe(
|
||||
model: "Whisper",
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
*,
|
||||
verbose: Optional[bool] = None,
|
||||
temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
|
||||
compression_ratio_threshold: Optional[float] = 2.4,
|
||||
logprob_threshold: Optional[float] = -1.0,
|
||||
no_speech_threshold: Optional[float] = 0.6,
|
||||
condition_on_previous_text: bool = True,
|
||||
initial_prompt: Optional[str] = None,
|
||||
word_timestamps: bool = False,
|
||||
prepend_punctuations: str = "\"'“¿([{-",
|
||||
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
||||
clip_timestamps: Union[str, List[float]] = "0",
|
||||
hallucination_silence_threshold: Optional[float] = None,
|
||||
**decode_options,
|
||||
):
|
||||
"""
|
||||
Transcribe an audio file using Whisper
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model: Whisper
|
||||
The Whisper model instance
|
||||
|
||||
audio: Union[str, np.ndarray, torch.Tensor]
|
||||
The path to the audio file to open, or the audio waveform
|
||||
|
||||
verbose: bool
|
||||
Whether to display the text being decoded to the console. If True, displays all the details,
|
||||
If False, displays minimal details. If None, does not display anything
|
||||
|
||||
temperature: Union[float, Tuple[float, ...]]
|
||||
Temperature for sampling. It can be a tuple of temperatures, which will be successively used
|
||||
upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
|
||||
|
||||
compression_ratio_threshold: float
|
||||
If the gzip compression ratio is above this value, treat as failed
|
||||
|
||||
logprob_threshold: float
|
||||
If the average log probability over sampled tokens is below this value, treat as failed
|
||||
|
||||
no_speech_threshold: float
|
||||
If the no_speech probability is higher than this value AND the average log probability
|
||||
over sampled tokens is below `logprob_threshold`, consider the segment as silent
|
||||
|
||||
condition_on_previous_text: bool
|
||||
if True, the previous output of the model is provided as a prompt for the next window;
|
||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
||||
|
||||
word_timestamps: bool
|
||||
Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
|
||||
and include the timestamps for each word in each segment.
|
||||
|
||||
prepend_punctuations: str
|
||||
If word_timestamps is True, merge these punctuation symbols with the next word
|
||||
|
||||
append_punctuations: str
|
||||
If word_timestamps is True, merge these punctuation symbols with the previous word
|
||||
|
||||
initial_prompt: Optional[str]
|
||||
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||
to make it more likely to predict those word correctly.
|
||||
|
||||
decode_options: dict
|
||||
Keyword arguments to construct `DecodingOptions` instances
|
||||
|
||||
clip_timestamps: Union[str, List[float]]
|
||||
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
|
||||
The last end timestamp defaults to the end of the file.
|
||||
|
||||
hallucination_silence_threshold: Optional[float]
|
||||
When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
|
||||
when a possible hallucination is detected
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
|
||||
the spoken language ("language"), which is detected when `decode_options["language"]` is None.
|
||||
"""
|
||||
dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
|
||||
if model.device == torch.device("cpu"):
|
||||
if torch.cuda.is_available():
|
||||
warnings.warn("Performing inference on CPU when CUDA is available")
|
||||
if dtype == torch.float16:
|
||||
warnings.warn("FP16 is not supported on CPU; using FP32 instead")
|
||||
dtype = torch.float32
|
||||
|
||||
if dtype == torch.float32:
|
||||
decode_options["fp16"] = False
|
||||
|
||||
# Pad 30-seconds of silence to the input audio, for slicing
|
||||
mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
|
||||
content_frames = mel.shape[-1] - N_FRAMES
|
||||
content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
|
||||
|
||||
if decode_options.get("language", None) is None:
|
||||
if not model.is_multilingual:
|
||||
decode_options["language"] = "en"
|
||||
else:
|
||||
if verbose:
|
||||
print(
|
||||
"Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
|
||||
)
|
||||
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
|
||||
_, probs = model.detect_language(mel_segment)
|
||||
decode_options["language"] = max(probs, key=probs.get)
|
||||
if verbose is not None:
|
||||
print(
|
||||
f"Detected language: {LANGUAGES[decode_options['language']].title()}"
|
||||
)
|
||||
|
||||
language: str = decode_options["language"]
|
||||
task: str = decode_options.get("task", "transcribe")
|
||||
tokenizer = get_tokenizer(
|
||||
model.is_multilingual,
|
||||
num_languages=model.num_languages,
|
||||
language=language,
|
||||
task=task,
|
||||
)
|
||||
|
||||
if isinstance(clip_timestamps, str):
|
||||
clip_timestamps = [
|
||||
float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
|
||||
]
|
||||
seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
|
||||
if len(seek_points) == 0:
|
||||
seek_points.append(0)
|
||||
if len(seek_points) % 2 == 1:
|
||||
seek_points.append(content_frames)
|
||||
seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
|
||||
|
||||
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
||||
|
||||
if word_timestamps and task == "translate":
|
||||
warnings.warn("Word-level timestamps on translations may not be reliable.")
|
||||
|
||||
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
|
||||
temperatures = (
|
||||
[temperature] if isinstance(temperature, (int, float)) else temperature
|
||||
)
|
||||
decode_result = None
|
||||
|
||||
for t in temperatures:
|
||||
kwargs = {**decode_options}
|
||||
if t > 0:
|
||||
# disable beam_size and patience when t > 0
|
||||
kwargs.pop("beam_size", None)
|
||||
kwargs.pop("patience", None)
|
||||
else:
|
||||
# disable best_of when t == 0
|
||||
kwargs.pop("best_of", None)
|
||||
|
||||
options = DecodingOptions(**kwargs, temperature=t)
|
||||
decode_result = model.decode(segment, options)
|
||||
|
||||
needs_fallback = False
|
||||
if (
|
||||
compression_ratio_threshold is not None
|
||||
and decode_result.compression_ratio > compression_ratio_threshold
|
||||
):
|
||||
needs_fallback = True # too repetitive
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and decode_result.avg_logprob < logprob_threshold
|
||||
):
|
||||
needs_fallback = True # average log probability is too low
|
||||
if (
|
||||
no_speech_threshold is not None
|
||||
and decode_result.no_speech_prob > no_speech_threshold
|
||||
):
|
||||
needs_fallback = False # silence
|
||||
if not needs_fallback:
|
||||
break
|
||||
|
||||
return decode_result
|
||||
|
||||
clip_idx = 0
|
||||
seek = seek_clips[clip_idx][0]
|
||||
input_stride = exact_div(
|
||||
N_FRAMES, model.dims.n_audio_ctx
|
||||
) # mel frames per output token: 2
|
||||
time_precision = (
|
||||
input_stride * HOP_LENGTH / SAMPLE_RATE
|
||||
) # time per output token: 0.02 (seconds)
|
||||
all_tokens = []
|
||||
all_segments = []
|
||||
prompt_reset_since = 0
|
||||
|
||||
if initial_prompt is not None:
|
||||
initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
|
||||
all_tokens.extend(initial_prompt_tokens)
|
||||
else:
|
||||
initial_prompt_tokens = []
|
||||
|
||||
def new_segment(
|
||||
*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
|
||||
):
|
||||
tokens = tokens.tolist()
|
||||
text_tokens = [token for token in tokens if token < tokenizer.eot]
|
||||
return {
|
||||
"seek": seek,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"text": tokenizer.decode(text_tokens),
|
||||
"tokens": tokens,
|
||||
"temperature": result.temperature,
|
||||
"avg_logprob": result.avg_logprob,
|
||||
"compression_ratio": result.compression_ratio,
|
||||
"no_speech_prob": result.no_speech_prob,
|
||||
}
|
||||
|
||||
# show the progress bar when verbose is False (if True, transcribed text will be printed)
|
||||
with tqdm.tqdm(
|
||||
total=content_frames, unit="frames", disable=verbose is not False
|
||||
) as pbar:
|
||||
last_speech_timestamp = 0.0
|
||||
# NOTE: This loop is obscurely flattened to make the diff readable.
|
||||
# A later commit should turn this into a simpler nested loop.
|
||||
# for seek_clip_start, seek_clip_end in seek_clips:
|
||||
# while seek < seek_clip_end
|
||||
while clip_idx < len(seek_clips):
|
||||
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
||||
if seek < seek_clip_start:
|
||||
seek = seek_clip_start
|
||||
if seek >= seek_clip_end:
|
||||
clip_idx += 1
|
||||
if clip_idx < len(seek_clips):
|
||||
seek = seek_clips[clip_idx][0]
|
||||
continue
|
||||
time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
|
||||
window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
|
||||
segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
|
||||
mel_segment = mel[:, seek : seek + segment_size]
|
||||
segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
|
||||
mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
|
||||
|
||||
decode_options["prompt"] = all_tokens[prompt_reset_since:]
|
||||
result: DecodingResult = decode_with_fallback(mel_segment)
|
||||
tokens = torch.tensor(result.tokens)
|
||||
|
||||
if no_speech_threshold is not None:
|
||||
# no voice activity check
|
||||
should_skip = result.no_speech_prob > no_speech_threshold
|
||||
if (
|
||||
logprob_threshold is not None
|
||||
and result.avg_logprob > logprob_threshold
|
||||
):
|
||||
# don't skip if the logprob is high enough, despite the no_speech_prob
|
||||
should_skip = False
|
||||
|
||||
if should_skip:
|
||||
seek += segment_size # fast-forward to the next segment boundary
|
||||
continue
|
||||
|
||||
previous_seek = seek
|
||||
current_segments = []
|
||||
|
||||
# anomalous words are very long/short/improbable
|
||||
def word_anomaly_score(word: dict) -> float:
|
||||
probability = word.get("probability", 0.0)
|
||||
duration = word["end"] - word["start"]
|
||||
score = 0.0
|
||||
if probability < 0.15:
|
||||
score += 1.0
|
||||
if duration < 0.133:
|
||||
score += (0.133 - duration) * 15
|
||||
if duration > 2.0:
|
||||
score += duration - 2.0
|
||||
return score
|
||||
|
||||
def is_segment_anomaly(segment: Optional[dict]) -> bool:
|
||||
if segment is None or not segment["words"]:
|
||||
return False
|
||||
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
||||
words = words[:8]
|
||||
score = sum(word_anomaly_score(w) for w in words)
|
||||
return score >= 3 or score + 0.01 >= len(words)
|
||||
|
||||
def next_words_segment(segments: List[dict]) -> Optional[dict]:
|
||||
return next((s for s in segments if s["words"]), None)
|
||||
|
||||
timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
|
||||
single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
|
||||
|
||||
consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
|
||||
consecutive.add_(1)
|
||||
if len(consecutive) > 0:
|
||||
# if the output contains two consecutive timestamp tokens
|
||||
slices = consecutive.tolist()
|
||||
if single_timestamp_ending:
|
||||
slices.append(len(tokens))
|
||||
|
||||
last_slice = 0
|
||||
for current_slice in slices:
|
||||
sliced_tokens = tokens[last_slice:current_slice]
|
||||
start_timestamp_pos = (
|
||||
sliced_tokens[0].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
end_timestamp_pos = (
|
||||
sliced_tokens[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset + start_timestamp_pos * time_precision,
|
||||
end=time_offset + end_timestamp_pos * time_precision,
|
||||
tokens=sliced_tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
last_slice = current_slice
|
||||
|
||||
if single_timestamp_ending:
|
||||
# single timestamp at the end means no speech after the last timestamp.
|
||||
seek += segment_size
|
||||
else:
|
||||
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
||||
last_timestamp_pos = (
|
||||
tokens[last_slice - 1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
seek += last_timestamp_pos * input_stride
|
||||
else:
|
||||
duration = segment_duration
|
||||
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
|
||||
if (
|
||||
len(timestamps) > 0
|
||||
and timestamps[-1].item() != tokenizer.timestamp_begin
|
||||
):
|
||||
# no consecutive timestamps but it has a timestamp; use the last one.
|
||||
last_timestamp_pos = (
|
||||
timestamps[-1].item() - tokenizer.timestamp_begin
|
||||
)
|
||||
duration = last_timestamp_pos * time_precision
|
||||
|
||||
current_segments.append(
|
||||
new_segment(
|
||||
start=time_offset,
|
||||
end=time_offset + duration,
|
||||
tokens=tokens,
|
||||
result=result,
|
||||
)
|
||||
)
|
||||
seek += segment_size
|
||||
|
||||
if word_timestamps:
|
||||
add_word_timestamps(
|
||||
segments=current_segments,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
mel=mel_segment,
|
||||
num_frames=segment_size,
|
||||
prepend_punctuations=prepend_punctuations,
|
||||
append_punctuations=append_punctuations,
|
||||
last_speech_timestamp=last_speech_timestamp,
|
||||
)
|
||||
|
||||
if not single_timestamp_ending:
|
||||
last_word_end = get_end(current_segments)
|
||||
if last_word_end is not None and last_word_end > time_offset:
|
||||
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||
|
||||
# skip silence before possible hallucinations
|
||||
if hallucination_silence_threshold is not None:
|
||||
threshold = hallucination_silence_threshold
|
||||
if not single_timestamp_ending:
|
||||
last_word_end = get_end(current_segments)
|
||||
if last_word_end is not None and last_word_end > time_offset:
|
||||
remaining_duration = window_end_time - last_word_end
|
||||
if remaining_duration > threshold:
|
||||
seek = round(last_word_end * FRAMES_PER_SECOND)
|
||||
else:
|
||||
seek = previous_seek + segment_size
|
||||
|
||||
# if first segment might be a hallucination, skip leading silence
|
||||
first_segment = next_words_segment(current_segments)
|
||||
if first_segment is not None and is_segment_anomaly(first_segment):
|
||||
gap = first_segment["start"] - time_offset
|
||||
if gap > threshold:
|
||||
seek = previous_seek + round(gap * FRAMES_PER_SECOND)
|
||||
continue
|
||||
|
||||
# skip silence before any possible hallucination that is surrounded
|
||||
# by silence or more hallucinations
|
||||
hal_last_end = last_speech_timestamp
|
||||
for si in range(len(current_segments)):
|
||||
segment = current_segments[si]
|
||||
if not segment["words"]:
|
||||
continue
|
||||
if is_segment_anomaly(segment):
|
||||
next_segment = next_words_segment(
|
||||
current_segments[si + 1 :]
|
||||
)
|
||||
if next_segment is not None:
|
||||
hal_next_start = next_segment["words"][0]["start"]
|
||||
else:
|
||||
hal_next_start = time_offset + segment_duration
|
||||
silence_before = (
|
||||
segment["start"] - hal_last_end > threshold
|
||||
or segment["start"] < threshold
|
||||
or segment["start"] - time_offset < 2.0
|
||||
)
|
||||
silence_after = (
|
||||
hal_next_start - segment["end"] > threshold
|
||||
or is_segment_anomaly(next_segment)
|
||||
or window_end_time - segment["end"] < 2.0
|
||||
)
|
||||
if silence_before and silence_after:
|
||||
seek = round(
|
||||
max(time_offset + 1, segment["start"])
|
||||
* FRAMES_PER_SECOND
|
||||
)
|
||||
if content_duration - segment["end"] < threshold:
|
||||
seek = content_frames
|
||||
current_segments[si:] = []
|
||||
break
|
||||
hal_last_end = segment["end"]
|
||||
|
||||
last_word_end = get_end(current_segments)
|
||||
if last_word_end is not None:
|
||||
last_speech_timestamp = last_word_end
|
||||
|
||||
if verbose:
|
||||
for segment in current_segments:
|
||||
start, end, text = segment["start"], segment["end"], segment["text"]
|
||||
line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
|
||||
print(make_safe(line))
|
||||
|
||||
# if a segment is instantaneous or does not contain text, clear it
|
||||
for i, segment in enumerate(current_segments):
|
||||
if segment["start"] == segment["end"] or segment["text"].strip() == "":
|
||||
segment["text"] = ""
|
||||
segment["tokens"] = []
|
||||
segment["words"] = []
|
||||
|
||||
all_segments.extend(
|
||||
[
|
||||
{"id": i, **segment}
|
||||
for i, segment in enumerate(
|
||||
current_segments, start=len(all_segments)
|
||||
)
|
||||
]
|
||||
)
|
||||
all_tokens.extend(
|
||||
[token for segment in current_segments for token in segment["tokens"]]
|
||||
)
|
||||
|
||||
if not condition_on_previous_text or result.temperature > 0.5:
|
||||
# do not feed the prompt tokens if a high temperature was used
|
||||
prompt_reset_since = len(all_tokens)
|
||||
|
||||
# update progress bar
|
||||
pbar.update(min(content_frames, seek) - previous_seek)
|
||||
|
||||
return dict(
|
||||
text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
|
||||
segments=all_segments,
|
||||
language=language,
|
||||
)
|
||||
|
||||
|
||||
def cli():
|
||||
from . import available_models
|
||||
|
||||
def valid_model_name(name):
|
||||
if name in available_models() or os.path.exists(name):
|
||||
return name
|
||||
raise ValueError(
|
||||
f"model should be one of {available_models()} or path to a model checkpoint"
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
|
||||
parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use")
|
||||
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
|
||||
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
|
||||
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
|
||||
parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
|
||||
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
|
||||
|
||||
parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
|
||||
parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
|
||||
|
||||
parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
|
||||
parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
|
||||
parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
|
||||
parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
|
||||
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
||||
|
||||
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
||||
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
||||
parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
||||
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
||||
|
||||
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
||||
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
||||
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
||||
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
||||
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
||||
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
||||
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
||||
parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
||||
parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
|
||||
parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
|
||||
parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
|
||||
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
||||
parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
|
||||
parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
|
||||
# fmt: on
|
||||
|
||||
args = parser.parse_args().__dict__
|
||||
model_name: str = args.pop("model")
|
||||
model_dir: str = args.pop("model_dir")
|
||||
output_dir: str = args.pop("output_dir")
|
||||
output_format: str = args.pop("output_format")
|
||||
device: str = args.pop("device")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
||||
if args["language"] is not None:
|
||||
warnings.warn(
|
||||
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
||||
)
|
||||
args["language"] = "en"
|
||||
|
||||
temperature = args.pop("temperature")
|
||||
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
||||
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
|
||||
else:
|
||||
temperature = [temperature]
|
||||
|
||||
if (threads := args.pop("threads")) > 0:
|
||||
torch.set_num_threads(threads)
|
||||
|
||||
from . import load_model
|
||||
|
||||
model = load_model(model_name, device=device, download_root=model_dir)
|
||||
|
||||
writer = get_writer(output_format, output_dir)
|
||||
word_options = [
|
||||
"highlight_words",
|
||||
"max_line_count",
|
||||
"max_line_width",
|
||||
"max_words_per_line",
|
||||
]
|
||||
if not args["word_timestamps"]:
|
||||
for option in word_options:
|
||||
if args[option]:
|
||||
parser.error(f"--{option} requires --word_timestamps True")
|
||||
if args["max_line_count"] and not args["max_line_width"]:
|
||||
warnings.warn("--max_line_count has no effect without --max_line_width")
|
||||
if args["max_words_per_line"] and args["max_line_width"]:
|
||||
warnings.warn("--max_words_per_line has no effect with --max_line_width")
|
||||
writer_args = {arg: args.pop(arg) for arg in word_options}
|
||||
for audio_path in args.pop("audio"):
|
||||
try:
|
||||
result = transcribe(model, audio_path, temperature=temperature, **args)
|
||||
writer(result, audio_path, **writer_args)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
|
@ -0,0 +1,316 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import zlib
|
||||
from typing import Callable, List, Optional, TextIO
|
||||
|
||||
system_encoding = sys.getdefaultencoding()
|
||||
|
||||
if system_encoding != "utf-8":
|
||||
|
||||
def make_safe(string):
|
||||
# replaces any character not representable using the system default encoding with an '?',
|
||||
# avoiding UnicodeEncodeError (https://github.com/openai/whisper/discussions/729).
|
||||
return string.encode(system_encoding, errors="replace").decode(system_encoding)
|
||||
|
||||
else:
|
||||
|
||||
def make_safe(string):
|
||||
# utf-8 can encode any Unicode code point, so no need to do the round-trip encoding
|
||||
return string
|
||||
|
||||
|
||||
def exact_div(x, y):
|
||||
assert x % y == 0
|
||||
return x // y
|
||||
|
||||
|
||||
def str2bool(string):
|
||||
str2val = {"True": True, "False": False}
|
||||
if string in str2val:
|
||||
return str2val[string]
|
||||
else:
|
||||
raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
|
||||
|
||||
|
||||
def optional_int(string):
|
||||
return None if string == "None" else int(string)
|
||||
|
||||
|
||||
def optional_float(string):
|
||||
return None if string == "None" else float(string)
|
||||
|
||||
|
||||
def compression_ratio(text) -> float:
|
||||
text_bytes = text.encode("utf-8")
|
||||
return len(text_bytes) / len(zlib.compress(text_bytes))
|
||||
|
||||
|
||||
def format_timestamp(
|
||||
seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
|
||||
):
|
||||
assert seconds >= 0, "non-negative timestamp expected"
|
||||
milliseconds = round(seconds * 1000.0)
|
||||
|
||||
hours = milliseconds // 3_600_000
|
||||
milliseconds -= hours * 3_600_000
|
||||
|
||||
minutes = milliseconds // 60_000
|
||||
milliseconds -= minutes * 60_000
|
||||
|
||||
seconds = milliseconds // 1_000
|
||||
milliseconds -= seconds * 1_000
|
||||
|
||||
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
||||
return (
|
||||
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||
)
|
||||
|
||||
|
||||
def get_start(segments: List[dict]) -> Optional[float]:
|
||||
return next(
|
||||
(w["start"] for s in segments for w in s["words"]),
|
||||
segments[0]["start"] if segments else None,
|
||||
)
|
||||
|
||||
|
||||
def get_end(segments: List[dict]) -> Optional[float]:
|
||||
return next(
|
||||
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
||||
segments[-1]["end"] if segments else None,
|
||||
)
|
||||
|
||||
|
||||
class ResultWriter:
|
||||
extension: str
|
||||
|
||||
def __init__(self, output_dir: str):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def __call__(
|
||||
self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
audio_basename = os.path.basename(audio_path)
|
||||
audio_basename = os.path.splitext(audio_basename)[0]
|
||||
output_path = os.path.join(
|
||||
self.output_dir, audio_basename + "." + self.extension
|
||||
)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
self.write_result(result, file=f, options=options, **kwargs)
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class WriteTXT(ResultWriter):
|
||||
extension: str = "txt"
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
for segment in result["segments"]:
|
||||
print(segment["text"].strip(), file=file, flush=True)
|
||||
|
||||
|
||||
class SubtitlesWriter(ResultWriter):
|
||||
always_include_hours: bool
|
||||
decimal_marker: str
|
||||
|
||||
def iterate_result(
|
||||
self,
|
||||
result: dict,
|
||||
options: Optional[dict] = None,
|
||||
*,
|
||||
max_line_width: Optional[int] = None,
|
||||
max_line_count: Optional[int] = None,
|
||||
highlight_words: bool = False,
|
||||
max_words_per_line: Optional[int] = None,
|
||||
):
|
||||
options = options or {}
|
||||
max_line_width = max_line_width or options.get("max_line_width")
|
||||
max_line_count = max_line_count or options.get("max_line_count")
|
||||
highlight_words = highlight_words or options.get("highlight_words", False)
|
||||
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
|
||||
preserve_segments = max_line_count is None or max_line_width is None
|
||||
max_line_width = max_line_width or 1000
|
||||
max_words_per_line = max_words_per_line or 1000
|
||||
|
||||
def iterate_subtitles():
|
||||
line_len = 0
|
||||
line_count = 1
|
||||
# the next subtitle to yield (a list of word timings with whitespace)
|
||||
subtitle: List[dict] = []
|
||||
last: float = get_start(result["segments"]) or 0.0
|
||||
for segment in result["segments"]:
|
||||
chunk_index = 0
|
||||
words_count = max_words_per_line
|
||||
while chunk_index < len(segment["words"]):
|
||||
remaining_words = len(segment["words"]) - chunk_index
|
||||
if max_words_per_line > len(segment["words"]) - chunk_index:
|
||||
words_count = remaining_words
|
||||
for i, original_timing in enumerate(
|
||||
segment["words"][chunk_index : chunk_index + words_count]
|
||||
):
|
||||
timing = original_timing.copy()
|
||||
long_pause = (
|
||||
not preserve_segments and timing["start"] - last > 3.0
|
||||
)
|
||||
has_room = line_len + len(timing["word"]) <= max_line_width
|
||||
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
||||
if (
|
||||
line_len > 0
|
||||
and has_room
|
||||
and not long_pause
|
||||
and not seg_break
|
||||
):
|
||||
# line continuation
|
||||
line_len += len(timing["word"])
|
||||
else:
|
||||
# new line
|
||||
timing["word"] = timing["word"].strip()
|
||||
if (
|
||||
len(subtitle) > 0
|
||||
and max_line_count is not None
|
||||
and (long_pause or line_count >= max_line_count)
|
||||
or seg_break
|
||||
):
|
||||
# subtitle break
|
||||
yield subtitle
|
||||
subtitle = []
|
||||
line_count = 1
|
||||
elif line_len > 0:
|
||||
# line break
|
||||
line_count += 1
|
||||
timing["word"] = "\n" + timing["word"]
|
||||
line_len = len(timing["word"].strip())
|
||||
subtitle.append(timing)
|
||||
last = timing["start"]
|
||||
chunk_index += max_words_per_line
|
||||
if len(subtitle) > 0:
|
||||
yield subtitle
|
||||
|
||||
if len(result["segments"]) > 0 and "words" in result["segments"][0]:
|
||||
for subtitle in iterate_subtitles():
|
||||
subtitle_start = self.format_timestamp(subtitle[0]["start"])
|
||||
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
|
||||
subtitle_text = "".join([word["word"] for word in subtitle])
|
||||
if highlight_words:
|
||||
last = subtitle_start
|
||||
all_words = [timing["word"] for timing in subtitle]
|
||||
for i, this_word in enumerate(subtitle):
|
||||
start = self.format_timestamp(this_word["start"])
|
||||
end = self.format_timestamp(this_word["end"])
|
||||
if last != start:
|
||||
yield last, start, subtitle_text
|
||||
|
||||
yield start, end, "".join(
|
||||
[
|
||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||
if j == i
|
||||
else word
|
||||
for j, word in enumerate(all_words)
|
||||
]
|
||||
)
|
||||
last = end
|
||||
else:
|
||||
yield subtitle_start, subtitle_end, subtitle_text
|
||||
else:
|
||||
for segment in result["segments"]:
|
||||
segment_start = self.format_timestamp(segment["start"])
|
||||
segment_end = self.format_timestamp(segment["end"])
|
||||
segment_text = segment["text"].strip().replace("-->", "->")
|
||||
yield segment_start, segment_end, segment_text
|
||||
|
||||
def format_timestamp(self, seconds: float):
|
||||
return format_timestamp(
|
||||
seconds=seconds,
|
||||
always_include_hours=self.always_include_hours,
|
||||
decimal_marker=self.decimal_marker,
|
||||
)
|
||||
|
||||
|
||||
class WriteVTT(SubtitlesWriter):
|
||||
extension: str = "vtt"
|
||||
always_include_hours: bool = False
|
||||
decimal_marker: str = "."
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
print("WEBVTT\n", file=file)
|
||||
for start, end, text in self.iterate_result(result, options, **kwargs):
|
||||
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
|
||||
class WriteSRT(SubtitlesWriter):
|
||||
extension: str = "srt"
|
||||
always_include_hours: bool = True
|
||||
decimal_marker: str = ","
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
for i, (start, end, text) in enumerate(
|
||||
self.iterate_result(result, options, **kwargs), start=1
|
||||
):
|
||||
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
|
||||
class WriteTSV(ResultWriter):
|
||||
"""
|
||||
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
|
||||
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
|
||||
|
||||
Using integer milliseconds as start and end times means there's no chance of interference from
|
||||
an environment setting a language encoding that causes the decimal in a floating point number
|
||||
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
|
||||
"""
|
||||
|
||||
extension: str = "tsv"
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
print("start", "end", "text", sep="\t", file=file)
|
||||
for segment in result["segments"]:
|
||||
print(round(1000 * segment["start"]), file=file, end="\t")
|
||||
print(round(1000 * segment["end"]), file=file, end="\t")
|
||||
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
|
||||
|
||||
|
||||
class WriteJSON(ResultWriter):
|
||||
extension: str = "json"
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
json.dump(result, file)
|
||||
|
||||
|
||||
def get_writer(
|
||||
output_format: str, output_dir: str
|
||||
) -> Callable[[dict, TextIO, dict], None]:
|
||||
writers = {
|
||||
"txt": WriteTXT,
|
||||
"vtt": WriteVTT,
|
||||
"srt": WriteSRT,
|
||||
"tsv": WriteTSV,
|
||||
"json": WriteJSON,
|
||||
}
|
||||
|
||||
if output_format == "all":
|
||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||
|
||||
def write_all(
|
||||
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
for writer in all_writers:
|
||||
writer(result, file, options, **kwargs)
|
||||
|
||||
return write_all
|
||||
|
||||
return writers[output_format](output_dir)
|
|
@ -44,7 +44,7 @@ To use the Llama and Mistral models, you will need to go through an extra step t
|
|||
1. Visit
|
||||
- LLaMA 2: [https://huggingface.co/meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
|
||||
- LLaMA 3: [https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
|
||||
- Mistral: [https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
|
||||
- Mistral: [https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
|
||||
2. Follow the steps on the Hugging Face page to obtain access
|
||||
3. Run `huggingface-cli login`
|
||||
4. Paste your [Hugging Face User Access Token](https://huggingface.co/docs/hub/en/security-tokens) to login
|
||||
|
@ -57,7 +57,7 @@ Run the chatbot app using the following command:
|
|||
> python app.py
|
||||
```
|
||||
|
||||
The chatbot app will start with the default settings, which uses `DirectML` as the backend to run the `Phi-3` model for inference using `float16`. The app will automatically download `Phi-3-4k-instruct` on the first run from the default `model_repo` which is set to `microsoft/Phi-3-4k-instruct`.
|
||||
The chatbot app will start with the default settings, which uses `DirectML` as the backend to run the `Phi-3` model for inference using `float16`. The app will automatically download `Phi-3-4k-instruct` on the first run from the default `hf_model` which is set to `microsoft/Phi-3-4k-instruct`.
|
||||
|
||||
This model is optimized to take advantage of DirectML operators and to use the custom DirectML graph implementations for Rotary Positional Embedding (RoPE), Multi-Head Attention (MHA), and the Feedforward layers (MLP).
|
||||
|
||||
|
@ -109,13 +109,13 @@ You can also select another model to run (`microsoft/Phi-3-mini-4k-instruct`, `m
|
|||
For example to run `Mistral-7B-Instruct-v0.1` use the following command:
|
||||
|
||||
```
|
||||
> python app.py --precision float16 --model_repo "mistralai/Mistral-7B-Instruct-v0.1"
|
||||
> python app.py --precision float16 --hf_model "mistralai/Mistral-7B-Instruct-v0.1"
|
||||
```
|
||||
|
||||
You should see the result such as this:
|
||||
|
||||
```
|
||||
> python app.py --precision float16 --model_repo "mistralai/Mistral-7B-Instruct-v0.1"
|
||||
> python app.py --precision float16 --hf_model "mistralai/Mistral-7B-Instruct-v0.1"
|
||||
checkpoints\mistralai\Mistral-7B-Instruct-v0.1\model.pth doesnt exist. Downloading and converting from huggingface hub
|
||||
README.md: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3.90k/3.90k [00:00<?, ?B/s]
|
||||
model.safetensors.index.json: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25.1k/25.1k [00:00<?, ?B/s]
|
||||
|
@ -146,9 +146,9 @@ Following is a list of the basic settings supported by `app.py`:
|
|||
| Flag | Description | Default |
|
||||
| ---------------------- | ------------------------------------------------------------ | ------- |
|
||||
| `--help` | Show this help message. | N/A |
|
||||
| `--model_repo` | Specify the model to downloading using the Hugging Face Repository ID. | `microsoft/Phi-3-mini-4k-instruct` |
|
||||
| `--hf_model` | Specify the model to downloading using the Hugging Face Repository ID. | `microsoft/Phi-3-mini-4k-instruct` |
|
||||
| `--precision` | Model precision to use during generation. Options: [`float16`, `float32`] | `float16` |
|
||||
| `--checkpoint_path` | Path to converted PyTorch model checkpoint. | `checkpoints/{model_repo}/model.pth` |
|
||||
| `--checkpoint_path` | Path to converted PyTorch model checkpoint. | `checkpoints/{hf_model}/model.pth` |
|
||||
| `--max_context_length` | Max prompt length including the history. If exceeded, history is clipped starting from the first (user, assistant) pair. | `1500` |
|
||||
| `--disable_history` | Disable the chat history during generation. | Enabled |
|
||||
|
||||
|
@ -170,25 +170,25 @@ We offer two methods for preparing PyTorch models:
|
|||
### Use `download_and_convert.py` to download a language model:
|
||||
|
||||
```
|
||||
> python .\scripts\download_and_convert.py --model_repo "microsoft/Phi-3-mini-4k-instruct"
|
||||
> python .\scripts\download_and_convert.py --hf_model "microsoft/Phi-3-mini-4k-instruct"
|
||||
```
|
||||
|
||||
After the model is downloaded and converted, you can pass the following parameter to `app.py` to run the language model:
|
||||
|
||||
```
|
||||
> python app.py --model_repo "microsoft/Phi-3-mini-4k-instruct"
|
||||
> python app.py --hf_model "microsoft/Phi-3-mini-4k-instruct"
|
||||
```
|
||||
|
||||
### Download a DirectML optimized PyTorch model from the [Microsoft Hugging Face repo](https://huggingface.co/microsoft):
|
||||
|
||||
1. cd checkpoints
|
||||
2. git clone https://huggingface.co/{model_repo} {model_repo}
|
||||
2. git clone https://huggingface.co/{hf_model} {hf_model}
|
||||
3. cd ../
|
||||
|
||||
After the model is downloaded, you can pass the following parameter to `app.py` to run the language model:
|
||||
|
||||
```
|
||||
> python app.py --checkpoint_path "checkpoints/{model_repo}/model.pth"
|
||||
> python app.py --checkpoint_path "checkpoints/{hf_model}/model.pth"
|
||||
```
|
||||
|
||||
## External Links
|
||||
|
|
|
@ -1,13 +1,19 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
#
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Union, List, Tuple, Iterator
|
||||
|
||||
import torch_directml
|
||||
import torch
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
wd = Path(__file__).parent.parent.resolve()
|
||||
sys.path.append(str(wd))
|
||||
|
@ -18,18 +24,19 @@ from models.phi3 import Transformer as Phi3Transformer
|
|||
from models.phi2 import Transformer as Phi2Transformer
|
||||
from models.llama import Transformer as LlamaTransformer
|
||||
|
||||
|
||||
device = torch_directml.device(torch_directml.default_device())
|
||||
|
||||
def decode_n_tokens(
|
||||
model: Union[Phi2Transformer, Phi3Transformer, LlamaTransformer],
|
||||
model: Union[Phi2Transformer, Phi3Transformer, LlamaTransformer],
|
||||
cur_token: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
num_new_tokens: int,
|
||||
tokenizer,
|
||||
tokenizer: PreTrainedTokenizerFast,
|
||||
stream_every_n: int,
|
||||
is_llama_3: bool = False,
|
||||
**sampling_kwargs
|
||||
):
|
||||
) -> Iterator[str]:
|
||||
res = tokenizer.decode(cur_token[0][0].item(), skip_special_tokens=True).strip() + " "
|
||||
yield res
|
||||
|
||||
|
@ -48,7 +55,7 @@ def decode_n_tokens(
|
|||
new_tokens.append(next_token.clone())
|
||||
cur_token = next_token.view(1, -1)
|
||||
|
||||
# Handle output and overlap at the specified intervals or at the last token for adding
|
||||
# Handle output and overlap at the specified intervals or at the last token for adding
|
||||
# the space correctly between stream batches
|
||||
if ((i + 1) % stream_every_n == 0 or i == num_new_tokens - 1):
|
||||
# Determine the range of tokens to decode, including the overlap
|
||||
|
@ -68,8 +75,6 @@ def decode_n_tokens(
|
|||
from_index = max(0, start_pos - overlap_size)
|
||||
yield decode_with_overlap(tokenizer, new_tokens, from_index, overlap_text)
|
||||
break
|
||||
return new_tokens
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
|
@ -77,11 +82,11 @@ def generate(
|
|||
prompt: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
*,
|
||||
tokenizer = None,
|
||||
tokenizer: PreTrainedTokenizerFast,
|
||||
stream_every_n: int = 10,
|
||||
is_llama_3: bool = False,
|
||||
**sampling_kwargs
|
||||
) -> torch.Tensor:
|
||||
) -> Iterator[str]:
|
||||
"""
|
||||
Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
|
||||
"""
|
||||
|
@ -98,36 +103,38 @@ def generate(
|
|||
# create an empty tensor of the expected final shape and fill in the current tokens
|
||||
input_pos = torch.arange(0, T, device=device)
|
||||
next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
|
||||
|
||||
|
||||
input_pos = torch.tensor([T], device=device, dtype=torch.int)
|
||||
|
||||
# generated_tokens = decode_n_tokens(
|
||||
return decode_n_tokens(
|
||||
yield from decode_n_tokens(
|
||||
model, next_token.view(1, -1), input_pos, max_new_tokens - 1, tokenizer, stream_every_n, is_llama_3=is_llama_3, **sampling_kwargs)
|
||||
|
||||
|
||||
class LLM_Model:
|
||||
def __init__(self,
|
||||
prompt: str = "Hello, my name is",
|
||||
interactive: bool = False,
|
||||
num_samples: int = 5,
|
||||
max_new_tokens: int = 100,
|
||||
top_k: int = 200,
|
||||
temperature: float = 0.01,
|
||||
model_repo: str = "microsoft/Phi-3-mini-4k-instruct",
|
||||
checkpoint_path: str = None,
|
||||
precision: str = 'float32',
|
||||
stream_every_n: int = 7,
|
||||
max_context_length: int = 3500,
|
||||
use_history: bool = False):
|
||||
def __init__(
|
||||
self,
|
||||
prompt: str = "Hello, my name is",
|
||||
interactive: bool = False,
|
||||
num_samples: int = 5,
|
||||
max_new_tokens: int = 100,
|
||||
top_k: int = 200,
|
||||
temperature: float = 0.01,
|
||||
hf_model: str = "microsoft/Phi-3-mini-4k-instruct",
|
||||
checkpoint_path: str = None,
|
||||
precision: str = 'float32',
|
||||
stream_every_n: int = 7,
|
||||
max_context_length: int = 3500,
|
||||
use_history: bool = False
|
||||
):
|
||||
self.prompt = prompt
|
||||
self.interactive = interactive
|
||||
self.num_samples = num_samples
|
||||
self.max_new_tokens = max_new_tokens
|
||||
self.top_k = top_k
|
||||
self.temperature = temperature
|
||||
self.model_repo = model_repo
|
||||
self.checkpoint_path = Path(f"checkpoints/{model_repo}/model.pth") if checkpoint_path is None else Path(checkpoint_path)
|
||||
self.hf_model = hf_model
|
||||
self.checkpoint_path = Path(f"checkpoints/{hf_model}/model.pth") if checkpoint_path is None else Path(checkpoint_path)
|
||||
self.precision = torch.float32 if precision == 'float32' else torch.float16
|
||||
self.stream_every_n = stream_every_n
|
||||
self.max_context_length = max_context_length
|
||||
|
@ -136,21 +143,34 @@ class LLM_Model:
|
|||
self.tokenizer = None
|
||||
self.model = None
|
||||
|
||||
def encode_tokens(self, prompt, conversation_history, device=None, max_context_length=1500, bos=True):
|
||||
def encode_tokens(
|
||||
self,
|
||||
prompt: str,
|
||||
conversation_history: List[List[str]],
|
||||
device: torch.device = None,
|
||||
max_context_length: int = 1500,
|
||||
bos: bool = True
|
||||
) -> torch.Tensor:
|
||||
if self.is_phi_2:
|
||||
tokens = self.format_prompt_phi2_chat_and_encode(
|
||||
prompt, conversation_history, device, max_context_length, bos
|
||||
)
|
||||
else:
|
||||
tokens = self.format_prompt_and_encode(
|
||||
prompt, conversation_history, device, max_context_length,
|
||||
prompt, conversation_history, device, max_context_length,
|
||||
)
|
||||
return tokens
|
||||
|
||||
def format_prompt_and_encode(self, prompt, history, device=None, max_context_length=3500):
|
||||
|
||||
def format_prompt_and_encode(
|
||||
self,
|
||||
prompt: str,
|
||||
conversation_history: List[List[str]],
|
||||
device: torch.device = None,
|
||||
max_context_length: int = 1500,
|
||||
) -> torch.Tensor:
|
||||
messages = []
|
||||
if len(history) and self.use_history:
|
||||
for user, assistant in history:
|
||||
if len(conversation_history) and self.use_history:
|
||||
for user, assistant in conversation_history:
|
||||
user = {"role": "user", "content": user}
|
||||
assistant = {"role": "assistant", "content": assistant}
|
||||
messages.append(user)
|
||||
|
@ -172,7 +192,14 @@ class LLM_Model:
|
|||
|
||||
return tokens
|
||||
|
||||
def format_prompt_phi2_chat_and_encode(self, prompt, conversation_history, device=None, max_context_length=1500, bos=True):
|
||||
def format_prompt_phi2_chat_and_encode(
|
||||
self,
|
||||
prompt: str,
|
||||
conversation_history: List[List[str]],
|
||||
device: torch.device = None,
|
||||
max_context_length: int = 1500,
|
||||
bos: bool = True
|
||||
) -> torch.Tensor:
|
||||
formatted_prompt = ""
|
||||
if self.use_history:
|
||||
for user_prompt, llm_response in conversation_history:
|
||||
|
@ -194,7 +221,14 @@ class LLM_Model:
|
|||
token_tensor = torch.tensor(tokens, dtype=torch.int, device=device)
|
||||
return token_tensor
|
||||
|
||||
def format_prompt_phi2_qa_and_encode(self, prompt, conversation_history, max_context_length=1500, bos=True, device=None):
|
||||
def format_prompt_phi2_qa_and_encode(
|
||||
self,
|
||||
prompt: str,
|
||||
conversation_history: List[List[str]],
|
||||
device: torch.device = None,
|
||||
max_context_length: int = 1500,
|
||||
bos: bool = True
|
||||
) -> torch.Tensor:
|
||||
formatted_prompt = ""
|
||||
if self.use_history:
|
||||
for user_prompt, llm_response in conversation_history:
|
||||
|
@ -218,22 +252,22 @@ class LLM_Model:
|
|||
|
||||
token_tensor = torch.tensor(tokens, dtype=torch.int, device=device)
|
||||
return token_tensor
|
||||
|
||||
def download_and_convert(self):
|
||||
checkpoint_dir = hf_download(self.model_repo)
|
||||
|
||||
def download_and_convert(self) -> None:
|
||||
checkpoint_dir = hf_download(self.hf_model)
|
||||
convert_hf_checkpoint(
|
||||
checkpoint_dir=Path(checkpoint_dir),
|
||||
)
|
||||
self.checkpoint_path = Path(f"checkpoints/{self.model_repo}/model.pth")
|
||||
self.checkpoint_path = Path(f"{checkpoint_dir}/model.pth")
|
||||
|
||||
def load_model(self):
|
||||
def load_model(self) -> None:
|
||||
if not self.checkpoint_path.is_file():
|
||||
print(f"{self.checkpoint_path} doesnt exist. Downloading and converting {self.model_repo} from huggingface hub. "
|
||||
"Specify a different model with --model_repo or valid pre-converted checkpoint with --checkpoint_path")
|
||||
print(f"{self.checkpoint_path} doesnt exist. Downloading and converting {self.hf_model} from huggingface hub. "
|
||||
"Specify a different model with --hf_model or valid pre-converted checkpoint with --checkpoint_path")
|
||||
self.download_and_convert()
|
||||
print("Running app...")
|
||||
print(f"Loading model from {self.checkpoint_path}")
|
||||
|
||||
|
||||
self.is_llama_3 = "Llama-3" in self.checkpoint_path.parent.name
|
||||
self.is_phi_2 = "phi-2" in self.checkpoint_path.parent.name
|
||||
print(f"Using device={device}, is_llama_3={self.is_llama_3}, is_phi_2={self.is_phi_2}")
|
||||
|
@ -245,9 +279,14 @@ class LLM_Model:
|
|||
if self.max_context_length > self.model.config.block_size - (self.max_new_tokens+1):
|
||||
raise ValueError(
|
||||
f"Expected max_context_length to be less than {self.model.config.block_size - (self.max_new_tokens+1)} but got {self.max_context_length}")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def chat(self, prompt, history, **sampling_kwargs):
|
||||
def chat(
|
||||
self,
|
||||
prompt: str,
|
||||
history: List[List[str]],
|
||||
**sampling_kwargs
|
||||
) -> Iterator[str]:
|
||||
torch.manual_seed(1235)
|
||||
encoded = self.encode_tokens(
|
||||
prompt,
|
||||
|
@ -256,7 +295,7 @@ class LLM_Model:
|
|||
max_context_length=self.max_context_length,
|
||||
)
|
||||
|
||||
toks = generate(
|
||||
yield from generate(
|
||||
self.model,
|
||||
encoded,
|
||||
self.max_new_tokens,
|
||||
|
@ -266,32 +305,30 @@ class LLM_Model:
|
|||
temperature=self.temperature,
|
||||
top_k=self.top_k,
|
||||
)
|
||||
return toks
|
||||
|
||||
|
||||
def chat(message, history):
|
||||
def chat(message: str, history: List[List[str]]) -> Iterator[str]:
|
||||
total_msg = ""
|
||||
for msg in llm_model.chat(message, history):
|
||||
total_msg += msg
|
||||
yield total_msg
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description='Your CLI description.')
|
||||
|
||||
parser.add_argument(
|
||||
'--model_repo',
|
||||
'--hf_model',
|
||||
type=str,
|
||||
default="microsoft/Phi-3-mini-4k-instruct",
|
||||
help='Huggingface Repository ID to download from.'
|
||||
default="phi-3",
|
||||
help='Huggingface Repository ID to download from. Or one of the model name from ["phi-2", "phi-3", "llama-2", "llama-3", "mistral"]'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--checkpoint_path',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Converted pytorch model checkpoint path. Defaults to `checkpoints/{model_repo}/model.pth`.'
|
||||
help='Converted pytorch model checkpoint path. Defaults to `checkpoints/{hf_model}/model.pth`.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--max_context_length',
|
||||
|
@ -319,7 +356,7 @@ if __name__ == "__main__":
|
|||
max_new_tokens = 500,
|
||||
top_k = 200,
|
||||
temperature = 0.8,
|
||||
model_repo = args.model_repo,
|
||||
hf_model = args.hf_model,
|
||||
checkpoint_path = args.checkpoint_path,
|
||||
precision = args.precision,
|
||||
max_context_length = args.max_context_length,
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
#
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from models.configs import ModelArgs, find_multiple
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.max_position_embeddings = config.block_size
|
||||
self.rope_base = config.rope_base
|
||||
self.max_batch_size = -1
|
||||
self.max_seq_length = -1
|
||||
|
||||
def setup_caches(self, max_batch_size, max_seq_length):
|
||||
head_dim = self.config.dim // self.config.n_head
|
||||
max_seq_length = find_multiple(max_seq_length, 8)
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
dtype=self.output.weight.dtype
|
||||
for b in self.layers:
|
||||
b.attention._init_rope(self.max_position_embeddings, self.rope_base, dtype=dtype)
|
||||
|
||||
self.causal_mask = torch.tril(
|
||||
torch.ones(self.config.n_head, self.max_seq_length, self.max_seq_length, dtype=torch.int32)
|
||||
)
|
||||
|
||||
def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
|
||||
mask = self.causal_mask[None, :, input_pos, :input_pos[-1].item()+1]
|
||||
x = self.tok_embeddings(idx)
|
||||
|
||||
for _, layer in enumerate(self.layers):
|
||||
x = layer(x, input_pos, mask)
|
||||
|
||||
x = self.norm(x)
|
||||
logits = self.output(x)
|
||||
return logits
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, name: str):
|
||||
return cls(ModelArgs.from_name(name))
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
#
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def find_multiple(n: int, k: int) -> int:
|
||||
if n % k == 0:
|
||||
return n
|
||||
return n + k - (n % k)
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
block_size: int = 2048
|
||||
vocab_size: int = 32000
|
||||
n_layer: int = 32
|
||||
n_head: int = 32
|
||||
dim: int = 4096
|
||||
intermediate_size: int = None
|
||||
n_local_heads: int = -1
|
||||
head_dim: int = 64
|
||||
rope_base: float = 10000
|
||||
norm_eps: float = 1e-5
|
||||
partial_rotary_factor: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
if self.n_local_heads == -1:
|
||||
self.n_local_heads = self.n_head
|
||||
if self.intermediate_size is None:
|
||||
hidden_dim = 4 * self.dim
|
||||
n_hidden = int(2 * hidden_dim / 3)
|
||||
self.intermediate_size = find_multiple(n_hidden, 256)
|
||||
self.head_dim = self.dim // self.n_head
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, name: str):
|
||||
if name in transformer_configs:
|
||||
return cls(**transformer_configs[name])
|
||||
# fuzzy search
|
||||
config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
|
||||
|
||||
# We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
|
||||
# take longer name (as it have more symbols matched)
|
||||
if len(config) > 1:
|
||||
config.sort(key=len, reverse=True)
|
||||
assert len(config[0]) != len(config[1]), name # make sure only one 'best' match
|
||||
|
||||
if not config:
|
||||
raise ValueError(f"No configuration found for the model named '{name}'. Supported models: {list(transformer_configs.keys())}")
|
||||
|
||||
return cls(**transformer_configs[config[0]])
|
||||
|
||||
transformer_configs = {
|
||||
"7B": dict(block_size=4096, n_layer=32, n_head=32, dim=4096),
|
||||
"phi-2": dict(block_size=2048, n_layer=32, n_head=32, dim=2560, intermediate_size=10240, rope_base=10000, vocab_size=51200, partial_rotary_factor=0.4),
|
||||
"Phi-3-mini-4k-instruct": dict(block_size=4096, n_layer=32, n_head=32, dim=3072, intermediate_size=8192, rope_base=10000, vocab_size=32064),
|
||||
"Mistral-7B": dict(block_size=4096, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
|
||||
"Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000.0),
|
||||
}
|
||||
|
||||
default_models = {
|
||||
"llama-3": "meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"llama-2": "meta-llama/Llama-2-7b-chat-hf",
|
||||
"phi-3": "microsoft/Phi-3-mini-4k-instruct",
|
||||
"phi-2": "microsoft/phi-2",
|
||||
"mistral": "mistralai/Mistral-7B-Instruct-v0.1"
|
||||
}
|
|
@ -0,0 +1,137 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
#
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Optional, Union, Tuple, Dict
|
||||
|
||||
import torch_directml
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from models.configs import ModelArgs
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.rmsnorm = torch_directml.rmsnorm
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
output = self.rmsnorm(x.float(), self.weight.float(), self.eps)
|
||||
return output.type_as(x)
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
max_position_embeddings: int = 4096,
|
||||
base: Union[int, float] = 10000,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
device = device if device is not None else torch_directml.device(torch_directml.default_device())
|
||||
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len: int, device: torch.device) -> None:
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0).to(self.dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0).to(self.dtype), persistent=False)
|
||||
|
||||
def forward(self) -> Tuple[torch.Tensor, torch.Tensor] :
|
||||
return (
|
||||
self.cos_cached,
|
||||
self.sin_cached
|
||||
)
|
||||
|
||||
class LlamaAttention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
assert config.dim % config.n_head == 0
|
||||
|
||||
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
|
||||
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
|
||||
self.wo = nn.Linear(config.dim, config.dim, bias=False)
|
||||
self.kv_cache = None
|
||||
|
||||
self.n_head = config.n_head
|
||||
self.head_dim = config.head_dim
|
||||
self.n_local_heads = config.n_local_heads
|
||||
self.dim = config.dim
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict: Dict[str, torch.Tensor], prefix: str, *argspy):
|
||||
if prefix + "wq.weight" in state_dict:
|
||||
wq = state_dict.pop(prefix + "wq.weight")
|
||||
wk = state_dict.pop(prefix + "wk.weight")
|
||||
wv = state_dict.pop(prefix + "wv.weight")
|
||||
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
|
||||
|
||||
def _init_rope( self, max_position_embeddings: int = 4096, rope_base: Union[int, float] = 10000.0, dtype: torch.dtype = torch.float16) -> None:
|
||||
self.min_position = 0
|
||||
self.past_key_tensor = None
|
||||
self.past_value_tensor = None
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_base, dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
kv_size = self.n_local_heads * self.head_dim
|
||||
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
|
||||
|
||||
q = q.reshape(bsz, seqlen, self.n_head, self.head_dim).transpose(1, 2)
|
||||
k = k.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb()
|
||||
|
||||
q, k = torch_directml.apply_rotary_position_emb(
|
||||
q, k, cos, sin, self.min_position, seqlen, self.head_dim)
|
||||
self.min_position += seqlen
|
||||
|
||||
if self.n_head != self.n_local_heads:
|
||||
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
||||
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
||||
|
||||
q, k, v = map(lambda x: x.transpose(1, 2).reshape(bsz, -1, self.dim), (q, k, v))
|
||||
|
||||
y, self.past_key_tensor, self.past_value_tensor = torch_directml.multi_head_attention(
|
||||
q, k, v, self.dim, self.n_head, self.past_key_tensor, self.past_value_tensor, mask
|
||||
)
|
||||
y = self.wo(y)
|
||||
return y
|
||||
|
||||
class LlamaTransformerBlock(nn.Module):
|
||||
def __init__(self, config: ModelArgs, feed_forward_module: nn.Module):
|
||||
super().__init__()
|
||||
self.attention = LlamaAttention(config)
|
||||
self.feed_forward = feed_forward_module(config)
|
||||
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
|
||||
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
|
||||
|
||||
def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor:
|
||||
h = x + self.attention(self.attention_norm(x), mask, input_pos)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
|
@ -1,180 +1,30 @@
|
|||
from dataclasses import dataclass
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
#
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch_directml
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from models.configs import ModelArgs
|
||||
from models.layers import RMSNorm
|
||||
from models.layers import LlamaTransformerBlock as TransformerBlock
|
||||
from models.base import Transformer
|
||||
|
||||
|
||||
def find_multiple(n: int, k: int) -> int:
|
||||
if n % k == 0:
|
||||
return n
|
||||
return n + k - (n % k)
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
block_size: int = 2048
|
||||
vocab_size: int = 32000
|
||||
n_layer: int = 32
|
||||
n_head: int = 32
|
||||
dim: int = 4096
|
||||
intermediate_size: int = None
|
||||
n_local_heads: int = -1
|
||||
head_dim: int = 64
|
||||
rope_base: float = 10000
|
||||
norm_eps: float = 1e-5
|
||||
|
||||
def __post_init__(self):
|
||||
if self.n_local_heads == -1:
|
||||
self.n_local_heads = self.n_head
|
||||
if self.intermediate_size is None:
|
||||
hidden_dim = 4 * self.dim
|
||||
n_hidden = int(2 * hidden_dim / 3)
|
||||
self.intermediate_size = find_multiple(n_hidden, 256)
|
||||
self.head_dim = self.dim // self.n_head
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, name: str):
|
||||
if name in transformer_configs:
|
||||
return cls(**transformer_configs[name])
|
||||
# fuzzy search
|
||||
config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
|
||||
|
||||
# We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
|
||||
# take longer name (as it have more symbols matched)
|
||||
if len(config) > 1:
|
||||
config.sort(key=len, reverse=True)
|
||||
assert len(config[0]) != len(config[1]), name # make sure only one 'best' match
|
||||
|
||||
return cls(**transformer_configs[config[0]])
|
||||
|
||||
transformer_configs = {
|
||||
"7B": dict(block_size=4096, n_layer=32, n_head=32, dim=4096),
|
||||
"Mistral-7B": dict(block_size=4096, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=32000),
|
||||
"Llama-3-8B": dict(block_size=8192, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, vocab_size=128256, rope_base=500000.0),
|
||||
}
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
class LlamaTransformer(Transformer):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.max_position_embeddings = config.block_size
|
||||
self.rope_base = config.rope_base
|
||||
|
||||
super().__init__(config)
|
||||
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
|
||||
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
|
||||
self.layers = nn.ModuleList(TransformerBlock(config, FeedForward) for _ in range(config.n_layer))
|
||||
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
|
||||
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
|
||||
|
||||
self.max_batch_size = -1
|
||||
self.max_seq_length = -1
|
||||
|
||||
def setup_caches(self, max_batch_size, max_seq_length):
|
||||
head_dim = self.config.dim // self.config.n_head
|
||||
max_seq_length = find_multiple(max_seq_length, 8)
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
dtype=self.output.weight.dtype
|
||||
for b in self.layers:
|
||||
b.attention._init_rope(self.max_position_embeddings, self.rope_base, dtype=dtype)
|
||||
|
||||
self.causal_mask = torch.tril(
|
||||
torch.ones(self.config.n_head, self.max_seq_length, self.max_seq_length, dtype=torch.int32)
|
||||
)
|
||||
|
||||
def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
|
||||
mask = self.causal_mask[None, :, input_pos, :input_pos[-1].item()+1]
|
||||
x = self.tok_embeddings(idx)
|
||||
|
||||
for _, layer in enumerate(self.layers):
|
||||
x = layer(x, input_pos, mask)
|
||||
|
||||
x = self.norm(x)
|
||||
logits = self.output(x)
|
||||
return logits
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, name: str):
|
||||
return cls(ModelArgs.from_name(name))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.attention = Attention(config)
|
||||
self.feed_forward = FeedForward(config)
|
||||
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
|
||||
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
|
||||
|
||||
def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor:
|
||||
h = x + self.attention(self.attention_norm(x), mask, input_pos)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
assert config.dim % config.n_head == 0
|
||||
|
||||
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
|
||||
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
|
||||
self.wo = nn.Linear(config.dim, config.dim, bias=False)
|
||||
self.kv_cache = None
|
||||
|
||||
self.n_head = config.n_head
|
||||
self.head_dim = config.head_dim
|
||||
self.n_local_heads = config.n_local_heads
|
||||
self.dim = config.dim
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict, prefix, *argspy):
|
||||
if prefix + "wq.weight" in state_dict:
|
||||
wq = state_dict.pop(prefix + "wq.weight")
|
||||
wk = state_dict.pop(prefix + "wk.weight")
|
||||
wv = state_dict.pop(prefix + "wv.weight")
|
||||
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
|
||||
|
||||
def _init_rope(self, max_position_embeddings=4096, rope_base=10000.0, dtype=torch.float16):
|
||||
self.min_position = 0
|
||||
self.past_key_tensor = None
|
||||
self.past_value_tensor = None
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_base, dtype=dtype)
|
||||
|
||||
def forward(self, x: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
kv_size = self.n_local_heads * self.head_dim
|
||||
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
|
||||
|
||||
q = q.reshape(bsz, seqlen, self.n_head, self.head_dim).transpose(1, 2)
|
||||
k = k.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb()
|
||||
|
||||
q, k = torch_directml.apply_rotary_position_emb(
|
||||
q, k, cos, sin, self.min_position, seqlen, self.head_dim)
|
||||
self.min_position += seqlen
|
||||
|
||||
if self.n_head != self.n_local_heads:
|
||||
k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
||||
v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
|
||||
|
||||
q, k, v = map(lambda x: x.transpose(1, 2).reshape(bsz, -1, self.dim), (q, k, v))
|
||||
|
||||
y, self.past_key_tensor, self.past_value_tensor = torch_directml.multi_head_attention(
|
||||
q, k, v, self.dim, self.n_head, self.past_key_tensor, self.past_value_tensor, mask
|
||||
)
|
||||
y = self.wo(y)
|
||||
return y
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
|
@ -185,47 +35,3 @@ class FeedForward(nn.Module):
|
|||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.mlp(x, self.w1.weight, self.w3.weight, self.w2.weight)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.rmsnorm = torch_directml.rmsnorm
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
output = self.rmsnorm(x.float(), self.weight.float(), self.eps)
|
||||
return output.type_as(x)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=4096, base=10000, dtype=torch.float16, device=None):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
device = device if device is not None else torch_directml.device(torch_directml.default_device())
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
|
||||
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0).to(self.dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0).to(self.dtype), persistent=False)
|
||||
|
||||
def forward(self):
|
||||
return (
|
||||
self.cos_cached,
|
||||
self.sin_cached
|
||||
)
|
||||
|
|
|
@ -1,94 +1,28 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
import math
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
#
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Optional, Dict, Union
|
||||
import torch_directml
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
def find_multiple(n: int, k: int) -> int:
|
||||
if n % k == 0:
|
||||
return n
|
||||
return n + k - (n % k)
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
block_size: int = 2048
|
||||
vocab_size: int = 51200
|
||||
n_layer: int = 32
|
||||
n_head: int = 32
|
||||
dim: int = 2560
|
||||
intermediate_size: int = 10240
|
||||
n_local_heads: int = -1
|
||||
head_dim: int = 64
|
||||
rope_base: float = 10000
|
||||
norm_eps: float = 1e-5
|
||||
|
||||
def __post_init__(self):
|
||||
if self.n_local_heads == -1:
|
||||
self.n_local_heads = self.n_head
|
||||
else:
|
||||
self.n_kv_groups = self.n_head // self.n_local_heads
|
||||
if self.intermediate_size is None:
|
||||
hidden_dim = 4 * self.dim
|
||||
n_hidden = int(2 * hidden_dim / 3)
|
||||
self.intermediate_size = find_multiple(n_hidden, 256)
|
||||
self.head_dim = self.dim // self.n_head
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, name: str):
|
||||
if name in transformer_configs:
|
||||
return cls(**transformer_configs[name])
|
||||
# fuzzy search
|
||||
config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
|
||||
assert len(config) == 1, name
|
||||
return cls(**transformer_configs[config[0]])
|
||||
|
||||
transformer_configs = {
|
||||
"phi-2": dict(block_size=2048, n_layer=32, n_head=32, dim=2560, intermediate_size=10240, rope_base=10000),
|
||||
}
|
||||
from models.configs import ModelArgs
|
||||
from models.layers import RotaryEmbedding
|
||||
from models.base import Transformer
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
class Phi2Transformer(Transformer):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
super().__init__(config)
|
||||
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
|
||||
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
|
||||
self.norm = nn.LayerNorm(config.dim, eps=config.norm_eps)
|
||||
self.output = nn.Linear(config.dim, config.vocab_size)
|
||||
|
||||
self.max_batch_size = -1
|
||||
self.max_seq_length = -1
|
||||
|
||||
def setup_caches(self, max_batch_size, max_seq_length, dtype=torch.float32):
|
||||
head_dim = self.config.dim // self.config.n_head
|
||||
max_seq_length = find_multiple(max_seq_length, 8)
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
for b in self.layers:
|
||||
b.attention._init_rope(dtype=dtype)
|
||||
|
||||
self.causal_mask = torch.tril(
|
||||
torch.ones(self.config.n_head, self.max_seq_length, self.max_seq_length, dtype=torch.int32)
|
||||
)
|
||||
|
||||
def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
|
||||
mask = self.causal_mask[None, :, input_pos, :input_pos[-1].item()+1]
|
||||
x = self.tok_embeddings(idx)
|
||||
|
||||
for _, layer in enumerate(self.layers):
|
||||
x = layer(x, input_pos, mask)
|
||||
x = self.norm(x)
|
||||
logits = self.output(x)
|
||||
return logits
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, name: str):
|
||||
return cls(ModelArgs.from_name(name))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
|
@ -102,7 +36,6 @@ class TransformerBlock(nn.Module):
|
|||
ffn_hidden_states = self.feed_forward(hidden_states)
|
||||
return attn_outputs + ffn_hidden_states + x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
|
@ -111,26 +44,37 @@ class Attention(nn.Module):
|
|||
self.wqkv = nn.Linear(config.dim, 3 * config.dim)
|
||||
self.wo = nn.Linear(config.dim, config.dim)
|
||||
self.kv_cache = None
|
||||
|
||||
|
||||
self.n_head = config.n_head
|
||||
self.head_dim = config.head_dim
|
||||
self.n_local_heads = config.n_local_heads
|
||||
self.dim = config.dim
|
||||
self.partial_rotary_factor = config.partial_rotary_factor
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict, prefix, *argspy):
|
||||
def load_hook(self, state_dict: Dict[str, torch.Tensor], prefix: str, *argspy):
|
||||
if prefix + "wq.weight" in state_dict:
|
||||
wq = state_dict.pop(prefix + "wq.weight")
|
||||
wk = state_dict.pop(prefix + "wk.weight")
|
||||
wv = state_dict.pop(prefix + "wv.weight")
|
||||
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
|
||||
|
||||
def _init_rope(self, dtype=torch.float32):
|
||||
def _init_rope(
|
||||
self,
|
||||
max_position_embeddings: int = 4096,
|
||||
rope_base: Union[int, float] = 10000.0,
|
||||
dtype: torch.dtype = torch.float32
|
||||
) -> None:
|
||||
self.min_position = 0
|
||||
self.past_key_tensor = None
|
||||
self.past_value_tensor = None
|
||||
self.rotary_emb = PhiRotaryEmbedding(self.n_head, dtype=dtype)
|
||||
self.rotary_emb = RotaryEmbedding(
|
||||
int(self.head_dim * self.partial_rotary_factor),
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
base=rope_base,
|
||||
dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
@ -139,9 +83,9 @@ class Attention(nn.Module):
|
|||
|
||||
q = q.reshape(bsz, seqlen, self.n_head, self.head_dim).transpose(1,2)
|
||||
k = k.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1,2)
|
||||
|
||||
|
||||
cos, sin = self.rotary_emb()
|
||||
|
||||
|
||||
q, k = torch_directml.apply_rotary_position_emb(
|
||||
q, k, cos, sin, self.min_position, seqlen, self.rotary_emb.dim)
|
||||
|
||||
|
@ -155,7 +99,6 @@ class Attention(nn.Module):
|
|||
y = self.wo(y)
|
||||
return y
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
|
@ -165,34 +108,3 @@ class FeedForward(nn.Module):
|
|||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.mlp(x, self.w1.weight, self.w2.weight, self.w1.bias, self.w2.bias)
|
||||
|
||||
|
||||
class PhiRotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
device = device if device is not None else torch_directml.device(torch_directml.default_device())
|
||||
self.dtype = dtype
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=dtype).to(device) / self.dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=dtype
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.dtype)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0).to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0).to(dtype), persistent=False)
|
||||
|
||||
def forward(self):
|
||||
return (
|
||||
self.cos_cached,
|
||||
self.sin_cached,
|
||||
)
|
||||
|
|
|
@ -1,86 +1,29 @@
|
|||
from dataclasses import dataclass
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
#
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch_directml
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from models.configs import ModelArgs
|
||||
from models.layers import RMSNorm
|
||||
from models.layers import LlamaTransformerBlock as TransformerBlock
|
||||
from models.base import Transformer
|
||||
|
||||
|
||||
def find_multiple(n: int, k: int) -> int:
|
||||
if n % k == 0:
|
||||
return n
|
||||
return n + k - (n % k)
|
||||
|
||||
@dataclass
|
||||
class ModelArgs:
|
||||
block_size: int = 2048
|
||||
vocab_size: int = 32000
|
||||
n_layer: int = 32
|
||||
n_head: int = 32
|
||||
dim: int = 4096
|
||||
intermediate_size: int = None
|
||||
n_local_heads: int = -1
|
||||
head_dim: int = 64
|
||||
rope_base: float = 10000
|
||||
norm_eps: float = 1e-5
|
||||
|
||||
def __post_init__(self):
|
||||
if self.n_local_heads == -1:
|
||||
self.n_local_heads = self.n_head
|
||||
if self.intermediate_size is None:
|
||||
hidden_dim = 4 * self.dim
|
||||
n_hidden = int(2 * hidden_dim / 3)
|
||||
self.intermediate_size = find_multiple(n_hidden, 256)
|
||||
self.head_dim = self.dim // self.n_head
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, name: str):
|
||||
if name in transformer_configs:
|
||||
return cls(**transformer_configs[name])
|
||||
# fuzzy search
|
||||
config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
|
||||
|
||||
# We may have two or more configs matched (e.g. "7B" and "Mistral-7B"). Find the best config match,
|
||||
# take longer name (as it have more symbols matched)
|
||||
if len(config) > 1:
|
||||
config.sort(key=len, reverse=True)
|
||||
assert len(config[0]) != len(config[1]), name # make sure only one 'best' match
|
||||
|
||||
return cls(**transformer_configs[config[0]])
|
||||
|
||||
transformer_configs = {
|
||||
"Phi-3-mini-4k-instruct": dict(block_size=4096, n_layer=32, n_head=32, dim=3072, intermediate_size=8192, rope_base=10000, vocab_size=32064),
|
||||
}
|
||||
|
||||
class Transformer(nn.Module):
|
||||
class Phi3Transformer(Transformer):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
super().__init__(config)
|
||||
self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
|
||||
self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
|
||||
self.layers = nn.ModuleList(TransformerBlock(config, FeedForward) for _ in range(config.n_layer))
|
||||
self.norm = RMSNorm(config.dim, eps=config.norm_eps)
|
||||
self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
|
||||
|
||||
self.mask_cache: Optional[Tensor] = None
|
||||
self.max_batch_size = -1
|
||||
self.max_seq_length = -1
|
||||
|
||||
def setup_caches(self, max_batch_size, max_seq_length):
|
||||
head_dim = self.config.dim // self.config.n_head
|
||||
max_seq_length = find_multiple(max_seq_length, 8)
|
||||
self.max_seq_length = max_seq_length
|
||||
self.max_batch_size = max_batch_size
|
||||
dtype=self.output.weight.dtype
|
||||
for b in self.layers:
|
||||
b.attention._init_rope(dtype=dtype)
|
||||
|
||||
self.causal_mask = torch.tril(
|
||||
torch.ones(self.config.n_head, self.max_seq_length, self.max_seq_length, dtype=torch.int32)
|
||||
)
|
||||
|
||||
def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
|
||||
mask = self.causal_mask[None, :, input_pos, :input_pos[-1].item()+1]
|
||||
x = self.tok_embeddings(idx)
|
||||
|
@ -92,80 +35,6 @@ class Transformer(nn.Module):
|
|||
logits = self.output(x)
|
||||
return logits
|
||||
|
||||
@classmethod
|
||||
def from_name(cls, name: str):
|
||||
return cls(ModelArgs.from_name(name))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.attention = Attention(config)
|
||||
self.feed_forward = FeedForward(config)
|
||||
self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
|
||||
self.attention_norm = RMSNorm(config.dim, config.norm_eps)
|
||||
|
||||
def forward(self, x: Tensor, input_pos: Tensor, mask: Tensor) -> Tensor:
|
||||
h = x + self.attention(self.attention_norm(x), mask, input_pos)
|
||||
out = h + self.feed_forward(self.ffn_norm(h))
|
||||
return out
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, config: ModelArgs):
|
||||
super().__init__()
|
||||
assert config.dim % config.n_head == 0
|
||||
|
||||
total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
|
||||
self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
|
||||
self.wo = nn.Linear(config.dim, config.dim, bias=False)
|
||||
self.kv_cache = None
|
||||
|
||||
self.n_head = config.n_head
|
||||
self.head_dim = config.head_dim
|
||||
self.n_local_heads = config.n_local_heads
|
||||
self.dim = config.dim
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(self, state_dict, prefix, *argspy):
|
||||
if prefix + "wq.weight" in state_dict:
|
||||
wq = state_dict.pop(prefix + "wq.weight")
|
||||
wk = state_dict.pop(prefix + "wk.weight")
|
||||
wv = state_dict.pop(prefix + "wv.weight")
|
||||
state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
|
||||
|
||||
def _init_rope(self, dtype=torch.float32):
|
||||
self.min_position = 0
|
||||
self.past_key_tensor = None
|
||||
self.past_value_tensor = None
|
||||
self.rotary_emb = RotaryEmbedding(self.head_dim, dtype=dtype)
|
||||
|
||||
def forward(self, x: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
|
||||
bsz, seqlen, _ = x.shape
|
||||
|
||||
kv_size = self.n_local_heads * self.head_dim
|
||||
q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
|
||||
|
||||
q = q.reshape(bsz, seqlen, self.n_head, self.head_dim).transpose(1, 2)
|
||||
k = k.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
|
||||
v = v.reshape(bsz, seqlen, self.n_local_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
cos, sin = self.rotary_emb()
|
||||
q, k = torch_directml.apply_rotary_position_emb(
|
||||
q, k, cos, sin, self.min_position, seqlen, self.head_dim)
|
||||
self.min_position += seqlen
|
||||
|
||||
q, k, v = map(lambda x: x.transpose(1, 2).reshape(bsz, -1, self.dim), (q, k, v))
|
||||
|
||||
y, self.past_key_tensor, self.past_value_tensor = torch_directml.multi_head_attention(
|
||||
q, k, v, self.dim, self.n_head, self.past_key_tensor, self.past_value_tensor, mask
|
||||
)
|
||||
|
||||
y = self.wo(y)
|
||||
return y
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, config: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
|
@ -176,46 +45,3 @@ class FeedForward(nn.Module):
|
|||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.mlp(x, self.w1.weight, self.w2.weight)
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
self.rmsnorm = torch_directml.rmsnorm
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
output = self.rmsnorm(x.float(), self.weight.float(), self.eps)
|
||||
return output.type_as(x)
|
||||
|
||||
|
||||
class RotaryEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_position_embeddings=4096, base=10000, dtype=torch.float32, device=None):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
device = device if device is not None else torch_directml.device(torch_directml.default_device())
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
|
||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=dtype).to(device) / self.dim))
|
||||
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
self._set_cos_sin_cache(
|
||||
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=dtype
|
||||
)
|
||||
|
||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||
self.max_seq_len_cached = seq_len
|
||||
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.dtype)
|
||||
|
||||
freqs = torch.outer(t, self.inv_freq)
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0).to(dtype), persistent=False)
|
||||
self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0).to(dtype), persistent=False)
|
||||
|
||||
def forward(self):
|
||||
return (
|
||||
self.cos_cached,
|
||||
self.sin_cached
|
||||
)
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
#
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
|
@ -13,61 +15,73 @@ from typing import Optional
|
|||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from requests.exceptions import HTTPError
|
||||
from huggingface_hub.utils._errors import RepositoryNotFoundError
|
||||
|
||||
# support running without installing as a package
|
||||
wd = Path(__file__).parent.parent.resolve()
|
||||
sys.path.append(str(wd))
|
||||
|
||||
from models.phi2 import ModelArgs as Phi2ModelArgs
|
||||
from models.phi3 import ModelArgs as Phi3ModelArgs
|
||||
from models.llama import ModelArgs as LlamaModelArgs
|
||||
from models.configs import ModelArgs, default_models
|
||||
|
||||
def hf_download(model_repo: Optional[str] = None, hf_token: Optional[str] = None) -> None:
|
||||
|
||||
def is_dir_empty(directory: str) -> bool:
|
||||
return not any(os.scandir(directory))
|
||||
|
||||
def download_model_from_hf(hf_model: str, checkpoint_dir: str, hf_token: Optional[str]) -> str:
|
||||
from huggingface_hub import snapshot_download
|
||||
checkpoint_dir = f"checkpoints/{model_repo}"
|
||||
checkpoint_dir = f"{checkpoint_dir}/{hf_model}"
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
if os.listdir(checkpoint_dir):
|
||||
if not is_dir_empty(checkpoint_dir):
|
||||
print(f"The directory {checkpoint_dir} is not empty. Skipping download.")
|
||||
else:
|
||||
try:
|
||||
snapshot_download(model_repo, local_dir=checkpoint_dir, local_dir_use_symlinks=False, token=hf_token)
|
||||
snapshot_download(hf_model, local_dir=checkpoint_dir, local_dir_use_symlinks=False, token=hf_token)
|
||||
print(f"Downloaded {hf_model} successfully.")
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 401:
|
||||
print("You need to pass a valid `--hf_token=...` to download private checkpoints.")
|
||||
else:
|
||||
raise e
|
||||
print("You need to pass a valid Hugging Face token to download private models.")
|
||||
raise e
|
||||
return checkpoint_dir
|
||||
|
||||
def hf_download(hf_model: Optional[str] = None, hf_token: Optional[str] = None, checkpoint_dir: str = "checkpoints") -> None:
|
||||
try:
|
||||
checkpoint_dir_download = download_model_from_hf(hf_model, checkpoint_dir, hf_token)
|
||||
except RepositoryNotFoundError as e:
|
||||
# invalid repo passed, try to search for a default repo from the given hf_model
|
||||
os.rmdir(f"{checkpoint_dir}/{hf_model}")
|
||||
print(f"Couldn't find {hf_model} on HuggingFace. Searching for the closest supported match ...")
|
||||
if hf_model in default_models:
|
||||
hf_model = default_models[hf_model]
|
||||
else:
|
||||
raise ValueError(f"Please provide a valid hf_model to download from Huggingface. {hf_model} doesnt exist on Huggingface.")
|
||||
|
||||
print(f"Found closest match on Huggingface: {hf_model}")
|
||||
checkpoint_dir_download = download_model_from_hf(hf_model, checkpoint_dir, hf_token)
|
||||
return checkpoint_dir_download
|
||||
|
||||
@torch.inference_mode()
|
||||
def convert_hf_checkpoint(
|
||||
*,
|
||||
checkpoint_dir: Path = Path("microsoft/Phi-3-mini-4k-instruct"),
|
||||
checkpoint_dir: Path = Path("checkpoints/microsoft/Phi-3-mini-4k-instruct"),
|
||||
weight_map_path: str = "config/weight_map.json",
|
||||
model_name: Optional[str] = None,
|
||||
) -> None:
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
raise ValueError("Please download you model first with the hf_download function.")
|
||||
|
||||
if os.path.exists(checkpoint_dir / "model.pth"):
|
||||
print(f"Converted checkpoint already exists here {checkpoint_dir / 'model.pth'}. Skipping Conversion.")
|
||||
return
|
||||
|
||||
model_name = checkpoint_dir.name
|
||||
config = ModelArgs.from_name(model_name)
|
||||
|
||||
with open(weight_map_path, 'r') as file:
|
||||
weight_maps = json.load(file)
|
||||
|
||||
if model_name is None:
|
||||
model_name = checkpoint_dir.name
|
||||
|
||||
is_llama3 = "Llama-3" in model_name
|
||||
is_phi3 = "Phi-3" in model_name
|
||||
model_name = checkpoint_dir.name
|
||||
model_name = "llama" if "phi" not in model_name.lower() else model_name
|
||||
weight_map = weight_maps[model_name]
|
||||
|
||||
if "phi" not in model_name.lower():
|
||||
weight_map = weight_maps["llama"]
|
||||
config = LlamaModelArgs.from_name(model_name)
|
||||
else:
|
||||
weight_map = weight_maps[model_name]
|
||||
if is_phi3:
|
||||
config = Phi3ModelArgs.from_name(model_name)
|
||||
else:
|
||||
config = Phi2ModelArgs.from_name(model_name)
|
||||
|
||||
# Load the json file containing weight mapping
|
||||
model_map_json = checkpoint_dir / "model.safetensors.index.json"
|
||||
|
||||
|
@ -77,12 +91,12 @@ def convert_hf_checkpoint(
|
|||
bin_index = json.load(json_map)
|
||||
|
||||
bin_files = {checkpoint_dir / bin for bin in bin_index["weight_map"].values()}
|
||||
|
||||
|
||||
merged_result = {}
|
||||
for file in sorted(bin_files):
|
||||
state_dict = load_file(str(file))
|
||||
merged_result.update(state_dict)
|
||||
|
||||
|
||||
final_result = {}
|
||||
for key, value in merged_result.items():
|
||||
if "layers" in key:
|
||||
|
@ -125,13 +139,15 @@ def convert_hf_checkpoint(
|
|||
if __name__ == '__main__':
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description='Download and convert HuggingFace checkpoint.')
|
||||
parser.add_argument('--model_repo', type=str, default="microsoft/Phi-3-mini-4k-instruct", help='Huggingface Repository ID to download from.')
|
||||
parser.add_argument('--hf_model', type=str, default="microsoft/Phi-3-mini-4k-instruct", help='Huggingface Repository ID to download from.')
|
||||
parser.add_argument('--hf_token', type=str, default=None, help='HuggingFace API token.')
|
||||
parser.add_argument('--model_name', type=str, default=None)
|
||||
parser.add_argument(
|
||||
'--checkpoint_dir', type=str, default="checkpoints",
|
||||
help="Directory to downloads the Huggingface repo to. The model will be downloaded and converted to '{checkpoint_dir}/{hf_model}/"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
checkpoint_dir = hf_download(args.model_repo, args.hf_token)
|
||||
checkpoint_dir = hf_download(args.hf_model, args.hf_token, args.checkpoint_dir)
|
||||
convert_hf_checkpoint(
|
||||
checkpoint_dir=Path(checkpoint_dir),
|
||||
model_name=args.model_name,
|
||||
)
|
||||
|
|
|
@ -1,27 +1,33 @@
|
|||
# Copyright (c) Microsoft Corporation.
|
||||
#
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Tuple, Union, List
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from transformers import PreTrainedTokenizerFast
|
||||
|
||||
# support running without installing as a package
|
||||
wd = Path(__file__).parent.parent.resolve()
|
||||
sys.path.append(str(wd))
|
||||
|
||||
from models.phi3 import Transformer as Phi3Transformer
|
||||
from models.phi2 import Transformer as Phi2Transformer
|
||||
from models.llama import Transformer as LlamaTransformer
|
||||
from models.phi3 import Phi3Transformer
|
||||
from models.phi2 import Phi2Transformer
|
||||
from models.llama import LlamaTransformer
|
||||
|
||||
|
||||
def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
|
||||
def multinomial_sample_one_no_sync(probs_sort: Tensor) -> Tensor: # Does multinomial sampling without a cuda synchronization
|
||||
q = torch.empty_like(probs_sort).exponential_(1)
|
||||
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||
|
||||
|
||||
def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
|
||||
def logits_to_probs(logits: Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> Tensor:
|
||||
logits = logits / max(temperature, 1e-5)
|
||||
|
||||
if top_k is not None:
|
||||
|
@ -31,8 +37,7 @@ def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = Non
|
|||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
return probs
|
||||
|
||||
|
||||
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
|
||||
def sample(logits: Tensor, temperature: float = 1.0, top_k: Optional[int] = None) -> Tensor:
|
||||
probs = logits_to_probs(logits[0, -1], temperature, top_k)
|
||||
idx_next = multinomial_sample_one_no_sync(probs)
|
||||
return idx_next, probs
|
||||
|
@ -40,10 +45,10 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
|
|||
|
||||
def prefill(
|
||||
model: Union[Phi2Transformer, Phi3Transformer, LlamaTransformer],
|
||||
x: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
x: Tensor,
|
||||
input_pos: Tensor,
|
||||
**sampling_kwargs
|
||||
) -> torch.Tensor:
|
||||
) -> Tensor:
|
||||
# input_pos: [B, S]
|
||||
logits = model(x, input_pos)
|
||||
return sample(logits, **sampling_kwargs)[0]
|
||||
|
@ -51,17 +56,16 @@ def prefill(
|
|||
|
||||
def decode_one_token(
|
||||
model: Union[Phi2Transformer, Phi3Transformer, LlamaTransformer],
|
||||
x: torch.Tensor,
|
||||
input_pos: torch.Tensor,
|
||||
x: Tensor,
|
||||
input_pos: Tensor,
|
||||
**sampling_kwargs
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
# input_pos: [B, 1]
|
||||
assert input_pos.shape[-1] == 1
|
||||
logits = model(x, input_pos)
|
||||
return sample(logits, **sampling_kwargs)
|
||||
|
||||
|
||||
def decode_with_overlap(tokenizer, tokens, start, overlap):
|
||||
def decode_with_overlap(tokenizer: PreTrainedTokenizerFast, tokens: List[Tensor], start: int, overlap: str) -> str:
|
||||
"""Helper function to decode text, managing overlap."""
|
||||
current_decoded = tokenizer.decode(torch.IntTensor(tokens[start:]).tolist(), skip_special_tokens=True)
|
||||
if overlap and current_decoded.startswith(overlap):
|
||||
|
@ -70,8 +74,7 @@ def decode_with_overlap(tokenizer, tokens, start, overlap):
|
|||
text_output = current_decoded
|
||||
return text_output
|
||||
|
||||
|
||||
def _load_model(checkpoint_path, device, precision):
|
||||
def _load_model(checkpoint_path: str, device: torch.device, precision: torch.dtype) -> torch.nn.Module:
|
||||
model_name = checkpoint_path.parent.name
|
||||
with torch.device('meta'):
|
||||
if 'phi-2' in model_name.lower():
|
||||
|
|
Загрузка…
Ссылка в новой задаче