зеркало из https://github.com/microsoft/UniSpeech.git
add code
This commit is contained in:
Коммит
8f8cbd22d3
|
@ -0,0 +1,74 @@
|
|||
Attribution-ShareAlike 3.0 Unported
|
||||
|
||||
CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE LEGAL SERVICES. DISTRIBUTION OF THIS LICENSE DOES NOT CREATE AN ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES REGARDING THE INFORMATION PROVIDED, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM ITS USE.
|
||||
|
||||
License
|
||||
|
||||
THE WORK (AS DEFINED BELOW) IS PROVIDED UNDER THE TERMS OF THIS CREATIVE COMMONS PUBLIC LICENSE ("CCPL" OR "LICENSE"). THE WORK IS PROTECTED BY COPYRIGHT AND/OR OTHER APPLICABLE LAW. ANY USE OF THE WORK OTHER THAN AS AUTHORIZED UNDER THIS LICENSE OR COPYRIGHT LAW IS PROHIBITED.
|
||||
|
||||
BY EXERCISING ANY RIGHTS TO THE WORK PROVIDED HERE, YOU ACCEPT AND AGREE TO BE BOUND BY THE TERMS OF THIS LICENSE. TO THE EXTENT THIS LICENSE MAY BE CONSIDERED TO BE A CONTRACT, THE LICENSOR GRANTS YOU THE RIGHTS CONTAINED HERE IN CONSIDERATION OF YOUR ACCEPTANCE OF SUCH TERMS AND CONDITIONS.
|
||||
|
||||
1. Definitions
|
||||
|
||||
"Adaptation" means a work based upon the Work, or upon the Work and other pre-existing works, such as a translation, adaptation, derivative work, arrangement of music or other alterations of a literary or artistic work, or phonogram or performance and includes cinematographic adaptations or any other form in which the Work may be recast, transformed, or adapted including in any form recognizably derived from the original, except that a work that constitutes a Collection will not be considered an Adaptation for the purpose of this License. For the avoidance of doubt, where the Work is a musical work, performance or phonogram, the synchronization of the Work in timed-relation with a moving image ("synching") will be considered an Adaptation for the purpose of this License.
|
||||
"Collection" means a collection of literary or artistic works, such as encyclopedias and anthologies, or performances, phonograms or broadcasts, or other works or subject matter other than works listed in Section 1(f) below, which, by reason of the selection and arrangement of their contents, constitute intellectual creations, in which the Work is included in its entirety in unmodified form along with one or more other contributions, each constituting separate and independent works in themselves, which together are assembled into a collective whole. A work that constitutes a Collection will not be considered an Adaptation (as defined below) for the purposes of this License.
|
||||
"Creative Commons Compatible License" means a license that is listed at https://creativecommons.org/compatiblelicenses that has been approved by Creative Commons as being essentially equivalent to this License, including, at a minimum, because that license: (i) contains terms that have the same purpose, meaning and effect as the License Elements of this License; and, (ii) explicitly permits the relicensing of adaptations of works made available under that license under this License or a Creative Commons jurisdiction license with the same License Elements as this License.
|
||||
"Distribute" means to make available to the public the original and copies of the Work or Adaptation, as appropriate, through sale or other transfer of ownership.
|
||||
"License Elements" means the following high-level license attributes as selected by Licensor and indicated in the title of this License: Attribution, ShareAlike.
|
||||
"Licensor" means the individual, individuals, entity or entities that offer(s) the Work under the terms of this License.
|
||||
"Original Author" means, in the case of a literary or artistic work, the individual, individuals, entity or entities who created the Work or if no individual or entity can be identified, the publisher; and in addition (i) in the case of a performance the actors, singers, musicians, dancers, and other persons who act, sing, deliver, declaim, play in, interpret or otherwise perform literary or artistic works or expressions of folklore; (ii) in the case of a phonogram the producer being the person or legal entity who first fixes the sounds of a performance or other sounds; and, (iii) in the case of broadcasts, the organization that transmits the broadcast.
|
||||
"Work" means the literary and/or artistic work offered under the terms of this License including without limitation any production in the literary, scientific and artistic domain, whatever may be the mode or form of its expression including digital form, such as a book, pamphlet and other writing; a lecture, address, sermon or other work of the same nature; a dramatic or dramatico-musical work; a choreographic work or entertainment in dumb show; a musical composition with or without words; a cinematographic work to which are assimilated works expressed by a process analogous to cinematography; a work of drawing, painting, architecture, sculpture, engraving or lithography; a photographic work to which are assimilated works expressed by a process analogous to photography; a work of applied art; an illustration, map, plan, sketch or three-dimensional work relative to geography, topography, architecture or science; a performance; a broadcast; a phonogram; a compilation of data to the extent it is protected as a copyrightable work; or a work performed by a variety or circus performer to the extent it is not otherwise considered a literary or artistic work.
|
||||
"You" means an individual or entity exercising rights under this License who has not previously violated the terms of this License with respect to the Work, or who has received express permission from the Licensor to exercise rights under this License despite a previous violation.
|
||||
"Publicly Perform" means to perform public recitations of the Work and to communicate to the public those public recitations, by any means or process, including by wire or wireless means or public digital performances; to make available to the public Works in such a way that members of the public may access these Works from a place and at a place individually chosen by them; to perform the Work to the public by any means or process and the communication to the public of the performances of the Work, including by public digital performance; to broadcast and rebroadcast the Work by any means including signs, sounds or images.
|
||||
"Reproduce" means to make copies of the Work by any means including without limitation by sound or visual recordings and the right of fixation and reproducing fixations of the Work, including storage of a protected performance or phonogram in digital form or other electronic medium.
|
||||
|
||||
2. Fair Dealing Rights. Nothing in this License is intended to reduce, limit, or restrict any uses free from copyright or rights arising from limitations or exceptions that are provided for in connection with the copyright protection under copyright law or other applicable laws.
|
||||
|
||||
3. License Grant. Subject to the terms and conditions of this License, Licensor hereby grants You a worldwide, royalty-free, non-exclusive, perpetual (for the duration of the applicable copyright) license to exercise the rights in the Work as stated below:
|
||||
|
||||
to Reproduce the Work, to incorporate the Work into one or more Collections, and to Reproduce the Work as incorporated in the Collections;
|
||||
to create and Reproduce Adaptations provided that any such Adaptation, including any translation in any medium, takes reasonable steps to clearly label, demarcate or otherwise identify that changes were made to the original Work. For example, a translation could be marked "The original work was translated from English to Spanish," or a modification could indicate "The original work has been modified.";
|
||||
to Distribute and Publicly Perform the Work including as incorporated in Collections; and,
|
||||
to Distribute and Publicly Perform Adaptations.
|
||||
|
||||
For the avoidance of doubt:
|
||||
Non-waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme cannot be waived, the Licensor reserves the exclusive right to collect such royalties for any exercise by You of the rights granted under this License;
|
||||
Waivable Compulsory License Schemes. In those jurisdictions in which the right to collect royalties through any statutory or compulsory licensing scheme can be waived, the Licensor waives the exclusive right to collect such royalties for any exercise by You of the rights granted under this License; and,
|
||||
Voluntary License Schemes. The Licensor waives the right to collect royalties, whether individually or, in the event that the Licensor is a member of a collecting society that administers voluntary licensing schemes, via that society, from any exercise by You of the rights granted under this License.
|
||||
|
||||
The above rights may be exercised in all media and formats whether now known or hereafter devised. The above rights include the right to make such modifications as are technically necessary to exercise the rights in other media and formats. Subject to Section 8(f), all rights not expressly granted by Licensor are hereby reserved.
|
||||
|
||||
4. Restrictions. The license granted in Section 3 above is expressly made subject to and limited by the following restrictions:
|
||||
|
||||
You may Distribute or Publicly Perform the Work only under the terms of this License. You must include a copy of, or the Uniform Resource Identifier (URI) for, this License with every copy of the Work You Distribute or Publicly Perform. You may not offer or impose any terms on the Work that restrict the terms of this License or the ability of the recipient of the Work to exercise the rights granted to that recipient under the terms of the License. You may not sublicense the Work. You must keep intact all notices that refer to this License and to the disclaimer of warranties with every copy of the Work You Distribute or Publicly Perform. When You Distribute or Publicly Perform the Work, You may not impose any effective technological measures on the Work that restrict the ability of a recipient of the Work from You to exercise the rights granted to that recipient under the terms of the License. This Section 4(a) applies to the Work as incorporated in a Collection, but this does not require the Collection apart from the Work itself to be made subject to the terms of this License. If You create a Collection, upon notice from any Licensor You must, to the extent practicable, remove from the Collection any credit as required by Section 4(c), as requested. If You create an Adaptation, upon notice from any Licensor You must, to the extent practicable, remove from the Adaptation any credit as required by Section 4(c), as requested.
|
||||
You may Distribute or Publicly Perform an Adaptation only under the terms of: (i) this License; (ii) a later version of this License with the same License Elements as this License; (iii) a Creative Commons jurisdiction license (either this or a later license version) that contains the same License Elements as this License (e.g., Attribution-ShareAlike 3.0 US)); (iv) a Creative Commons Compatible License. If you license the Adaptation under one of the licenses mentioned in (iv), you must comply with the terms of that license. If you license the Adaptation under the terms of any of the licenses mentioned in (i), (ii) or (iii) (the "Applicable License"), you must comply with the terms of the Applicable License generally and the following provisions: (I) You must include a copy of, or the URI for, the Applicable License with every copy of each Adaptation You Distribute or Publicly Perform; (II) You may not offer or impose any terms on the Adaptation that restrict the terms of the Applicable License or the ability of the recipient of the Adaptation to exercise the rights granted to that recipient under the terms of the Applicable License; (III) You must keep intact all notices that refer to the Applicable License and to the disclaimer of warranties with every copy of the Work as included in the Adaptation You Distribute or Publicly Perform; (IV) when You Distribute or Publicly Perform the Adaptation, You may not impose any effective technological measures on the Adaptation that restrict the ability of a recipient of the Adaptation from You to exercise the rights granted to that recipient under the terms of the Applicable License. This Section 4(b) applies to the Adaptation as incorporated in a Collection, but this does not require the Collection apart from the Adaptation itself to be made subject to the terms of the Applicable License.
|
||||
If You Distribute, or Publicly Perform the Work or any Adaptations or Collections, You must, unless a request has been made pursuant to Section 4(a), keep intact all copyright notices for the Work and provide, reasonable to the medium or means You are utilizing: (i) the name of the Original Author (or pseudonym, if applicable) if supplied, and/or if the Original Author and/or Licensor designate another party or parties (e.g., a sponsor institute, publishing entity, journal) for attribution ("Attribution Parties") in Licensor's copyright notice, terms of service or by other reasonable means, the name of such party or parties; (ii) the title of the Work if supplied; (iii) to the extent reasonably practicable, the URI, if any, that Licensor specifies to be associated with the Work, unless such URI does not refer to the copyright notice or licensing information for the Work; and (iv) , consistent with Ssection 3(b), in the case of an Adaptation, a credit identifying the use of the Work in the Adaptation (e.g., "French translation of the Work by Original Author," or "Screenplay based on original Work by Original Author"). The credit required by this Section 4(c) may be implemented in any reasonable manner; provided, however, that in the case of a Adaptation or Collection, at a minimum such credit will appear, if a credit for all contributing authors of the Adaptation or Collection appears, then as part of these credits and in a manner at least as prominent as the credits for the other contributing authors. For the avoidance of doubt, You may only use the credit required by this Section for the purpose of attribution in the manner set out above and, by exercising Your rights under this License, You may not implicitly or explicitly assert or imply any connection with, sponsorship or endorsement by the Original Author, Licensor and/or Attribution Parties, as appropriate, of You or Your use of the Work, without the separate, express prior written permission of the Original Author, Licensor and/or Attribution Parties.
|
||||
Except as otherwise agreed in writing by the Licensor or as may be otherwise permitted by applicable law, if You Reproduce, Distribute or Publicly Perform the Work either by itself or as part of any Adaptations or Collections, You must not distort, mutilate, modify or take other derogatory action in relation to the Work which would be prejudicial to the Original Author's honor or reputation. Licensor agrees that in those jurisdictions (e.g. Japan), in which any exercise of the right granted in Section 3(b) of this License (the right to make Adaptations) would be deemed to be a distortion, mutilation, modification or other derogatory action prejudicial to the Original Author's honor and reputation, the Licensor will waive or not assert, as appropriate, this Section, to the fullest extent permitted by the applicable national law, to enable You to reasonably exercise Your right under Section 3(b) of this License (right to make Adaptations) but not otherwise.
|
||||
|
||||
5. Representations, Warranties and Disclaimer
|
||||
|
||||
UNLESS OTHERWISE MUTUALLY AGREED TO BY THE PARTIES IN WRITING, LICENSOR OFFERS THE WORK AS-IS AND MAKES NO REPRESENTATIONS OR WARRANTIES OF ANY KIND CONCERNING THE WORK, EXPRESS, IMPLIED, STATUTORY OR OTHERWISE, INCLUDING, WITHOUT LIMITATION, WARRANTIES OF TITLE, MERCHANTIBILITY, FITNESS FOR A PARTICULAR PURPOSE, NONINFRINGEMENT, OR THE ABSENCE OF LATENT OR OTHER DEFECTS, ACCURACY, OR THE PRESENCE OF ABSENCE OF ERRORS, WHETHER OR NOT DISCOVERABLE. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OF IMPLIED WARRANTIES, SO SUCH EXCLUSION MAY NOT APPLY TO YOU.
|
||||
|
||||
6. Limitation on Liability. EXCEPT TO THE EXTENT REQUIRED BY APPLICABLE LAW, IN NO EVENT WILL LICENSOR BE LIABLE TO YOU ON ANY LEGAL THEORY FOR ANY SPECIAL, INCIDENTAL, CONSEQUENTIAL, PUNITIVE OR EXEMPLARY DAMAGES ARISING OUT OF THIS LICENSE OR THE USE OF THE WORK, EVEN IF LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
||||
|
||||
7. Termination
|
||||
|
||||
This License and the rights granted hereunder will terminate automatically upon any breach by You of the terms of this License. Individuals or entities who have received Adaptations or Collections from You under this License, however, will not have their licenses terminated provided such individuals or entities remain in full compliance with those licenses. Sections 1, 2, 5, 6, 7, and 8 will survive any termination of this License.
|
||||
Subject to the above terms and conditions, the license granted here is perpetual (for the duration of the applicable copyright in the Work). Notwithstanding the above, Licensor reserves the right to release the Work under different license terms or to stop distributing the Work at any time; provided, however that any such election will not serve to withdraw this License (or any other license that has been, or is required to be, granted under the terms of this License), and this License will continue in full force and effect unless terminated as stated above.
|
||||
|
||||
8. Miscellaneous
|
||||
|
||||
Each time You Distribute or Publicly Perform the Work or a Collection, the Licensor offers to the recipient a license to the Work on the same terms and conditions as the license granted to You under this License.
|
||||
Each time You Distribute or Publicly Perform an Adaptation, Licensor offers to the recipient a license to the original Work on the same terms and conditions as the license granted to You under this License.
|
||||
If any provision of this License is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this License, and without further action by the parties to this agreement, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable.
|
||||
No term or provision of this License shall be deemed waived and no breach consented to unless such waiver or consent shall be in writing and signed by the party to be charged with such waiver or consent.
|
||||
This License constitutes the entire agreement between the parties with respect to the Work licensed here. There are no understandings, agreements or representations with respect to the Work not specified here. Licensor shall not be bound by any additional provisions that may appear in any communication from You. This License may not be modified without the mutual written agreement of the Licensor and You.
|
||||
The rights granted under, and the subject matter referenced, in this License were drafted utilizing the terminology of the Berne Convention for the Protection of Literary and Artistic Works (as amended on September 28, 1979), the Rome Convention of 1961, the WIPO Copyright Treaty of 1996, the WIPO Performances and Phonograms Treaty of 1996 and the Universal Copyright Convention (as revised on July 24, 1971). These rights and subject matter take effect in the relevant jurisdiction in which the License terms are sought to be enforced according to the corresponding provisions of the implementation of those treaty provisions in the applicable national law. If the standard suite of rights granted under applicable copyright law includes additional rights not granted under this License, such additional rights are deemed to be included in the License; this License is not intended to restrict the license of any rights under applicable law.
|
||||
|
||||
Creative Commons Notice
|
||||
|
||||
Creative Commons is not a party to this License, and makes no warranty whatsoever in connection with the Work. Creative Commons will not be liable to You or any party on any legal theory for any damages whatsoever, including without limitation any general, special, incidental or consequential damages arising in connection to this license. Notwithstanding the foregoing two (2) sentences, if Creative Commons has expressly identified itself as the Licensor hereunder, it shall have all rights and obligations of Licensor.
|
||||
|
||||
Except for the limited purpose of indicating to the public that the Work is licensed under the CCPL, Creative Commons does not authorize the use by either party of the trademark "Creative Commons" or any related trademark or logo of Creative Commons without the prior written consent of Creative Commons. Any permitted use will be in compliance with Creative Commons' then-current trademark usage guidelines, as may be published on its website or otherwise made available upon request from time to time. For the avoidance of doubt, this trademark restriction does not form part of the License.
|
||||
|
||||
Creative Commons may be contacted at https://creativecommons.org/.
|
|
@ -0,0 +1,58 @@
|
|||
# UniSpeech
|
||||
|
||||
This is the official implementation of paper "[UniSpeech: Unified Speech Representation Learning with Labeled and Unlabeled Data](https://arxiv.org/abs/2101.07597)". The implementation mainly based on [fairseq](https://github.com/pytorch/fairseq) codebase. We release the training recipes on CommonVoice dataset.
|
||||
|
||||
## Requirements and Installation
|
||||
|
||||
- Pytorch >= 1.4.0
|
||||
- python version >= 3.6
|
||||
``` bash
|
||||
cd unispeech
|
||||
pip install soundfile
|
||||
pip install librosa
|
||||
pip install pydub
|
||||
pip install --editable ./
|
||||
```
|
||||
## Data Preparation
|
||||
Download pretraining audio data from [here](https://commonvoice.mozilla.org/datasets). (We use the June 2020 release version in our paper).
|
||||
Get the wav list and the transcription for each dataset by run:
|
||||
```
|
||||
python examples/unispeech/unispeech_manifest.py input_meta_file --dest examples/unispeech/data/LANG
|
||||
```
|
||||
|
||||
Then convert the audio files in common voices to 16k HZ using the commond:
|
||||
```
|
||||
python examples/unispeech/adjust_sample_rate.py --wav-path /path/to/wav/ --dest-path /path/to/16kwav/ --input examples/unispeech/data/LANG/*.tsv --output examples/unispeech/data/LANG/*_16k.tsv
|
||||
```
|
||||
For the finetuning data, our train/val/test splits are following [this](https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz).
|
||||
The phoneme transcriptions are generated by [phonemizer](https://github.com/bootphon/phonemizer) to convert texts to phonemes. Then we create .id files using different vocabularies. All our pre-processed data as well as the dictionaries can be downloaded from [here].
|
||||
|
||||
## Pretraining
|
||||
|
||||
We give the training examples for large model here.
|
||||
### Stage 1. Pretraining UniSpeech with labeled data.
|
||||
The following script can be used to pre-train an English model:
|
||||
```
|
||||
bash examples/unispeech/scripts/one2one_large_pretrain_en1350.sh
|
||||
```
|
||||
To train a multilingual model:
|
||||
```
|
||||
bash examples/unispeech/scripts/multilingual_large_pretrain.sh
|
||||
```
|
||||
|
||||
### Stage 2. Continue pre-training with low-resource unlabeled data. (Optional)
|
||||
After stage 1, you can continue pre-training the UniSpeech model with only contrastive loss:
|
||||
```
|
||||
bash examples/unispeech/scripts/continue_pretran.sh
|
||||
```
|
||||
|
||||
### Stage 3. Finetuning with low-resource labeled data.
|
||||
Finally, fint-tune the model with 1 hour labeled data.
|
||||
For multilingual models, you can choose to use separate vocabulary (examples/unispeech/data/en/vocab_sep.json) or shared vocabulary (examples/unispeech/data/en/vocab_share.json)
|
||||
```
|
||||
bash examples/unispeech/scripts/finetune.sh
|
||||
```
|
||||
|
||||
### Models
|
||||
We release our pre-trained English large model and multilingual large model in stage 1 at [here]. You can finetune the model to get the reported results in our paper.
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
# @package _group_
|
||||
common:
|
||||
no_progress_bar: false
|
||||
log_interval: 100
|
||||
log_format: null
|
||||
tensorboard_logdir: null
|
||||
seed: 1
|
||||
cpu: false
|
||||
tpu: false
|
||||
bf16: false
|
||||
fp16: false
|
||||
memory_efficient_fp16: false
|
||||
memory_efficient_bf16: false
|
||||
fp16_no_flatten_grads: false
|
||||
fp16_init_scale: 128
|
||||
fp16_scale_window: null
|
||||
fp16_scale_tolerance: 0.0
|
||||
min_loss_scale: 1.0e-4
|
||||
threshold_loss_scale: null
|
||||
user_dir: null
|
||||
empty_cache_freq: 0
|
||||
all_gather_list_size: 16384
|
||||
model_parallel_size: 1
|
||||
quantization_config_path: null
|
||||
profile: false
|
||||
distributed_training:
|
||||
distributed_rank: 0
|
||||
distributed_backend: "nccl"
|
||||
distributed_init_method: null
|
||||
distributed_port: -1
|
||||
device_id: 0
|
||||
local_rank: 0
|
||||
distributed_no_spawn: false
|
||||
ddp_backend: "c10d"
|
||||
bucket_cap_mb: 25
|
||||
fix_batches_to_gpus: false
|
||||
find_unused_parameters: false
|
||||
fast_stat_sync: false
|
||||
broadcast_buffers: false
|
||||
distributed_wrapper: "DDP"
|
||||
slowmo_momentum: null
|
||||
slowmo_algorithm: "LocalSGD"
|
||||
localsgd_frequency: 3
|
||||
dataset:
|
||||
num_workers: 1
|
||||
skip_invalid_size_inputs_valid_test: false
|
||||
max_tokens: null
|
||||
batch_size: null
|
||||
required_batch_size_multiple: 8
|
||||
dataset_impl: null
|
||||
data_buffer_size: 10
|
||||
train_subset: "train"
|
||||
valid_subset: "valid"
|
||||
validate_interval: 1
|
||||
fixed_validation_seed: null
|
||||
disable_validation: false
|
||||
curriculum: 0
|
||||
gen_subset: "test"
|
||||
num_shards: 1
|
||||
shard_id: 0
|
||||
max_tokens_valid: ${dataset.max_tokens}
|
||||
batch_size_valid: ${dataset.batch_size}
|
||||
optimization:
|
||||
max_epoch: 0
|
||||
max_update: 0
|
||||
clip_norm: 25.0
|
||||
sentence_avg: false
|
||||
update_freq: [ 1 ]
|
||||
lr: [ 0.25 ]
|
||||
min_lr: -1.0
|
||||
use_bmuf: false
|
||||
checkpoint:
|
||||
save_dir: "checkpoints"
|
||||
restore_file: "checkpoint_last.pt"
|
||||
reset_dataloader: false
|
||||
reset_lr_scheduler: false
|
||||
reset_meters: false
|
||||
reset_optimizer: false
|
||||
optimizer_overrides: "{}"
|
||||
save_interval: 1
|
||||
save_interval_updates: 0
|
||||
keep_interval_updates: -1
|
||||
keep_last_epochs: -1
|
||||
keep_best_checkpoints: -1
|
||||
no_save: false
|
||||
no_epoch_checkpoints: false
|
||||
no_last_checkpoints: false
|
||||
no_save_optimizer_state: false
|
||||
best_checkpoint_metric: "loss"
|
||||
maximize_best_checkpoint_metric: false
|
||||
patience: -1
|
||||
checkpoint_suffix: ""
|
||||
bmuf:
|
||||
block_lr: 1
|
||||
block_momentum: 0.875
|
||||
global_sync_iter: 50
|
||||
warmup_iterations: 500
|
||||
use_nbm: false
|
||||
average_sync: false
|
||||
defaults:
|
||||
- task: language_modeling
|
||||
- model: null
|
||||
- criterion: null
|
||||
- optimizer: null
|
||||
- lr_scheduler: null
|
||||
- bpe: null
|
||||
- tokenizer: null
|
||||
- scoring: null
|
||||
- generation: null
|
||||
- common_eval: null
|
||||
- eval_lm: null
|
|
@ -0,0 +1,3 @@
|
|||
# @package _group_
|
||||
sentence_avg: ${optimization.sentence_avg}
|
||||
ddp_backend: ${distributed_training.ddp_backend}
|
|
@ -0,0 +1,2 @@
|
|||
# @package _group_
|
||||
sentence_avg: ${optimization.sentence_avg}
|
|
@ -0,0 +1,7 @@
|
|||
# @package _group_
|
||||
warmup_updates: 0
|
||||
warmup_init_lr: -1
|
||||
max_lr: 1.0
|
||||
t_mult: 1.0
|
||||
lr_period_updates: -1
|
||||
lr_shrink: 0.1
|
|
@ -0,0 +1,3 @@
|
|||
# @package _group_
|
||||
warmup_updates: 4000
|
||||
warmup_init_lr: -1
|
|
@ -0,0 +1,36 @@
|
|||
# @package _group_
|
||||
activation_fn: "relu"
|
||||
dropout: 0.1
|
||||
attention_dropout: 0.0
|
||||
activation_dropout: 0.0
|
||||
relu_dropout: 0.0
|
||||
decoder_embed_dim: 512
|
||||
decoder_output_dim: 512
|
||||
decoder_input_dim: 512
|
||||
decoder_ffn_embed_dim: 2048
|
||||
decoder_layers: 6
|
||||
decoder_attention_heads: 8
|
||||
decoder_normalize_before: true
|
||||
no_decoder_final_norm: false
|
||||
adaptive_softmax_cutoff: null
|
||||
adaptive_softmax_dropout: 0
|
||||
adaptive_softmax_factor: 4
|
||||
no_token_positional_embeddings: false
|
||||
share_decoder_input_output_embed: false
|
||||
character_embeddings: false
|
||||
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
||||
character_embedding_dim: 4
|
||||
char_embedder_highway_layers: 2
|
||||
adaptive_input: false
|
||||
adaptive_input_factor: 4
|
||||
adaptive_input_cutoff: null
|
||||
tie_adaptive_weights: false
|
||||
tie_adaptive_proj: false
|
||||
decoder_learned_pos: false
|
||||
decoder_layerdrop: 0
|
||||
decoder_layers_to_keep: null
|
||||
layernorm_embedding: false
|
||||
no_scale_embedding: false
|
||||
quant_noise_pq: 0
|
||||
quant_noise_pq_block_size: 8
|
||||
quant_noise_scalar: 0
|
|
@ -0,0 +1,36 @@
|
|||
# @package _group_
|
||||
activation_fn: "relu"
|
||||
dropout: 0.1
|
||||
attention_dropout: 0.1
|
||||
activation_dropout: 0.0
|
||||
relu_dropout: 0.0
|
||||
decoder_embed_dim: 512
|
||||
decoder_output_dim: 512
|
||||
decoder_input_dim: 512
|
||||
decoder_ffn_embed_dim: 4096
|
||||
decoder_layers: 12
|
||||
decoder_attention_heads: 16
|
||||
decoder_normalize_before: true
|
||||
no_decoder_final_norm: true
|
||||
adaptive_softmax_cutoff: null
|
||||
adaptive_softmax_dropout: 0
|
||||
adaptive_softmax_factor: 4
|
||||
no_token_positional_embeddings: false
|
||||
share_decoder_input_output_embed: false
|
||||
character_embeddings: false
|
||||
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
||||
character_embedding_dim: 4
|
||||
char_embedder_highway_layers: 2
|
||||
adaptive_input: false
|
||||
adaptive_input_factor: 4
|
||||
adaptive_input_cutoff: null
|
||||
tie_adaptive_weights: false
|
||||
tie_adaptive_proj: false
|
||||
decoder_learned_pos: false
|
||||
decoder_layerdrop: 0
|
||||
decoder_layers_to_keep: null
|
||||
layernorm_embedding: false
|
||||
no_scale_embedding: false
|
||||
quant_noise_pq: 0
|
||||
quant_noise_pq_block_size: 8
|
||||
quant_noise_scalar: 0
|
|
@ -0,0 +1,36 @@
|
|||
# @package _group_
|
||||
activation_fn: "relu"
|
||||
dropout: 0.3
|
||||
attention_dropout: 0.1
|
||||
activation_dropout: 0.1
|
||||
relu_dropout: 0.1
|
||||
decoder_embed_dim: 1024
|
||||
decoder_output_dim: 1024
|
||||
decoder_input_dim: 1024
|
||||
decoder_ffn_embed_dim: 4096
|
||||
decoder_layers: 16
|
||||
decoder_attention_heads: 8
|
||||
decoder_normalize_before: true
|
||||
no_decoder_final_norm: true
|
||||
adaptive_softmax_cutoff: "20000,60000"
|
||||
adaptive_softmax_dropout: 0.2
|
||||
adaptive_softmax_factor: 4
|
||||
no_token_positional_embeddings: false
|
||||
share_decoder_input_output_embed: false
|
||||
character_embeddings: false
|
||||
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
||||
character_embedding_dim: 4
|
||||
char_embedder_highway_layers: 2
|
||||
adaptive_input: true
|
||||
adaptive_input_factor: 4
|
||||
adaptive_input_cutoff: "20000,60000"
|
||||
tie_adaptive_weights: true
|
||||
tie_adaptive_proj: true
|
||||
decoder_learned_pos: false
|
||||
decoder_layerdrop: 0
|
||||
decoder_layers_to_keep: null
|
||||
layernorm_embedding: false
|
||||
no_scale_embedding: false
|
||||
quant_noise_pq: 0
|
||||
quant_noise_pq_block_size: 8
|
||||
quant_noise_scalar: 0
|
|
@ -0,0 +1,36 @@
|
|||
# @package _group_
|
||||
activation_fn: "relu"
|
||||
dropout: 0.1
|
||||
attention_dropout: 0.0
|
||||
activation_dropout: 0.0
|
||||
relu_dropout: 0.0
|
||||
decoder_embed_dim: 1024
|
||||
decoder_output_dim: 1024
|
||||
decoder_input_dim: 1024
|
||||
decoder_ffn_embed_dim: 4096
|
||||
decoder_layers: 12
|
||||
decoder_attention_heads: 16
|
||||
decoder_normalize_before: true
|
||||
no_decoder_final_norm: false
|
||||
adaptive_softmax_cutoff: null
|
||||
adaptive_softmax_dropout: 0
|
||||
adaptive_softmax_factor: 4
|
||||
no_token_positional_embeddings: false
|
||||
share_decoder_input_output_embed: false
|
||||
character_embeddings: false
|
||||
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
||||
character_embedding_dim: 4
|
||||
char_embedder_highway_layers: 2
|
||||
adaptive_input: false
|
||||
adaptive_input_factor: 4
|
||||
adaptive_input_cutoff: null
|
||||
tie_adaptive_weights: false
|
||||
tie_adaptive_proj: false
|
||||
decoder_learned_pos: false
|
||||
decoder_layerdrop: 0
|
||||
decoder_layers_to_keep: null
|
||||
layernorm_embedding: false
|
||||
no_scale_embedding: false
|
||||
quant_noise_pq: 0
|
||||
quant_noise_pq_block_size: 8
|
||||
quant_noise_scalar: 0
|
|
@ -0,0 +1,36 @@
|
|||
# @package _group_
|
||||
activation_fn: "relu"
|
||||
dropout: 0.1
|
||||
attention_dropout: 0.1
|
||||
activation_dropout: 0.0
|
||||
relu_dropout: 0.0
|
||||
decoder_embed_dim: 512
|
||||
decoder_output_dim: 512
|
||||
decoder_input_dim: 512
|
||||
decoder_ffn_embed_dim: 4096
|
||||
decoder_layers: 12
|
||||
decoder_attention_heads: 16
|
||||
decoder_normalize_before: true
|
||||
no_decoder_final_norm: true
|
||||
adaptive_softmax_cutoff: null
|
||||
adaptive_softmax_dropout: 0
|
||||
adaptive_softmax_factor: 4
|
||||
no_token_positional_embeddings: false
|
||||
share_decoder_input_output_embed: false
|
||||
character_embeddings: false
|
||||
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
||||
character_embedding_dim: 4
|
||||
char_embedder_highway_layers: 2
|
||||
adaptive_input: false
|
||||
adaptive_input_factor: 4
|
||||
adaptive_input_cutoff: null
|
||||
tie_adaptive_weights: false
|
||||
tie_adaptive_proj: false
|
||||
decoder_learned_pos: false
|
||||
decoder_layerdrop: 0
|
||||
decoder_layers_to_keep: null
|
||||
layernorm_embedding: false
|
||||
no_scale_embedding: false
|
||||
quant_noise_pq: 0
|
||||
quant_noise_pq_block_size: 8
|
||||
quant_noise_scalar: 0
|
|
@ -0,0 +1,36 @@
|
|||
# @package _group_
|
||||
activation_fn: "gelu"
|
||||
dropout: 0.1
|
||||
attention_dropout: 0.1
|
||||
activation_dropout: 0.0
|
||||
relu_dropout: 0.0
|
||||
decoder_embed_dim: 768
|
||||
decoder_output_dim: 768
|
||||
decoder_input_dim: 768
|
||||
decoder_ffn_embed_dim: 3072
|
||||
decoder_layers: 12
|
||||
decoder_attention_heads: 12
|
||||
decoder_normalize_before: true
|
||||
no_decoder_final_norm: false
|
||||
adaptive_softmax_cutoff: null
|
||||
adaptive_softmax_dropout: 0
|
||||
adaptive_softmax_factor: 4
|
||||
no_token_positional_embeddings: false
|
||||
share_decoder_input_output_embed: false
|
||||
character_embeddings: false
|
||||
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
||||
character_embedding_dim: 4
|
||||
char_embedder_highway_layers: 2
|
||||
adaptive_input: false
|
||||
adaptive_input_factor: 4
|
||||
adaptive_input_cutoff: null
|
||||
tie_adaptive_weights: false
|
||||
tie_adaptive_proj: false
|
||||
decoder_learned_pos: false
|
||||
decoder_layerdrop: 0
|
||||
decoder_layers_to_keep: null
|
||||
layernorm_embedding: false
|
||||
no_scale_embedding: false
|
||||
quant_noise_pq: 0
|
||||
quant_noise_pq_block_size: 8
|
||||
quant_noise_scalar: 0
|
|
@ -0,0 +1,36 @@
|
|||
# @package _group_
|
||||
activation_fn: "gelu"
|
||||
dropout: 0.1
|
||||
attention_dropout: 0.1
|
||||
activation_dropout: 0.0
|
||||
relu_dropout: 0.0
|
||||
decoder_embed_dim: 1600
|
||||
decoder_output_dim: 1600
|
||||
decoder_input_dim: 1600
|
||||
decoder_ffn_embed_dim: 6400
|
||||
decoder_layers: 48
|
||||
decoder_attention_heads: 25
|
||||
decoder_normalize_before: true
|
||||
no_decoder_final_norm: false
|
||||
adaptive_softmax_cutoff: null
|
||||
adaptive_softmax_dropout: 0
|
||||
adaptive_softmax_factor: 4
|
||||
no_token_positional_embeddings: false
|
||||
share_decoder_input_output_embed: false
|
||||
character_embeddings: false
|
||||
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
||||
character_embedding_dim: 4
|
||||
char_embedder_highway_layers: 2
|
||||
adaptive_input: false
|
||||
adaptive_input_factor: 4
|
||||
adaptive_input_cutoff: null
|
||||
tie_adaptive_weights: false
|
||||
tie_adaptive_proj: false
|
||||
decoder_learned_pos: false
|
||||
decoder_layerdrop: 0
|
||||
decoder_layers_to_keep: null
|
||||
layernorm_embedding: false
|
||||
no_scale_embedding: false
|
||||
quant_noise_pq: 0
|
||||
quant_noise_pq_block_size: 8
|
||||
quant_noise_scalar: 0
|
|
@ -0,0 +1,36 @@
|
|||
# @package _group_
|
||||
activation_fn: "gelu"
|
||||
dropout: 0.1
|
||||
attention_dropout: 0.1
|
||||
activation_dropout: 0.0
|
||||
relu_dropout: 0.0
|
||||
decoder_embed_dim: 1280
|
||||
decoder_output_dim: 1280
|
||||
decoder_input_dim: 1280
|
||||
decoder_ffn_embed_dim: 5120
|
||||
decoder_layers: 36
|
||||
decoder_attention_heads: 20
|
||||
decoder_normalize_before: true
|
||||
no_decoder_final_norm: false
|
||||
adaptive_softmax_cutoff: null
|
||||
adaptive_softmax_dropout: 0
|
||||
adaptive_softmax_factor: 4
|
||||
no_token_positional_embeddings: false
|
||||
share_decoder_input_output_embed: false
|
||||
character_embeddings: false
|
||||
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
||||
character_embedding_dim: 4
|
||||
char_embedder_highway_layers: 2
|
||||
adaptive_input: false
|
||||
adaptive_input_factor: 4
|
||||
adaptive_input_cutoff: null
|
||||
tie_adaptive_weights: false
|
||||
tie_adaptive_proj: false
|
||||
decoder_learned_pos: false
|
||||
decoder_layerdrop: 0
|
||||
decoder_layers_to_keep: null
|
||||
layernorm_embedding: false
|
||||
no_scale_embedding: false
|
||||
quant_noise_pq: 0
|
||||
quant_noise_pq_block_size: 8
|
||||
quant_noise_scalar: 0
|
|
@ -0,0 +1,36 @@
|
|||
# @package _group_
|
||||
activation_fn: "gelu"
|
||||
dropout: 0.1
|
||||
attention_dropout: 0.1
|
||||
activation_dropout: 0.0
|
||||
relu_dropout: 0.0
|
||||
decoder_embed_dim: 1024
|
||||
decoder_output_dim: 1024
|
||||
decoder_input_dim: 1024
|
||||
decoder_ffn_embed_dim: 4096
|
||||
decoder_layers: 24
|
||||
decoder_attention_heads: 16
|
||||
decoder_normalize_before: true
|
||||
no_decoder_final_norm: false
|
||||
adaptive_softmax_cutoff: null
|
||||
adaptive_softmax_dropout: 0
|
||||
adaptive_softmax_factor: 4
|
||||
no_token_positional_embeddings: false
|
||||
share_decoder_input_output_embed: false
|
||||
character_embeddings: false
|
||||
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
||||
character_embedding_dim: 4
|
||||
char_embedder_highway_layers: 2
|
||||
adaptive_input: false
|
||||
adaptive_input_factor: 4
|
||||
adaptive_input_cutoff: null
|
||||
tie_adaptive_weights: false
|
||||
tie_adaptive_proj: false
|
||||
decoder_learned_pos: false
|
||||
decoder_layerdrop: 0
|
||||
decoder_layers_to_keep: null
|
||||
layernorm_embedding: false
|
||||
no_scale_embedding: false
|
||||
quant_noise_pq: 0
|
||||
quant_noise_pq_block_size: 8
|
||||
quant_noise_scalar: 0
|
|
@ -0,0 +1,36 @@
|
|||
# @package _group_
|
||||
activation_fn: "relu"
|
||||
dropout: 0.3
|
||||
attention_dropout: 0.1
|
||||
activation_dropout: 0.1
|
||||
relu_dropout: 0.1
|
||||
decoder_embed_dim: 1024
|
||||
decoder_output_dim: 1024
|
||||
decoder_input_dim: 1024
|
||||
decoder_ffn_embed_dim: 4096
|
||||
decoder_layers: 16
|
||||
decoder_attention_heads: 8
|
||||
decoder_normalize_before: true
|
||||
no_decoder_final_norm: true
|
||||
adaptive_softmax_cutoff: "20000,60000"
|
||||
adaptive_softmax_dropout: 0.2
|
||||
adaptive_softmax_factor: 4
|
||||
no_token_positional_embeddings: false
|
||||
share_decoder_input_output_embed: false
|
||||
character_embeddings: false
|
||||
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
||||
character_embedding_dim: 4
|
||||
char_embedder_highway_layers: 2
|
||||
adaptive_input: true
|
||||
adaptive_input_factor: 4
|
||||
adaptive_input_cutoff: "20000,60000"
|
||||
tie_adaptive_weights: true
|
||||
tie_adaptive_proj: true
|
||||
decoder_learned_pos: false
|
||||
decoder_layerdrop: 0
|
||||
decoder_layers_to_keep: null
|
||||
layernorm_embedding: false
|
||||
no_scale_embedding: false
|
||||
quant_noise_pq: 0
|
||||
quant_noise_pq_block_size: 8
|
||||
quant_noise_scalar: 0
|
|
@ -0,0 +1,5 @@
|
|||
# @package _group_
|
||||
adam_betas: "(0.9, 0.999)"
|
||||
adam_eps: 1.0e-8
|
||||
weight_decay: 0
|
||||
use_old_adam: false
|
|
@ -0,0 +1,3 @@
|
|||
# @package _group_
|
||||
momentum: 0.99
|
||||
weight_decay: 0.0
|
|
@ -0,0 +1,10 @@
|
|||
# @package _group_
|
||||
data: ???
|
||||
sample_break_mode: "none"
|
||||
tokens_per_sample: 1024
|
||||
output_dictionary_size: -1
|
||||
self_target: false
|
||||
future_target: false
|
||||
past_target: false
|
||||
add_bos_token: false
|
||||
max_target_positions: null
|
|
@ -0,0 +1,2 @@
|
|||
!*/*.sh
|
||||
!*/*.md
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from fairseq import __version__ # noqa
|
|
@ -0,0 +1,106 @@
|
|||
# Speech Recognition
|
||||
`examples/speech_recognition` is implementing ASR task in Fairseq, along with needed features, datasets, models and loss functions to train and infer model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660).
|
||||
|
||||
|
||||
## Additional dependencies
|
||||
On top of main fairseq dependencies there are couple more additional requirements.
|
||||
|
||||
1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features.
|
||||
2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and installed from source from sctk package [here](http://www.openslr.org/4/). Training and inference doesn't require Sclite dependency.
|
||||
3) [sentencepiece](https://github.com/google/sentencepiece) is required in order to create dataset with word-piece targets.
|
||||
|
||||
## Preparing librispeech data
|
||||
```
|
||||
./examples/speech_recognition/datasets/prepare-librispeech.sh $DIR_TO_SAVE_RAW_DATA $DIR_FOR_PREPROCESSED_DATA
|
||||
```
|
||||
|
||||
## Training librispeech data
|
||||
```
|
||||
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/
|
||||
```
|
||||
|
||||
## Inference for librispeech
|
||||
`$SET` can be `test_clean` or `test_other`
|
||||
Any checkpoint in `$MODEL_PATH` can be selected. In this example we are working with `checkpoint_last.pt`
|
||||
```
|
||||
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --max-tokens 25000 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --beam 20 --results-path $RES_DIR --batch-size 40 --gen-subset $SET --user-dir examples/speech_recognition/
|
||||
```
|
||||
|
||||
## Inference for librispeech
|
||||
```
|
||||
sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
|
||||
```
|
||||
`Sum/Avg` row from first table of the report has WER
|
||||
|
||||
## Using wav2letter components
|
||||
[wav2letter](https://github.com/facebookresearch/wav2letter) now has integration with fairseq. Currently this includes:
|
||||
|
||||
* AutoSegmentationCriterion (ASG)
|
||||
* wav2letter-style Conv/GLU model
|
||||
* wav2letter's beam search decoder
|
||||
|
||||
To use these, follow the instructions on [this page](https://github.com/facebookresearch/wav2letter/tree/master/bindings/python) to install python bindings. Please note that python bindings are for a *subset* of wav2letter and don't require its full dependencies (notably, `flashlight` and `ArrayFire` are *not* required).
|
||||
|
||||
To quickly summarize the instructions: first, install [CUDA](https://developer.nvidia.com/cuda-downloads). Then follow these steps:
|
||||
```
|
||||
# additional prerequisites - use equivalents for your distro
|
||||
sudo apt-get install build-essential cmake libatlas-base-dev libfftw3-dev liblzma-dev libbz2-dev libzstd-dev
|
||||
# install KenLM from source
|
||||
git clone https://github.com/kpu/kenlm.git
|
||||
cd kenlm
|
||||
mkdir -p build && cd build
|
||||
cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_POSITION_INDEPENDENT_CODE=ON
|
||||
make -j16
|
||||
cd ..
|
||||
export KENLM_ROOT_DIR=$(pwd)
|
||||
cd ..
|
||||
# install wav2letter python bindings
|
||||
git clone https://github.com/facebookresearch/wav2letter.git
|
||||
cd wav2letter/bindings/python
|
||||
# make sure your python environment is active at this point
|
||||
pip install torch packaging
|
||||
pip install -e .
|
||||
# try some examples to verify installation succeeded
|
||||
python ./examples/criterion_example.py
|
||||
python ./examples/decoder_example.py ../../src/decoder/test
|
||||
python ./examples/feature_example.py ../../src/feature/test/data
|
||||
```
|
||||
|
||||
## Training librispeech data (wav2letter style, Conv/GLU + ASG loss)
|
||||
Training command:
|
||||
```
|
||||
python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition
|
||||
```
|
||||
|
||||
Note that ASG loss currently doesn't do well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`.
|
||||
|
||||
## Inference for librispeech (wav2letter decoder, n-gram LM)
|
||||
Inference command:
|
||||
```
|
||||
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
|
||||
```
|
||||
|
||||
`$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a wav2letter-style lexicon (list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels):
|
||||
```
|
||||
doorbell D O 1 R B E L 1 ▁
|
||||
```
|
||||
For CTC inference with word-pieces, repetition labels are not used and the lexicon should have most common spellings for each word (one can use sentencepiece's `NBestEncodeAsPieces` for this):
|
||||
```
|
||||
doorbell ▁DOOR BE LL
|
||||
doorbell ▁DOOR B E LL
|
||||
doorbell ▁DO OR BE LL
|
||||
doorbell ▁DOOR B EL L
|
||||
doorbell ▁DOOR BE L L
|
||||
doorbell ▁DO OR B E LL
|
||||
doorbell ▁DOOR B E L L
|
||||
doorbell ▁DO OR B EL L
|
||||
doorbell ▁DO O R BE LL
|
||||
doorbell ▁DO OR BE L L
|
||||
```
|
||||
Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`).
|
||||
|
||||
## Inference for librispeech (wav2letter decoder, viterbi only)
|
||||
Inference command:
|
||||
```
|
||||
python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
|
||||
```
|
|
@ -0,0 +1 @@
|
|||
from . import criterions, models, tasks # noqa
|
|
@ -0,0 +1,170 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
from examples.speech_recognition.data.replabels import pack_replabels
|
||||
from fairseq import utils
|
||||
from fairseq.criterions import FairseqCriterion, register_criterion
|
||||
|
||||
|
||||
@register_criterion("asg_loss")
|
||||
class ASGCriterion(FairseqCriterion):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
group = parser.add_argument_group("ASG Loss")
|
||||
group.add_argument(
|
||||
"--asg-transitions-init",
|
||||
help="initial diagonal value of transition matrix",
|
||||
type=float,
|
||||
default=0.0,
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-replabel", help="maximum # of replabels", type=int, default=2
|
||||
)
|
||||
group.add_argument(
|
||||
"--linseg-updates",
|
||||
help="# of training updates to use LinSeg initialization",
|
||||
type=int,
|
||||
default=0,
|
||||
)
|
||||
group.add_argument(
|
||||
"--hide-linseg-messages",
|
||||
help="hide messages about LinSeg initialization",
|
||||
action="store_true",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
silence_token,
|
||||
asg_transitions_init,
|
||||
max_replabel,
|
||||
linseg_updates,
|
||||
hide_linseg_messages,
|
||||
):
|
||||
from wav2letter.criterion import ASGLoss, CriterionScaleMode
|
||||
|
||||
super().__init__(task)
|
||||
self.tgt_dict = task.target_dictionary
|
||||
self.eos = self.tgt_dict.eos()
|
||||
self.silence = (
|
||||
self.tgt_dict.index(silence_token)
|
||||
if silence_token in self.tgt_dict
|
||||
else None
|
||||
)
|
||||
self.max_replabel = max_replabel
|
||||
|
||||
num_labels = len(self.tgt_dict)
|
||||
self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT)
|
||||
self.asg.trans = torch.nn.Parameter(
|
||||
asg_transitions_init * torch.eye(num_labels), requires_grad=True
|
||||
)
|
||||
|
||||
self.linseg_progress = torch.nn.Parameter(
|
||||
torch.tensor([0], dtype=torch.int), requires_grad=False
|
||||
)
|
||||
self.linseg_maximum = linseg_updates
|
||||
self.linseg_message_state = "none" if hide_linseg_messages else "start"
|
||||
|
||||
@classmethod
|
||||
def build_criterion(cls, args, task):
|
||||
return cls(
|
||||
task,
|
||||
args.silence_token,
|
||||
args.asg_transitions_init,
|
||||
args.max_replabel,
|
||||
args.linseg_updates,
|
||||
args.hide_linseg_messages,
|
||||
)
|
||||
|
||||
def linseg_step(self):
|
||||
if not self.training:
|
||||
return False
|
||||
if self.linseg_progress.item() < self.linseg_maximum:
|
||||
if self.linseg_message_state == "start":
|
||||
print("| using LinSeg to initialize ASG")
|
||||
self.linseg_message_state = "finish"
|
||||
self.linseg_progress.add_(1)
|
||||
return True
|
||||
elif self.linseg_message_state == "finish":
|
||||
print("| finished LinSeg initialization")
|
||||
self.linseg_message_state = "none"
|
||||
return False
|
||||
|
||||
def replace_eos_with_silence(self, tgt):
|
||||
if tgt[-1] != self.eos:
|
||||
return tgt
|
||||
elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence):
|
||||
return tgt[:-1]
|
||||
else:
|
||||
return tgt[:-1] + [self.silence]
|
||||
|
||||
def forward(self, model, sample, reduce=True):
|
||||
"""Compute the loss for the given sample.
|
||||
|
||||
Returns a tuple with three elements:
|
||||
1) the loss
|
||||
2) the sample size, which is used as the denominator for the gradient
|
||||
3) logging outputs to display while training
|
||||
"""
|
||||
|
||||
net_output = model(**sample["net_input"])
|
||||
emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
|
||||
B = emissions.size(0)
|
||||
T = emissions.size(1)
|
||||
device = emissions.device
|
||||
|
||||
target = torch.IntTensor(B, T)
|
||||
target_size = torch.IntTensor(B)
|
||||
using_linseg = self.linseg_step()
|
||||
|
||||
for b in range(B):
|
||||
initial_target_size = sample["target_lengths"][b].item()
|
||||
if initial_target_size == 0:
|
||||
raise ValueError("target size cannot be zero")
|
||||
|
||||
tgt = sample["target"][b, :initial_target_size].tolist()
|
||||
tgt = self.replace_eos_with_silence(tgt)
|
||||
tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
|
||||
tgt = tgt[:T]
|
||||
|
||||
if using_linseg:
|
||||
tgt = [tgt[t * len(tgt) // T] for t in range(T)]
|
||||
|
||||
target[b][: len(tgt)] = torch.IntTensor(tgt)
|
||||
target_size[b] = len(tgt)
|
||||
|
||||
loss = self.asg.forward(emissions, target.to(device), target_size.to(device))
|
||||
|
||||
if reduce:
|
||||
loss = torch.sum(loss)
|
||||
|
||||
sample_size = (
|
||||
sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
|
||||
)
|
||||
logging_output = {
|
||||
"loss": utils.item(loss.data) if reduce else loss.data,
|
||||
"ntokens": sample["ntokens"],
|
||||
"nsentences": sample["target"].size(0),
|
||||
"sample_size": sample_size,
|
||||
}
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
@staticmethod
|
||||
def aggregate_logging_outputs(logging_outputs):
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
||||
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
||||
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
|
||||
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
||||
agg_output = {
|
||||
"loss": loss_sum / nsentences,
|
||||
"ntokens": ntokens,
|
||||
"nsentences": nsentences,
|
||||
"sample_size": sample_size,
|
||||
}
|
||||
return agg_output
|
|
@ -0,0 +1,17 @@
|
|||
import importlib
|
||||
import os
|
||||
|
||||
|
||||
# ASG loss requires wav2letter
|
||||
files_to_skip = set()
|
||||
try:
|
||||
import wav2letter
|
||||
except ImportError:
|
||||
files_to_skip.add("ASG_loss.py")
|
||||
|
||||
for file in os.listdir(os.path.dirname(__file__)):
|
||||
if file.endswith(".py") and not file.startswith("_") and file not in files_to_skip:
|
||||
criterion_name = file[: file.find(".py")]
|
||||
importlib.import_module(
|
||||
"examples.speech_recognition.criterions." + criterion_name
|
||||
)
|
|
@ -0,0 +1,130 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairseq import utils
|
||||
from fairseq.criterions import FairseqCriterion, register_criterion
|
||||
|
||||
|
||||
@register_criterion("cross_entropy_acc")
|
||||
class CrossEntropyWithAccCriterion(FairseqCriterion):
|
||||
def __init__(self, task, sentence_avg):
|
||||
super().__init__(task)
|
||||
self.sentence_avg = sentence_avg
|
||||
|
||||
def compute_loss(self, model, net_output, target, reduction, log_probs):
|
||||
# N, T -> N * T
|
||||
target = target.view(-1)
|
||||
lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
|
||||
if not hasattr(lprobs, "batch_first"):
|
||||
logging.warning(
|
||||
"ERROR: we need to know whether "
|
||||
"batch first for the net output; "
|
||||
"you need to set batch_first attribute for the return value of "
|
||||
"model.get_normalized_probs. Now, we assume this is true, but "
|
||||
"in the future, we will raise exception instead. "
|
||||
)
|
||||
batch_first = getattr(lprobs, "batch_first", True)
|
||||
if not batch_first:
|
||||
lprobs = lprobs.transpose(0, 1)
|
||||
|
||||
# N, T, D -> N * T, D
|
||||
lprobs = lprobs.view(-1, lprobs.size(-1))
|
||||
loss = F.nll_loss(
|
||||
lprobs, target, ignore_index=self.padding_idx, reduction=reduction
|
||||
)
|
||||
return lprobs, loss
|
||||
|
||||
def get_logging_output(self, sample, target, lprobs, loss):
|
||||
target = target.view(-1)
|
||||
mask = target != self.padding_idx
|
||||
correct = torch.sum(
|
||||
lprobs.argmax(1).masked_select(mask) == target.masked_select(mask)
|
||||
)
|
||||
total = torch.sum(mask)
|
||||
sample_size = (
|
||||
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
|
||||
)
|
||||
|
||||
logging_output = {
|
||||
"loss": utils.item(loss.data), # * sample['ntokens'],
|
||||
"ntokens": sample["ntokens"],
|
||||
"nsentences": sample["target"].size(0),
|
||||
"sample_size": sample_size,
|
||||
"correct": utils.item(correct.data),
|
||||
"total": utils.item(total.data),
|
||||
"nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
|
||||
}
|
||||
|
||||
return sample_size, logging_output
|
||||
|
||||
def forward(self, model, sample, reduction="sum", log_probs=True):
|
||||
"""Computes the cross entropy with accuracy metric for the given sample.
|
||||
|
||||
This is similar to CrossEntropyCriterion in fairseq, but also
|
||||
computes accuracy metrics as part of logging
|
||||
|
||||
Args:
|
||||
logprobs (Torch.tensor) of shape N, T, D i.e.
|
||||
batchsize, timesteps, dimensions
|
||||
targets (Torch.tensor) of shape N, T i.e batchsize, timesteps
|
||||
|
||||
Returns:
|
||||
tuple: With three elements:
|
||||
1) the loss
|
||||
2) the sample size, which is used as the denominator for the gradient
|
||||
3) logging outputs to display while training
|
||||
|
||||
TODO:
|
||||
* Currently this Criterion will only work with LSTMEncoderModels or
|
||||
FairseqModels which have decoder, or Models which return TorchTensor
|
||||
as net_output.
|
||||
We need to make a change to support all FairseqEncoder models.
|
||||
"""
|
||||
net_output = model(**sample["net_input"])
|
||||
target = model.get_targets(sample, net_output)
|
||||
lprobs, loss = self.compute_loss(
|
||||
model, net_output, target, reduction, log_probs
|
||||
)
|
||||
sample_size, logging_output = self.get_logging_output(
|
||||
sample, target, lprobs, loss
|
||||
)
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
@staticmethod
|
||||
def aggregate_logging_outputs(logging_outputs):
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
|
||||
total_sum = sum(log.get("total", 0) for log in logging_outputs)
|
||||
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
||||
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
||||
nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
|
||||
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
||||
nframes = sum(log.get("nframes", 0) for log in logging_outputs)
|
||||
agg_output = {
|
||||
"loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
|
||||
# if args.sentence_avg, then sample_size is nsentences, then loss
|
||||
# is per-sentence loss; else sample_size is ntokens, the loss
|
||||
# becomes per-output token loss
|
||||
"ntokens": ntokens,
|
||||
"nsentences": nsentences,
|
||||
"nframes": nframes,
|
||||
"sample_size": sample_size,
|
||||
"acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
|
||||
"correct": correct_sum,
|
||||
"total": total_sum,
|
||||
# total is the number of validate tokens
|
||||
}
|
||||
if sample_size != ntokens:
|
||||
agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
|
||||
# loss: per output token loss
|
||||
# nll_loss: per sentence loss
|
||||
return agg_output
|
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .asr_dataset import AsrDataset
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AsrDataset",
|
||||
]
|
|
@ -0,0 +1,122 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
from fairseq.data import FairseqDataset
|
||||
|
||||
from . import data_utils
|
||||
from .collaters import Seq2SeqCollater
|
||||
|
||||
|
||||
class AsrDataset(FairseqDataset):
|
||||
"""
|
||||
A dataset representing speech and corresponding transcription.
|
||||
|
||||
Args:
|
||||
aud_paths: (List[str]): A list of str with paths to audio files.
|
||||
aud_durations_ms (List[int]): A list of int containing the durations of
|
||||
audio files.
|
||||
tgt (List[torch.LongTensor]): A list of LongTensors containing the indices
|
||||
of target transcriptions.
|
||||
tgt_dict (~fairseq.data.Dictionary): target vocabulary.
|
||||
ids (List[str]): A list of utterance IDs.
|
||||
speakers (List[str]): A list of speakers corresponding to utterances.
|
||||
num_mel_bins (int): Number of triangular mel-frequency bins (default: 80)
|
||||
frame_length (float): Frame length in milliseconds (default: 25.0)
|
||||
frame_shift (float): Frame shift in milliseconds (default: 10.0)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
aud_paths,
|
||||
aud_durations_ms,
|
||||
tgt,
|
||||
tgt_dict,
|
||||
ids,
|
||||
speakers,
|
||||
num_mel_bins=80,
|
||||
frame_length=25.0,
|
||||
frame_shift=10.0,
|
||||
):
|
||||
assert frame_length > 0
|
||||
assert frame_shift > 0
|
||||
assert all(x > frame_length for x in aud_durations_ms)
|
||||
self.frame_sizes = [
|
||||
int(1 + (d - frame_length) / frame_shift) for d in aud_durations_ms
|
||||
]
|
||||
|
||||
assert len(aud_paths) > 0
|
||||
assert len(aud_paths) == len(aud_durations_ms)
|
||||
assert len(aud_paths) == len(tgt)
|
||||
assert len(aud_paths) == len(ids)
|
||||
assert len(aud_paths) == len(speakers)
|
||||
self.aud_paths = aud_paths
|
||||
self.tgt_dict = tgt_dict
|
||||
self.tgt = tgt
|
||||
self.ids = ids
|
||||
self.speakers = speakers
|
||||
self.num_mel_bins = num_mel_bins
|
||||
self.frame_length = frame_length
|
||||
self.frame_shift = frame_shift
|
||||
|
||||
self.s2s_collater = Seq2SeqCollater(
|
||||
0,
|
||||
1,
|
||||
pad_index=self.tgt_dict.pad(),
|
||||
eos_index=self.tgt_dict.eos(),
|
||||
move_eos_to_beginning=True,
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
import torchaudio
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
|
||||
tgt_item = self.tgt[index] if self.tgt is not None else None
|
||||
|
||||
path = self.aud_paths[index]
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError("Audio file not found: {}".format(path))
|
||||
sound, sample_rate = torchaudio.load_wav(path)
|
||||
output = kaldi.fbank(
|
||||
sound,
|
||||
num_mel_bins=self.num_mel_bins,
|
||||
frame_length=self.frame_length,
|
||||
frame_shift=self.frame_shift,
|
||||
)
|
||||
output_cmvn = data_utils.apply_mv_norm(output)
|
||||
|
||||
return {"id": index, "data": [output_cmvn.detach(), tgt_item]}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.aud_paths)
|
||||
|
||||
def collater(self, samples):
|
||||
"""Merge a list of samples to form a mini-batch.
|
||||
|
||||
Args:
|
||||
samples (List[int]): sample indices to collate
|
||||
|
||||
Returns:
|
||||
dict: a mini-batch suitable for forwarding with a Model
|
||||
"""
|
||||
return self.s2s_collater.collate(samples)
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.frame_sizes[index]
|
||||
|
||||
def size(self, index):
|
||||
"""Return an example's size as a float or tuple. This value is used when
|
||||
filtering a dataset with ``--max-positions``."""
|
||||
return (
|
||||
self.frame_sizes[index],
|
||||
len(self.tgt[index]) if self.tgt is not None else 0,
|
||||
)
|
||||
|
||||
def ordered_indices(self):
|
||||
"""Return an ordered list of indices. Batches will be constructed based
|
||||
on this order."""
|
||||
return np.arange(len(self))
|
|
@ -0,0 +1,131 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""
|
||||
This module contains collection of classes which implement
|
||||
collate functionalities for various tasks.
|
||||
|
||||
Collaters should know what data to expect for each sample
|
||||
and they should pack / collate them into batches
|
||||
"""
|
||||
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq.data import data_utils as fairseq_data_utils
|
||||
|
||||
|
||||
class Seq2SeqCollater(object):
|
||||
"""
|
||||
Implements collate function mainly for seq2seq tasks
|
||||
This expects each sample to contain feature (src_tokens) and
|
||||
targets.
|
||||
This collator is also used for aligned training task.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_index=0,
|
||||
label_index=1,
|
||||
pad_index=1,
|
||||
eos_index=2,
|
||||
move_eos_to_beginning=True,
|
||||
):
|
||||
self.feature_index = feature_index
|
||||
self.label_index = label_index
|
||||
self.pad_index = pad_index
|
||||
self.eos_index = eos_index
|
||||
self.move_eos_to_beginning = move_eos_to_beginning
|
||||
|
||||
def _collate_frames(self, frames):
|
||||
"""Convert a list of 2d frames into a padded 3d tensor
|
||||
Args:
|
||||
frames (list): list of 2d frames of size L[i]*f_dim. Where L[i] is
|
||||
length of i-th frame and f_dim is static dimension of features
|
||||
Returns:
|
||||
3d tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
|
||||
"""
|
||||
len_max = max(frame.size(0) for frame in frames)
|
||||
f_dim = frames[0].size(1)
|
||||
res = frames[0].new(len(frames), len_max, f_dim).fill_(0.0)
|
||||
|
||||
for i, v in enumerate(frames):
|
||||
res[i, : v.size(0)] = v
|
||||
|
||||
return res
|
||||
|
||||
def collate(self, samples):
|
||||
"""
|
||||
utility function to collate samples into batch for speech recognition.
|
||||
"""
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
|
||||
# parse samples into torch tensors
|
||||
parsed_samples = []
|
||||
for s in samples:
|
||||
# skip invalid samples
|
||||
if s["data"][self.feature_index] is None:
|
||||
continue
|
||||
source = s["data"][self.feature_index]
|
||||
if isinstance(source, (np.ndarray, np.generic)):
|
||||
source = torch.from_numpy(source)
|
||||
target = s["data"][self.label_index]
|
||||
if isinstance(target, (np.ndarray, np.generic)):
|
||||
target = torch.from_numpy(target).long()
|
||||
elif isinstance(target, list):
|
||||
target = torch.LongTensor(target)
|
||||
|
||||
parsed_sample = {"id": s["id"], "source": source, "target": target}
|
||||
parsed_samples.append(parsed_sample)
|
||||
samples = parsed_samples
|
||||
|
||||
id = torch.LongTensor([s["id"] for s in samples])
|
||||
frames = self._collate_frames([s["source"] for s in samples])
|
||||
# sort samples by descending number of frames
|
||||
frames_lengths = torch.LongTensor([s["source"].size(0) for s in samples])
|
||||
frames_lengths, sort_order = frames_lengths.sort(descending=True)
|
||||
id = id.index_select(0, sort_order)
|
||||
frames = frames.index_select(0, sort_order)
|
||||
|
||||
target = None
|
||||
target_lengths = None
|
||||
prev_output_tokens = None
|
||||
if samples[0].get("target", None) is not None:
|
||||
ntokens = sum(len(s["target"]) for s in samples)
|
||||
target = fairseq_data_utils.collate_tokens(
|
||||
[s["target"] for s in samples],
|
||||
self.pad_index,
|
||||
self.eos_index,
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=False,
|
||||
)
|
||||
target = target.index_select(0, sort_order)
|
||||
target_lengths = torch.LongTensor(
|
||||
[s["target"].size(0) for s in samples]
|
||||
).index_select(0, sort_order)
|
||||
prev_output_tokens = fairseq_data_utils.collate_tokens(
|
||||
[s["target"] for s in samples],
|
||||
self.pad_index,
|
||||
self.eos_index,
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=self.move_eos_to_beginning,
|
||||
)
|
||||
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
|
||||
else:
|
||||
ntokens = sum(len(s["source"]) for s in samples)
|
||||
|
||||
batch = {
|
||||
"id": id,
|
||||
"ntokens": ntokens,
|
||||
"net_input": {"src_tokens": frames, "src_lengths": frames_lengths},
|
||||
"target": target,
|
||||
"target_lengths": target_lengths,
|
||||
"nsentences": len(samples),
|
||||
}
|
||||
if prev_output_tokens is not None:
|
||||
batch["net_input"]["prev_output_tokens"] = prev_output_tokens
|
||||
return batch
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def calc_mean_invstddev(feature):
|
||||
if len(feature.size()) != 2:
|
||||
raise ValueError("We expect the input feature to be 2-D tensor")
|
||||
mean = feature.mean(0)
|
||||
var = feature.var(0)
|
||||
# avoid division by ~zero
|
||||
eps = 1e-8
|
||||
if (var < eps).any():
|
||||
return mean, 1.0 / (torch.sqrt(var) + eps)
|
||||
return mean, 1.0 / torch.sqrt(var)
|
||||
|
||||
|
||||
def apply_mv_norm(features):
|
||||
# If there is less than 2 spectrograms, the variance cannot be computed (is NaN)
|
||||
# and normalization is not possible, so return the item as it is
|
||||
if features.size(0) < 2:
|
||||
return features
|
||||
mean, invstddev = calc_mean_invstddev(features)
|
||||
res = (features - mean) * invstddev
|
||||
return res
|
||||
|
||||
|
||||
def lengths_to_encoder_padding_mask(lengths, batch_first=False):
|
||||
"""
|
||||
convert lengths (a 1-D Long/Int tensor) to 2-D binary tensor
|
||||
|
||||
Args:
|
||||
lengths: a (B, )-shaped tensor
|
||||
|
||||
Return:
|
||||
max_length: maximum length of B sequences
|
||||
encoder_padding_mask: a (max_length, B) binary mask, where
|
||||
[t, b] = 0 for t < lengths[b] and 1 otherwise
|
||||
|
||||
TODO:
|
||||
kernelize this function if benchmarking shows this function is slow
|
||||
"""
|
||||
max_lengths = torch.max(lengths).item()
|
||||
bsz = lengths.size(0)
|
||||
encoder_padding_mask = torch.arange(
|
||||
max_lengths
|
||||
).to( # a (T, ) tensor with [0, ..., T-1]
|
||||
lengths.device
|
||||
).view( # move to the right device
|
||||
1, max_lengths
|
||||
).expand( # reshape to (1, T)-shaped tensor
|
||||
bsz, -1
|
||||
) >= lengths.view( # expand to (B, T)-shaped tensor
|
||||
bsz, 1
|
||||
).expand(
|
||||
-1, max_lengths
|
||||
)
|
||||
if not batch_first:
|
||||
return encoder_padding_mask.t(), max_lengths
|
||||
else:
|
||||
return encoder_padding_mask, max_lengths
|
||||
|
||||
|
||||
def encoder_padding_mask_to_lengths(
|
||||
encoder_padding_mask, max_lengths, batch_size, device
|
||||
):
|
||||
"""
|
||||
convert encoder_padding_mask (2-D binary tensor) to a 1-D tensor
|
||||
|
||||
Conventionally, encoder output contains a encoder_padding_mask, which is
|
||||
a 2-D mask in a shape (T, B), whose (t, b) element indicate whether
|
||||
encoder_out[t, b] is a valid output (=0) or not (=1). Occasionally, we
|
||||
need to convert this mask tensor to a 1-D tensor in shape (B, ), where
|
||||
[b] denotes the valid length of b-th sequence
|
||||
|
||||
Args:
|
||||
encoder_padding_mask: a (T, B)-shaped binary tensor or None; if None,
|
||||
indicating all are valid
|
||||
Return:
|
||||
seq_lengths: a (B,)-shaped tensor, where its (b, )-th element is the
|
||||
number of valid elements of b-th sequence
|
||||
|
||||
max_lengths: maximum length of all sequence, if encoder_padding_mask is
|
||||
not None, max_lengths must equal to encoder_padding_mask.size(0)
|
||||
|
||||
batch_size: batch size; if encoder_padding_mask is
|
||||
not None, max_lengths must equal to encoder_padding_mask.size(1)
|
||||
|
||||
device: which device to put the result on
|
||||
"""
|
||||
if encoder_padding_mask is None:
|
||||
return torch.Tensor([max_lengths] * batch_size).to(torch.int32).to(device)
|
||||
|
||||
assert encoder_padding_mask.size(0) == max_lengths, "max_lengths does not match"
|
||||
assert encoder_padding_mask.size(1) == batch_size, "batch_size does not match"
|
||||
|
||||
return max_lengths - torch.sum(encoder_padding_mask, dim=0)
|
|
@ -0,0 +1,70 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Replabel transforms for use with wav2letter's ASG criterion.
|
||||
"""
|
||||
|
||||
|
||||
def replabel_symbol(i):
|
||||
"""
|
||||
Replabel symbols used in wav2letter, currently just "1", "2", ...
|
||||
This prevents training with numeral tokens, so this might change in the future
|
||||
"""
|
||||
return str(i)
|
||||
|
||||
|
||||
def pack_replabels(tokens, dictionary, max_reps):
|
||||
"""
|
||||
Pack a token sequence so that repeated symbols are replaced by replabels
|
||||
"""
|
||||
if len(tokens) == 0 or max_reps <= 0:
|
||||
return tokens
|
||||
|
||||
replabel_value_to_idx = [0] * (max_reps + 1)
|
||||
for i in range(1, max_reps + 1):
|
||||
replabel_value_to_idx[i] = dictionary.index(replabel_symbol(i))
|
||||
|
||||
result = []
|
||||
prev_token = -1
|
||||
num_reps = 0
|
||||
for token in tokens:
|
||||
if token == prev_token and num_reps < max_reps:
|
||||
num_reps += 1
|
||||
else:
|
||||
if num_reps > 0:
|
||||
result.append(replabel_value_to_idx[num_reps])
|
||||
num_reps = 0
|
||||
result.append(token)
|
||||
prev_token = token
|
||||
if num_reps > 0:
|
||||
result.append(replabel_value_to_idx[num_reps])
|
||||
return result
|
||||
|
||||
|
||||
def unpack_replabels(tokens, dictionary, max_reps):
|
||||
"""
|
||||
Unpack a token sequence so that replabels are replaced by repeated symbols
|
||||
"""
|
||||
if len(tokens) == 0 or max_reps <= 0:
|
||||
return tokens
|
||||
|
||||
replabel_idx_to_value = {}
|
||||
for i in range(1, max_reps + 1):
|
||||
replabel_idx_to_value[dictionary.index(replabel_symbol(i))] = i
|
||||
|
||||
result = []
|
||||
prev_token = -1
|
||||
for token in tokens:
|
||||
try:
|
||||
for _ in range(replabel_idx_to_value[token]):
|
||||
result.append(prev_token)
|
||||
prev_token = -1
|
||||
except KeyError:
|
||||
result.append(token)
|
||||
prev_token = token
|
||||
return result
|
|
@ -0,0 +1,125 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from itertools import chain
|
||||
|
||||
import sentencepiece as spm
|
||||
from fairseq.data import Dictionary
|
||||
|
||||
|
||||
MILLISECONDS_TO_SECONDS = 0.001
|
||||
|
||||
|
||||
def process_sample(aud_path, lable, utt_id, sp, tgt_dict):
|
||||
import torchaudio
|
||||
|
||||
input = {}
|
||||
output = {}
|
||||
si, ei = torchaudio.info(aud_path)
|
||||
input["length_ms"] = int(
|
||||
si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS
|
||||
)
|
||||
input["path"] = aud_path
|
||||
|
||||
token = " ".join(sp.EncodeAsPieces(lable))
|
||||
ids = tgt_dict.encode_line(token, append_eos=False)
|
||||
output["text"] = lable
|
||||
output["token"] = token
|
||||
output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids]))
|
||||
return {utt_id: {"input": input, "output": output}}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--audio-dirs",
|
||||
nargs="+",
|
||||
default=["-"],
|
||||
required=True,
|
||||
help="input directories with audio files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--labels",
|
||||
required=True,
|
||||
help="aggregated input labels with format <ID LABEL> per line",
|
||||
type=argparse.FileType("r", encoding="UTF-8"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--spm-model",
|
||||
required=True,
|
||||
help="sentencepiece model to use for encoding",
|
||||
type=argparse.FileType("r", encoding="UTF-8"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dictionary",
|
||||
required=True,
|
||||
help="file to load fairseq dictionary from",
|
||||
type=argparse.FileType("r", encoding="UTF-8"),
|
||||
)
|
||||
parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
required=True,
|
||||
type=argparse.FileType("w"),
|
||||
help="path to save json output",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.Load(args.spm_model.name)
|
||||
|
||||
tgt_dict = Dictionary.load(args.dictionary)
|
||||
|
||||
labels = {}
|
||||
for line in args.labels:
|
||||
(utt_id, label) = line.split(" ", 1)
|
||||
labels[utt_id] = label
|
||||
if len(labels) == 0:
|
||||
raise Exception("No labels found in ", args.labels_path)
|
||||
|
||||
Sample = namedtuple("Sample", "aud_path utt_id")
|
||||
samples = []
|
||||
for path, _, files in chain.from_iterable(
|
||||
os.walk(path) for path in args.audio_dirs
|
||||
):
|
||||
for f in files:
|
||||
if f.endswith(args.audio_format):
|
||||
if len(os.path.splitext(f)) != 2:
|
||||
raise Exception("Expect <utt_id.extension> file name. Got: ", f)
|
||||
utt_id = os.path.splitext(f)[0]
|
||||
if utt_id not in labels:
|
||||
continue
|
||||
samples.append(Sample(os.path.join(path, f), utt_id))
|
||||
|
||||
utts = {}
|
||||
num_cpu = multiprocessing.cpu_count()
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
|
||||
future_to_sample = {
|
||||
executor.submit(
|
||||
process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
|
||||
): s
|
||||
for s in samples
|
||||
}
|
||||
for future in concurrent.futures.as_completed(future_to_sample):
|
||||
try:
|
||||
data = future.result()
|
||||
except Exception as exc:
|
||||
print("generated an exception: ", exc)
|
||||
else:
|
||||
utts.update(data)
|
||||
json.dump({"utts": utts}, args.output, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,88 @@
|
|||
#!/usr/bin/env bash
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Prepare librispeech dataset
|
||||
|
||||
base_url=www.openslr.org/resources/12
|
||||
train_dir=train_960
|
||||
|
||||
if [ "$#" -ne 2 ]; then
|
||||
echo "Usage: $0 <download_dir> <out_dir>"
|
||||
echo "e.g.: $0 /tmp/librispeech_raw/ ~/data/librispeech_final"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
download_dir=${1%/}
|
||||
out_dir=${2%/}
|
||||
|
||||
fairseq_root=~/fairseq-py/
|
||||
mkdir -p ${out_dir}
|
||||
cd ${out_dir} || exit
|
||||
|
||||
nbpe=5000
|
||||
bpemode=unigram
|
||||
|
||||
if [ ! -d "$fairseq_root" ]; then
|
||||
echo "$0: Please set correct fairseq_root"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "Data Download"
|
||||
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
|
||||
url=$base_url/$part.tar.gz
|
||||
if ! wget -P $download_dir $url; then
|
||||
echo "$0: wget failed for $url"
|
||||
exit 1
|
||||
fi
|
||||
if ! tar -C $download_dir -xvzf $download_dir/$part.tar.gz; then
|
||||
echo "$0: error un-tarring archive $download_dir/$part.tar.gz"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
echo "Merge all train packs into one"
|
||||
mkdir -p ${download_dir}/LibriSpeech/${train_dir}/
|
||||
for part in train-clean-100 train-clean-360 train-other-500; do
|
||||
mv ${download_dir}/LibriSpeech/${part}/* $download_dir/LibriSpeech/${train_dir}/
|
||||
done
|
||||
echo "Merge train text"
|
||||
find ${download_dir}/LibriSpeech/${train_dir}/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/${train_dir}/text
|
||||
|
||||
# Use combined dev-clean and dev-other as validation set
|
||||
find ${download_dir}/LibriSpeech/dev-clean/ ${download_dir}/LibriSpeech/dev-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/valid_text
|
||||
find ${download_dir}/LibriSpeech/test-clean/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-clean/text
|
||||
find ${download_dir}/LibriSpeech/test-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-other/text
|
||||
|
||||
|
||||
dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_units.txt
|
||||
encoded=data/lang_char/${train_dir}_${bpemode}${nbpe}_encoded.txt
|
||||
fairseq_dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_fairseq_dict.txt
|
||||
bpemodel=data/lang_char/${train_dir}_${bpemode}${nbpe}
|
||||
echo "dictionary: ${dict}"
|
||||
echo "Dictionary preparation"
|
||||
mkdir -p data/lang_char/
|
||||
echo "<unk> 3" > ${dict}
|
||||
echo "</s> 2" >> ${dict}
|
||||
echo "<pad> 1" >> ${dict}
|
||||
cut -f 2- -d" " ${download_dir}/LibriSpeech/${train_dir}/text > data/lang_char/input.txt
|
||||
spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 --unk_id=3 --eos_id=2 --pad_id=1 --bos_id=-1 --character_coverage=1
|
||||
spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt > ${encoded}
|
||||
cat ${encoded} | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+3}' >> ${dict}
|
||||
cat ${encoded} | tr ' ' '\n' | sort | uniq -c | awk '{print $2 " " $1}' > ${fairseq_dict}
|
||||
wc -l ${dict}
|
||||
|
||||
echo "Prepare train and test jsons"
|
||||
for part in train_960 test-other test-clean; do
|
||||
python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/${part} --labels ${download_dir}/LibriSpeech/${part}/text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output ${part}.json
|
||||
done
|
||||
# fairseq expects to find train.json and valid.json during training
|
||||
mv train_960.json train.json
|
||||
|
||||
echo "Prepare valid json"
|
||||
python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/dev-clean ${download_dir}/LibriSpeech/dev-other --labels ${download_dir}/LibriSpeech/valid_text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output valid.json
|
||||
|
||||
cp ${fairseq_dict} ./dict.txt
|
||||
cp ${bpemodel}.model ./spm.model
|
|
@ -0,0 +1,474 @@
|
|||
#!/usr/bin/env python3 -u
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Run inference for pre-processed data with a trained model.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
if '/home/fairseq' in sys.path:
|
||||
sys.path.remove('/home/fairseq')
|
||||
sys.path.append('/datablob/users/v-chengw/pyproj/fairseq')
|
||||
|
||||
import editdistance
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils, pdb
|
||||
from fairseq.data.data_utils import post_process
|
||||
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
||||
from fairseq.logging.meters import StopwatchMeter, TimeMeter
|
||||
|
||||
|
||||
logging.basicConfig()
|
||||
logging.root.setLevel(logging.INFO)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def add_asr_eval_argument(parser):
|
||||
parser.add_argument("--kspmodel", default=None, help="sentence piece model")
|
||||
parser.add_argument(
|
||||
"--wfstlm", default=None, help="wfstlm on dictonary output units"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rnnt_decoding_type",
|
||||
default="greedy",
|
||||
help="wfstlm on dictonary\
|
||||
output units",
|
||||
)
|
||||
try:
|
||||
parser.add_argument(
|
||||
"--lm-weight",
|
||||
"--lm_weight",
|
||||
type=float,
|
||||
default=0.2,
|
||||
help="weight for lm while interpolating with neural score",
|
||||
)
|
||||
except:
|
||||
pass
|
||||
parser.add_argument(
|
||||
"--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--w2l-decoder",
|
||||
choices=["viterbi", "kenlm", "fairseqlm"],
|
||||
help="use a w2l decoder",
|
||||
)
|
||||
parser.add_argument("--lexicon", help="lexicon for w2l decoder")
|
||||
parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
|
||||
parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
|
||||
parser.add_argument("--beam-threshold", type=float, default=25.0)
|
||||
parser.add_argument("--beam-size-token", type=float, default=100)
|
||||
parser.add_argument("--word-score", type=float, default=1.0)
|
||||
parser.add_argument("--unk-weight", type=float, default=-math.inf)
|
||||
parser.add_argument("--sil-weight", type=float, default=0.0)
|
||||
parser.add_argument(
|
||||
"--dump-emissions",
|
||||
type=str,
|
||||
default=None,
|
||||
help="if present, dumps emissions into this file and exits",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dump-features",
|
||||
type=str,
|
||||
default=None,
|
||||
help="if present, dumps features into this file and exits",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load-emissions",
|
||||
type=str,
|
||||
default=None,
|
||||
help="if present, loads emissions from this file",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def check_args(args):
|
||||
# assert args.path is not None, "--path required for generation!"
|
||||
# assert args.results_path is not None, "--results_path required for generation!"
|
||||
assert (
|
||||
not args.sampling or args.nbest == args.beam
|
||||
), "--sampling requires --nbest to be equal to --beam"
|
||||
assert (
|
||||
args.replace_unk is None or args.raw_text
|
||||
), "--replace-unk requires a raw text dataset (--raw-text)"
|
||||
|
||||
|
||||
def get_dataset_itr(args, task, models):
|
||||
return task.get_batch_iterator(
|
||||
dataset=task.dataset(args.gen_subset),
|
||||
max_tokens=args.max_tokens,
|
||||
max_sentences=args.batch_size,
|
||||
max_positions=(sys.maxsize, sys.maxsize),
|
||||
ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
|
||||
required_batch_size_multiple=args.required_batch_size_multiple,
|
||||
num_shards=args.num_shards,
|
||||
shard_id=args.shard_id,
|
||||
num_workers=args.num_workers,
|
||||
data_buffer_size=args.data_buffer_size,
|
||||
).next_epoch_itr(shuffle=False)
|
||||
|
||||
|
||||
def process_predictions(
|
||||
args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
|
||||
):
|
||||
for hypo in hypos[: min(len(hypos), args.nbest)]:
|
||||
hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
|
||||
|
||||
if "words" in hypo:
|
||||
hyp_words = " ".join(hypo["words"])
|
||||
else:
|
||||
hyp_words = post_process(hyp_pieces, args.post_process)
|
||||
|
||||
if res_files is not None:
|
||||
print(
|
||||
"{} ({}-{})".format(hyp_pieces, speaker, id),
|
||||
file=res_files["hypo.units"],
|
||||
)
|
||||
print(
|
||||
"{} ({}-{})".format(hyp_words, speaker, id),
|
||||
file=res_files["hypo.words"],
|
||||
)
|
||||
|
||||
tgt_pieces = tgt_dict.string(target_tokens)
|
||||
tgt_words = post_process(tgt_pieces, args.post_process)
|
||||
|
||||
if res_files is not None:
|
||||
print(
|
||||
"{} ({}-{})".format(tgt_pieces, speaker, id),
|
||||
file=res_files["ref.units"],
|
||||
)
|
||||
print(
|
||||
"{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
|
||||
)
|
||||
# only score top hypothesis
|
||||
if not args.quiet:
|
||||
logger.debug("HYPO:" + hyp_words)
|
||||
logger.debug("TARGET:" + tgt_words)
|
||||
logger.debug("___________________")
|
||||
|
||||
hyp_words = hyp_words.split()
|
||||
tgt_words = tgt_words.split()
|
||||
return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
|
||||
|
||||
|
||||
def prepare_result_files(args):
|
||||
def get_res_file(file_prefix):
|
||||
if args.num_shards > 1:
|
||||
file_prefix = f"{args.shard_id}_{file_prefix}"
|
||||
path = os.path.join(
|
||||
args.results_path,
|
||||
"{}-{}-{}.txt".format(
|
||||
file_prefix, os.path.basename(args.path), args.gen_subset
|
||||
),
|
||||
)
|
||||
return open(path, "w", buffering=1)
|
||||
|
||||
if not args.results_path:
|
||||
return None
|
||||
|
||||
return {
|
||||
"hypo.words": get_res_file("hypo.word"),
|
||||
"hypo.units": get_res_file("hypo.units"),
|
||||
"ref.words": get_res_file("ref.word"),
|
||||
"ref.units": get_res_file("ref.units"),
|
||||
}
|
||||
|
||||
|
||||
def load_models_and_criterions(
|
||||
filenames, data_path, arg_overrides=None, task=None, model_state=None
|
||||
):
|
||||
models = []
|
||||
criterions = []
|
||||
|
||||
if arg_overrides is None:
|
||||
arg_overrides = {}
|
||||
|
||||
arg_overrides["wer_args"] = None
|
||||
arg_overrides["data"] = data_path
|
||||
|
||||
if filenames is None:
|
||||
assert model_state is not None
|
||||
filenames = [0]
|
||||
else:
|
||||
filenames = filenames.split(":")
|
||||
|
||||
for filename in filenames:
|
||||
if model_state is None:
|
||||
if not os.path.exists(filename):
|
||||
raise IOError("Model file not found: {}".format(filename))
|
||||
state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)
|
||||
else:
|
||||
state = model_state
|
||||
|
||||
if "cfg" in state:
|
||||
cfg = state["cfg"]
|
||||
else:
|
||||
cfg = convert_namespace_to_omegaconf(state["args"])
|
||||
|
||||
if task is None:
|
||||
if hasattr(cfg.task, 'data'):
|
||||
cfg.task.data = data_path
|
||||
task = tasks.setup_task(cfg.task)
|
||||
|
||||
model = task.build_model(cfg.model)
|
||||
model.load_state_dict(state["model"], strict=True)
|
||||
models.append(model)
|
||||
|
||||
criterion = task.build_criterion(cfg.criterion)
|
||||
if "criterion" in state:
|
||||
criterion.load_state_dict(state["criterion"], strict=True)
|
||||
criterions.append(criterion)
|
||||
return models, criterions, task
|
||||
|
||||
|
||||
def optimize_models(args, use_cuda, models):
|
||||
"""Optimize ensemble for generation"""
|
||||
for model in models:
|
||||
model.make_generation_fast_(
|
||||
beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
|
||||
need_attn=args.print_alignment,
|
||||
)
|
||||
if args.fp16:
|
||||
model.half()
|
||||
if use_cuda:
|
||||
model.cuda()
|
||||
|
||||
|
||||
class ExistingEmissionsDecoder(object):
|
||||
def __init__(self, decoder, emissions):
|
||||
self.decoder = decoder
|
||||
self.emissions = emissions
|
||||
|
||||
def generate(self, models, sample, **unused):
|
||||
ids = sample["id"].cpu().numpy()
|
||||
try:
|
||||
emissions = np.stack(self.emissions[ids])
|
||||
except:
|
||||
print([x.shape for x in self.emissions[ids]])
|
||||
raise Exception("invalid sizes")
|
||||
emissions = torch.from_numpy(emissions)
|
||||
return self.decoder.decode(emissions)
|
||||
|
||||
|
||||
def main(args, task=None, model_state=None):
|
||||
check_args(args)
|
||||
|
||||
if args.max_tokens is None and args.batch_size is None:
|
||||
args.max_tokens = 4000000
|
||||
logger.info(args)
|
||||
|
||||
use_cuda = torch.cuda.is_available() and not args.cpu
|
||||
|
||||
|
||||
logger.info("| decoding with criterion {}".format(args.criterion))
|
||||
|
||||
# Load ensemble
|
||||
if args.load_emissions:
|
||||
models, criterions = [], []
|
||||
task = tasks.setup_task(args)
|
||||
else:
|
||||
logger.info("| loading model(s) from {}".format(args.path))
|
||||
models, criterions, task = load_models_and_criterions(
|
||||
args.path,
|
||||
data_path=args.data,
|
||||
arg_overrides=eval(args.model_overrides), # noqa
|
||||
task=task,
|
||||
model_state=model_state,
|
||||
)
|
||||
optimize_models(args, use_cuda, models)
|
||||
|
||||
# Load dataset splits
|
||||
task.load_dataset(args.gen_subset)
|
||||
|
||||
# Set dictionary
|
||||
tgt_dict = task.target_dictionary
|
||||
|
||||
logger.info(
|
||||
"| {} {} {} examples".format(
|
||||
args.data, args.gen_subset, len(task.dataset(args.gen_subset))
|
||||
)
|
||||
)
|
||||
|
||||
# hack to pass transitions to W2lDecoder
|
||||
if args.criterion == "asg_loss":
|
||||
trans = criterions[0].asg.trans.data
|
||||
args.asg_transitions = torch.flatten(trans).tolist()
|
||||
|
||||
# Load dataset (possibly sharded)
|
||||
itr = get_dataset_itr(args, task, models)
|
||||
|
||||
# Initialize generator
|
||||
gen_timer = StopwatchMeter()
|
||||
|
||||
def build_generator(args):
|
||||
w2l_decoder = getattr(args, "w2l_decoder", None)
|
||||
if w2l_decoder == "viterbi":
|
||||
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
|
||||
|
||||
return W2lViterbiDecoder(args, task.target_dictionary)
|
||||
elif w2l_decoder == "kenlm":
|
||||
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
|
||||
|
||||
return W2lKenLMDecoder(args, task.target_dictionary)
|
||||
elif w2l_decoder == "fairseqlm":
|
||||
from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
|
||||
|
||||
return W2lFairseqLMDecoder(args, task.target_dictionary)
|
||||
else:
|
||||
print(
|
||||
"only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
|
||||
)
|
||||
|
||||
# please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
|
||||
generator = build_generator(args)
|
||||
|
||||
if args.load_emissions:
|
||||
generator = ExistingEmissionsDecoder(
|
||||
generator, np.load(args.load_emissions, allow_pickle=True)
|
||||
)
|
||||
logger.info("loaded emissions from " + args.load_emissions)
|
||||
|
||||
num_sentences = 0
|
||||
|
||||
if args.results_path is not None and not os.path.exists(args.results_path):
|
||||
os.makedirs(args.results_path)
|
||||
|
||||
max_source_pos = (
|
||||
utils.resolve_max_positions(
|
||||
task.max_positions(), *[model.max_positions() for model in models]
|
||||
),
|
||||
)
|
||||
|
||||
if max_source_pos is not None:
|
||||
max_source_pos = max_source_pos[0]
|
||||
if max_source_pos is not None:
|
||||
max_source_pos = max_source_pos[0] - 1
|
||||
|
||||
if args.dump_emissions:
|
||||
emissions = {}
|
||||
if args.dump_features:
|
||||
features = {}
|
||||
models[0].bert.proj = None
|
||||
else:
|
||||
res_files = prepare_result_files(args)
|
||||
errs_t = 0
|
||||
lengths_t = 0
|
||||
with progress_bar.build_progress_bar(args, itr) as t:
|
||||
wps_meter = TimeMeter()
|
||||
for sample in t:
|
||||
sample = utils.move_to_cuda(sample) if use_cuda else sample
|
||||
if "net_input" not in sample:
|
||||
continue
|
||||
|
||||
prefix_tokens = None
|
||||
if args.prefix_size > 0:
|
||||
prefix_tokens = sample["target"][:, : args.prefix_size]
|
||||
|
||||
gen_timer.start()
|
||||
if args.dump_emissions:
|
||||
with torch.no_grad():
|
||||
encoder_out = models[0](**sample["net_input"])
|
||||
emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
|
||||
emm = emm.transpose(0, 1).cpu().numpy()
|
||||
for i, id in enumerate(sample["id"]):
|
||||
emissions[id.item()] = emm[i]
|
||||
continue
|
||||
elif args.dump_features:
|
||||
with torch.no_grad():
|
||||
encoder_out = models[0](**sample["net_input"])
|
||||
feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
|
||||
for i, id in enumerate(sample["id"]):
|
||||
padding = (
|
||||
encoder_out["encoder_padding_mask"][i].cpu().numpy()
|
||||
if encoder_out["encoder_padding_mask"] is not None
|
||||
else None
|
||||
)
|
||||
features[id.item()] = (feat[i], padding)
|
||||
continue
|
||||
hypos = task.inference_step(generator, models, sample, prefix_tokens)
|
||||
num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
|
||||
gen_timer.stop(num_generated_tokens)
|
||||
|
||||
for i, sample_id in enumerate(sample["id"].tolist()):
|
||||
speaker = None
|
||||
# id = task.dataset(args.gen_subset).ids[int(sample_id)]
|
||||
id = sample_id
|
||||
toks = (
|
||||
sample["target"][i, :]
|
||||
if "target_label" not in sample
|
||||
else sample["target_label"][i, :]
|
||||
)
|
||||
target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
|
||||
# Process top predictions
|
||||
errs, length = process_predictions(
|
||||
args,
|
||||
hypos[i],
|
||||
None,
|
||||
tgt_dict,
|
||||
target_tokens,
|
||||
res_files,
|
||||
speaker,
|
||||
id,
|
||||
)
|
||||
errs_t += errs
|
||||
lengths_t += length
|
||||
|
||||
wps_meter.update(num_generated_tokens)
|
||||
t.log({"wps": round(wps_meter.avg)})
|
||||
num_sentences += (
|
||||
sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
|
||||
)
|
||||
|
||||
wer = None
|
||||
if args.dump_emissions:
|
||||
emm_arr = []
|
||||
for i in range(len(emissions)):
|
||||
emm_arr.append(emissions[i])
|
||||
np.save(args.dump_emissions, emm_arr)
|
||||
logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
|
||||
elif args.dump_features:
|
||||
feat_arr = []
|
||||
for i in range(len(features)):
|
||||
feat_arr.append(features[i])
|
||||
np.save(args.dump_features, feat_arr)
|
||||
logger.info(f"saved {len(features)} emissions to {args.dump_features}")
|
||||
else:
|
||||
if lengths_t > 0:
|
||||
wer = errs_t * 100.0 / lengths_t
|
||||
logger.info(f"WER: {wer}")
|
||||
|
||||
logger.info(
|
||||
"| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
|
||||
"sentences/s, {:.2f} tokens/s)".format(
|
||||
num_sentences,
|
||||
gen_timer.n,
|
||||
gen_timer.sum,
|
||||
num_sentences / gen_timer.sum,
|
||||
1.0 / gen_timer.avg,
|
||||
)
|
||||
)
|
||||
logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
|
||||
return task, wer
|
||||
|
||||
|
||||
def make_parser():
|
||||
parser = options.get_generation_parser()
|
||||
parser = add_asr_eval_argument(parser)
|
||||
return parser
|
||||
|
||||
|
||||
def cli_main():
|
||||
parser = make_parser()
|
||||
args = options.parse_args_and_arch(parser)
|
||||
main(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli_main()
|
|
@ -0,0 +1,8 @@
|
|||
import importlib
|
||||
import os
|
||||
|
||||
|
||||
for file in os.listdir(os.path.dirname(__file__)):
|
||||
if file.endswith(".py") and not file.startswith("_"):
|
||||
model_name = file[: file.find(".py")]
|
||||
importlib.import_module("examples.speech_recognition.models." + model_name)
|
Разница между файлами не показана из-за своего большого размера
Загрузить разницу
|
@ -0,0 +1,177 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairseq.models import (
|
||||
FairseqEncoder,
|
||||
FairseqEncoderModel,
|
||||
register_model,
|
||||
register_model_architecture,
|
||||
)
|
||||
from fairseq.modules.fairseq_dropout import FairseqDropout
|
||||
|
||||
|
||||
default_conv_enc_config = """[
|
||||
(400, 13, 170, 0.2),
|
||||
(440, 14, 0, 0.214),
|
||||
(484, 15, 0, 0.22898),
|
||||
(532, 16, 0, 0.2450086),
|
||||
(584, 17, 0, 0.262159202),
|
||||
(642, 18, 0, 0.28051034614),
|
||||
(706, 19, 0, 0.30014607037),
|
||||
(776, 20, 0, 0.321156295296),
|
||||
(852, 21, 0, 0.343637235966),
|
||||
(936, 22, 0, 0.367691842484),
|
||||
(1028, 23, 0, 0.393430271458),
|
||||
(1130, 24, 0, 0.42097039046),
|
||||
(1242, 25, 0, 0.450438317792),
|
||||
(1366, 26, 0, 0.481969000038),
|
||||
(1502, 27, 0, 0.51570683004),
|
||||
(1652, 28, 0, 0.551806308143),
|
||||
(1816, 29, 0, 0.590432749713),
|
||||
]"""
|
||||
|
||||
|
||||
@register_model("asr_w2l_conv_glu_encoder")
|
||||
class W2lConvGluEncoderModel(FairseqEncoderModel):
|
||||
def __init__(self, encoder):
|
||||
super().__init__(encoder)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add model-specific arguments to the parser."""
|
||||
parser.add_argument(
|
||||
"--input-feat-per-channel",
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="encoder input dimension per input channel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--in-channels",
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="number of encoder input channels",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--conv-enc-config",
|
||||
type=str,
|
||||
metavar="EXPR",
|
||||
help="""
|
||||
an array of tuples each containing the configuration of one conv layer
|
||||
[(out_channels, kernel_size, padding, dropout), ...]
|
||||
""",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args, task):
|
||||
"""Build a new model instance."""
|
||||
conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
|
||||
encoder = W2lConvGluEncoder(
|
||||
vocab_size=len(task.target_dictionary),
|
||||
input_feat_per_channel=args.input_feat_per_channel,
|
||||
in_channels=args.in_channels,
|
||||
conv_enc_config=eval(conv_enc_config),
|
||||
)
|
||||
return cls(encoder)
|
||||
|
||||
def get_normalized_probs(self, net_output, log_probs, sample=None):
|
||||
lprobs = super().get_normalized_probs(net_output, log_probs, sample)
|
||||
lprobs.batch_first = False
|
||||
return lprobs
|
||||
|
||||
|
||||
class W2lConvGluEncoder(FairseqEncoder):
|
||||
def __init__(
|
||||
self, vocab_size, input_feat_per_channel, in_channels, conv_enc_config
|
||||
):
|
||||
super().__init__(None)
|
||||
|
||||
self.input_dim = input_feat_per_channel
|
||||
if in_channels != 1:
|
||||
raise ValueError("only 1 input channel is currently supported")
|
||||
|
||||
self.conv_layers = nn.ModuleList()
|
||||
self.linear_layers = nn.ModuleList()
|
||||
self.dropouts = []
|
||||
cur_channels = input_feat_per_channel
|
||||
|
||||
for out_channels, kernel_size, padding, dropout in conv_enc_config:
|
||||
layer = nn.Conv1d(cur_channels, out_channels, kernel_size, padding=padding)
|
||||
layer.weight.data.mul_(math.sqrt(3)) # match wav2letter init
|
||||
self.conv_layers.append(nn.utils.weight_norm(layer))
|
||||
self.dropouts.append(
|
||||
FairseqDropout(dropout, module_name=self.__class__.__name__)
|
||||
)
|
||||
if out_channels % 2 != 0:
|
||||
raise ValueError("odd # of out_channels is incompatible with GLU")
|
||||
cur_channels = out_channels // 2 # halved by GLU
|
||||
|
||||
for out_channels in [2 * cur_channels, vocab_size]:
|
||||
layer = nn.Linear(cur_channels, out_channels)
|
||||
layer.weight.data.mul_(math.sqrt(3))
|
||||
self.linear_layers.append(nn.utils.weight_norm(layer))
|
||||
cur_channels = out_channels // 2
|
||||
|
||||
def forward(self, src_tokens, src_lengths, **kwargs):
|
||||
|
||||
"""
|
||||
src_tokens: padded tensor (B, T, C * feat)
|
||||
src_lengths: tensor of original lengths of input utterances (B,)
|
||||
"""
|
||||
B, T, _ = src_tokens.size()
|
||||
x = src_tokens.transpose(1, 2).contiguous() # (B, feat, T) assuming C == 1
|
||||
|
||||
for layer_idx in range(len(self.conv_layers)):
|
||||
x = self.conv_layers[layer_idx](x)
|
||||
x = F.glu(x, dim=1)
|
||||
x = self.dropouts[layer_idx](x)
|
||||
|
||||
x = x.transpose(1, 2).contiguous() # (B, T, 908)
|
||||
x = self.linear_layers[0](x)
|
||||
x = F.glu(x, dim=2)
|
||||
x = self.dropouts[-1](x)
|
||||
x = self.linear_layers[1](x)
|
||||
|
||||
assert x.size(0) == B
|
||||
assert x.size(1) == T
|
||||
|
||||
encoder_out = x.transpose(0, 1) # (T, B, vocab_size)
|
||||
|
||||
# need to debug this -- find a simpler/elegant way in pytorch APIs
|
||||
encoder_padding_mask = (
|
||||
torch.arange(T).view(1, T).expand(B, -1).to(x.device)
|
||||
>= src_lengths.view(B, 1).expand(-1, T)
|
||||
).t() # (B x T) -> (T x B)
|
||||
|
||||
return {
|
||||
"encoder_out": encoder_out, # (T, B, vocab_size)
|
||||
"encoder_padding_mask": encoder_padding_mask, # (T, B)
|
||||
}
|
||||
|
||||
def reorder_encoder_out(self, encoder_out, new_order):
|
||||
encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
|
||||
1, new_order
|
||||
)
|
||||
encoder_out["encoder_padding_mask"] = encoder_out[
|
||||
"encoder_padding_mask"
|
||||
].index_select(1, new_order)
|
||||
return encoder_out
|
||||
|
||||
def max_positions(self):
|
||||
"""Maximum input length supported by the encoder."""
|
||||
return (1e6, 1e6) # an arbitrary large number
|
||||
|
||||
|
||||
@register_model_architecture("asr_w2l_conv_glu_encoder", "w2l_conv_glu_enc")
|
||||
def w2l_conv_glu_enc(args):
|
||||
args.input_feat_per_channel = getattr(args, "input_feat_per_channel", 80)
|
||||
args.in_channels = getattr(args, "in_channels", 1)
|
||||
args.conv_enc_config = getattr(args, "conv_enc_config", default_conv_enc_config)
|
|
@ -0,0 +1,8 @@
|
|||
import importlib
|
||||
import os
|
||||
|
||||
|
||||
for file in os.listdir(os.path.dirname(__file__)):
|
||||
if file.endswith(".py") and not file.startswith("_"):
|
||||
task_name = file[: file.find(".py")]
|
||||
importlib.import_module("examples.speech_recognition.tasks." + task_name)
|
|
@ -0,0 +1,157 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
|
||||
import torch
|
||||
from examples.speech_recognition.data import AsrDataset
|
||||
from examples.speech_recognition.data.replabels import replabel_symbol
|
||||
from fairseq.data import Dictionary
|
||||
from fairseq.tasks import LegacyFairseqTask, register_task
|
||||
|
||||
|
||||
def get_asr_dataset_from_json(data_json_path, tgt_dict):
|
||||
"""
|
||||
Parse data json and create dataset.
|
||||
See scripts/asr_prep_json.py which pack json from raw files
|
||||
|
||||
Json example:
|
||||
{
|
||||
"utts": {
|
||||
"4771-29403-0025": {
|
||||
"input": {
|
||||
"length_ms": 170,
|
||||
"path": "/tmp/file1.flac"
|
||||
},
|
||||
"output": {
|
||||
"text": "HELLO \n",
|
||||
"token": "HE LLO",
|
||||
"tokenid": "4815, 861"
|
||||
}
|
||||
},
|
||||
"1564-142299-0096": {
|
||||
...
|
||||
}
|
||||
}
|
||||
"""
|
||||
if not os.path.isfile(data_json_path):
|
||||
raise FileNotFoundError("Dataset not found: {}".format(data_json_path))
|
||||
with open(data_json_path, "rb") as f:
|
||||
data_samples = json.load(f)["utts"]
|
||||
assert len(data_samples) != 0
|
||||
sorted_samples = sorted(
|
||||
data_samples.items(),
|
||||
key=lambda sample: int(sample[1]["input"]["length_ms"]),
|
||||
reverse=True,
|
||||
)
|
||||
aud_paths = [s[1]["input"]["path"] for s in sorted_samples]
|
||||
ids = [s[0] for s in sorted_samples]
|
||||
speakers = []
|
||||
for s in sorted_samples:
|
||||
m = re.search("(.+?)-(.+?)-(.+?)", s[0])
|
||||
speakers.append(m.group(1) + "_" + m.group(2))
|
||||
frame_sizes = [s[1]["input"]["length_ms"] for s in sorted_samples]
|
||||
tgt = [
|
||||
[int(i) for i in s[1]["output"]["tokenid"].split(", ")]
|
||||
for s in sorted_samples
|
||||
]
|
||||
# append eos
|
||||
tgt = [[*t, tgt_dict.eos()] for t in tgt]
|
||||
return AsrDataset(aud_paths, frame_sizes, tgt, tgt_dict, ids, speakers)
|
||||
|
||||
|
||||
@register_task("speech_recognition")
|
||||
class SpeechRecognitionTask(LegacyFairseqTask):
|
||||
"""
|
||||
Task for training speech recognition model.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add task-specific arguments to the parser."""
|
||||
parser.add_argument("data", help="path to data directory")
|
||||
parser.add_argument(
|
||||
"--silence-token", default="\u2581", help="token for silence (used by w2l)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-source-positions",
|
||||
default=sys.maxsize,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="max number of frames in the source sequence",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-target-positions",
|
||||
default=1024,
|
||||
type=int,
|
||||
metavar="N",
|
||||
help="max number of tokens in the target sequence",
|
||||
)
|
||||
|
||||
def __init__(self, args, tgt_dict):
|
||||
super().__init__(args)
|
||||
self.tgt_dict = tgt_dict
|
||||
|
||||
@classmethod
|
||||
def setup_task(cls, args, **kwargs):
|
||||
"""Setup the task (e.g., load dictionaries)."""
|
||||
dict_path = os.path.join(args.data, "dict.txt")
|
||||
if not os.path.isfile(dict_path):
|
||||
raise FileNotFoundError("Dict not found: {}".format(dict_path))
|
||||
tgt_dict = Dictionary.load(dict_path)
|
||||
|
||||
if args.criterion == "ctc_loss":
|
||||
tgt_dict.add_symbol("<ctc_blank>")
|
||||
elif args.criterion == "asg_loss":
|
||||
for i in range(1, args.max_replabel + 1):
|
||||
tgt_dict.add_symbol(replabel_symbol(i))
|
||||
|
||||
print("| dictionary: {} types".format(len(tgt_dict)))
|
||||
return cls(args, tgt_dict)
|
||||
|
||||
def load_dataset(self, split, combine=False, **kwargs):
|
||||
"""Load a given dataset split.
|
||||
|
||||
Args:
|
||||
split (str): name of the split (e.g., train, valid, test)
|
||||
"""
|
||||
data_json_path = os.path.join(self.args.data, "{}.json".format(split))
|
||||
self.datasets[split] = get_asr_dataset_from_json(data_json_path, self.tgt_dict)
|
||||
|
||||
def build_generator(self, models, args, **unused):
|
||||
w2l_decoder = getattr(args, "w2l_decoder", None)
|
||||
if w2l_decoder == "viterbi":
|
||||
from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
|
||||
|
||||
return W2lViterbiDecoder(args, self.target_dictionary)
|
||||
elif w2l_decoder == "kenlm":
|
||||
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
|
||||
|
||||
return W2lKenLMDecoder(args, self.target_dictionary)
|
||||
elif w2l_decoder == "fairseqlm":
|
||||
from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
|
||||
|
||||
return W2lFairseqLMDecoder(args, self.target_dictionary)
|
||||
else:
|
||||
return super().build_generator(models, args)
|
||||
|
||||
@property
|
||||
def target_dictionary(self):
|
||||
"""Return the :class:`~fairseq.data.Dictionary` for the language
|
||||
model."""
|
||||
return self.tgt_dict
|
||||
|
||||
@property
|
||||
def source_dictionary(self):
|
||||
"""Return the source :class:`~fairseq.data.Dictionary` (if applicable
|
||||
for this task)."""
|
||||
return None
|
||||
|
||||
def max_positions(self):
|
||||
"""Return the max speech and sentence length allowed by the task."""
|
||||
return (self.args.max_source_positions, self.args.max_target_positions)
|
|
@ -0,0 +1,381 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import re
|
||||
from collections import deque
|
||||
from enum import Enum
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
"""
|
||||
Utility modules for computation of Word Error Rate,
|
||||
Alignments, as well as more granular metrics like
|
||||
deletion, insersion and substitutions.
|
||||
"""
|
||||
|
||||
|
||||
class Code(Enum):
|
||||
match = 1
|
||||
substitution = 2
|
||||
insertion = 3
|
||||
deletion = 4
|
||||
|
||||
|
||||
class Token(object):
|
||||
def __init__(self, lbl="", st=np.nan, en=np.nan):
|
||||
if np.isnan(st):
|
||||
self.label, self.start, self.end = "", 0.0, 0.0
|
||||
else:
|
||||
self.label, self.start, self.end = lbl, st, en
|
||||
|
||||
|
||||
class AlignmentResult(object):
|
||||
def __init__(self, refs, hyps, codes, score):
|
||||
self.refs = refs # std::deque<int>
|
||||
self.hyps = hyps # std::deque<int>
|
||||
self.codes = codes # std::deque<Code>
|
||||
self.score = score # float
|
||||
|
||||
|
||||
def coordinate_to_offset(row, col, ncols):
|
||||
return int(row * ncols + col)
|
||||
|
||||
|
||||
def offset_to_row(offset, ncols):
|
||||
return int(offset / ncols)
|
||||
|
||||
|
||||
def offset_to_col(offset, ncols):
|
||||
return int(offset % ncols)
|
||||
|
||||
|
||||
def trimWhitespace(str):
|
||||
return re.sub(" +", " ", re.sub(" *$", "", re.sub("^ *", "", str)))
|
||||
|
||||
|
||||
def str2toks(str):
|
||||
pieces = trimWhitespace(str).split(" ")
|
||||
toks = []
|
||||
for p in pieces:
|
||||
toks.append(Token(p, 0.0, 0.0))
|
||||
return toks
|
||||
|
||||
|
||||
class EditDistance(object):
|
||||
def __init__(self, time_mediated):
|
||||
self.time_mediated_ = time_mediated
|
||||
self.scores_ = np.nan # Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>
|
||||
self.backtraces_ = (
|
||||
np.nan
|
||||
) # Eigen::Matrix<size_t, Eigen::Dynamic, Eigen::Dynamic> backtraces_;
|
||||
self.confusion_pairs_ = {}
|
||||
|
||||
def cost(self, ref, hyp, code):
|
||||
if self.time_mediated_:
|
||||
if code == Code.match:
|
||||
return abs(ref.start - hyp.start) + abs(ref.end - hyp.end)
|
||||
elif code == Code.insertion:
|
||||
return hyp.end - hyp.start
|
||||
elif code == Code.deletion:
|
||||
return ref.end - ref.start
|
||||
else: # substitution
|
||||
return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + 0.1
|
||||
else:
|
||||
if code == Code.match:
|
||||
return 0
|
||||
elif code == Code.insertion or code == Code.deletion:
|
||||
return 3
|
||||
else: # substitution
|
||||
return 4
|
||||
|
||||
def get_result(self, refs, hyps):
|
||||
res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=np.nan)
|
||||
|
||||
num_rows, num_cols = self.scores_.shape
|
||||
res.score = self.scores_[num_rows - 1, num_cols - 1]
|
||||
|
||||
curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols)
|
||||
|
||||
while curr_offset != 0:
|
||||
curr_row = offset_to_row(curr_offset, num_cols)
|
||||
curr_col = offset_to_col(curr_offset, num_cols)
|
||||
|
||||
prev_offset = self.backtraces_[curr_row, curr_col]
|
||||
|
||||
prev_row = offset_to_row(prev_offset, num_cols)
|
||||
prev_col = offset_to_col(prev_offset, num_cols)
|
||||
|
||||
res.refs.appendleft(curr_row - 1) # Note: this was .push_front() in C++
|
||||
res.hyps.appendleft(curr_col - 1)
|
||||
if curr_row - 1 == prev_row and curr_col == prev_col:
|
||||
res.codes.appendleft(Code.deletion)
|
||||
elif curr_row == prev_row and curr_col - 1 == prev_col:
|
||||
res.codes.appendleft(Code.insertion)
|
||||
else:
|
||||
# assert(curr_row - 1 == prev_row and curr_col - 1 == prev_col)
|
||||
ref_str = refs[res.refs[0]].label
|
||||
hyp_str = hyps[res.hyps[0]].label
|
||||
|
||||
if ref_str == hyp_str:
|
||||
res.codes.appendleft(Code.match)
|
||||
else:
|
||||
res.codes.appendleft(Code.substitution)
|
||||
|
||||
confusion_pair = "%s -> %s" % (ref_str, hyp_str)
|
||||
if confusion_pair not in self.confusion_pairs_:
|
||||
self.confusion_pairs_[confusion_pair] = 1
|
||||
else:
|
||||
self.confusion_pairs_[confusion_pair] += 1
|
||||
|
||||
curr_offset = prev_offset
|
||||
|
||||
return res
|
||||
|
||||
def align(self, refs, hyps):
|
||||
if len(refs) == 0 and len(hyps) == 0:
|
||||
return np.nan
|
||||
|
||||
# NOTE: we're not resetting the values in these matrices because every value
|
||||
# will be overridden in the loop below. If this assumption doesn't hold,
|
||||
# be sure to set all entries in self.scores_ and self.backtraces_ to 0.
|
||||
self.scores_ = np.zeros((len(refs) + 1, len(hyps) + 1))
|
||||
self.backtraces_ = np.zeros((len(refs) + 1, len(hyps) + 1))
|
||||
|
||||
num_rows, num_cols = self.scores_.shape
|
||||
|
||||
for i in range(num_rows):
|
||||
for j in range(num_cols):
|
||||
if i == 0 and j == 0:
|
||||
self.scores_[i, j] = 0.0
|
||||
self.backtraces_[i, j] = 0
|
||||
continue
|
||||
|
||||
if i == 0:
|
||||
self.scores_[i, j] = self.scores_[i, j - 1] + self.cost(
|
||||
None, hyps[j - 1], Code.insertion
|
||||
)
|
||||
self.backtraces_[i, j] = coordinate_to_offset(i, j - 1, num_cols)
|
||||
continue
|
||||
|
||||
if j == 0:
|
||||
self.scores_[i, j] = self.scores_[i - 1, j] + self.cost(
|
||||
refs[i - 1], None, Code.deletion
|
||||
)
|
||||
self.backtraces_[i, j] = coordinate_to_offset(i - 1, j, num_cols)
|
||||
continue
|
||||
|
||||
# Below here both i and j are greater than 0
|
||||
ref = refs[i - 1]
|
||||
hyp = hyps[j - 1]
|
||||
best_score = self.scores_[i - 1, j - 1] + (
|
||||
self.cost(ref, hyp, Code.match)
|
||||
if (ref.label == hyp.label)
|
||||
else self.cost(ref, hyp, Code.substitution)
|
||||
)
|
||||
|
||||
prev_row = i - 1
|
||||
prev_col = j - 1
|
||||
ins = self.scores_[i, j - 1] + self.cost(None, hyp, Code.insertion)
|
||||
if ins < best_score:
|
||||
best_score = ins
|
||||
prev_row = i
|
||||
prev_col = j - 1
|
||||
|
||||
delt = self.scores_[i - 1, j] + self.cost(ref, None, Code.deletion)
|
||||
if delt < best_score:
|
||||
best_score = delt
|
||||
prev_row = i - 1
|
||||
prev_col = j
|
||||
|
||||
self.scores_[i, j] = best_score
|
||||
self.backtraces_[i, j] = coordinate_to_offset(
|
||||
prev_row, prev_col, num_cols
|
||||
)
|
||||
|
||||
return self.get_result(refs, hyps)
|
||||
|
||||
|
||||
class WERTransformer(object):
|
||||
def __init__(self, hyp_str, ref_str, verbose=True):
|
||||
self.ed_ = EditDistance(False)
|
||||
self.id2oracle_errs_ = {}
|
||||
self.utts_ = 0
|
||||
self.words_ = 0
|
||||
self.insertions_ = 0
|
||||
self.deletions_ = 0
|
||||
self.substitutions_ = 0
|
||||
|
||||
self.process(["dummy_str", hyp_str, ref_str])
|
||||
|
||||
if verbose:
|
||||
print("'%s' vs '%s'" % (hyp_str, ref_str))
|
||||
self.report_result()
|
||||
|
||||
def process(self, input): # std::vector<std::string>&& input
|
||||
if len(input) < 3:
|
||||
print(
|
||||
"Input must be of the form <id> ... <hypo> <ref> , got ",
|
||||
len(input),
|
||||
" inputs:",
|
||||
)
|
||||
return None
|
||||
|
||||
# Align
|
||||
# std::vector<Token> hyps;
|
||||
# std::vector<Token> refs;
|
||||
|
||||
hyps = str2toks(input[-2])
|
||||
refs = str2toks(input[-1])
|
||||
|
||||
alignment = self.ed_.align(refs, hyps)
|
||||
if alignment is None:
|
||||
print("Alignment is null")
|
||||
return np.nan
|
||||
|
||||
# Tally errors
|
||||
ins = 0
|
||||
dels = 0
|
||||
subs = 0
|
||||
for code in alignment.codes:
|
||||
if code == Code.substitution:
|
||||
subs += 1
|
||||
elif code == Code.insertion:
|
||||
ins += 1
|
||||
elif code == Code.deletion:
|
||||
dels += 1
|
||||
|
||||
# Output
|
||||
row = input
|
||||
row.append(str(len(refs)))
|
||||
row.append(str(ins))
|
||||
row.append(str(dels))
|
||||
row.append(str(subs))
|
||||
# print(row)
|
||||
|
||||
# Accumulate
|
||||
kIdIndex = 0
|
||||
kNBestSep = "/"
|
||||
|
||||
pieces = input[kIdIndex].split(kNBestSep)
|
||||
|
||||
if len(pieces) == 0:
|
||||
print(
|
||||
"Error splitting ",
|
||||
input[kIdIndex],
|
||||
" on '",
|
||||
kNBestSep,
|
||||
"', got empty list",
|
||||
)
|
||||
return np.nan
|
||||
|
||||
id = pieces[0]
|
||||
if id not in self.id2oracle_errs_:
|
||||
self.utts_ += 1
|
||||
self.words_ += len(refs)
|
||||
self.insertions_ += ins
|
||||
self.deletions_ += dels
|
||||
self.substitutions_ += subs
|
||||
self.id2oracle_errs_[id] = [ins, dels, subs]
|
||||
else:
|
||||
curr_err = ins + dels + subs
|
||||
prev_err = np.sum(self.id2oracle_errs_[id])
|
||||
if curr_err < prev_err:
|
||||
self.id2oracle_errs_[id] = [ins, dels, subs]
|
||||
|
||||
return 0
|
||||
|
||||
def report_result(self):
|
||||
# print("---------- Summary ---------------")
|
||||
if self.words_ == 0:
|
||||
print("No words counted")
|
||||
return
|
||||
|
||||
# 1-best
|
||||
best_wer = (
|
||||
100.0
|
||||
* (self.insertions_ + self.deletions_ + self.substitutions_)
|
||||
/ self.words_
|
||||
)
|
||||
|
||||
print(
|
||||
"\tWER = %0.2f%% (%i utts, %i words, %0.2f%% ins, "
|
||||
"%0.2f%% dels, %0.2f%% subs)"
|
||||
% (
|
||||
best_wer,
|
||||
self.utts_,
|
||||
self.words_,
|
||||
100.0 * self.insertions_ / self.words_,
|
||||
100.0 * self.deletions_ / self.words_,
|
||||
100.0 * self.substitutions_ / self.words_,
|
||||
)
|
||||
)
|
||||
|
||||
def wer(self):
|
||||
if self.words_ == 0:
|
||||
wer = np.nan
|
||||
else:
|
||||
wer = (
|
||||
100.0
|
||||
* (self.insertions_ + self.deletions_ + self.substitutions_)
|
||||
/ self.words_
|
||||
)
|
||||
return wer
|
||||
|
||||
def stats(self):
|
||||
if self.words_ == 0:
|
||||
stats = {}
|
||||
else:
|
||||
wer = (
|
||||
100.0
|
||||
* (self.insertions_ + self.deletions_ + self.substitutions_)
|
||||
/ self.words_
|
||||
)
|
||||
stats = dict(
|
||||
{
|
||||
"wer": wer,
|
||||
"utts": self.utts_,
|
||||
"numwords": self.words_,
|
||||
"ins": self.insertions_,
|
||||
"dels": self.deletions_,
|
||||
"subs": self.substitutions_,
|
||||
"confusion_pairs": self.ed_.confusion_pairs_,
|
||||
}
|
||||
)
|
||||
return stats
|
||||
|
||||
|
||||
def calc_wer(hyp_str, ref_str):
|
||||
t = WERTransformer(hyp_str, ref_str, verbose=0)
|
||||
return t.wer()
|
||||
|
||||
|
||||
def calc_wer_stats(hyp_str, ref_str):
|
||||
t = WERTransformer(hyp_str, ref_str, verbose=0)
|
||||
return t.stats()
|
||||
|
||||
|
||||
def get_wer_alignment_codes(hyp_str, ref_str):
|
||||
"""
|
||||
INPUT: hypothesis string, reference string
|
||||
OUTPUT: List of alignment codes (intermediate results from WER computation)
|
||||
"""
|
||||
t = WERTransformer(hyp_str, ref_str, verbose=0)
|
||||
return t.ed_.align(str2toks(ref_str), str2toks(hyp_str)).codes
|
||||
|
||||
|
||||
def merge_counts(x, y):
|
||||
# Merge two hashes which have 'counts' as their values
|
||||
# This can be used for example to merge confusion pair counts
|
||||
# conf_pairs = merge_counts(conf_pairs, stats['confusion_pairs'])
|
||||
for k, v in y.items():
|
||||
if k not in x:
|
||||
x[k] = 0
|
||||
x[k] += v
|
||||
return x
|
|
@ -0,0 +1,453 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Wav2letter decoders.
|
||||
"""
|
||||
|
||||
import gc
|
||||
import pdb
|
||||
import itertools as it
|
||||
import os.path as osp
|
||||
import warnings
|
||||
from collections import deque, namedtuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from examples.speech_recognition.data.replabels import unpack_replabels
|
||||
from fairseq import tasks, pdb
|
||||
from fairseq.utils import apply_to_sample
|
||||
|
||||
|
||||
try:
|
||||
from wav2letter.common import create_word_dict, load_words
|
||||
from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
|
||||
from wav2letter.decoder import (
|
||||
CriterionType,
|
||||
DecoderOptions,
|
||||
KenLM,
|
||||
LM,
|
||||
LMState,
|
||||
SmearingMode,
|
||||
Trie,
|
||||
LexiconDecoder,
|
||||
)
|
||||
except:
|
||||
warnings.warn(
|
||||
"wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings"
|
||||
)
|
||||
LM = object
|
||||
LMState = object
|
||||
|
||||
|
||||
class W2lDecoder(object):
|
||||
def __init__(self, args, tgt_dict):
|
||||
self.tgt_dict = tgt_dict
|
||||
self.vocab_size = len(tgt_dict)
|
||||
self.nbest = args.nbest
|
||||
|
||||
# criterion-specific init
|
||||
if args.criterion == "ctc":
|
||||
self.criterion_type = "ctc"
|
||||
self.blank = (
|
||||
tgt_dict.index("<ctc_blank>")
|
||||
if "<ctc_blank>" in tgt_dict.indices
|
||||
else tgt_dict.bos()
|
||||
)
|
||||
self.asg_transitions = None
|
||||
elif args.criterion == "asg_loss":
|
||||
self.criterion_type = CriterionType.ASG
|
||||
self.blank = -1
|
||||
self.asg_transitions = args.asg_transitions
|
||||
self.max_replabel = args.max_replabel
|
||||
assert len(self.asg_transitions) == self.vocab_size ** 2
|
||||
else:
|
||||
raise RuntimeError(f"unknown criterion: {args.criterion}")
|
||||
|
||||
def generate(self, models, sample, **unused):
|
||||
"""Generate a batch of inferences."""
|
||||
# model.forward normally channels prev_output_tokens into the decoder
|
||||
# separately, but SequenceGenerator directly calls model.encoder
|
||||
encoder_input = {
|
||||
k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
|
||||
}
|
||||
emissions = self.get_emissions(models, encoder_input)
|
||||
return self.decode(emissions)
|
||||
|
||||
def get_emissions(self, models, encoder_input):
|
||||
"""Run encoder and normalize emissions"""
|
||||
# encoder_out = models[0].encoder(**encoder_input)
|
||||
encoder_out = models[0](**encoder_input)
|
||||
if self.criterion_type == "ctc":
|
||||
emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)
|
||||
elif self.criterion_type == CriterionType.ASG:
|
||||
emissions = encoder_out["encoder_out"]
|
||||
return emissions.transpose(0, 1).float().cpu().contiguous()
|
||||
|
||||
def get_tokens(self, idxs):
|
||||
"""Normalize tokens by handling CTC blank, ASG replabels, etc."""
|
||||
idxs = (g[0] for g in it.groupby(idxs))
|
||||
if self.criterion_type == "ctc":
|
||||
idxs = filter(lambda x: x != self.blank, idxs)
|
||||
elif self.criterion_type == CriterionType.ASG:
|
||||
idxs = filter(lambda x: x >= 0, idxs)
|
||||
idxs = unpack_replabels(list(idxs), self.tgt_dict, self.max_replabel)
|
||||
return torch.LongTensor(list(idxs))
|
||||
|
||||
|
||||
class W2lViterbiDecoder(W2lDecoder):
|
||||
def __init__(self, args, tgt_dict):
|
||||
super().__init__(args, tgt_dict)
|
||||
|
||||
def decode(self, emissions):
|
||||
B, T, N = emissions.size()
|
||||
results = []
|
||||
"""
|
||||
hypos = []
|
||||
if self.asg_transitions is None:
|
||||
transitions = torch.FloatTensor(N, N).zero_()
|
||||
else:
|
||||
transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
|
||||
viterbi_path = torch.IntTensor(B, T)
|
||||
workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
|
||||
CpuViterbiPath.compute(
|
||||
B,
|
||||
T,
|
||||
N,
|
||||
get_data_ptr_as_bytes(emissions),
|
||||
get_data_ptr_as_bytes(transitions),
|
||||
get_data_ptr_as_bytes(viterbi_path),
|
||||
get_data_ptr_as_bytes(workspace),
|
||||
)
|
||||
return [
|
||||
[{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
|
||||
for b in range(B)
|
||||
]
|
||||
"""
|
||||
for b in range(B):
|
||||
am_scores = emissions[b]
|
||||
tokens = am_scores.argmax(dim=-1)
|
||||
ids = []
|
||||
for i in range(T):
|
||||
if tokens[i].item() != self.tgt_dict.bos():
|
||||
if i == 0:
|
||||
ids.append(tokens[i].item())
|
||||
elif tokens[i].item() != tokens[i-1].item():
|
||||
ids.append(tokens[i].item())
|
||||
results.append([{"tokens": torch.LongTensor(ids), "score": 0.0}])
|
||||
return results
|
||||
|
||||
|
||||
|
||||
|
||||
class W2lKenLMDecoder(W2lDecoder):
|
||||
def __init__(self, args, tgt_dict):
|
||||
super().__init__(args, tgt_dict)
|
||||
|
||||
self.silence = (
|
||||
tgt_dict.index("<ctc_blank>")
|
||||
if "<ctc_blank>" in tgt_dict.indices
|
||||
else tgt_dict.bos()
|
||||
)
|
||||
self.lexicon = load_words(args.lexicon)
|
||||
self.word_dict = create_word_dict(self.lexicon)
|
||||
self.unk_word = self.word_dict.get_index("<unk>")
|
||||
|
||||
self.lm = KenLM(args.kenlm_model, self.word_dict)
|
||||
self.trie = Trie(self.vocab_size, self.silence)
|
||||
|
||||
start_state = self.lm.start(False)
|
||||
for i, (word, spellings) in enumerate(self.lexicon.items()):
|
||||
word_idx = self.word_dict.get_index(word)
|
||||
_, score = self.lm.score(start_state, word_idx)
|
||||
for spelling in spellings:
|
||||
spelling_idxs = [tgt_dict.index(token) for token in spelling]
|
||||
assert (
|
||||
tgt_dict.unk() not in spelling_idxs
|
||||
), f"{spelling} {spelling_idxs}"
|
||||
self.trie.insert(spelling_idxs, word_idx, score)
|
||||
self.trie.smear(SmearingMode.MAX)
|
||||
|
||||
self.decoder_opts = DecoderOptions(
|
||||
args.beam,
|
||||
int(getattr(args, "beam_size_token", len(tgt_dict))),
|
||||
args.beam_threshold,
|
||||
args.lm_weight,
|
||||
args.word_score,
|
||||
args.unk_weight,
|
||||
args.sil_weight,
|
||||
0,
|
||||
False,
|
||||
self.criterion_type,
|
||||
)
|
||||
|
||||
if self.asg_transitions is None:
|
||||
N = 768
|
||||
# self.asg_transitions = torch.FloatTensor(N, N).zero_()
|
||||
self.asg_transitions = []
|
||||
|
||||
self.decoder = LexiconDecoder(
|
||||
self.decoder_opts,
|
||||
self.trie,
|
||||
self.lm,
|
||||
self.silence,
|
||||
self.blank,
|
||||
self.unk_word,
|
||||
self.asg_transitions,
|
||||
False,
|
||||
)
|
||||
|
||||
def decode(self, emissions):
|
||||
B, T, N = emissions.size()
|
||||
hypos = []
|
||||
for b in range(B):
|
||||
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
|
||||
results = self.decoder.decode(emissions_ptr, T, N)
|
||||
|
||||
nbest_results = results[: self.nbest]
|
||||
hypos.append(
|
||||
[
|
||||
{
|
||||
"tokens": self.get_tokens(result.tokens),
|
||||
"score": result.score,
|
||||
"words": [
|
||||
self.word_dict.get_entry(x) for x in result.words if x >= 0
|
||||
],
|
||||
}
|
||||
for result in nbest_results
|
||||
]
|
||||
)
|
||||
return hypos
|
||||
|
||||
|
||||
FairseqLMState = namedtuple("FairseqLMState", ["prefix", "incremental_state", "probs"])
|
||||
|
||||
|
||||
class FairseqLM(LM):
|
||||
def __init__(self, dictionary, model):
|
||||
LM.__init__(self)
|
||||
self.dictionary = dictionary
|
||||
self.model = model
|
||||
self.unk = self.dictionary.unk()
|
||||
|
||||
self.save_incremental = False # this currently does not work properly
|
||||
self.max_cache = 20_000
|
||||
|
||||
model.cuda()
|
||||
model.eval()
|
||||
model.make_generation_fast_()
|
||||
|
||||
self.states = {}
|
||||
self.stateq = deque()
|
||||
|
||||
def start(self, start_with_nothing):
|
||||
state = LMState()
|
||||
prefix = torch.LongTensor([[self.dictionary.eos()]])
|
||||
incremental_state = {} if self.save_incremental else None
|
||||
with torch.no_grad():
|
||||
res = self.model(prefix.cuda(), incremental_state=incremental_state)
|
||||
probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)
|
||||
|
||||
if incremental_state is not None:
|
||||
incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)
|
||||
self.states[state] = FairseqLMState(
|
||||
prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
|
||||
)
|
||||
self.stateq.append(state)
|
||||
|
||||
return state
|
||||
|
||||
def score(self, state: LMState, token_index: int, no_cache: bool = False):
|
||||
"""
|
||||
Evaluate language model based on the current lm state and new word
|
||||
Parameters:
|
||||
-----------
|
||||
state: current lm state
|
||||
token_index: index of the word
|
||||
(can be lexicon index then you should store inside LM the
|
||||
mapping between indices of lexicon and lm, or lm index of a word)
|
||||
|
||||
Returns:
|
||||
--------
|
||||
(LMState, float): pair of (new state, score for the current word)
|
||||
"""
|
||||
curr_state = self.states[state]
|
||||
|
||||
def trim_cache(targ_size):
|
||||
while len(self.stateq) > targ_size:
|
||||
rem_k = self.stateq.popleft()
|
||||
rem_st = self.states[rem_k]
|
||||
rem_st = FairseqLMState(rem_st.prefix, None, None)
|
||||
self.states[rem_k] = rem_st
|
||||
|
||||
if curr_state.probs is None:
|
||||
new_incremental_state = (
|
||||
curr_state.incremental_state.copy()
|
||||
if curr_state.incremental_state is not None
|
||||
else None
|
||||
)
|
||||
with torch.no_grad():
|
||||
if new_incremental_state is not None:
|
||||
new_incremental_state = apply_to_sample(
|
||||
lambda x: x.cuda(), new_incremental_state
|
||||
)
|
||||
elif self.save_incremental:
|
||||
new_incremental_state = {}
|
||||
|
||||
res = self.model(
|
||||
torch.from_numpy(curr_state.prefix).cuda(),
|
||||
incremental_state=new_incremental_state,
|
||||
)
|
||||
probs = self.model.get_normalized_probs(
|
||||
res, log_probs=True, sample=None
|
||||
)
|
||||
|
||||
if new_incremental_state is not None:
|
||||
new_incremental_state = apply_to_sample(
|
||||
lambda x: x.cpu(), new_incremental_state
|
||||
)
|
||||
|
||||
curr_state = FairseqLMState(
|
||||
curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
|
||||
)
|
||||
|
||||
if not no_cache:
|
||||
self.states[state] = curr_state
|
||||
self.stateq.append(state)
|
||||
|
||||
score = curr_state.probs[token_index].item()
|
||||
|
||||
trim_cache(self.max_cache)
|
||||
|
||||
outstate = state.child(token_index)
|
||||
if outstate not in self.states and not no_cache:
|
||||
prefix = np.concatenate(
|
||||
[curr_state.prefix, torch.LongTensor([[token_index]])], -1
|
||||
)
|
||||
incr_state = curr_state.incremental_state
|
||||
|
||||
self.states[outstate] = FairseqLMState(prefix, incr_state, None)
|
||||
|
||||
if token_index == self.unk:
|
||||
score = float("-inf")
|
||||
|
||||
return outstate, score
|
||||
|
||||
def finish(self, state: LMState):
|
||||
"""
|
||||
Evaluate eos for language model based on the current lm state
|
||||
|
||||
Returns:
|
||||
--------
|
||||
(LMState, float): pair of (new state, score for the current word)
|
||||
"""
|
||||
return self.score(state, self.dictionary.eos())
|
||||
|
||||
def empty_cache(self):
|
||||
self.states = {}
|
||||
self.stateq = deque()
|
||||
gc.collect()
|
||||
|
||||
|
||||
class W2lFairseqLMDecoder(W2lDecoder):
|
||||
def __init__(self, args, tgt_dict):
|
||||
super().__init__(args, tgt_dict)
|
||||
|
||||
self.silence = tgt_dict.bos()
|
||||
|
||||
self.unit_lm = getattr(args, "unit_lm", False)
|
||||
|
||||
self.lexicon = load_words(args.lexicon) if args.lexicon else None
|
||||
self.idx_to_wrd = {}
|
||||
|
||||
checkpoint = torch.load(args.kenlm_model, map_location="cpu")
|
||||
lm_args = checkpoint["args"]
|
||||
lm_args.data = osp.dirname(args.kenlm_model)
|
||||
print(lm_args)
|
||||
task = tasks.setup_task(lm_args)
|
||||
model = task.build_model(lm_args)
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
|
||||
self.trie = Trie(self.vocab_size, self.silence)
|
||||
|
||||
self.word_dict = task.dictionary
|
||||
self.unk_word = self.word_dict.unk()
|
||||
self.lm = FairseqLM(self.word_dict, model)
|
||||
|
||||
self.decoder_opts = DecoderOptions(
|
||||
args.beam,
|
||||
int(getattr(args, "beam_size_token", len(tgt_dict))),
|
||||
args.beam_threshold,
|
||||
args.lm_weight,
|
||||
args.word_score,
|
||||
args.unk_weight,
|
||||
args.sil_weight,
|
||||
0,
|
||||
False,
|
||||
self.criterion_type,
|
||||
)
|
||||
|
||||
if self.lexicon:
|
||||
start_state = self.lm.start(False)
|
||||
for i, (word, spellings) in enumerate(self.lexicon.items()):
|
||||
if self.unit_lm:
|
||||
word_idx = i
|
||||
self.idx_to_wrd[i] = word
|
||||
score = 0
|
||||
else:
|
||||
word_idx = self.word_dict.index(word)
|
||||
_, score = self.lm.score(start_state, word_idx, no_cache=True)
|
||||
|
||||
for spelling in spellings:
|
||||
spelling_idxs = [tgt_dict.index(token) for token in spelling]
|
||||
assert (
|
||||
tgt_dict.unk() not in spelling_idxs
|
||||
), f"{spelling} {spelling_idxs}"
|
||||
self.trie.insert(spelling_idxs, word_idx, score)
|
||||
self.trie.smear(SmearingMode.MAX)
|
||||
|
||||
self.decoder = LexiconDecoder(
|
||||
self.decoder_opts,
|
||||
self.trie,
|
||||
self.lm,
|
||||
self.silence,
|
||||
self.blank,
|
||||
self.unk_word,
|
||||
[],
|
||||
self.unit_lm,
|
||||
)
|
||||
else:
|
||||
from wav2letter.decoder import LexiconFreeDecoder
|
||||
self.decoder = LexiconFreeDecoder(
|
||||
self.decoder_opts, self.lm, self.silence, self.blank, []
|
||||
)
|
||||
|
||||
def decode(self, emissions):
|
||||
B, T, N = emissions.size()
|
||||
hypos = []
|
||||
|
||||
def idx_to_word(idx):
|
||||
if self.unit_lm:
|
||||
return self.idx_to_wrd[idx]
|
||||
else:
|
||||
return self.word_dict[idx]
|
||||
|
||||
def make_hypo(result):
|
||||
hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
|
||||
if self.lexicon:
|
||||
hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
|
||||
return hypo
|
||||
|
||||
for b in range(B):
|
||||
emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
|
||||
results = self.decoder.decode(emissions_ptr, T, N)
|
||||
|
||||
nbest_results = results[: self.nbest]
|
||||
hypos.append([make_hypo(result) for result in nbest_results])
|
||||
self.lm.empty_cache()
|
||||
|
||||
return hypos
|
|
@ -0,0 +1,60 @@
|
|||
import argparse
|
||||
import os
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
|
||||
from pydub import AudioSegment
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--wav-path', type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dest-path', type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
'--input', type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
'--output', type=str
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(args):
|
||||
os.makedirs(args.dest_path, exist_ok=True)
|
||||
|
||||
f = open(args.input)
|
||||
data = f.readlines()
|
||||
f.close()
|
||||
|
||||
wf = open(args.output, 'w')
|
||||
wf.write(args.dest_path+"\n")
|
||||
count = len(data)
|
||||
for line in data:
|
||||
wav_name = line.strip()
|
||||
wav_file = os.path.join(args.wav_path, wav_name)
|
||||
base_wav_name = os.path.splitext(wav_name)[0]
|
||||
output_file = os.path.join(args.dest_path, base_wav_name+".wav")
|
||||
if os.path.exists(wav_file) and not os.path.exists(output_file):
|
||||
sound = AudioSegment.from_mp3(wav_file)
|
||||
sound.export(os.path.join(args.dest_path, 'tmp.wav'), format='wav')
|
||||
y, sr = librosa.load(os.path.join(args.dest_path, 'tmp.wav'), sr=16000)
|
||||
sf.write(output_file, y, sr)
|
||||
infos = sf.info(output_file)
|
||||
frames = infos.frames
|
||||
sr = infos.samplerate
|
||||
wf.write("{}\t{}\t{}\n".format(base_wav_name+".wav", frames, sr))
|
||||
count += 1
|
||||
if count % 100 == 0:
|
||||
print('process {} done'.format(count))
|
||||
|
||||
wf.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,56 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Helper script to pre-compute embeddings for a wav2letter++ dataset
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("tsv")
|
||||
parser.add_argument("--output-dir", required=True)
|
||||
parser.add_argument("--output-name", required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
transcriptions = {}
|
||||
|
||||
with open(args.tsv, "r") as tsv, open(
|
||||
os.path.join(args.output_dir, args.output_name + ".ltr"), "w"
|
||||
) as ltr_out, open(
|
||||
os.path.join(args.output_dir, args.output_name + ".wrd"), "w"
|
||||
) as wrd_out:
|
||||
root = next(tsv).strip()
|
||||
for line in tsv:
|
||||
line = line.strip()
|
||||
dir = os.path.dirname(line)
|
||||
if dir not in transcriptions:
|
||||
parts = dir.split(os.path.sep)
|
||||
trans_path = f"{parts[-2]}-{parts[-1]}.trans.txt"
|
||||
path = os.path.join(root, dir, trans_path)
|
||||
assert os.path.exists(path)
|
||||
texts = {}
|
||||
with open(path, "r") as trans_f:
|
||||
for tline in trans_f:
|
||||
items = tline.strip().split()
|
||||
texts[items[0]] = " ".join(items[1:])
|
||||
transcriptions[dir] = texts
|
||||
part = os.path.basename(line).split(".")[0]
|
||||
assert part in transcriptions[dir]
|
||||
print(transcriptions[dir][part], file=wrd_out)
|
||||
print(
|
||||
" ".join(list(transcriptions[dir][part].replace(" ", "|"))) + " |",
|
||||
file=ltr_out,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -0,0 +1,11 @@
|
|||
model_path=YOUR_MODEL_PATH
|
||||
train_subset=pretrain_HOUR_16k
|
||||
valid_subset=valSeqs_1.0_uniform_new_version_16k
|
||||
WORLD_SIZE=8
|
||||
|
||||
|
||||
update_freq=2
|
||||
|
||||
mkdir -p ${model_path}
|
||||
|
||||
python train.py --distributed-world-size ${WORLD_SIZE} --distributed-port 0 examples/unispeech/data/LANG --save-dir ${model_path} --fp16 --num-workers 10 --task audio_pretraining --criterion wav2vec --arch unispeech --train-subset ${train_subset} --valid-subset ${valid_subset} --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --normalize --extractor-mode "layer_norm" --encoder-layers 24 --encoder-embed-dim 1024 --encoder-ffn-embed-dim 4096 --encoder-attention-heads 16 --final-dim 768 --layer-norm-first --conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --latent-vars 320 --latent-groups 2 --latent-temp '(2,0.1,0.999995)' --infonce --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 100000 --lr 0.0002 --warmup-updates 10000 --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 --encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 --loss-weights '[0.1, 0]' --conv-pos 128 --conv-pos-groups 16 --num-negatives 100 --cross-sample-negatives 0 --max-sample-size 250000 --min-sample-size 32000 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 1000000 --max-update 100000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d --update-freq ${update_freq} --pretrained-path PRETRAINED_MODEL_FROM_STAGE_1 --no-epoch-checkpoints --transpose --save-interval 4 --validate-interval 4
|
|
@ -0,0 +1,11 @@
|
|||
model_path=~/test
|
||||
pretrained_model=/datablob/users/v-chengw/commonvoice_model/ptmtls2_en1350.large.lr1e-3.64gpu.fgm0.1/checkpoint_best.pt
|
||||
train_subset=trainSeqs_1.0_uniform_new_version_16k
|
||||
valid_subset=valSeqs_1.0_uniform_new_version_16k
|
||||
|
||||
mkdir -p ${model_path}
|
||||
WORLD_SIZE=4
|
||||
updata_freq=1
|
||||
|
||||
|
||||
python train.py --distributed-world-size $WORLD_SIZE --distributed-port 0 /datablob/users/v-chengw/data/commonvoice_20200622/resource/nl --save-dir ${model_path} --post-process word --train-subset ${train_subset} --valid-subset ${valid_subset} --no-epoch-checkpoints --best-checkpoint-metric uer --num-workers 4 --max-update 20000 --sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path ${pretrained_model} --labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.75 --layerdrop 0.1 --mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.25 --zero-infinity --feature-grad-mult 0.0 --freeze-finetune-updates 2000 --validate-after-updates 2000 --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 2000 --hold-steps 8000 --decay-steps 10000 --final-lr-scale 0.05 --activation-dropout 0.1 --dropout 0.1 --attention-dropout 0.1 --final-dropout 0.1 --dropout-input 0.1 --criterion ctc --max-tokens 1000000 --seed 1337 --log-format json --log-interval 100 --ddp-backend no_c10d --fp16 --update-freq ${updata_freq} --dict-path /datablob/users/v-chengw/data/commonvoice_20200622/common_voices_splits/nl/phonesMatches_reduced.json --save-interval 10 --validate-interval 10 --normalize
|
|
@ -0,0 +1,13 @@
|
|||
model_path=MODEL_PATH
|
||||
valid_subset=en/valid_16k
|
||||
WORLD_SIZE=NUM_OF_GPUS
|
||||
|
||||
|
||||
update_freq=$((64/$WORLD_SIZE)) #ngpu * update_freq = 64
|
||||
|
||||
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
|
||||
|
||||
|
||||
mkdir -p ${model_path}
|
||||
|
||||
python -m torch.distributed.launch $DISTRIBUTED_ARGS train.py --distributed-world-size ${WORLD_SIZE} --distributed-port 0 examples/unispeech/data --save-dir ${model_path} --fp16 --num-workers 10 --task audio_pretraining --criterion wav2vec_mtl --arch unispeech --extractor-mode "layer_norm" --encoder-layers 24 --encoder-embed-dim 1024 --encoder-ffn-embed-dim 4096 --encoder-attention-heads 16 --final-dim 768 --layer-norm-first --conv-bias --logit-temp 0.1 --train-subset en/pretrain_1350_16k,es/pretrain_168_16k_sep,fr/pretrain_353_16k_sep,it/pretrain_90_16k_sep --valid-subset ${valid_subset} --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --latent-vars 320 --latent-groups 2 --latent-temp '(2,0.1,0.999995)' --infonce --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 200000 --lr 0.001 --warmup-updates 25000 --mask-length 10 --mask-prob 0.5 --mask-selection static --mask-other 0 --encoder-layerdrop 0.0 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 1.0 --loss-weights '[0.1, 0]' --conv-pos 128 --conv-pos-groups 16 --num-negatives 100 --cross-sample-negatives 0 --max-sample-size 320000 --min-sample-size 32000 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 1200000 --max-update 250000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d --update-freq ${update_freq} --post-process word --labels ltr --dict-path examples/unispeech/data/en/vocab_sep.json --negatives-from-everywhere --mtlalpha 0.5 --replace-prob 0.5 --transpose --no-epoch-checkpoints --log-format json
|
|
@ -0,0 +1,12 @@
|
|||
model_path=MODEL_PATH
|
||||
train_subset=pretrain_1350_16k
|
||||
valid_subset=valid_16k
|
||||
WORLD_SIZE=NUM_OF_GPUS
|
||||
|
||||
|
||||
update_freq=$((64/$WORLD_SIZE)) #ngpu * update_freq = 64
|
||||
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
|
||||
|
||||
mkdir -p ${model_path}
|
||||
|
||||
python -m torch.distributed.launch $DISTRIBUTED_ARGS train.py --distributed-world-size ${WORLD_SIZE} --distributed-port 0 examples/unispeech/data/en --save-dir ${model_path} --fp16 --num-workers 10 --task audio_pretraining --criterion wav2vec_mtl --arch unispeech --extractor-mode "layer_norm" --encoder-layers 24 --encoder-embed-dim 1024 --encoder-ffn-embed-dim 4096 --encoder-attention-heads 16 --final-dim 768 --layer-norm-first --conv-bias --logit-temp 0.1 --train-subset ${train_subset} --valid-subset ${valid_subset} --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --latent-vars 320 --latent-groups 2 --latent-temp '(2,0.1,0.999995)' --infonce --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 200000 --lr 0.001 --warmup-updates 25000 --mask-length 10 --mask-prob 0.5 --mask-selection static --mask-other 0 --encoder-layerdrop 0.0 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 1.0 --loss-weights '[0.1, 0]' --conv-pos 128 --conv-pos-groups 16 --num-negatives 100 --cross-sample-negatives 0 --max-sample-size 320000 --min-sample-size 32000 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 1200000 --max-update 250000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d --update-freq ${update_freq} --post-process word --labels ltr --dict-path examples/unispeech/data/en/vocab.json --negatives-from-everywhere --mtlalpha 0.5 --replace-prob 0.5 --transpose --no-epoch-checkpoints --log-format json
|
|
@ -0,0 +1,46 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
'input',
|
||||
type=str,
|
||||
help="input .tsv file"
|
||||
)
|
||||
parser.add_argument(
|
||||
'--dest',
|
||||
type=str,
|
||||
help="output directory"
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def main(args):
|
||||
wav_names = []
|
||||
text = []
|
||||
with open(args.input) as f:
|
||||
f.readline()
|
||||
for line in f:
|
||||
items = line.strip().split("\t")
|
||||
wav_names.append(items[1])
|
||||
text.append(items[2])
|
||||
base_name = os.path.basename(args.input)
|
||||
file_name = os.path.splitext(base_name)[0]
|
||||
|
||||
with open(os.path.join(args.dest, file_name+'.list'), 'w') as f:
|
||||
for name in wav_names:
|
||||
f.write(name+"\n")
|
||||
|
||||
with open(os.path.join(args.dest, file_name+'.text'), 'w') as f:
|
||||
for i in range(len(wav_names)):
|
||||
f.write("{}\t{}\n".format(wav_names[i], text[i]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,76 @@
|
|||
#!/usr/bin/env python3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""
|
||||
Data pre-processing: build vocabularies and binarize training data.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
|
||||
import soundfile
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"root", metavar="DIR", help="root directory containing flac files to index"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--valid-percent",
|
||||
default=0.01,
|
||||
type=float,
|
||||
metavar="D",
|
||||
help="percentage of data to use as validation set (between 0 and 1)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dest", default=".", type=str, metavar="DIR", help="output directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ext", default="flac", type=str, metavar="EXT", help="extension to look for"
|
||||
)
|
||||
parser.add_argument("--seed", default=42, type=int, metavar="N", help="random seed")
|
||||
parser.add_argument(
|
||||
"--path-must-contain",
|
||||
default=None,
|
||||
type=str,
|
||||
metavar="FRAG",
|
||||
help="if set, path must contain this substring for a file to be included in the manifest",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def main(args):
|
||||
assert args.valid_percent >= 0 and args.valid_percent <= 1.0
|
||||
|
||||
dir_path = os.path.realpath(args.root)
|
||||
search_path = os.path.join(dir_path, "**/*." + args.ext)
|
||||
rand = random.Random(args.seed)
|
||||
|
||||
with open(os.path.join(args.dest, "train.tsv"), "w") as train_f, open(
|
||||
os.path.join(args.dest, "valid.tsv"), "w"
|
||||
) as valid_f:
|
||||
print(dir_path, file=train_f)
|
||||
print(dir_path, file=valid_f)
|
||||
|
||||
for fname in glob.iglob(search_path, recursive=True):
|
||||
file_path = os.path.realpath(fname)
|
||||
|
||||
if args.path_must_contain and args.path_must_contain not in file_path:
|
||||
continue
|
||||
|
||||
frames = soundfile.info(fname).frames
|
||||
dest = train_f if rand.random() > args.valid_percent else valid_f
|
||||
print(
|
||||
"{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=dest
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""isort:skip_file"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
try:
|
||||
from .version import __version__ # noqa
|
||||
except ImportError:
|
||||
version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
|
||||
with open(version_txt) as f:
|
||||
__version__ = f.read().strip()
|
||||
|
||||
__all__ = ["pdb"]
|
||||
|
||||
# backwards compatibility to support `from fairseq.meters import AverageMeter`
|
||||
from fairseq.logging import meters, metrics, progress_bar # noqa
|
||||
|
||||
sys.modules["fairseq.meters"] = meters
|
||||
sys.modules["fairseq.metrics"] = metrics
|
||||
sys.modules["fairseq.progress_bar"] = progress_bar
|
||||
|
||||
# initialize hydra
|
||||
from fairseq.dataclass.initialize import hydra_init
|
||||
hydra_init()
|
||||
|
||||
import fairseq.criterions # noqa
|
||||
import fairseq.models # noqa
|
||||
import fairseq.modules # noqa
|
||||
import fairseq.optim # noqa
|
||||
import fairseq.optim.lr_scheduler # noqa
|
||||
import fairseq.pdb # noqa
|
||||
import fairseq.scoring # noqa
|
||||
import fairseq.tasks # noqa
|
||||
import fairseq.token_generation_constraints # noqa
|
||||
|
||||
import fairseq.benchmark # noqa
|
||||
import fairseq.model_parallel # noqa
|
|
@ -0,0 +1,7 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# import models/tasks to register them
|
||||
from . import dummy_lm, dummy_masked_lm, dummy_model, dummy_mt # noqa
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq.data import Dictionary, FairseqDataset
|
||||
from fairseq.tasks import LegacyFairseqTask, register_task
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_task("dummy_lm")
|
||||
class DummyLMTask(LegacyFairseqTask):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add task-specific arguments to the parser."""
|
||||
parser.add_argument("--dict-size", default=49996, type=int)
|
||||
parser.add_argument("--dataset-size", default=100000, type=int)
|
||||
parser.add_argument(
|
||||
"--tokens-per-sample",
|
||||
default=512,
|
||||
type=int,
|
||||
help="max number of total tokens over all segments "
|
||||
"per sample for BERT dataset",
|
||||
)
|
||||
parser.add_argument("--add-bos-token", action="store_true", help="unused")
|
||||
parser.add_argument(
|
||||
"--max-target-positions",
|
||||
default=None,
|
||||
help="max number of tokens in the target sequence",
|
||||
)
|
||||
|
||||
def __init__(self, args, dictionary):
|
||||
super().__init__(args)
|
||||
self.dictionary = dictionary
|
||||
self.seed = args.seed
|
||||
|
||||
dictionary.pad_to_multiple_(8) # often faster if divisible by 8
|
||||
|
||||
seq = torch.arange(args.tokens_per_sample + 1) + dictionary.pad() + 1
|
||||
|
||||
self.dummy_src = seq[:-1]
|
||||
self.dummy_tgt = seq[1:]
|
||||
|
||||
@classmethod
|
||||
def setup_task(cls, args, **kwargs):
|
||||
"""Setup the task. """
|
||||
dictionary = Dictionary()
|
||||
for i in range(args.dict_size):
|
||||
dictionary.add_symbol("word{}".format(i))
|
||||
logger.info("dictionary: {} types".format(len(dictionary)))
|
||||
return cls(args, dictionary)
|
||||
|
||||
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
||||
"""Load a given dataset split.
|
||||
Args:
|
||||
split (str): name of the split (e.g., train, valid, test)
|
||||
"""
|
||||
if self.args.batch_size is not None:
|
||||
bsz = self.args.batch_size
|
||||
else:
|
||||
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
|
||||
self.datasets[split] = DummyDataset(
|
||||
{
|
||||
"id": 1,
|
||||
"net_input": {
|
||||
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
|
||||
"src_lengths": torch.full(
|
||||
(bsz,), self.args.tokens_per_sample, dtype=torch.long
|
||||
),
|
||||
},
|
||||
"target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
|
||||
"nsentences": bsz,
|
||||
"ntokens": bsz * self.args.tokens_per_sample,
|
||||
},
|
||||
num_items=self.args.dataset_size,
|
||||
item_size=self.args.tokens_per_sample,
|
||||
)
|
||||
|
||||
@property
|
||||
def source_dictionary(self):
|
||||
return self.dictionary
|
||||
|
||||
@property
|
||||
def target_dictionary(self):
|
||||
return self.dictionary
|
||||
|
||||
|
||||
class DummyDataset(FairseqDataset):
|
||||
def __init__(self, batch, num_items, item_size):
|
||||
super().__init__()
|
||||
self.batch = batch
|
||||
self.num_items = num_items
|
||||
self.item_size = item_size
|
||||
|
||||
def __getitem__(self, index):
|
||||
return index
|
||||
|
||||
def __len__(self):
|
||||
return self.num_items
|
||||
|
||||
def collater(self, samples):
|
||||
return self.batch
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return np.array([self.item_size] * self.num_items)
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.item_size
|
||||
|
||||
def size(self, index):
|
||||
return self.item_size
|
||||
|
||||
def ordered_indices(self):
|
||||
return np.arange(self.num_items)
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False
|
|
@ -0,0 +1,127 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq.data import Dictionary, FairseqDataset
|
||||
from fairseq.tasks import LegacyFairseqTask, register_task
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_task("dummy_masked_lm")
|
||||
class DummyMaskedLMTask(LegacyFairseqTask):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add task-specific arguments to the parser."""
|
||||
parser.add_argument("--dict-size", default=49995, type=int)
|
||||
parser.add_argument("--dataset-size", default=100000, type=int)
|
||||
parser.add_argument(
|
||||
"--tokens-per-sample",
|
||||
default=512,
|
||||
type=int,
|
||||
help="max number of total tokens over all segments "
|
||||
"per sample for BERT dataset",
|
||||
)
|
||||
|
||||
def __init__(self, args, dictionary):
|
||||
super().__init__(args)
|
||||
self.dictionary = dictionary
|
||||
|
||||
# add mask token
|
||||
self.mask_idx = dictionary.add_symbol("<mask>")
|
||||
dictionary.pad_to_multiple_(8) # often faster if divisible by 8
|
||||
|
||||
mask_idx = 0
|
||||
pad_idx = 1
|
||||
seq = torch.arange(args.tokens_per_sample) + pad_idx + 1
|
||||
mask = torch.arange(2, args.tokens_per_sample, 7) # ~15%
|
||||
src = seq.clone()
|
||||
src[mask] = mask_idx
|
||||
tgt = torch.full_like(seq, pad_idx)
|
||||
tgt[mask] = seq[mask]
|
||||
|
||||
self.dummy_src = src
|
||||
self.dummy_tgt = tgt
|
||||
|
||||
@classmethod
|
||||
def setup_task(cls, args, **kwargs):
|
||||
"""Setup the task. """
|
||||
dictionary = Dictionary()
|
||||
for i in range(args.dict_size):
|
||||
dictionary.add_symbol("word{}".format(i))
|
||||
logger.info("dictionary: {} types".format(len(dictionary)))
|
||||
return cls(args, dictionary)
|
||||
|
||||
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
||||
"""Load a given dataset split.
|
||||
Args:
|
||||
split (str): name of the split (e.g., train, valid, test)
|
||||
"""
|
||||
if self.args.batch_size is not None:
|
||||
bsz = self.args.batch_size
|
||||
else:
|
||||
bsz = max(1, self.args.max_tokens // self.args.tokens_per_sample)
|
||||
self.datasets[split] = DummyDataset(
|
||||
{
|
||||
"id": 1,
|
||||
"net_input": {
|
||||
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
|
||||
"src_lengths": torch.full(
|
||||
(bsz,), self.args.tokens_per_sample, dtype=torch.long
|
||||
),
|
||||
},
|
||||
"target": torch.stack([self.dummy_tgt for _ in range(bsz)]),
|
||||
"nsentences": bsz,
|
||||
"ntokens": bsz * self.args.tokens_per_sample,
|
||||
},
|
||||
num_items=self.args.dataset_size,
|
||||
item_size=self.args.tokens_per_sample,
|
||||
)
|
||||
|
||||
@property
|
||||
def source_dictionary(self):
|
||||
return self.dictionary
|
||||
|
||||
@property
|
||||
def target_dictionary(self):
|
||||
return self.dictionary
|
||||
|
||||
|
||||
class DummyDataset(FairseqDataset):
|
||||
def __init__(self, batch, num_items, item_size):
|
||||
super().__init__()
|
||||
self.batch = batch
|
||||
self.num_items = num_items
|
||||
self.item_size = item_size
|
||||
|
||||
def __getitem__(self, index):
|
||||
return index
|
||||
|
||||
def __len__(self):
|
||||
return self.num_items
|
||||
|
||||
def collater(self, samples):
|
||||
return self.batch
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return np.array([self.item_size] * self.num_items)
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.item_size
|
||||
|
||||
def size(self, index):
|
||||
return self.item_size
|
||||
|
||||
def ordered_indices(self):
|
||||
return np.arange(self.num_items)
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from fairseq.data import Dictionary
|
||||
from fairseq.models import (
|
||||
FairseqDecoder,
|
||||
FairseqLanguageModel,
|
||||
register_model,
|
||||
register_model_architecture,
|
||||
)
|
||||
|
||||
|
||||
@register_model("dummy_model")
|
||||
class DummyModel(FairseqLanguageModel):
|
||||
def __init__(self, args, encoder):
|
||||
super().__init__(encoder)
|
||||
self.args = args
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
parser.add_argument("--num-layers", type=int, default=24)
|
||||
parser.add_argument("--embed-dim", type=int, default=1024)
|
||||
|
||||
@classmethod
|
||||
def build_model(cls, args, task):
|
||||
encoder = DummyEncoder(
|
||||
num_embed=len(task.target_dictionary),
|
||||
embed_dim=args.embed_dim,
|
||||
num_layers=args.num_layers,
|
||||
)
|
||||
return cls(args, encoder)
|
||||
|
||||
def forward(self, src_tokens, masked_tokens=None, **kwargs):
|
||||
return self.decoder(src_tokens, masked_tokens=masked_tokens)
|
||||
|
||||
|
||||
class DummyEncoder(FairseqDecoder):
|
||||
def __init__(self, num_embed=50000, embed_dim=1024, num_layers=24):
|
||||
super().__init__(Dictionary())
|
||||
self.embed = nn.Embedding(
|
||||
num_embeddings=num_embed, embedding_dim=embed_dim, padding_idx=0
|
||||
)
|
||||
self.layers_a = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.LayerNorm(embed_dim),
|
||||
nn.Linear(embed_dim, 3 * embed_dim), # q, k, v input projection
|
||||
nn.Linear(3 * embed_dim, embed_dim), # skip self-attention
|
||||
nn.Linear(embed_dim, embed_dim), # output projection
|
||||
nn.Dropout(),
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.layers_b = nn.ModuleList(
|
||||
[
|
||||
nn.Sequential(
|
||||
nn.LayerNorm(embed_dim),
|
||||
nn.Linear(embed_dim, 4 * embed_dim), # FFN
|
||||
nn.ReLU(),
|
||||
nn.Linear(4 * embed_dim, embed_dim), # FFN
|
||||
nn.Dropout(0.1),
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
self.out_proj = nn.Linear(embed_dim, num_embed)
|
||||
|
||||
def forward(self, tokens, masked_tokens=None):
|
||||
x = self.embed(tokens)
|
||||
for layer_a, layer_b in zip(self.layers_a, self.layers_b):
|
||||
x = x + layer_a(x)
|
||||
x = x + layer_b(x)
|
||||
x = self.out_proj(x)
|
||||
if masked_tokens is not None:
|
||||
x = x[masked_tokens]
|
||||
return (x,)
|
||||
|
||||
def max_positions(self):
|
||||
return 1024
|
||||
|
||||
def get_normalized_probs(self, net_output, log_probs, sample=None):
|
||||
logits = net_output[0].float()
|
||||
if log_probs:
|
||||
return F.log_softmax(logits, dim=-1)
|
||||
else:
|
||||
return F.softmax(logits, dim=-1)
|
||||
|
||||
|
||||
@register_model_architecture("dummy_model", "dummy_model")
|
||||
def base_architecture(args):
|
||||
pass
|
|
@ -0,0 +1,119 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq.data import Dictionary, FairseqDataset
|
||||
from fairseq.tasks import LegacyFairseqTask, register_task
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@register_task("dummy_mt")
|
||||
class DummyMTTask(LegacyFairseqTask):
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add task-specific arguments to the parser."""
|
||||
parser.add_argument("--dict-size", default=49996, type=int)
|
||||
parser.add_argument("--dataset-size", default=100000, type=int)
|
||||
parser.add_argument("--src-len", default=30, type=int)
|
||||
parser.add_argument("--tgt-len", default=30, type=int)
|
||||
|
||||
def __init__(self, args, dictionary):
|
||||
super().__init__(args)
|
||||
self.dictionary = dictionary
|
||||
self.seed = args.seed
|
||||
|
||||
dictionary.pad_to_multiple_(8) # often faster if divisible by 8
|
||||
|
||||
self.dummy_src = torch.arange(args.src_len + 1) + dictionary.pad() + 1
|
||||
self.dummy_tgt = torch.arange(args.tgt_len + 1) + dictionary.pad() + 1
|
||||
|
||||
@classmethod
|
||||
def setup_task(cls, args, **kwargs):
|
||||
"""Setup the task. """
|
||||
dictionary = Dictionary()
|
||||
for i in range(args.dict_size):
|
||||
dictionary.add_symbol("word{}".format(i))
|
||||
logger.info("dictionary: {} types".format(len(dictionary)))
|
||||
|
||||
args.max_source_positions = args.src_len + dictionary.pad() + 2
|
||||
args.max_target_positions = args.tgt_len + dictionary.pad() + 2
|
||||
|
||||
return cls(args, dictionary)
|
||||
|
||||
def load_dataset(self, split, epoch=1, combine=False, **kwargs):
|
||||
"""Load a given dataset split.
|
||||
Args:
|
||||
split (str): name of the split (e.g., train, valid, test)
|
||||
"""
|
||||
item_size = max(self.args.src_len, self.args.tgt_len)
|
||||
if self.args.batch_size is not None:
|
||||
bsz = self.args.batch_size
|
||||
else:
|
||||
bsz = max(1, self.args.max_tokens // item_size)
|
||||
tgt = torch.stack([self.dummy_tgt for _ in range(bsz)])
|
||||
self.datasets[split] = DummyDataset(
|
||||
{
|
||||
"id": 1,
|
||||
"net_input": {
|
||||
"src_tokens": torch.stack([self.dummy_src for _ in range(bsz)]),
|
||||
"src_lengths": torch.full(
|
||||
(bsz,), self.args.src_len, dtype=torch.long
|
||||
),
|
||||
"prev_output_tokens": tgt.clone(),
|
||||
},
|
||||
"target": tgt,
|
||||
"nsentences": bsz,
|
||||
"ntokens": bsz * self.args.tgt_len,
|
||||
},
|
||||
num_items=self.args.dataset_size,
|
||||
item_size=item_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def source_dictionary(self):
|
||||
return self.dictionary
|
||||
|
||||
@property
|
||||
def target_dictionary(self):
|
||||
return self.dictionary
|
||||
|
||||
|
||||
class DummyDataset(FairseqDataset):
|
||||
def __init__(self, batch, num_items, item_size):
|
||||
super().__init__()
|
||||
self.batch = batch
|
||||
self.num_items = num_items
|
||||
self.item_size = item_size
|
||||
|
||||
def __getitem__(self, index):
|
||||
return index
|
||||
|
||||
def __len__(self):
|
||||
return self.num_items
|
||||
|
||||
def collater(self, samples):
|
||||
return self.batch
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return np.array([self.item_size] * self.num_items)
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.item_size
|
||||
|
||||
def size(self, index):
|
||||
return self.item_size
|
||||
|
||||
def ordered_indices(self):
|
||||
return np.arange(self.num_items)
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return False
|
|
@ -0,0 +1,105 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from collections import Counter
|
||||
|
||||
import torch
|
||||
from fairseq.file_io import PathManager
|
||||
from fairseq.tokenizer import tokenize_line
|
||||
|
||||
|
||||
def safe_readline(f):
|
||||
pos = f.tell()
|
||||
while True:
|
||||
try:
|
||||
return f.readline()
|
||||
except UnicodeDecodeError:
|
||||
pos -= 1
|
||||
f.seek(pos) # search where this character begins
|
||||
|
||||
|
||||
class Binarizer:
|
||||
@staticmethod
|
||||
def binarize(
|
||||
filename,
|
||||
dict,
|
||||
consumer,
|
||||
tokenize=tokenize_line,
|
||||
append_eos=True,
|
||||
reverse_order=False,
|
||||
offset=0,
|
||||
end=-1,
|
||||
already_numberized=False,
|
||||
):
|
||||
nseq, ntok = 0, 0
|
||||
replaced = Counter()
|
||||
|
||||
def replaced_consumer(word, idx):
|
||||
if idx == dict.unk_index and word != dict.unk_word:
|
||||
replaced.update([word])
|
||||
|
||||
with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
|
||||
f.seek(offset)
|
||||
# next(f) breaks f.tell(), hence readline() must be used
|
||||
line = safe_readline(f)
|
||||
while line:
|
||||
if end > 0 and f.tell() > end:
|
||||
break
|
||||
if already_numberized:
|
||||
id_strings = line.strip().split()
|
||||
id_list = [int(id_string) for id_string in id_strings]
|
||||
if reverse_order:
|
||||
id_list.reverse()
|
||||
if append_eos:
|
||||
id_list.append(dict.eos())
|
||||
ids = torch.IntTensor(id_list)
|
||||
else:
|
||||
ids = dict.encode_line(
|
||||
line=line,
|
||||
line_tokenizer=tokenize,
|
||||
add_if_not_exist=False,
|
||||
consumer=replaced_consumer,
|
||||
append_eos=append_eos,
|
||||
reverse_order=reverse_order,
|
||||
)
|
||||
nseq += 1
|
||||
ntok += len(ids)
|
||||
consumer(ids)
|
||||
line = f.readline()
|
||||
return {
|
||||
"nseq": nseq,
|
||||
"nunk": sum(replaced.values()),
|
||||
"ntok": ntok,
|
||||
"replaced": replaced,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def binarize_alignments(filename, alignment_parser, consumer, offset=0, end=-1):
|
||||
nseq = 0
|
||||
|
||||
with open(PathManager.get_local_path(filename), "r") as f:
|
||||
f.seek(offset)
|
||||
line = safe_readline(f)
|
||||
while line:
|
||||
if end > 0 and f.tell() > end:
|
||||
break
|
||||
ids = alignment_parser(line)
|
||||
nseq += 1
|
||||
consumer(ids)
|
||||
line = f.readline()
|
||||
return {"nseq": nseq}
|
||||
|
||||
@staticmethod
|
||||
def find_offsets(filename, num_chunks):
|
||||
with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
|
||||
size = os.fstat(f.fileno()).st_size
|
||||
chunk_size = size // num_chunks
|
||||
offsets = [0 for _ in range(num_chunks + 1)]
|
||||
for i in range(1, num_chunks):
|
||||
f.seek(chunk_size * i)
|
||||
safe_readline(f)
|
||||
offsets[i] = f.tell()
|
||||
return offsets
|
|
@ -0,0 +1,610 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import ast
|
||||
import collections
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from fairseq.dataclass.configs import CheckpointConfig, FairseqConfig
|
||||
from fairseq.dataclass.utils import (
|
||||
convert_namespace_to_omegaconf,
|
||||
overwrite_args_by_name,
|
||||
)
|
||||
from fairseq.file_io import PathManager
|
||||
from fairseq.models import FairseqDecoder, FairseqEncoder
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from torch.serialization import default_restore_location
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
|
||||
from fairseq import meters
|
||||
|
||||
# only one worker should attempt to create the required dir
|
||||
if cfg.distributed_rank == 0:
|
||||
os.makedirs(cfg.save_dir, exist_ok=True)
|
||||
|
||||
prev_best = getattr(save_checkpoint, "best", val_loss)
|
||||
if val_loss is not None:
|
||||
best_function = max if cfg.maximize_best_checkpoint_metric else min
|
||||
save_checkpoint.best = best_function(val_loss, prev_best)
|
||||
|
||||
if cfg.no_save:
|
||||
return
|
||||
|
||||
trainer.consolidate_optimizer()
|
||||
|
||||
if not trainer.is_data_parallel_master:
|
||||
return
|
||||
|
||||
def is_better(a, b):
|
||||
return a >= b if cfg.maximize_best_checkpoint_metric else a <= b
|
||||
|
||||
write_timer = meters.StopwatchMeter()
|
||||
write_timer.start()
|
||||
|
||||
epoch = epoch_itr.epoch
|
||||
end_of_epoch = epoch_itr.end_of_epoch()
|
||||
updates = trainer.get_num_updates()
|
||||
|
||||
suffix = cfg.checkpoint_suffix or ""
|
||||
checkpoint_conds = collections.OrderedDict()
|
||||
checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
|
||||
end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
|
||||
)
|
||||
checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
|
||||
not end_of_epoch
|
||||
and cfg.save_interval_updates > 0
|
||||
and updates % cfg.save_interval_updates == 0
|
||||
)
|
||||
checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
|
||||
not hasattr(save_checkpoint, "best")
|
||||
or is_better(val_loss, save_checkpoint.best)
|
||||
)
|
||||
if val_loss is not None and cfg.keep_best_checkpoints > 0:
|
||||
checkpoint_conds[
|
||||
"checkpoint.best_{}_{:.2f}.pt".format(cfg.best_checkpoint_metric, val_loss)
|
||||
] = not hasattr(save_checkpoint, "best") or is_better(
|
||||
val_loss, save_checkpoint.best
|
||||
)
|
||||
checkpoint_conds[
|
||||
"checkpoint_last{}.pt".format(suffix)
|
||||
] = not cfg.no_last_checkpoints
|
||||
|
||||
extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
|
||||
if hasattr(save_checkpoint, "best"):
|
||||
extra_state.update({"best": save_checkpoint.best})
|
||||
|
||||
checkpoints = [
|
||||
os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
|
||||
]
|
||||
if len(checkpoints) > 0:
|
||||
trainer.save_checkpoint(checkpoints[0], extra_state)
|
||||
for cp in checkpoints[1:]:
|
||||
PathManager.copy(checkpoints[0], cp, overwrite=True)
|
||||
|
||||
write_timer.stop()
|
||||
logger.info(
|
||||
"saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
|
||||
checkpoints[0], epoch, updates, val_loss, write_timer.sum
|
||||
)
|
||||
)
|
||||
|
||||
if not end_of_epoch and cfg.keep_interval_updates > 0:
|
||||
# remove old checkpoints; checkpoints are sorted in descending order
|
||||
checkpoints = checkpoint_paths(
|
||||
cfg.save_dir, pattern=r"checkpoint_\d+_(\d+)\.pt"
|
||||
)
|
||||
for old_chk in checkpoints[cfg.keep_interval_updates :]:
|
||||
if os.path.lexists(old_chk):
|
||||
os.remove(old_chk)
|
||||
|
||||
if cfg.keep_last_epochs > 0:
|
||||
# remove old epoch checkpoints; checkpoints are sorted in descending order
|
||||
checkpoints = checkpoint_paths(cfg.save_dir, pattern=r"checkpoint(\d+)\.pt")
|
||||
for old_chk in checkpoints[cfg.keep_last_epochs :]:
|
||||
if os.path.lexists(old_chk):
|
||||
os.remove(old_chk)
|
||||
|
||||
if cfg.keep_best_checkpoints > 0:
|
||||
# only keep the best N checkpoints according to validation metric
|
||||
checkpoints = checkpoint_paths(
|
||||
cfg.save_dir,
|
||||
pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
|
||||
cfg.best_checkpoint_metric
|
||||
),
|
||||
)
|
||||
if not cfg.maximize_best_checkpoint_metric:
|
||||
checkpoints = checkpoints[::-1]
|
||||
for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
|
||||
if os.path.lexists(old_chk):
|
||||
os.remove(old_chk)
|
||||
|
||||
|
||||
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
|
||||
"""
|
||||
Load a checkpoint and restore the training iterator.
|
||||
|
||||
*passthrough_args* will be passed through to
|
||||
``trainer.get_train_iterator``.
|
||||
"""
|
||||
|
||||
reset_optimizer = cfg.reset_optimizer
|
||||
reset_lr_scheduler = cfg.reset_lr_scheduler
|
||||
optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
|
||||
reset_meters = cfg.reset_meters
|
||||
reset_dataloader = cfg.reset_dataloader
|
||||
|
||||
if cfg.finetune_from_model is not None and (
|
||||
reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
|
||||
):
|
||||
raise ValueError(
|
||||
"--finetune-from-model can not be set together with either --reset-optimizer"
|
||||
" or reset_lr_scheduler or reset_meters or reset_dataloader"
|
||||
)
|
||||
|
||||
suffix = cfg.checkpoint_suffix
|
||||
if (
|
||||
cfg.restore_file == "checkpoint_last.pt"
|
||||
): # default value of restore_file is 'checkpoint_last.pt'
|
||||
checkpoint_path = os.path.join(
|
||||
cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
|
||||
)
|
||||
first_launch = not PathManager.exists(checkpoint_path)
|
||||
if cfg.finetune_from_model is not None and first_launch:
|
||||
# if there is no last checkpoint to restore, start the finetune from pretrained model
|
||||
# else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
|
||||
if PathManager.exists(cfg.finetune_from_model):
|
||||
checkpoint_path = cfg.finetune_from_model
|
||||
reset_optimizer = True
|
||||
reset_lr_scheduler = True
|
||||
reset_meters = True
|
||||
reset_dataloader = True
|
||||
logger.info(
|
||||
f"loading pretrained model from {checkpoint_path}: "
|
||||
"optimizer, lr scheduler, meters, dataloader will be reset"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"--funetune-from-model {cfg.finetune_from_model} does not exist"
|
||||
)
|
||||
elif cfg.model_parallel_size > 1:
|
||||
checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
|
||||
else:
|
||||
checkpoint_path = cfg.restore_file
|
||||
|
||||
if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
|
||||
raise ValueError(
|
||||
"--finetune-from-model and --restore-file (non-default value) "
|
||||
"can not be specified together: " + str(cfg)
|
||||
)
|
||||
|
||||
extra_state = trainer.load_checkpoint(
|
||||
checkpoint_path,
|
||||
reset_optimizer,
|
||||
reset_lr_scheduler,
|
||||
optimizer_overrides,
|
||||
reset_meters=reset_meters,
|
||||
)
|
||||
|
||||
if (
|
||||
extra_state is not None
|
||||
and "best" in extra_state
|
||||
and not reset_optimizer
|
||||
and not reset_meters
|
||||
):
|
||||
save_checkpoint.best = extra_state["best"]
|
||||
|
||||
if extra_state is not None and not reset_dataloader:
|
||||
# restore iterator from checkpoint
|
||||
itr_state = extra_state["train_iterator"]
|
||||
epoch_itr = trainer.get_train_iterator(
|
||||
epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
|
||||
)
|
||||
epoch_itr.load_state_dict(itr_state)
|
||||
else:
|
||||
epoch_itr = trainer.get_train_iterator(
|
||||
epoch=1, load_dataset=True, **passthrough_args
|
||||
)
|
||||
|
||||
trainer.lr_step(epoch_itr.epoch)
|
||||
|
||||
return extra_state, epoch_itr
|
||||
|
||||
|
||||
def load_checkpoint_to_cpu(path, arg_overrides=None):
|
||||
"""Loads a checkpoint to CPU (with upgrading for backward compatibility)."""
|
||||
with open(PathManager.get_local_path(path), "rb") as f:
|
||||
state = torch.load(
|
||||
f, map_location=lambda s, l: default_restore_location(s, "cpu")
|
||||
)
|
||||
|
||||
if "args" in state and state["args"] is not None and arg_overrides is not None:
|
||||
args = state["args"]
|
||||
for arg_name, arg_val in arg_overrides.items():
|
||||
setattr(args, arg_name, arg_val)
|
||||
|
||||
if "cfg" in state and state["cfg"] is not None and arg_overrides is not None:
|
||||
overwrite_args_by_name(state["cfg"], arg_overrides)
|
||||
|
||||
state = _upgrade_state_dict(state)
|
||||
return state
|
||||
|
||||
|
||||
def load_model_ensemble(
|
||||
filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
|
||||
):
|
||||
"""Loads an ensemble of models.
|
||||
|
||||
Args:
|
||||
filenames (List[str]): checkpoint files to load
|
||||
arg_overrides (Dict[str,Any], optional): override model args that
|
||||
were used during model training
|
||||
task (fairseq.tasks.FairseqTask, optional): task to use for loading
|
||||
"""
|
||||
assert not (
|
||||
strict and num_shards > 1
|
||||
), "Cannot load state dict with strict=True and checkpoint shards > 1"
|
||||
ensemble, args, _task = load_model_ensemble_and_task(
|
||||
filenames,
|
||||
arg_overrides,
|
||||
task,
|
||||
strict,
|
||||
suffix,
|
||||
num_shards,
|
||||
)
|
||||
return ensemble, args
|
||||
|
||||
|
||||
def load_model_ensemble_and_task(
|
||||
filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1
|
||||
):
|
||||
from fairseq import tasks
|
||||
|
||||
assert not (
|
||||
strict and num_shards > 1
|
||||
), "Cannot load state dict with strict=True and checkpoint shards > 1"
|
||||
ensemble = []
|
||||
for filename in filenames:
|
||||
orig_filename = filename
|
||||
for shard_idx in range(num_shards):
|
||||
if num_shards == 1:
|
||||
filename = filename.replace(".pt", suffix + ".pt")
|
||||
else:
|
||||
filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
|
||||
|
||||
if not PathManager.exists(filename):
|
||||
raise IOError("Model file not found: {}".format(filename))
|
||||
state = load_checkpoint_to_cpu(filename, arg_overrides)
|
||||
if "args" in state and state["args"] is not None:
|
||||
cfg = convert_namespace_to_omegaconf(state["args"])
|
||||
elif "cfg" in state and state["cfg"] is not None:
|
||||
cfg = state["cfg"]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Neither args nor cfg exist in state keys = {state.keys()}"
|
||||
)
|
||||
|
||||
if task is None:
|
||||
task = tasks.setup_task(cfg.task)
|
||||
|
||||
# build model for ensemble
|
||||
model = task.build_model(cfg.model)
|
||||
|
||||
model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model)
|
||||
ensemble.append(model)
|
||||
return ensemble, cfg, task
|
||||
|
||||
|
||||
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"):
|
||||
"""Retrieves all checkpoints found in `path` directory.
|
||||
|
||||
Checkpoints are identified by matching filename to the specified pattern. If
|
||||
the pattern contains groups, the result will be sorted by the first group in
|
||||
descending order.
|
||||
"""
|
||||
pt_regexp = re.compile(pattern)
|
||||
files = os.listdir(path)
|
||||
|
||||
entries = []
|
||||
for i, f in enumerate(files):
|
||||
m = pt_regexp.fullmatch(f)
|
||||
if m is not None:
|
||||
idx = float(m.group(1)) if len(m.groups()) > 0 else i
|
||||
entries.append((idx, m.group(0)))
|
||||
return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
|
||||
|
||||
|
||||
def torch_persistent_save(obj, f):
|
||||
if isinstance(f, str):
|
||||
with PathManager.open(f, "wb") as h:
|
||||
torch_persistent_save(obj, h)
|
||||
return
|
||||
for i in range(3):
|
||||
try:
|
||||
return torch.save(obj, f)
|
||||
except Exception:
|
||||
if i == 2:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
def save_state(
|
||||
filename,
|
||||
cfg: FairseqConfig,
|
||||
model_state_dict,
|
||||
criterion,
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
num_updates,
|
||||
optim_history=None,
|
||||
extra_state=None,
|
||||
**kwargs,
|
||||
):
|
||||
from fairseq import utils
|
||||
|
||||
if optim_history is None:
|
||||
optim_history = []
|
||||
if extra_state is None:
|
||||
extra_state = {}
|
||||
state_dict = {
|
||||
"cfg": cfg,
|
||||
"args": kwargs.get("args", None),
|
||||
"model": model_state_dict or {},
|
||||
"optimizer_history": optim_history
|
||||
+ [
|
||||
{
|
||||
"criterion_name": criterion.__class__.__name__,
|
||||
"optimizer_name": optimizer.__class__.__name__,
|
||||
"lr_scheduler_state": lr_scheduler.state_dict(),
|
||||
"num_updates": num_updates,
|
||||
}
|
||||
],
|
||||
"extra_state": extra_state,
|
||||
}
|
||||
if utils.has_parameters(criterion):
|
||||
state_dict["criterion"] = criterion.state_dict()
|
||||
|
||||
if cfg is None:
|
||||
cfg = state_dict["args"]
|
||||
assert cfg is not None, "must provide cfg or args"
|
||||
|
||||
if isinstance(cfg, DictConfig):
|
||||
no_save_optimizer_state = cfg.checkpoint.no_save_optimizer_state
|
||||
else:
|
||||
no_save_optimizer_state = cfg.no_save_optimizer_state
|
||||
if not no_save_optimizer_state:
|
||||
state_dict["last_optimizer_state"] = optimizer.state_dict()
|
||||
|
||||
with PathManager.open(filename, "wb") as f:
|
||||
torch_persistent_save(state_dict, f)
|
||||
|
||||
|
||||
def _upgrade_state_dict(state):
|
||||
"""Helper for upgrading old model checkpoints."""
|
||||
from fairseq import models, registry, tasks
|
||||
|
||||
# add optimizer_history
|
||||
if "optimizer_history" not in state:
|
||||
state["optimizer_history"] = [
|
||||
{"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
|
||||
]
|
||||
state["last_optimizer_state"] = state["optimizer"]
|
||||
del state["optimizer"]
|
||||
del state["best_loss"]
|
||||
# move extra_state into sub-dictionary
|
||||
if "epoch" in state and "extra_state" not in state:
|
||||
state["extra_state"] = {
|
||||
"epoch": state["epoch"],
|
||||
"batch_offset": state["batch_offset"],
|
||||
"val_loss": state["val_loss"],
|
||||
}
|
||||
del state["epoch"]
|
||||
del state["batch_offset"]
|
||||
del state["val_loss"]
|
||||
# reduce optimizer history's memory usage (only keep the last state)
|
||||
if "optimizer" in state["optimizer_history"][-1]:
|
||||
state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
|
||||
for optim_hist in state["optimizer_history"]:
|
||||
del optim_hist["optimizer"]
|
||||
# record the optimizer class name
|
||||
if "optimizer_name" not in state["optimizer_history"][-1]:
|
||||
state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
|
||||
# move best_loss into lr_scheduler_state
|
||||
if "lr_scheduler_state" not in state["optimizer_history"][-1]:
|
||||
state["optimizer_history"][-1]["lr_scheduler_state"] = {
|
||||
"best": state["optimizer_history"][-1]["best_loss"]
|
||||
}
|
||||
del state["optimizer_history"][-1]["best_loss"]
|
||||
# keep track of number of updates
|
||||
if "num_updates" not in state["optimizer_history"][-1]:
|
||||
state["optimizer_history"][-1]["num_updates"] = 0
|
||||
# use stateful training data iterator
|
||||
if "train_iterator" not in state["extra_state"]:
|
||||
state["extra_state"]["train_iterator"] = {
|
||||
"epoch": state["extra_state"]["epoch"],
|
||||
"iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
|
||||
}
|
||||
|
||||
# old model checkpoints may not have separate source/target positions
|
||||
# backward compatibility, cfg updates
|
||||
if "args" in state and state["args"] is not None:
|
||||
# default to translation task
|
||||
if not hasattr(state["args"], "task"):
|
||||
state["args"].task = "translation"
|
||||
# --raw-text and --lazy-load are deprecated
|
||||
if getattr(state["args"], "raw_text", False):
|
||||
state["args"].dataset_impl = "raw"
|
||||
elif getattr(state["args"], "lazy_load", False):
|
||||
state["args"].dataset_impl = "lazy"
|
||||
# epochs start at 1
|
||||
if state["extra_state"]["train_iterator"] is not None:
|
||||
state["extra_state"]["train_iterator"]["epoch"] = max(
|
||||
state["extra_state"]["train_iterator"].get("epoch", 1), 1
|
||||
)
|
||||
|
||||
if hasattr(state["args"], "remove_bpe"):
|
||||
state["args"].post_process = state["args"].remove_bpe
|
||||
|
||||
state["cfg"] = convert_namespace_to_omegaconf(state["args"])
|
||||
|
||||
if "cfg" in state and state["cfg"] is not None:
|
||||
with open_dict(state["cfg"]):
|
||||
if state["cfg"].task is not None:
|
||||
if hasattr(state["cfg"].task, "max_positions") and not hasattr(
|
||||
state["cfg"].task, "max_source_positions"
|
||||
):
|
||||
state["cfg"].task.max_source_positions = state[
|
||||
"cfg"
|
||||
].task.max_positions
|
||||
state["cfg"].task.max_target_positions = state[
|
||||
"cfg"
|
||||
].task.max_positions
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
|
||||
"""Prune the given state_dict if desired for LayerDrop
|
||||
(https://arxiv.org/abs/1909.11556).
|
||||
|
||||
Training with LayerDrop allows models to be robust to pruning at inference
|
||||
time. This function prunes state_dict to allow smaller models to be loaded
|
||||
from a larger model and re-maps the existing state_dict for this to occur.
|
||||
|
||||
It's called by functions that load models from checkpoints and does not
|
||||
need to be called directly.
|
||||
"""
|
||||
arch = None
|
||||
if model_cfg is not None:
|
||||
arch = (
|
||||
model_cfg._name
|
||||
if isinstance(model_cfg, DictConfig)
|
||||
else getattr(model_cfg, "arch", None)
|
||||
)
|
||||
|
||||
if not model_cfg or arch is None or arch == "ptt_transformer":
|
||||
# args should not be none, but don't crash if it is.
|
||||
return state_dict
|
||||
|
||||
encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
|
||||
decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)
|
||||
|
||||
if not encoder_layers_to_keep and not decoder_layers_to_keep:
|
||||
return state_dict
|
||||
|
||||
# apply pruning
|
||||
logger.info(
|
||||
"Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
|
||||
)
|
||||
|
||||
def create_pruning_pass(layers_to_keep, layer_name):
|
||||
keep_layers = sorted(
|
||||
int(layer_string) for layer_string in layers_to_keep.split(",")
|
||||
)
|
||||
mapping_dict = {}
|
||||
for i in range(len(keep_layers)):
|
||||
mapping_dict[str(keep_layers[i])] = str(i)
|
||||
|
||||
regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
|
||||
return {"substitution_regex": regex, "mapping_dict": mapping_dict}
|
||||
|
||||
pruning_passes = []
|
||||
if encoder_layers_to_keep:
|
||||
pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
|
||||
if decoder_layers_to_keep:
|
||||
pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))
|
||||
|
||||
new_state_dict = {}
|
||||
for layer_name in state_dict.keys():
|
||||
match = re.search(r"\.layers\.(\d+)\.", layer_name)
|
||||
# if layer has no number in it, it is a supporting layer, such as an
|
||||
# embedding
|
||||
if not match:
|
||||
new_state_dict[layer_name] = state_dict[layer_name]
|
||||
continue
|
||||
|
||||
# otherwise, layer should be pruned.
|
||||
original_layer_number = match.group(1)
|
||||
# figure out which mapping dict to replace from
|
||||
for pruning_pass in pruning_passes:
|
||||
if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
|
||||
"substitution_regex"
|
||||
].search(layer_name):
|
||||
new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
|
||||
substitution_match = pruning_pass["substitution_regex"].search(
|
||||
layer_name
|
||||
)
|
||||
new_state_key = (
|
||||
layer_name[: substitution_match.start(1)]
|
||||
+ new_layer_number
|
||||
+ layer_name[substitution_match.end(1) :]
|
||||
)
|
||||
new_state_dict[new_state_key] = state_dict[layer_name]
|
||||
|
||||
# Since layers are now pruned, *_layers_to_keep are no longer needed.
|
||||
# This is more of "It would make it work fix" rather than a proper fix.
|
||||
|
||||
with open_dict(model_cfg):
|
||||
if hasattr(model_cfg, "encoder_layers_to_keep"):
|
||||
model_cfg.encoder_layers_to_keep = None
|
||||
if hasattr(model_cfg, "decoder_layers_to_keep"):
|
||||
model_cfg.decoder_layers_to_keep = None
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def load_pretrained_component_from_model(
|
||||
component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
|
||||
):
|
||||
"""
|
||||
Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
|
||||
provided `component` object. If state_dict fails to load, there may be a
|
||||
mismatch in the architecture of the corresponding `component` found in the
|
||||
`checkpoint` file.
|
||||
"""
|
||||
if not PathManager.exists(checkpoint):
|
||||
raise IOError("Model file not found: {}".format(checkpoint))
|
||||
state = load_checkpoint_to_cpu(checkpoint)
|
||||
if isinstance(component, FairseqEncoder):
|
||||
component_type = "encoder"
|
||||
elif isinstance(component, FairseqDecoder):
|
||||
component_type = "decoder"
|
||||
else:
|
||||
raise ValueError(
|
||||
"component to load must be either a FairseqEncoder or "
|
||||
"FairseqDecoder. Loading other component types are not supported."
|
||||
)
|
||||
component_state_dict = OrderedDict()
|
||||
for key in state["model"].keys():
|
||||
if key.startswith(component_type):
|
||||
# encoder.input_layers.0.0.weight --> input_layers.0.0.weight
|
||||
component_subkey = key[len(component_type) + 1 :]
|
||||
component_state_dict[component_subkey] = state["model"][key]
|
||||
component.load_state_dict(component_state_dict, strict=True)
|
||||
return component
|
||||
|
||||
|
||||
def verify_checkpoint_directory(save_dir: str) -> None:
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
temp_file_path = os.path.join(save_dir, "dummy")
|
||||
try:
|
||||
with open(temp_file_path, "w"):
|
||||
pass
|
||||
except OSError as e:
|
||||
logger.warning(
|
||||
"Unable to access checkpoint save directory: {}".format(save_dir)
|
||||
)
|
||||
raise e
|
||||
else:
|
||||
os.remove(temp_file_path)
|
|
@ -0,0 +1,141 @@
|
|||
/**
|
||||
* Copyright 2017-present, Facebook, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <map>
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <cstdio>
|
||||
|
||||
typedef struct
|
||||
{
|
||||
size_t reflen;
|
||||
size_t predlen;
|
||||
size_t match1;
|
||||
size_t count1;
|
||||
size_t match2;
|
||||
size_t count2;
|
||||
size_t match3;
|
||||
size_t count3;
|
||||
size_t match4;
|
||||
size_t count4;
|
||||
} bleu_stat;
|
||||
|
||||
// left trim (remove pad)
|
||||
void bleu_ltrim(size_t* len, int** sent, int pad) {
|
||||
size_t start = 0;
|
||||
while(start < *len) {
|
||||
if (*(*sent + start) != pad) { break; }
|
||||
start++;
|
||||
}
|
||||
*sent += start;
|
||||
*len -= start;
|
||||
}
|
||||
|
||||
// right trim remove (eos)
|
||||
void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
|
||||
size_t end = *len - 1;
|
||||
while (end > 0) {
|
||||
if (*(*sent + end) != eos && *(*sent + end) != pad) { break; }
|
||||
end--;
|
||||
}
|
||||
*len = end + 1;
|
||||
}
|
||||
|
||||
// left and right trim
|
||||
void bleu_trim(size_t* len, int** sent, int pad, int eos) {
|
||||
bleu_ltrim(len, sent, pad);
|
||||
bleu_rtrim(len, sent, pad, eos);
|
||||
}
|
||||
|
||||
size_t bleu_hash(int len, int* data) {
|
||||
size_t h = 14695981039346656037ul;
|
||||
size_t prime = 0x100000001b3;
|
||||
char* b = (char*) data;
|
||||
size_t blen = sizeof(int) * len;
|
||||
|
||||
while (blen-- > 0) {
|
||||
h ^= *b++;
|
||||
h *= prime;
|
||||
}
|
||||
|
||||
return h;
|
||||
}
|
||||
|
||||
void bleu_addngram(
|
||||
size_t *ntotal, size_t *nmatch, size_t n,
|
||||
size_t reflen, int* ref, size_t predlen, int* pred) {
|
||||
|
||||
if (predlen < n) { return; }
|
||||
|
||||
predlen = predlen - n + 1;
|
||||
(*ntotal) += predlen;
|
||||
|
||||
if (reflen < n) { return; }
|
||||
|
||||
reflen = reflen - n + 1;
|
||||
|
||||
std::map<size_t, size_t> count;
|
||||
while (predlen > 0) {
|
||||
size_t w = bleu_hash(n, pred++);
|
||||
count[w]++;
|
||||
predlen--;
|
||||
}
|
||||
|
||||
while (reflen > 0) {
|
||||
size_t w = bleu_hash(n, ref++);
|
||||
if (count[w] > 0) {
|
||||
(*nmatch)++;
|
||||
count[w] -=1;
|
||||
}
|
||||
reflen--;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" {
|
||||
|
||||
#ifdef _WIN64
|
||||
__declspec(dllexport)
|
||||
#endif
|
||||
void bleu_zero_init(bleu_stat* stat) {
|
||||
std::memset(stat, 0, sizeof(bleu_stat));
|
||||
}
|
||||
|
||||
#ifdef _WIN64
|
||||
__declspec(dllexport)
|
||||
#endif
|
||||
void bleu_one_init(bleu_stat* stat) {
|
||||
bleu_zero_init(stat);
|
||||
stat->count1 = 0;
|
||||
stat->count2 = 1;
|
||||
stat->count3 = 1;
|
||||
stat->count4 = 1;
|
||||
stat->match1 = 0;
|
||||
stat->match2 = 1;
|
||||
stat->match3 = 1;
|
||||
stat->match4 = 1;
|
||||
}
|
||||
|
||||
#ifdef _WIN64
|
||||
__declspec(dllexport)
|
||||
#endif
|
||||
void bleu_add(
|
||||
bleu_stat* stat,
|
||||
size_t reflen, int* ref, size_t predlen, int* pred, int pad, int eos) {
|
||||
|
||||
bleu_trim(&reflen, &ref, pad, eos);
|
||||
bleu_trim(&predlen, &pred, pad, eos);
|
||||
stat->reflen += reflen;
|
||||
stat->predlen += predlen;
|
||||
|
||||
bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
|
||||
bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
|
||||
bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
|
||||
bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
/**
|
||||
* Copyright 2017-present, Facebook, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <Python.h>
|
||||
|
||||
|
||||
static PyMethodDef method_def[] = {
|
||||
{NULL, NULL, 0, NULL}
|
||||
};
|
||||
|
||||
static struct PyModuleDef module_def = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"libbleu", /* name of module */
|
||||
NULL, /* module documentation, may be NULL */
|
||||
-1, /* size of per-interpreter state of the module,
|
||||
or -1 if the module keeps state in global variables. */
|
||||
method_def
|
||||
};
|
||||
|
||||
|
||||
#if PY_MAJOR_VERSION == 2
|
||||
PyMODINIT_FUNC init_libbleu()
|
||||
#else
|
||||
PyMODINIT_FUNC PyInit_libbleu()
|
||||
#endif
|
||||
{
|
||||
PyObject *m = PyModule_Create(&module_def);
|
||||
if (!m) {
|
||||
return NULL;
|
||||
}
|
||||
return m;
|
||||
}
|
|
@ -0,0 +1,231 @@
|
|||
/**
|
||||
* Copyright 2017-present, Facebook, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include <torch/torch.h> // @manual=//caffe2:torch_extension
|
||||
#include <pybind11/detail/common.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <iosfwd>
|
||||
#include <memory>
|
||||
#include <new>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
using namespace ::std;
|
||||
|
||||
vector<vector<uint32_t>> edit_distance2_with_dp(
|
||||
vector<uint32_t>& x,
|
||||
vector<uint32_t>& y) {
|
||||
uint32_t lx = x.size();
|
||||
uint32_t ly = y.size();
|
||||
vector<vector<uint32_t>> d(lx + 1, vector<uint32_t>(ly + 1));
|
||||
for (uint32_t i = 0; i < lx + 1; i++) {
|
||||
d[i][0] = i;
|
||||
}
|
||||
for (uint32_t j = 0; j < ly + 1; j++) {
|
||||
d[0][j] = j;
|
||||
}
|
||||
for (uint32_t i = 1; i < lx + 1; i++) {
|
||||
for (uint32_t j = 1; j < ly + 1; j++) {
|
||||
d[i][j] =
|
||||
min(min(d[i - 1][j], d[i][j - 1]) + 1,
|
||||
d[i - 1][j - 1] + 2 * (x.at(i - 1) == y.at(j - 1) ? 0 : 1));
|
||||
}
|
||||
}
|
||||
return d;
|
||||
}
|
||||
|
||||
vector<vector<uint32_t>> edit_distance2_backtracking(
|
||||
vector<vector<uint32_t>>& d,
|
||||
vector<uint32_t>& x,
|
||||
vector<uint32_t>& y,
|
||||
uint32_t terminal_symbol) {
|
||||
vector<uint32_t> seq;
|
||||
vector<vector<uint32_t>> edit_seqs(x.size() + 2, vector<uint32_t>());
|
||||
/*
|
||||
edit_seqs:
|
||||
0~x.size() cell is the insertion sequences
|
||||
last cell is the delete sequence
|
||||
*/
|
||||
|
||||
if (x.size() == 0) {
|
||||
edit_seqs.at(0) = y;
|
||||
return edit_seqs;
|
||||
}
|
||||
|
||||
uint32_t i = d.size() - 1;
|
||||
uint32_t j = d.at(0).size() - 1;
|
||||
|
||||
while ((i >= 0) && (j >= 0)) {
|
||||
if ((i == 0) && (j == 0)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
|
||||
seq.push_back(1); // insert
|
||||
seq.push_back(y.at(j - 1));
|
||||
j--;
|
||||
} else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
|
||||
seq.push_back(2); // delete
|
||||
seq.push_back(x.at(i - 1));
|
||||
i--;
|
||||
} else {
|
||||
seq.push_back(3); // keep
|
||||
seq.push_back(x.at(i - 1));
|
||||
i--;
|
||||
j--;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t prev_op, op, s, word;
|
||||
prev_op = 0, s = 0;
|
||||
for (uint32_t k = 0; k < seq.size() / 2; k++) {
|
||||
op = seq.at(seq.size() - 2 * k - 2);
|
||||
word = seq.at(seq.size() - 2 * k - 1);
|
||||
if (prev_op != 1) {
|
||||
s++;
|
||||
}
|
||||
if (op == 1) // insert
|
||||
{
|
||||
edit_seqs.at(s - 1).push_back(word);
|
||||
} else if (op == 2) // delete
|
||||
{
|
||||
edit_seqs.at(x.size() + 1).push_back(1);
|
||||
} else {
|
||||
edit_seqs.at(x.size() + 1).push_back(0);
|
||||
}
|
||||
|
||||
prev_op = op;
|
||||
}
|
||||
|
||||
for (uint32_t k = 0; k < edit_seqs.size(); k++) {
|
||||
if (edit_seqs[k].size() == 0) {
|
||||
edit_seqs[k].push_back(terminal_symbol);
|
||||
}
|
||||
}
|
||||
return edit_seqs;
|
||||
}
|
||||
|
||||
vector<vector<uint32_t>> edit_distance2_backtracking_with_delete(
|
||||
vector<vector<uint32_t>>& d,
|
||||
vector<uint32_t>& x,
|
||||
vector<uint32_t>& y,
|
||||
uint32_t terminal_symbol,
|
||||
uint32_t deletion_symbol) {
|
||||
vector<uint32_t> seq;
|
||||
vector<vector<uint32_t>> edit_seqs(x.size() + 1, vector<uint32_t>());
|
||||
/*
|
||||
edit_seqs:
|
||||
0~x.size() cell is the insertion sequences
|
||||
last cell is the delete sequence
|
||||
*/
|
||||
|
||||
if (x.size() == 0) {
|
||||
edit_seqs.at(0) = y;
|
||||
return edit_seqs;
|
||||
}
|
||||
|
||||
uint32_t i = d.size() - 1;
|
||||
uint32_t j = d.at(0).size() - 1;
|
||||
|
||||
while ((i >= 0) && (j >= 0)) {
|
||||
if ((i == 0) && (j == 0)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
|
||||
seq.push_back(1); // insert
|
||||
seq.push_back(y.at(j - 1));
|
||||
j--;
|
||||
} else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
|
||||
seq.push_back(2); // delete
|
||||
seq.push_back(x.at(i - 1));
|
||||
i--;
|
||||
} else {
|
||||
seq.push_back(3); // keep
|
||||
seq.push_back(x.at(i - 1));
|
||||
i--;
|
||||
j--;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t prev_op, op, s, word;
|
||||
prev_op = 0, s = 0;
|
||||
for (uint32_t k = 0; k < seq.size() / 2; k++) {
|
||||
op = seq.at(seq.size() - 2 * k - 2);
|
||||
word = seq.at(seq.size() - 2 * k - 1);
|
||||
if (prev_op != 1) {
|
||||
s++;
|
||||
}
|
||||
if (op == 1) // insert
|
||||
{
|
||||
edit_seqs.at(s - 1).push_back(word);
|
||||
} else if (op == 2) // delete
|
||||
{
|
||||
edit_seqs.at(s - 1).push_back(deletion_symbol);
|
||||
}
|
||||
|
||||
prev_op = op;
|
||||
}
|
||||
|
||||
for (uint32_t k = 0; k < edit_seqs.size(); k++) {
|
||||
if (edit_seqs.at(k).size() == 0) {
|
||||
edit_seqs.at(k).push_back(terminal_symbol);
|
||||
}
|
||||
}
|
||||
return edit_seqs;
|
||||
}
|
||||
|
||||
vector<uint32_t> compute_ed2(
|
||||
vector<vector<uint32_t>>& xs,
|
||||
vector<vector<uint32_t>>& ys) {
|
||||
vector<uint32_t> distances(xs.size());
|
||||
for (uint32_t i = 0; i < xs.size(); i++) {
|
||||
vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
|
||||
distances.at(i) = d.at(xs.at(i).size()).at(ys.at(i).size());
|
||||
}
|
||||
return distances;
|
||||
}
|
||||
|
||||
vector<vector<vector<uint32_t>>> suggested_ed2_path(
|
||||
vector<vector<uint32_t>>& xs,
|
||||
vector<vector<uint32_t>>& ys,
|
||||
uint32_t terminal_symbol) {
|
||||
vector<vector<vector<uint32_t>>> seq(xs.size());
|
||||
for (uint32_t i = 0; i < xs.size(); i++) {
|
||||
vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
|
||||
seq.at(i) =
|
||||
edit_distance2_backtracking(d, xs.at(i), ys.at(i), terminal_symbol);
|
||||
}
|
||||
return seq;
|
||||
}
|
||||
|
||||
vector<vector<vector<uint32_t>>> suggested_ed2_path_with_delete(
|
||||
vector<vector<uint32_t>>& xs,
|
||||
vector<vector<uint32_t>>& ys,
|
||||
uint32_t terminal_symbol,
|
||||
uint32_t deletion_symbol) {
|
||||
vector<vector<vector<uint32_t>>> seq(xs.size());
|
||||
for (uint32_t i = 0; i < xs.size(); i++) {
|
||||
vector<vector<uint32_t>> d = edit_distance2_with_dp(xs.at(i), ys.at(i));
|
||||
seq.at(i) = edit_distance2_backtracking_with_delete(
|
||||
d, xs.at(i), ys.at(i), terminal_symbol, deletion_symbol);
|
||||
}
|
||||
return seq;
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(libnat, m) {
|
||||
m.def("compute_ed2", &compute_ed2, "compute_ed2");
|
||||
m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path");
|
||||
m.def(
|
||||
"suggested_ed2_path_with_delete",
|
||||
&suggested_ed2_path_with_delete,
|
||||
"suggested_ed2_path_with_delete");
|
||||
}
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright 2017-present, Facebook, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
/*
|
||||
This code is partially adpoted from https://github.com/1ytic/pytorch-edit-distance
|
||||
*/
|
||||
|
||||
#include "edit_dist.h"
|
||||
#include <torch/types.h>
|
||||
|
||||
#ifndef TORCH_CHECK
|
||||
#define TORCH_CHECK AT_CHECK
|
||||
#endif
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
|
||||
torch::Tensor LevenshteinDistance(
|
||||
torch::Tensor source,
|
||||
torch::Tensor target,
|
||||
torch::Tensor source_length,
|
||||
torch::Tensor target_length) {
|
||||
|
||||
CHECK_INPUT(source);
|
||||
CHECK_INPUT(target);
|
||||
CHECK_INPUT(source_length);
|
||||
CHECK_INPUT(target_length);
|
||||
return LevenshteinDistanceCuda(source, target, source_length, target_length);
|
||||
}
|
||||
|
||||
torch::Tensor GenerateDeletionLabel(
|
||||
torch::Tensor source,
|
||||
torch::Tensor operations) {
|
||||
|
||||
CHECK_INPUT(source);
|
||||
CHECK_INPUT(operations);
|
||||
return GenerateDeletionLabelCuda(source, operations);
|
||||
}
|
||||
|
||||
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
|
||||
torch::Tensor target,
|
||||
torch::Tensor operations) {
|
||||
|
||||
CHECK_INPUT(target);
|
||||
CHECK_INPUT(operations);
|
||||
return GenerateInsertionLabelCuda(target, operations);
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance");
|
||||
m.def("generate_deletion_labels", &GenerateDeletionLabel, "Generate Deletion Label");
|
||||
m.def("generate_insertion_labels", &GenerateInsertionLabel, "Generate Insertion Label");
|
||||
}
|
|
@ -0,0 +1,332 @@
|
|||
/**
|
||||
* Copyright 2017-present, Facebook, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#include "edit_dist.h"
|
||||
#include <THC/THC.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <device_launch_parameters.h>
|
||||
#include <utility> // std::pair
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void generate_deletion_label_kernel(
|
||||
const scalar_t* __restrict__ source,
|
||||
const size_t source_size,
|
||||
const size_t operation_size,
|
||||
int* __restrict__ operations,
|
||||
int* __restrict__ labels) {
|
||||
|
||||
const int index = blockIdx.x;
|
||||
const int offset = index * operation_size;
|
||||
const int offset_label = index * source_size;
|
||||
|
||||
for (int i = 0; i < source_size; i++) {
|
||||
labels[offset_label + i] = 0;
|
||||
}
|
||||
|
||||
int k = 0;
|
||||
for (int i = 0; i < operation_size; i++){
|
||||
if (operations[offset + i] == 0){
|
||||
break;
|
||||
} else if (operations[offset + i] == 1){
|
||||
continue;
|
||||
} else {
|
||||
labels[offset_label + k] = 3 - operations[offset + i];
|
||||
k++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void generate_insertion_label_kernel(
|
||||
const scalar_t* __restrict__ target,
|
||||
const size_t target_size,
|
||||
const size_t operation_size,
|
||||
int* __restrict__ operations,
|
||||
int* __restrict__ labels,
|
||||
int* __restrict__ masks) {
|
||||
|
||||
const int index = blockIdx.x;
|
||||
const int offset = index * operation_size;
|
||||
const int offset_label = index * target_size;
|
||||
|
||||
int k = 0;
|
||||
int u = 0;
|
||||
int m = 0;
|
||||
|
||||
for (int i = 0; i < target_size; i++) {
|
||||
labels[offset_label + i] = 0;
|
||||
masks[offset_label + i] = 0;
|
||||
}
|
||||
|
||||
for (int i = 0; i < operation_size-1; i++){
|
||||
if (operations[offset + i] == 0){
|
||||
break;
|
||||
} else if (operations[offset + i] == 2){
|
||||
continue;
|
||||
} else if (operations[offset + i] == 1){
|
||||
masks[offset_label + m] = 1;
|
||||
u++; m++;
|
||||
} else {
|
||||
labels[offset_label + k] = u;
|
||||
masks[offset_label + m] = 0;
|
||||
k++; m++;
|
||||
u = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void levenshtein_distance_kernel(
|
||||
const scalar_t* __restrict__ source,
|
||||
const scalar_t* __restrict__ target,
|
||||
const int* __restrict__ source_length,
|
||||
const int* __restrict__ target_length,
|
||||
const size_t source_size,
|
||||
const size_t target_size,
|
||||
int* __restrict__ operations,
|
||||
int* __restrict__ errors_curr) {
|
||||
|
||||
const int index = blockIdx.x;
|
||||
const int offset = index * (source_size + target_size);
|
||||
const int d = index * (source_size + 1) * (target_size + 1);
|
||||
const int t = target_size + 1;
|
||||
|
||||
auto err_idx = [d, t](int i, int j) { return d + i * t + j; };
|
||||
auto opt_idx = [offset](int k) { return offset + k; };
|
||||
|
||||
const int hyp_len = source_length[index];
|
||||
const int ref_len = target_length[index];
|
||||
const scalar_t* hyp_begin = source + index * source_size;
|
||||
const scalar_t* ref_begin = target + index * target_size;
|
||||
|
||||
// dynamic programming
|
||||
for (int i = 0; i <= hyp_len; i++){
|
||||
errors_curr[err_idx(i, 0)] = i;
|
||||
}
|
||||
for (int j = 0; j <= ref_len; j++){
|
||||
errors_curr[err_idx(0, j)] = j;
|
||||
}
|
||||
for (int i = 1; i <= hyp_len; i++){
|
||||
for (int j = 1; j <= ref_len; j++){
|
||||
errors_curr[err_idx(i, j)] = min(
|
||||
min(
|
||||
errors_curr[err_idx(i-1, j)],
|
||||
errors_curr[err_idx(i, j-1)]
|
||||
) + 1,
|
||||
errors_curr[err_idx(i-1, j-1)] + 2 * (
|
||||
*(hyp_begin+i-1) == *(ref_begin+j-1) ? 0 : 1
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// back-tracing
|
||||
int i = hyp_len;
|
||||
int j = ref_len;
|
||||
int o = hyp_len + ref_len;
|
||||
|
||||
for (int k = 0; k < source_size + target_size; k++) {
|
||||
operations[opt_idx(k)] = 0;
|
||||
}
|
||||
|
||||
while ((i >= 0) && (j >= 0)) {
|
||||
if ((i == 0) && (j == 0)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if ((j > 0) && (errors_curr[err_idx(i, j-1)] < errors_curr[err_idx(i, j)])) {
|
||||
o--; operations[opt_idx(o)] = 1; j--; // insertion
|
||||
} else if ((i > 0) && (errors_curr[err_idx(i-1, j)] < errors_curr[err_idx(i, j)])) {
|
||||
o--; operations[opt_idx(o)] = 2; i--; // deletion
|
||||
} else {
|
||||
o--; operations[opt_idx(o)] = 3; i--; j--; // do nothing
|
||||
}
|
||||
}
|
||||
|
||||
// moving to the left
|
||||
for (int k = 0; k < hyp_len + ref_len; k++) {
|
||||
if (k + o < hyp_len + ref_len){
|
||||
operations[opt_idx(k)] = operations[opt_idx(k+o)];
|
||||
} else{
|
||||
operations[opt_idx(k)] = 0; // padding
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void faster_levenshtein_distance_kernel(
|
||||
const scalar_t* __restrict__ source,
|
||||
const scalar_t* __restrict__ target,
|
||||
const int* __restrict__ source_length,
|
||||
const int* __restrict__ target_length,
|
||||
const size_t source_size,
|
||||
const size_t target_size,
|
||||
int* __restrict__ operations) {
|
||||
|
||||
extern __shared__ short errors[];
|
||||
auto errors_curr = errors;
|
||||
|
||||
const int index = blockIdx.x;
|
||||
const int offset = index * (source_size + target_size);
|
||||
const int t = target_size + 1;
|
||||
|
||||
auto err_idx = [t](int i, int j) { return i * t + j; };
|
||||
auto opt_idx = [offset](int k) { return offset + k; };
|
||||
|
||||
const int hyp_len = source_length[index];
|
||||
const int ref_len = target_length[index];
|
||||
const scalar_t* hyp_begin = source + index * source_size;
|
||||
const scalar_t* ref_begin = target + index * target_size;
|
||||
|
||||
// dynamic programming
|
||||
for (int i = 0; i <= hyp_len; i++){
|
||||
errors_curr[err_idx(i, 0)] = i;
|
||||
}
|
||||
for (int j = 0; j <= ref_len; j++){
|
||||
errors_curr[err_idx(0, j)] = j;
|
||||
}
|
||||
for (int i = 1; i <= hyp_len; i++){
|
||||
for (int j = 1; j <= ref_len; j++){
|
||||
errors_curr[err_idx(i, j)] = min(
|
||||
min(
|
||||
errors_curr[err_idx(i-1, j)],
|
||||
errors_curr[err_idx(i, j-1)]
|
||||
) + 1,
|
||||
errors_curr[err_idx(i-1, j-1)] + 2 * (
|
||||
*(hyp_begin+i-1) == *(ref_begin+j-1) ? 0 : 1
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// back-tracing
|
||||
int i = hyp_len;
|
||||
int j = ref_len;
|
||||
int o = hyp_len + ref_len;
|
||||
|
||||
for (int k = 0; k < source_size + target_size; k++) {
|
||||
operations[opt_idx(k)] = 0;
|
||||
}
|
||||
|
||||
while ((i >= 0) && (j >= 0)) {
|
||||
if ((i == 0) && (j == 0)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if ((j > 0) && (errors_curr[err_idx(i, j-1)] < errors_curr[err_idx(i, j)])) {
|
||||
o--; operations[opt_idx(o)] = 1; j--; // insertion
|
||||
} else if ((i > 0) && (errors_curr[err_idx(i-1, j)] < errors_curr[err_idx(i, j)])) {
|
||||
o--; operations[opt_idx(o)] = 2; i--; // deletion
|
||||
} else {
|
||||
o--; operations[opt_idx(o)] = 3; i--; j--; // do nothing
|
||||
}
|
||||
}
|
||||
|
||||
// moving to the left
|
||||
for (int k = 0; k < hyp_len + ref_len; k++) {
|
||||
if (k + o < hyp_len + ref_len){
|
||||
operations[opt_idx(k)] = operations[opt_idx(k+o)];
|
||||
} else{
|
||||
operations[opt_idx(k)] = 0; // padding
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
torch::Tensor GenerateDeletionLabelCuda(
|
||||
torch::Tensor source,
|
||||
torch::Tensor operations) {
|
||||
|
||||
const auto batch_size = source.size(0);
|
||||
at::TensorOptions options(source.device());
|
||||
options = options.dtype(at::ScalarType::Int);
|
||||
auto labels = torch::empty({batch_size, source.size(1)}, options);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(source.device().index());
|
||||
|
||||
AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
|
||||
generate_deletion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
|
||||
source.data_ptr<scalar_t>(),
|
||||
source.size(1),
|
||||
operations.size(1),
|
||||
operations.data_ptr<int>(),
|
||||
labels.data_ptr<int>());
|
||||
}));
|
||||
|
||||
return labels;
|
||||
}
|
||||
|
||||
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
|
||||
torch::Tensor target,
|
||||
torch::Tensor operations) {
|
||||
|
||||
const auto batch_size = target.size(0);
|
||||
at::TensorOptions options(target.device());
|
||||
options = options.dtype(at::ScalarType::Int);
|
||||
auto labels = torch::empty({batch_size, target.size(1)}, options);
|
||||
auto masks = torch::empty({batch_size, target.size(1)}, options);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(target.device().index());
|
||||
|
||||
AT_DISPATCH_ALL_TYPES(target.scalar_type(), "generate_insertion_labels", ([&] {
|
||||
generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
|
||||
target.data_ptr<scalar_t>(),
|
||||
target.size(1),
|
||||
operations.size(1),
|
||||
operations.data_ptr<int>(),
|
||||
labels.data_ptr<int>(),
|
||||
masks.data_ptr<int>());
|
||||
}));
|
||||
|
||||
return std::make_pair(labels, masks);
|
||||
}
|
||||
|
||||
|
||||
torch::Tensor LevenshteinDistanceCuda(
|
||||
torch::Tensor source,
|
||||
torch::Tensor target,
|
||||
torch::Tensor source_length,
|
||||
torch::Tensor target_length) {
|
||||
|
||||
const auto batch_size = source.size(0);
|
||||
const auto shared_size = (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short);
|
||||
|
||||
at::TensorOptions options(source.device());
|
||||
options = options.dtype(at::ScalarType::Int);
|
||||
auto operations = torch::empty({batch_size, source.size(1) + target.size(1)}, options);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(source.device().index());
|
||||
|
||||
if (shared_size > 40000) {
|
||||
auto distances = torch::empty({batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
|
||||
AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
|
||||
levenshtein_distance_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
|
||||
source.data_ptr<scalar_t>(),
|
||||
target.data_ptr<scalar_t>(),
|
||||
source_length.data_ptr<int>(),
|
||||
target_length.data_ptr<int>(),
|
||||
source.size(1),
|
||||
target.size(1),
|
||||
operations.data_ptr<int>(),
|
||||
distances.data_ptr<int>());
|
||||
}));
|
||||
} else {
|
||||
AT_DISPATCH_ALL_TYPES(source.scalar_type(), "faster_levenshtein_distance", ([&] {
|
||||
faster_levenshtein_distance_kernel<scalar_t><<<batch_size, 1, shared_size, stream>>>(
|
||||
source.data_ptr<scalar_t>(),
|
||||
target.data_ptr<scalar_t>(),
|
||||
source_length.data_ptr<int>(),
|
||||
target_length.data_ptr<int>(),
|
||||
source.size(1),
|
||||
target.size(1),
|
||||
operations.data_ptr<int>());
|
||||
}));
|
||||
}
|
||||
|
||||
return operations;
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Copyright 2017-present, Facebook, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <torch/extension.h>
|
||||
|
||||
torch::Tensor LevenshteinDistanceCuda(
|
||||
torch::Tensor source,
|
||||
torch::Tensor target,
|
||||
torch::Tensor source_length,
|
||||
torch::Tensor target_length);
|
||||
|
||||
torch::Tensor GenerateDeletionLabelCuda(
|
||||
torch::Tensor source,
|
||||
torch::Tensor operations);
|
||||
|
||||
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
|
||||
torch::Tensor source,
|
||||
torch::Tensor operations);
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""isort:skip_file"""
|
||||
|
||||
import importlib
|
||||
import os
|
||||
|
||||
from fairseq import registry
|
||||
from fairseq.criterions.fairseq_criterion import ( # noqa
|
||||
FairseqCriterion,
|
||||
LegacyFairseqCriterion,
|
||||
)
|
||||
from omegaconf import DictConfig
|
||||
|
||||
|
||||
(
|
||||
build_criterion_,
|
||||
register_criterion,
|
||||
CRITERION_REGISTRY,
|
||||
CRITERION_DATACLASS_REGISTRY,
|
||||
) = registry.setup_registry(
|
||||
"--criterion", base_class=FairseqCriterion, default="cross_entropy"
|
||||
)
|
||||
|
||||
|
||||
def build_criterion(cfg: DictConfig, task):
|
||||
return build_criterion_(cfg, task)
|
||||
|
||||
|
||||
# automatically import any Python files in the criterions/ directory
|
||||
for file in os.listdir(os.path.dirname(__file__)):
|
||||
if file.endswith(".py") and not file.startswith("_"):
|
||||
file_name = file[: file.find(".py")]
|
||||
importlib.import_module("fairseq.criterions." + file_name)
|
|
@ -0,0 +1,263 @@
|
|||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the LICENSE file in
|
||||
# the root directory of this source tree. An additional grant of patent rights
|
||||
# can be found in the PATENTS file in the same directory.
|
||||
|
||||
import math
|
||||
from argparse import Namespace
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairseq import metrics, utils
|
||||
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
|
||||
from fairseq.data.data_utils import post_process
|
||||
from fairseq.logging.meters import safe_round
|
||||
|
||||
|
||||
@register_criterion("ctc")
|
||||
class CtcCriterion(LegacyFairseqCriterion):
|
||||
def __init__(self, args, task):
|
||||
super().__init__(args, task)
|
||||
self.blank_idx = task.target_dictionary.bos()
|
||||
self.pad_idx = task.target_dictionary.pad()
|
||||
self.eos_idx = task.target_dictionary.eos()
|
||||
self.post_process = args.post_process
|
||||
|
||||
if args.wer_args is not None:
|
||||
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
|
||||
|
||||
wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(args.wer_args)
|
||||
|
||||
dec_args = Namespace()
|
||||
dec_args.nbest = 1
|
||||
dec_args.criterion = "ctc"
|
||||
dec_args.kenlm_model = wer_compute_kenlm
|
||||
dec_args.lexicon = wer_lexicon
|
||||
dec_args.beam = 50
|
||||
dec_args.beam_size_token = min(50, len(task.target_dictionary))
|
||||
dec_args.beam_threshold = min(50, len(task.target_dictionary))
|
||||
dec_args.lm_weight = lm_w
|
||||
dec_args.word_score = ws_w
|
||||
dec_args.unk_weight = -math.inf
|
||||
dec_args.sil_weight = 0
|
||||
|
||||
self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
|
||||
else:
|
||||
self.w2l_decoder = None
|
||||
|
||||
self.zero_infinity = args.zero_infinity
|
||||
self.sentence_avg = args.sentence_avg
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add criterion-specific arguments to the parser."""
|
||||
parser.add_argument(
|
||||
"--zero-infinity", action="store_true", help="zero inf loss"
|
||||
)
|
||||
try:
|
||||
parser.add_argument(
|
||||
"--post-process",
|
||||
"--remove-bpe",
|
||||
default="letter",
|
||||
help="remove BPE tokens before scoring (can be set to sentencepiece, letter, and more)",
|
||||
)
|
||||
except:
|
||||
pass # this option might have been added from eval args
|
||||
parser.add_argument(
|
||||
"--wer-args",
|
||||
type=str,
|
||||
default=None,
|
||||
help="options for wer computation on valid set using 4 gram lm. this should be a tuple of 4 elements: path to 4-gram lm, \
|
||||
path to lexicon, lm score, word score",
|
||||
)
|
||||
|
||||
def get_net_output(self, model, sample):
|
||||
net_output = model(**sample["net_input"])
|
||||
return net_output
|
||||
|
||||
def get_loss(self, model, sample, net_output, reduce=True):
|
||||
lprobs = model.get_normalized_probs(
|
||||
net_output, log_probs=True
|
||||
).contiguous() # (T, B, C) from the encoder
|
||||
|
||||
if "src_lengths" in sample["net_input"]:
|
||||
input_lengths = sample["net_input"]["src_lengths"]
|
||||
else:
|
||||
non_padding_mask = ~net_output["padding_mask"]
|
||||
input_lengths = non_padding_mask.long().sum(-1)
|
||||
|
||||
pad_mask = (sample["target"] != self.pad_idx) & (
|
||||
sample["target"] != self.eos_idx
|
||||
)
|
||||
targets_flat = sample["target"].masked_select(pad_mask)
|
||||
target_lengths = sample["target_lengths"]
|
||||
|
||||
with torch.backends.cudnn.flags(enabled=False):
|
||||
loss = F.ctc_loss(
|
||||
lprobs,
|
||||
targets_flat,
|
||||
input_lengths,
|
||||
target_lengths,
|
||||
blank=self.blank_idx,
|
||||
reduction="sum",
|
||||
zero_infinity=self.zero_infinity,
|
||||
)
|
||||
|
||||
ntokens = (
|
||||
sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
|
||||
)
|
||||
|
||||
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
|
||||
logging_output = {
|
||||
"loss": utils.item(loss.data), # * sample['ntokens'],
|
||||
"ntokens": ntokens,
|
||||
"nsentences": sample["id"].numel(),
|
||||
"sample_size": sample_size,
|
||||
}
|
||||
|
||||
if not model.training:
|
||||
import editdistance
|
||||
|
||||
with torch.no_grad():
|
||||
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
|
||||
|
||||
c_err = 0
|
||||
c_len = 0
|
||||
w_errs = 0
|
||||
w_len = 0
|
||||
wv_errs = 0
|
||||
for lp, t, inp_l in zip(
|
||||
lprobs_t,
|
||||
sample["target_label"]
|
||||
if "target_label" in sample
|
||||
else sample["target"],
|
||||
input_lengths,
|
||||
):
|
||||
lp = lp[:inp_l].unsqueeze(0)
|
||||
|
||||
decoded = None
|
||||
if self.w2l_decoder is not None:
|
||||
decoded = self.w2l_decoder.decode(lp)
|
||||
if len(decoded) < 1:
|
||||
decoded = None
|
||||
else:
|
||||
decoded = decoded[0]
|
||||
if len(decoded) < 1:
|
||||
decoded = None
|
||||
else:
|
||||
decoded = decoded[0]
|
||||
|
||||
p = (t != self.task.target_dictionary.pad()) & (
|
||||
t != self.task.target_dictionary.eos()
|
||||
)
|
||||
targ = t[p]
|
||||
targ_units = self.task.target_dictionary.string(targ)
|
||||
targ_units_arr = targ.tolist()
|
||||
|
||||
toks = lp.argmax(dim=-1).unique_consecutive()
|
||||
pred_units_arr = toks[toks != self.blank_idx].tolist()
|
||||
|
||||
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
|
||||
c_len += len(targ_units_arr)
|
||||
|
||||
targ_words = post_process(targ_units, self.post_process).split()
|
||||
|
||||
pred_units = self.task.target_dictionary.string(pred_units_arr)
|
||||
pred_words_raw = post_process(pred_units, self.post_process).split()
|
||||
|
||||
if decoded is not None and "words" in decoded:
|
||||
pred_words = decoded["words"]
|
||||
w_errs += editdistance.eval(pred_words, targ_words)
|
||||
wv_errs += editdistance.eval(pred_words_raw, targ_words)
|
||||
else:
|
||||
dist = editdistance.eval(pred_words_raw, targ_words)
|
||||
w_errs += dist
|
||||
wv_errs += dist
|
||||
|
||||
w_len += len(targ_words)
|
||||
|
||||
logging_output["wv_errors"] = wv_errs
|
||||
logging_output["w_errors"] = w_errs
|
||||
logging_output["w_total"] = w_len
|
||||
logging_output["c_errors"] = c_err
|
||||
logging_output["c_total"] = c_len
|
||||
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
def forward(self, model, sample, reduce=True):
|
||||
import pdb
|
||||
net_output = self.get_net_output(model, sample)
|
||||
loss, sample_size, logging_output = self.get_loss(model, sample, net_output, reduce)
|
||||
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
@staticmethod
|
||||
def reduce_metrics(logging_outputs) -> None:
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
|
||||
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
|
||||
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
|
||||
nsentences = utils.item(
|
||||
sum(log.get("nsentences", 0) for log in logging_outputs)
|
||||
)
|
||||
sample_size = utils.item(
|
||||
sum(log.get("sample_size", 0) for log in logging_outputs)
|
||||
)
|
||||
|
||||
metrics.log_scalar(
|
||||
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
|
||||
)
|
||||
metrics.log_scalar("ntokens", ntokens)
|
||||
metrics.log_scalar("nsentences", nsentences)
|
||||
if sample_size != ntokens:
|
||||
metrics.log_scalar(
|
||||
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
|
||||
)
|
||||
|
||||
c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_c_errors", c_errors)
|
||||
c_total = sum(log.get("c_total", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_c_total", c_total)
|
||||
w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_w_errors", w_errors)
|
||||
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_wv_errors", wv_errors)
|
||||
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_w_total", w_total)
|
||||
|
||||
if c_total > 0:
|
||||
metrics.log_derived(
|
||||
"uer",
|
||||
lambda meters: safe_round(
|
||||
meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
|
||||
)
|
||||
if meters["_c_total"].sum > 0
|
||||
else float("nan"),
|
||||
)
|
||||
if w_total > 0:
|
||||
metrics.log_derived(
|
||||
"wer",
|
||||
lambda meters: safe_round(
|
||||
meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
|
||||
)
|
||||
if meters["_w_total"].sum > 0
|
||||
else float("nan"),
|
||||
)
|
||||
metrics.log_derived(
|
||||
"raw_wer",
|
||||
lambda meters: safe_round(
|
||||
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
|
||||
)
|
||||
if meters["_w_total"].sum > 0
|
||||
else float("nan"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def logging_outputs_can_be_summed() -> bool:
|
||||
"""
|
||||
Whether the logging outputs returned by `forward` can be summed
|
||||
across workers prior to calling `reduce_metrics`. Setting this
|
||||
to True will improves distributed training speed.
|
||||
"""
|
||||
return True
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fairseq import metrics, utils
|
||||
from fairseq.dataclass.utils import gen_parser_from_dataclass
|
||||
from omegaconf import DictConfig
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
|
||||
class FairseqCriterion(_Loss):
|
||||
def __init__(self, task):
|
||||
super().__init__()
|
||||
self.task = task
|
||||
if hasattr(task, "target_dictionary"):
|
||||
tgt_dict = task.target_dictionary
|
||||
self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100
|
||||
|
||||
@classmethod
|
||||
def add_args(cls, parser):
|
||||
"""Add criterion-specific arguments to the parser."""
|
||||
dc = getattr(cls, "__dataclass", None)
|
||||
if dc is not None:
|
||||
gen_parser_from_dataclass(parser, dc())
|
||||
|
||||
@classmethod
|
||||
def build_criterion(cls, cfg: DictConfig, task):
|
||||
"""Construct a criterion from command-line args."""
|
||||
# arguments in the __init__.
|
||||
init_args = {}
|
||||
for p in inspect.signature(cls).parameters.values():
|
||||
if (
|
||||
p.kind == p.POSITIONAL_ONLY
|
||||
or p.kind == p.VAR_POSITIONAL
|
||||
or p.kind == p.VAR_KEYWORD
|
||||
):
|
||||
# we haven't implemented inference for these argument types,
|
||||
# but PRs welcome :)
|
||||
raise NotImplementedError("{} not supported".format(p.kind))
|
||||
|
||||
assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY}
|
||||
|
||||
if p.name == "task":
|
||||
init_args["task"] = task
|
||||
elif hasattr(cfg, p.name):
|
||||
init_args[p.name] = getattr(cfg, p.name)
|
||||
elif p.default != p.empty:
|
||||
pass # we'll use the default value
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unable to infer Criterion arguments, please implement "
|
||||
"{}.build_criterion".format(cls.__name__)
|
||||
)
|
||||
return cls(**init_args)
|
||||
|
||||
def forward(self, model, sample, reduce=True):
|
||||
"""Compute the loss for the given sample.
|
||||
|
||||
Returns a tuple with three elements:
|
||||
1) the loss
|
||||
2) the sample size, which is used as the denominator for the gradient
|
||||
3) logging outputs to display while training
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def aggregate_logging_outputs(
|
||||
logging_outputs: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
utils.deprecation_warning(
|
||||
"The aggregate_logging_outputs API is deprecated. "
|
||||
"Please use the reduce_metrics API instead."
|
||||
)
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
utils.deprecation_warning(
|
||||
"Criterions should implement the reduce_metrics API. "
|
||||
"Falling back to deprecated aggregate_logging_outputs API."
|
||||
)
|
||||
agg_logging_outputs = cls.aggregate_logging_outputs(logging_outputs)
|
||||
for k, v in agg_logging_outputs.items():
|
||||
if k in {"nsentences", "ntokens", "sample_size"}:
|
||||
continue
|
||||
metrics.log_scalar(k, v)
|
||||
|
||||
@staticmethod
|
||||
def logging_outputs_can_be_summed() -> bool:
|
||||
"""
|
||||
Whether the logging outputs returned by `forward` can be summed
|
||||
across workers prior to calling `reduce_metrics`. Setting this
|
||||
to True will improves distributed training speed.
|
||||
"""
|
||||
return False
|
||||
|
||||
|
||||
class LegacyFairseqCriterion(FairseqCriterion):
|
||||
def __init__(self, args, task):
|
||||
super().__init__(task=task)
|
||||
self.args = args
|
||||
|
||||
utils.deprecation_warning(
|
||||
"Criterions should take explicit arguments instead of an "
|
||||
"argparse.Namespace object, please update your criterion by "
|
||||
"extending FairseqCriterion instead of LegacyFairseqCriterion."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_criterion(cls, args, task):
|
||||
"""Construct a criterion from command-line args."""
|
||||
return cls(args, task)
|
|
@ -0,0 +1,160 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from fairseq import metrics, utils
|
||||
from fairseq.criterions import FairseqCriterion, register_criterion
|
||||
|
||||
|
||||
def label_smoothed_nll_loss(lprobs, target, epsilon, ignore_index=None, reduce=True):
|
||||
if target.dim() == lprobs.dim() - 1:
|
||||
target = target.unsqueeze(-1)
|
||||
nll_loss = -lprobs.gather(dim=-1, index=target)
|
||||
smooth_loss = -lprobs.sum(dim=-1, keepdim=True)
|
||||
if ignore_index is not None:
|
||||
pad_mask = target.eq(ignore_index)
|
||||
nll_loss.masked_fill_(pad_mask, 0.0)
|
||||
smooth_loss.masked_fill_(pad_mask, 0.0)
|
||||
else:
|
||||
nll_loss = nll_loss.squeeze(-1)
|
||||
smooth_loss = smooth_loss.squeeze(-1)
|
||||
if reduce:
|
||||
nll_loss = nll_loss.sum()
|
||||
smooth_loss = smooth_loss.sum()
|
||||
eps_i = epsilon / lprobs.size(-1)
|
||||
loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
|
||||
return loss, nll_loss
|
||||
|
||||
|
||||
@register_criterion("label_smoothed_cross_entropy")
|
||||
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
|
||||
def __init__(
|
||||
self,
|
||||
task,
|
||||
sentence_avg,
|
||||
label_smoothing,
|
||||
ignore_prefix_size=0,
|
||||
report_accuracy=False,
|
||||
):
|
||||
super().__init__(task)
|
||||
self.sentence_avg = sentence_avg
|
||||
self.eps = label_smoothing
|
||||
self.ignore_prefix_size = ignore_prefix_size
|
||||
self.report_accuracy = report_accuracy
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add criterion-specific arguments to the parser."""
|
||||
# fmt: off
|
||||
parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
|
||||
help='epsilon for label smoothing, 0 means no label smoothing')
|
||||
parser.add_argument('--report-accuracy', action='store_true',
|
||||
help='report accuracy metric')
|
||||
parser.add_argument('--ignore-prefix-size', default=0, type=int,
|
||||
help='Ignore first N tokens')
|
||||
# fmt: on
|
||||
|
||||
def forward(self, model, sample, reduce=True):
|
||||
"""Compute the loss for the given sample.
|
||||
|
||||
Returns a tuple with three elements:
|
||||
1) the loss
|
||||
2) the sample size, which is used as the denominator for the gradient
|
||||
3) logging outputs to display while training
|
||||
"""
|
||||
net_output = model(**sample["net_input"])
|
||||
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
|
||||
sample_size = (
|
||||
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
|
||||
)
|
||||
logging_output = {
|
||||
"loss": loss.data,
|
||||
"nll_loss": nll_loss.data,
|
||||
"ntokens": sample["ntokens"],
|
||||
"nsentences": sample["target"].size(0),
|
||||
"sample_size": sample_size,
|
||||
}
|
||||
if self.report_accuracy:
|
||||
n_correct, total = self.compute_accuracy(model, net_output, sample)
|
||||
logging_output["n_correct"] = utils.item(n_correct.data)
|
||||
logging_output["total"] = utils.item(total.data)
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
def get_lprobs_and_target(self, model, net_output, sample):
|
||||
lprobs = model.get_normalized_probs(net_output, log_probs=True)
|
||||
target = model.get_targets(sample, net_output)
|
||||
if self.ignore_prefix_size > 0:
|
||||
if getattr(lprobs, "batch_first", False):
|
||||
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
|
||||
target = target[:, self.ignore_prefix_size :].contiguous()
|
||||
else:
|
||||
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
|
||||
target = target[self.ignore_prefix_size :, :].contiguous()
|
||||
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
|
||||
|
||||
def compute_loss(self, model, net_output, sample, reduce=True):
|
||||
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
|
||||
loss, nll_loss = label_smoothed_nll_loss(
|
||||
lprobs,
|
||||
target,
|
||||
self.eps,
|
||||
ignore_index=self.padding_idx,
|
||||
reduce=reduce,
|
||||
)
|
||||
return loss, nll_loss
|
||||
|
||||
def compute_accuracy(self, model, net_output, sample):
|
||||
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
|
||||
mask = target.ne(self.padding_idx)
|
||||
n_correct = torch.sum(
|
||||
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
|
||||
)
|
||||
total = torch.sum(mask)
|
||||
return n_correct, total
|
||||
|
||||
@classmethod
|
||||
def reduce_metrics(cls, logging_outputs) -> None:
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
|
||||
nll_loss_sum = sum(log.get("nll_loss", 0) for log in logging_outputs)
|
||||
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
|
||||
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
|
||||
|
||||
metrics.log_scalar(
|
||||
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
|
||||
)
|
||||
metrics.log_scalar(
|
||||
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
|
||||
)
|
||||
metrics.log_derived(
|
||||
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
|
||||
)
|
||||
|
||||
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
|
||||
if total > 0:
|
||||
metrics.log_scalar("total", total)
|
||||
n_correct = utils.item(
|
||||
sum(log.get("n_correct", 0) for log in logging_outputs)
|
||||
)
|
||||
metrics.log_scalar("n_correct", n_correct)
|
||||
metrics.log_derived(
|
||||
"accuracy",
|
||||
lambda meters: round(
|
||||
meters["n_correct"].sum * 100.0 / meters["total"].sum, 3
|
||||
)
|
||||
if meters["total"].sum > 0
|
||||
else float("nan"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def logging_outputs_can_be_summed() -> bool:
|
||||
"""
|
||||
Whether the logging outputs returned by `forward` can be summed
|
||||
across workers prior to calling `reduce_metrics`. Setting this
|
||||
to True will improves distributed training speed.
|
||||
"""
|
||||
return True
|
|
@ -0,0 +1,125 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
|
||||
from fairseq import metrics, utils
|
||||
from fairseq.criterions import register_criterion
|
||||
|
||||
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
|
||||
|
||||
|
||||
@register_criterion("label_smoothed_cross_entropy_with_alignment")
|
||||
class LabelSmoothedCrossEntropyCriterionWithAlignment(
|
||||
LabelSmoothedCrossEntropyCriterion
|
||||
):
|
||||
def __init__(self, task, sentence_avg, label_smoothing, alignment_lambda):
|
||||
super().__init__(task, sentence_avg, label_smoothing)
|
||||
self.alignment_lambda = alignment_lambda
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add criterion-specific arguments to the parser."""
|
||||
LabelSmoothedCrossEntropyCriterion.add_args(parser)
|
||||
parser.add_argument(
|
||||
"--alignment-lambda",
|
||||
default=0.05,
|
||||
type=float,
|
||||
metavar="D",
|
||||
help="weight for the alignment loss",
|
||||
)
|
||||
|
||||
def forward(self, model, sample, reduce=True):
|
||||
"""Compute the loss for the given sample.
|
||||
|
||||
Returns a tuple with three elements:
|
||||
1) the loss
|
||||
2) the sample size, which is used as the denominator for the gradient
|
||||
3) logging outputs to display while training
|
||||
"""
|
||||
net_output = model(**sample["net_input"])
|
||||
loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
|
||||
sample_size = (
|
||||
sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
|
||||
)
|
||||
logging_output = {
|
||||
"loss": utils.item(loss.data) if reduce else loss.data,
|
||||
"nll_loss": utils.item(nll_loss.data) if reduce else nll_loss.data,
|
||||
"ntokens": sample["ntokens"],
|
||||
"nsentences": sample["target"].size(0),
|
||||
"sample_size": sample_size,
|
||||
}
|
||||
|
||||
alignment_loss = None
|
||||
|
||||
# Compute alignment loss only for training set and non dummy batches.
|
||||
if "alignments" in sample and sample["alignments"] is not None:
|
||||
alignment_loss = self.compute_alignment_loss(sample, net_output)
|
||||
|
||||
if alignment_loss is not None:
|
||||
logging_output["alignment_loss"] = utils.item(alignment_loss.data)
|
||||
loss += self.alignment_lambda * alignment_loss
|
||||
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
def compute_alignment_loss(self, sample, net_output):
|
||||
attn_prob = net_output[1]["attn"][0]
|
||||
bsz, tgt_sz, src_sz = attn_prob.shape
|
||||
attn = attn_prob.view(bsz * tgt_sz, src_sz)
|
||||
|
||||
align = sample["alignments"]
|
||||
align_weights = sample["align_weights"].float()
|
||||
|
||||
if len(align) > 0:
|
||||
# Alignment loss computation. align (shape [:, 2]) contains the src-tgt index pairs corresponding to
|
||||
# the alignments. align_weights (shape [:]) contains the 1 / frequency of a tgt index for normalizing.
|
||||
loss = -(
|
||||
(attn[align[:, 1][:, None], align[:, 0][:, None]]).log()
|
||||
* align_weights[:, None]
|
||||
).sum()
|
||||
else:
|
||||
return None
|
||||
|
||||
return loss
|
||||
|
||||
@staticmethod
|
||||
def reduce_metrics(logging_outputs) -> None:
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
|
||||
nll_loss_sum = utils.item(
|
||||
sum(log.get("nll_loss", 0) for log in logging_outputs)
|
||||
)
|
||||
alignment_loss_sum = utils.item(
|
||||
sum(log.get("alignment_loss", 0) for log in logging_outputs)
|
||||
)
|
||||
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
|
||||
sample_size = utils.item(
|
||||
sum(log.get("sample_size", 0) for log in logging_outputs)
|
||||
)
|
||||
|
||||
metrics.log_scalar(
|
||||
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
|
||||
)
|
||||
metrics.log_scalar(
|
||||
"nll_loss", nll_loss_sum / ntokens / math.log(2), ntokens, round=3
|
||||
)
|
||||
metrics.log_scalar(
|
||||
"alignment_loss",
|
||||
alignment_loss_sum / sample_size / math.log(2),
|
||||
sample_size,
|
||||
round=3,
|
||||
)
|
||||
metrics.log_derived(
|
||||
"ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def logging_outputs_can_be_summed() -> bool:
|
||||
"""
|
||||
Whether the logging outputs returned by `forward` can be summed
|
||||
across workers prior to calling `reduce_metrics`. Setting this
|
||||
to True will improves distributed training speed.
|
||||
"""
|
||||
return True
|
|
@ -0,0 +1,192 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairseq import metrics, utils
|
||||
from fairseq.criterions import FairseqCriterion, register_criterion
|
||||
from fairseq.logging.meters import safe_round
|
||||
|
||||
|
||||
@register_criterion("wav2vec")
|
||||
class Wav2vecCriterion(FairseqCriterion):
|
||||
def __init__(self, task, infonce=False, loss_weights=None, log_keys=None):
|
||||
super().__init__(task)
|
||||
self.infonce = infonce
|
||||
self.loss_weights = None if loss_weights is None else eval(loss_weights)
|
||||
self.log_keys = [] if log_keys is None else eval(log_keys)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
"""Add criterion-specific arguments to the parser."""
|
||||
# fmt: off
|
||||
parser.add_argument('--infonce', action='store_true',
|
||||
help='if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)')
|
||||
parser.add_argument('--loss-weights', type=str, default=None,
|
||||
help='weights for additional loss terms (not first one)')
|
||||
parser.add_argument('--log-keys', type=str, default=None,
|
||||
help='output keys to log')
|
||||
# fmt: on
|
||||
|
||||
def get_net_output(self, model, sample):
|
||||
"""Compute the loss for the given sample.
|
||||
|
||||
Returns a tuple with three elements:
|
||||
1) the loss
|
||||
2) the sample size, which is used as the denominator for the gradient
|
||||
3) logging outputs to display while training
|
||||
"""
|
||||
net_output = model(**sample["net_input"])
|
||||
return net_output
|
||||
|
||||
def get_loss(self, model, sample, net_output, reduce=True, log_pred=False):
|
||||
logits = model.get_logits(net_output).float()
|
||||
target = model.get_targets(sample, net_output)
|
||||
|
||||
weights = None
|
||||
if hasattr(model, "get_target_weights") and not self.infonce:
|
||||
weights = model.get_target_weights(target, net_output)
|
||||
if torch.is_tensor(weights):
|
||||
weights = weights.float()
|
||||
|
||||
losses = []
|
||||
|
||||
if self.infonce:
|
||||
loss = F.cross_entropy(
|
||||
logits,
|
||||
target,
|
||||
reduction="sum" if reduce else "none",
|
||||
)
|
||||
else:
|
||||
loss = F.binary_cross_entropy_with_logits(
|
||||
logits,
|
||||
target.float(),
|
||||
weights,
|
||||
reduction="sum" if reduce else "none",
|
||||
)
|
||||
|
||||
sample_size = target.numel() if self.infonce else target.long().sum().item()
|
||||
losses.append(loss.detach().clone())
|
||||
|
||||
if self.loss_weights is not None:
|
||||
assert hasattr(model, "get_extra_losses")
|
||||
extra_losses = model.get_extra_losses(net_output)
|
||||
if torch.is_tensor(extra_losses):
|
||||
extra_losses = [extra_losses]
|
||||
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
|
||||
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
|
||||
assert len(extra_losses) == len(
|
||||
self.loss_weights
|
||||
), f"{len(extra_losses)}, {len(self.loss_weights)}"
|
||||
for p, coef in zip(extra_losses, self.loss_weights):
|
||||
if coef != 0 and p is not None:
|
||||
p = coef * p.float() * sample_size
|
||||
loss += p
|
||||
losses.append(p)
|
||||
|
||||
logging_output = {
|
||||
"loss": loss.item() if reduce else loss,
|
||||
"ntokens": sample_size,
|
||||
"nsentences": sample["id"].numel(),
|
||||
"sample_size": sample_size,
|
||||
}
|
||||
|
||||
for lk in self.log_keys:
|
||||
if lk in net_output:
|
||||
logging_output[lk] = float((net_output[lk]))
|
||||
|
||||
if len(losses) > 1:
|
||||
for i, l in enumerate(losses):
|
||||
logging_output[f"loss_{i}"] = l.item()
|
||||
|
||||
if self.infonce:
|
||||
with torch.no_grad():
|
||||
if logits.numel() == 0:
|
||||
corr = 0
|
||||
count = 0
|
||||
else:
|
||||
assert logits.dim() > 1, logits.shape
|
||||
max = logits.argmax(-1) == 0
|
||||
min = logits.argmin(-1) == 0
|
||||
both = max & min
|
||||
corr = max.long().sum().item() - both.long().sum().item()
|
||||
count = max.numel()
|
||||
|
||||
logging_output["correct"] = corr
|
||||
logging_output["count"] = count
|
||||
|
||||
if log_pred:
|
||||
logging_output["logits"] = logits.cpu().numpy()
|
||||
logging_output["target"] = target.cpu().numpy()
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
def forward(self, model, sample, reduce=True, log_pred=False):
|
||||
net_output = self.get_net_output(model, sample)
|
||||
loss, sample_size, logging_output = self.get_loss(model, sample, net_output, reduce, log_pred)
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
@staticmethod
|
||||
def reduce_metrics(logging_outputs) -> None:
|
||||
"""Aggregate logging outputs from data parallel training."""
|
||||
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
|
||||
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
|
||||
nsentences = utils.item(
|
||||
sum(log.get("nsentences", 0) for log in logging_outputs)
|
||||
)
|
||||
sample_size = utils.item(
|
||||
sum(log.get("sample_size", 0) for log in logging_outputs)
|
||||
)
|
||||
|
||||
metrics.log_scalar(
|
||||
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
|
||||
)
|
||||
metrics.log_scalar("ntokens", ntokens)
|
||||
metrics.log_scalar("nsentences", nsentences)
|
||||
|
||||
correct = sum(log.get("correct", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_correct", correct)
|
||||
|
||||
total = sum(log.get("count", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_total", total)
|
||||
|
||||
if total > 0:
|
||||
metrics.log_derived(
|
||||
"accuracy",
|
||||
lambda meters: safe_round(
|
||||
meters["_correct"].sum / meters["_total"].sum, 5
|
||||
)
|
||||
if meters["_total"].sum > 0
|
||||
else float("nan"),
|
||||
)
|
||||
|
||||
builtin_keys = {
|
||||
"loss",
|
||||
"ntokens",
|
||||
"nsentences",
|
||||
"sample_size",
|
||||
"correct",
|
||||
"count",
|
||||
}
|
||||
|
||||
for k in logging_outputs[0]:
|
||||
if k not in builtin_keys:
|
||||
val = sum(log.get(k, 0) for log in logging_outputs) / len(
|
||||
logging_outputs
|
||||
)
|
||||
if k.startswith("loss"):
|
||||
metrics.log_scalar(k, val / sample_size / math.log(2), sample_size)
|
||||
else:
|
||||
metrics.log_scalar(k, val, round=3)
|
||||
|
||||
@staticmethod
|
||||
def logging_outputs_can_be_summed() -> bool:
|
||||
"""
|
||||
Whether the logging outputs returned by `forward` can be summed
|
||||
across workers prior to calling `reduce_metrics`. Setting this
|
||||
to True will improves distributed training speed.
|
||||
"""
|
||||
return False
|
|
@ -0,0 +1,137 @@
|
|||
import torch
|
||||
import math
|
||||
|
||||
from fairseq import pdb
|
||||
from fairseq import utils, metrics
|
||||
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
|
||||
from fairseq.criterions.wav2vec_criterion import Wav2vecCriterion
|
||||
from fairseq.criterions.ctc import CtcCriterion
|
||||
from fairseq.logging.meters import safe_round
|
||||
|
||||
@register_criterion('wav2vec_mtl')
|
||||
class Wav2vecMTLCriterion(LegacyFairseqCriterion):
|
||||
|
||||
def __init__(self, args, task):
|
||||
super().__init__(args, task)
|
||||
self.mtlalpha = args.mtlalpha
|
||||
self.w2v_criterion = Wav2vecCriterion(task, args.infonce, args.loss_weights, args.log_keys)
|
||||
if self.mtlalpha > 0:
|
||||
self.ctc_criterion = CtcCriterion(args, task)
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
Wav2vecCriterion.add_args(parser)
|
||||
CtcCriterion.add_args(parser)
|
||||
parser.add_argument('--mtlalpha', type=float, default=0.5)
|
||||
|
||||
|
||||
def forward(self, model, sample, reduce=True):
|
||||
net_output = model(**sample["net_input"])
|
||||
|
||||
if self.mtlalpha > 0.0:
|
||||
ctc_loss, ctc_sample_size, ctc_logging_output = self.ctc_criterion.get_loss(model, sample, net_output, reduce)
|
||||
else:
|
||||
ctc_loss = 0
|
||||
ctc_sample_size = 0
|
||||
ctc_logging_output = {}
|
||||
|
||||
infonce_loss, infonce_sample_size, infonce_logging_output = self.w2v_criterion.get_loss(model.w2v_encoder.w2v_model, sample, net_output['contrastive_res'], reduce)
|
||||
loss = self.mtlalpha * ctc_loss + (1.0 - self.mtlalpha) * infonce_loss
|
||||
sample_size = infonce_sample_size
|
||||
logging_output = {'loss': loss, 'ntokens': ctc_logging_output['ntokens'], 'nsentences': ctc_logging_output['nsentences'],
|
||||
'ctc': ctc_logging_output, 'infonce': infonce_logging_output}
|
||||
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
@staticmethod
|
||||
def logging_outputs_can_be_summed() -> bool:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def reduce_metrics(logging_outputs) -> None:
|
||||
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
|
||||
|
||||
ctc_loss_sum = utils.item(sum(log['ctc'].get('loss', 0) for log in logging_outputs))
|
||||
ctc_sample_size = utils.item(sum(log['ctc'].get('sample_size', 0) for log in logging_outputs))
|
||||
ctc_ntokens = utils.item(sum(log['ctc'].get('ntokens', 0) for log in logging_outputs))
|
||||
ctc_nsentences = utils.item(sum(log['ctc'].get('nsentences', 0) for log in logging_outputs))
|
||||
|
||||
ctras_loss_sum = utils.item(sum(log['infonce'].get('loss', 0) for log in logging_outputs))
|
||||
ctras_sample_size = utils.item(sum(log['infonce'].get('sample_size', 0) for log in logging_outputs))
|
||||
ctras_ntokens = utils.item(sum(log['infonce'].get('ntokens', 0) for log in logging_outputs))
|
||||
ctras_nsentences = utils.item(sum(log['infonce'].get('nsentences', 0) for log in logging_outputs))
|
||||
|
||||
metrics.log_scalar(
|
||||
"loss", loss_sum, 1, round=3)
|
||||
metrics.log_scalar(
|
||||
"ctc_loss", ctc_loss_sum / ctc_sample_size / math.log(2), ctc_sample_size, round=3
|
||||
)
|
||||
metrics.log_scalar(
|
||||
"contrastive_loss", ctras_loss_sum / ctras_sample_size / math.log(2), ctras_sample_size, round=3
|
||||
)
|
||||
if ctc_sample_size != ctc_ntokens:
|
||||
metrics.log_scalar(
|
||||
"nll_loss", ctc_loss_sum / ctc_ntokens / math.log(2), ctc_ntokens, round=3
|
||||
)
|
||||
c_errors = sum(log['ctc'].get("c_errors", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_c_errors", c_errors)
|
||||
c_total = sum(log['ctc'].get("c_total", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_c_total", c_total)
|
||||
w_errors = sum(log['ctc'].get("w_errors", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_w_errors", w_errors)
|
||||
wv_errors = sum(log['ctc'].get("wv_errors", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_wv_errors", wv_errors)
|
||||
w_total = sum(log['ctc'].get("w_total", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_w_total", w_total)
|
||||
|
||||
if c_total > 0:
|
||||
metrics.log_derived(
|
||||
"uer",
|
||||
lambda meters: safe_round(meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3)
|
||||
if meters["_c_total"].sum > 0
|
||||
else float("nan"),
|
||||
)
|
||||
if w_total > 0:
|
||||
metrics.log_derived(
|
||||
"wer",
|
||||
lambda meters: safe_round(meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3)
|
||||
if meters["_w_total"].sum > 0
|
||||
else float("nan"),
|
||||
)
|
||||
metrics.log_derived(
|
||||
"raw_wer",
|
||||
lambda meters: safe_round(meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3)
|
||||
if meters["_w_total"].sum > 0
|
||||
else float("nan"),
|
||||
)
|
||||
|
||||
|
||||
metrics.log_scalar("nsentences", ctras_nsentences)
|
||||
metrics.log_scalar("ctc_sample_size", ctc_sample_size)
|
||||
metrics.log_scalar("contrastive_sample_size", ctras_sample_size)
|
||||
|
||||
|
||||
correct = sum(log['infonce'].get("correct", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_correct", correct)
|
||||
|
||||
total = sum(log['infonce'].get("count", 0) for log in logging_outputs)
|
||||
metrics.log_scalar("_total", total)
|
||||
|
||||
|
||||
if total > 0:
|
||||
metrics.log_derived(
|
||||
"accuracy",
|
||||
lambda meters: safe_round(meters["_correct"].sum / meters["_total"].sum, 5)
|
||||
if meters["_total"].sum > 0
|
||||
else float("nan"),
|
||||
)
|
||||
|
||||
builtin_keys = {'loss', 'ntokens', 'nsentences', 'sample_size', 'correct', 'count'}
|
||||
for k in logging_outputs[0]['infonce']:
|
||||
if k not in builtin_keys:
|
||||
val = sum(log['infonce'].get(k, 0) for log in logging_outputs) / len(logging_outputs)
|
||||
if k.startswith('loss'):
|
||||
metrics.log_scalar(k, val / ctras_sample_size / math.log(2), ctras_sample_size)
|
||||
else:
|
||||
metrics.log_scalar(k, val, round=3)
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
"""isort:skip_file"""
|
||||
|
||||
from .dictionary import Dictionary, TruncatedDictionary
|
||||
|
||||
from .fairseq_dataset import FairseqDataset, FairseqIterableDataset
|
||||
|
||||
from .base_wrapper_dataset import BaseWrapperDataset
|
||||
|
||||
from .add_target_dataset import AddTargetDataset
|
||||
from .audio.raw_audio_dataset import FileAudioDataset
|
||||
from .concat_dataset import ConcatDataset
|
||||
from .id_dataset import IdDataset
|
||||
from .resampling_dataset import ResamplingDataset
|
||||
|
||||
from .iterators import (
|
||||
CountingIterator,
|
||||
EpochBatchIterator,
|
||||
GroupedIterator,
|
||||
ShardedIterator,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AddTargetDataset",
|
||||
"ConcatDataset",
|
||||
"CountingIterator",
|
||||
"Dictionary",
|
||||
"EpochBatchIterator",
|
||||
"FairseqDataset",
|
||||
"FairseqIterableDataset",
|
||||
"FastaDataset",
|
||||
"GroupedIterator",
|
||||
"IdDataset",
|
||||
"ResamplingDataset",
|
||||
"ShardedIterator",
|
||||
]
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from . import BaseWrapperDataset, data_utils
|
||||
|
||||
|
||||
class AddTargetDataset(BaseWrapperDataset):
|
||||
def __init__(
|
||||
self,
|
||||
dataset,
|
||||
labels,
|
||||
pad,
|
||||
eos,
|
||||
batch_targets,
|
||||
process_label=None,
|
||||
add_to_input=False,
|
||||
):
|
||||
super().__init__(dataset)
|
||||
self.labels = labels
|
||||
self.batch_targets = batch_targets
|
||||
self.pad = pad
|
||||
self.eos = eos
|
||||
self.process_label = process_label
|
||||
self.add_to_input = add_to_input
|
||||
|
||||
def get_label(self, index):
|
||||
return (
|
||||
self.labels[index]
|
||||
if self.process_label is None
|
||||
else self.process_label(self.labels[index])
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
item = self.dataset[index]
|
||||
item["label"] = self.get_label(index)
|
||||
return item
|
||||
|
||||
def size(self, index):
|
||||
sz = self.dataset.size(index)
|
||||
own_sz = len(self.get_label(index))
|
||||
return (sz, own_sz)
|
||||
|
||||
def collater(self, samples):
|
||||
collated = self.dataset.collater(samples)
|
||||
if len(collated) == 0:
|
||||
return collated
|
||||
indices = set(collated["id"].tolist())
|
||||
target = [s["label"] for s in samples if s["id"] in indices]
|
||||
|
||||
if self.batch_targets:
|
||||
collated["target_lengths"] = torch.LongTensor([len(t) for t in target])
|
||||
target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False)
|
||||
collated["ntokens"] = collated["target_lengths"].sum().item()
|
||||
else:
|
||||
collated["ntokens"] = sum([len(t) for t in target])
|
||||
|
||||
collated["target"] = target
|
||||
|
||||
if self.add_to_input:
|
||||
eos = target.new_full((target.size(0), 1), self.eos)
|
||||
collated["target"] = torch.cat([target, eos], dim=-1).long()
|
||||
collated["net_input"]["prev_output_tokens"] = torch.cat(
|
||||
[eos, target], dim=-1
|
||||
).long()
|
||||
collated["ntokens"] += target.size(0)
|
||||
return collated
|
|
@ -0,0 +1,85 @@
|
|||
import os.path as op
|
||||
from typing import BinaryIO, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_waveform(
|
||||
path_or_fp: Union[str, BinaryIO], normalization=True
|
||||
) -> Tuple[np.ndarray, int]:
|
||||
"""Get the waveform and sample rate of a 16-bit mono-channel WAV or FLAC.
|
||||
|
||||
Args:
|
||||
path_or_fp (str or BinaryIO): the path or file-like object
|
||||
normalization (bool): Normalize values to [-1, 1] (Default: True)
|
||||
"""
|
||||
if isinstance(path_or_fp, str):
|
||||
ext = op.splitext(op.basename(path_or_fp))[1]
|
||||
if ext not in {".flac", ".wav"}:
|
||||
raise ValueError(f"Unsupported audio format: {ext}")
|
||||
|
||||
try:
|
||||
import soundfile as sf
|
||||
except ImportError:
|
||||
raise ImportError("Please install soundfile to load WAV/FLAC file")
|
||||
|
||||
waveform, sample_rate = sf.read(path_or_fp, dtype="float32")
|
||||
if not normalization:
|
||||
waveform *= 2 ** 15 # denormalized to 16-bit signed integers
|
||||
return waveform, sample_rate
|
||||
|
||||
|
||||
def _get_kaldi_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]:
|
||||
"""Get mel-filter bank features via PyKaldi."""
|
||||
try:
|
||||
from kaldi.feat.mel import MelBanksOptions
|
||||
from kaldi.feat.fbank import FbankOptions, Fbank
|
||||
from kaldi.feat.window import FrameExtractionOptions
|
||||
from kaldi.matrix import Vector
|
||||
|
||||
mel_opts = MelBanksOptions()
|
||||
mel_opts.num_bins = n_bins
|
||||
frame_opts = FrameExtractionOptions()
|
||||
frame_opts.samp_freq = sample_rate
|
||||
opts = FbankOptions()
|
||||
opts.mel_opts = mel_opts
|
||||
opts.frame_opts = frame_opts
|
||||
fbank = Fbank(opts=opts)
|
||||
features = fbank.compute(Vector(waveform), 1.0).numpy()
|
||||
return features
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def _get_torchaudio_fbank(waveform, sample_rate, n_bins=80) -> Optional[np.ndarray]:
|
||||
"""Get mel-filter bank features via TorchAudio."""
|
||||
try:
|
||||
import torch
|
||||
import torchaudio.compliance.kaldi as ta_kaldi
|
||||
|
||||
waveform = torch.from_numpy(waveform).unsqueeze(0)
|
||||
features = ta_kaldi.fbank(
|
||||
waveform, num_mel_bins=n_bins, sample_frequency=sample_rate
|
||||
)
|
||||
return features.numpy()
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
def get_fbank(path_or_fp: Union[str, BinaryIO], n_bins=80) -> np.ndarray:
|
||||
"""Get mel-filter bank features via PyKaldi or TorchAudio. Prefer PyKaldi
|
||||
(faster CPP implementation) to TorchAudio (Python implementation). Note that
|
||||
Kaldi/TorchAudio requires 16-bit signed integers as inputs and hence the
|
||||
waveform should not be normalized."""
|
||||
sound, sample_rate = get_waveform(path_or_fp, normalization=False)
|
||||
|
||||
features = _get_kaldi_fbank(sound, sample_rate, n_bins)
|
||||
if features is None:
|
||||
features = _get_torchaudio_fbank(sound, sample_rate, n_bins)
|
||||
if features is None:
|
||||
raise ImportError(
|
||||
"Please install pyKaldi or torchaudio to enable "
|
||||
"online filterbank feature extraction"
|
||||
)
|
||||
|
||||
return features
|
|
@ -0,0 +1,82 @@
|
|||
import importlib
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class AudioFeatureTransform(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config_dict(cls, config: Optional[Dict] = None):
|
||||
pass
|
||||
|
||||
|
||||
AUDIO_FEATURE_TRANSFORM_REGISTRY = {}
|
||||
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES = set()
|
||||
|
||||
|
||||
def register_audio_feature_transform(name):
|
||||
def register_audio_feature_transform_cls(cls):
|
||||
if name in AUDIO_FEATURE_TRANSFORM_REGISTRY:
|
||||
raise ValueError(f"Cannot register duplicate transform ({name})")
|
||||
if not issubclass(cls, AudioFeatureTransform):
|
||||
raise ValueError(
|
||||
f"Transform ({name}: {cls.__name__}) must extend "
|
||||
"AudioFeatureTransform"
|
||||
)
|
||||
if cls.__name__ in AUDIO_FEATURE_TRANSFORM_CLASS_NAMES:
|
||||
raise ValueError(
|
||||
f"Cannot register audio feature transform with duplicate "
|
||||
f"class name ({cls.__name__})"
|
||||
)
|
||||
AUDIO_FEATURE_TRANSFORM_REGISTRY[name] = cls
|
||||
AUDIO_FEATURE_TRANSFORM_CLASS_NAMES.add(cls.__name__)
|
||||
return cls
|
||||
|
||||
return register_audio_feature_transform_cls
|
||||
|
||||
|
||||
def get_audio_feature_transform(name):
|
||||
return AUDIO_FEATURE_TRANSFORM_REGISTRY[name]
|
||||
|
||||
|
||||
transforms_dir = os.path.dirname(__file__)
|
||||
for file in os.listdir(transforms_dir):
|
||||
path = os.path.join(transforms_dir, file)
|
||||
if (
|
||||
not file.startswith("_")
|
||||
and not file.startswith(".")
|
||||
and (file.endswith(".py") or os.path.isdir(path))
|
||||
):
|
||||
name = file[: file.find(".py")] if file.endswith(".py") else file
|
||||
importlib.import_module("fairseq.data.audio.feature_transforms." + name)
|
||||
|
||||
|
||||
class CompositeAudioFeatureTransform(AudioFeatureTransform):
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
_transforms = _config.get("transforms")
|
||||
if _transforms is None:
|
||||
return None
|
||||
transforms = [
|
||||
get_audio_feature_transform(_t).from_config_dict(_config.get(_t))
|
||||
for _t in _transforms
|
||||
]
|
||||
return CompositeAudioFeatureTransform(transforms)
|
||||
|
||||
def __init__(self, transforms):
|
||||
self.transforms = [t for t in transforms if t is not None]
|
||||
|
||||
def __call__(self, x):
|
||||
for t in self.transforms:
|
||||
x = t(x)
|
||||
return x
|
||||
|
||||
def __repr__(self):
|
||||
format_string = (
|
||||
[self.__class__.__name__ + "("]
|
||||
+ [f" {t.__repr__()}" for t in self.transforms]
|
||||
+ [")"]
|
||||
)
|
||||
return "\n".join(format_string)
|
|
@ -0,0 +1,25 @@
|
|||
import numpy as np
|
||||
from fairseq.data.audio.feature_transforms import (
|
||||
AudioFeatureTransform,
|
||||
register_audio_feature_transform,
|
||||
)
|
||||
|
||||
|
||||
@register_audio_feature_transform("global_cmvn")
|
||||
class GlobalCMVN(AudioFeatureTransform):
|
||||
"""Global CMVN (cepstral mean and variance normalization). The global mean
|
||||
and variance need to be pre-computed and stored in NumPy format (.npz)."""
|
||||
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return GlobalCMVN(_config.get("stats_npz_path"))
|
||||
|
||||
def __init__(self, stats_npz_path):
|
||||
stats = np.load(stats_npz_path)
|
||||
self.mean, self.std = stats["mean"], stats["std"]
|
||||
|
||||
def __call__(self, x):
|
||||
x = np.subtract(x, self.mean)
|
||||
x = np.divide(x, self.std)
|
||||
return x
|
|
@ -0,0 +1,131 @@
|
|||
import math
|
||||
import numbers
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from fairseq.data.audio.feature_transforms import (
|
||||
AudioFeatureTransform,
|
||||
register_audio_feature_transform,
|
||||
)
|
||||
|
||||
|
||||
@register_audio_feature_transform("specaugment")
|
||||
class SpecAugmentTransform(AudioFeatureTransform):
|
||||
"""SpecAugment (https://arxiv.org/abs/1904.08779)"""
|
||||
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return SpecAugmentTransform(
|
||||
_config.get("time_warp_W", 0),
|
||||
_config.get("freq_mask_N", 0),
|
||||
_config.get("freq_mask_F", 0),
|
||||
_config.get("time_mask_N", 0),
|
||||
_config.get("time_mask_T", 0),
|
||||
_config.get("time_mask_p", 0.0),
|
||||
_config.get("mask_value", None),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
time_warp_w: int = 0,
|
||||
freq_mask_n: int = 0,
|
||||
freq_mask_f: int = 0,
|
||||
time_mask_n: int = 0,
|
||||
time_mask_t: int = 0,
|
||||
time_mask_p: float = 0.0,
|
||||
mask_value: Optional[float] = 0.0,
|
||||
):
|
||||
# Sanity checks
|
||||
assert mask_value is None or isinstance(
|
||||
mask_value, numbers.Number
|
||||
), f"mask_value (type: {type(mask_value)}) must be None or a number"
|
||||
if freq_mask_n > 0:
|
||||
assert freq_mask_f > 0, (
|
||||
f"freq_mask_F ({freq_mask_f}) "
|
||||
f"must be larger than 0 when doing freq masking."
|
||||
)
|
||||
if time_mask_n > 0:
|
||||
assert time_mask_t > 0, (
|
||||
f"time_mask_T ({time_mask_t}) must be larger than 0 when "
|
||||
f"doing time masking."
|
||||
)
|
||||
|
||||
self.time_warp_w = time_warp_w
|
||||
self.freq_mask_n = freq_mask_n
|
||||
self.freq_mask_f = freq_mask_f
|
||||
self.time_mask_n = time_mask_n
|
||||
self.time_mask_t = time_mask_t
|
||||
self.time_mask_p = time_mask_p
|
||||
self.mask_value = mask_value
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ "("
|
||||
+ ", ".join(
|
||||
[
|
||||
f"time_warp_w={self.time_warp_w}",
|
||||
f"freq_mask_n={self.freq_mask_n}",
|
||||
f"freq_mask_f={self.freq_mask_f}",
|
||||
f"time_mask_n={self.time_mask_n}",
|
||||
f"time_mask_t={self.time_mask_t}",
|
||||
f"time_mask_p={self.time_mask_p}",
|
||||
]
|
||||
)
|
||||
+ ")"
|
||||
)
|
||||
|
||||
def __call__(self, spectrogram):
|
||||
assert len(spectrogram.shape) == 2, "spectrogram must be a 2-D tensor."
|
||||
|
||||
distorted = spectrogram.copy() # make a copy of input spectrogram.
|
||||
num_frames = spectrogram.shape[0] # or 'tau' in the paper.
|
||||
num_freqs = spectrogram.shape[1] # or 'miu' in the paper.
|
||||
mask_value = self.mask_value
|
||||
|
||||
if mask_value is None: # if no value was specified, use local mean.
|
||||
mask_value = spectrogram.mean()
|
||||
|
||||
if num_frames == 0:
|
||||
return spectrogram
|
||||
|
||||
if num_freqs < self.freq_mask_f:
|
||||
return spectrogram
|
||||
|
||||
if self.time_warp_w > 0:
|
||||
if 2 * self.time_warp_w < num_frames:
|
||||
import cv2
|
||||
|
||||
w0 = np.random.randint(self.time_warp_w, num_frames - self.time_warp_w)
|
||||
w = np.random.randint(0, self.time_warp_w)
|
||||
upper, lower = distorted[:w0, :], distorted[w0:, :]
|
||||
upper = cv2.resize(
|
||||
upper, dsize=(num_freqs, w0 + w), interpolation=cv2.INTER_LINEAR
|
||||
)
|
||||
lower = cv2.resize(
|
||||
lower,
|
||||
dsize=(num_freqs, num_frames - w0 - w),
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
distorted = np.concatenate((upper, lower), axis=0)
|
||||
|
||||
for _i in range(self.freq_mask_n):
|
||||
f = np.random.randint(0, self.freq_mask_f)
|
||||
f0 = np.random.randint(0, num_freqs - f)
|
||||
if f != 0:
|
||||
distorted[:, f0 : f0 + f] = mask_value
|
||||
|
||||
max_time_mask_t = min(
|
||||
self.time_mask_t, math.floor(num_frames * self.time_mask_p)
|
||||
)
|
||||
if max_time_mask_t < 1:
|
||||
return distorted
|
||||
|
||||
for _i in range(self.time_mask_n):
|
||||
t = np.random.randint(0, max_time_mask_t)
|
||||
t0 = np.random.randint(0, num_frames - t)
|
||||
if t != 0:
|
||||
distorted[t0 : t0 + t, :] = mask_value
|
||||
|
||||
return distorted
|
|
@ -0,0 +1,40 @@
|
|||
import numpy as np
|
||||
from fairseq.data.audio.feature_transforms import (
|
||||
AudioFeatureTransform,
|
||||
register_audio_feature_transform,
|
||||
)
|
||||
|
||||
|
||||
@register_audio_feature_transform("utterance_cmvn")
|
||||
class UtteranceCMVN(AudioFeatureTransform):
|
||||
"""Utterance-level CMVN (cepstral mean and variance normalization)"""
|
||||
|
||||
@classmethod
|
||||
def from_config_dict(cls, config=None):
|
||||
_config = {} if config is None else config
|
||||
return UtteranceCMVN(
|
||||
_config.get("norm_means", True),
|
||||
_config.get("norm_vars", True),
|
||||
)
|
||||
|
||||
def __init__(self, norm_means=True, norm_vars=True):
|
||||
self.norm_means, self.norm_vars = norm_means, norm_vars
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ f"(norm_means={self.norm_means}, norm_vars={self.norm_vars})"
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
mean = x.mean(axis=0)
|
||||
square_sums = (x ** 2).sum(axis=0)
|
||||
|
||||
if self.norm_means:
|
||||
x = np.subtract(x, mean)
|
||||
if self.norm_vars:
|
||||
var = square_sums / x.shape[0] - mean ** 2
|
||||
std = np.sqrt(np.maximum(var, 1e-10))
|
||||
x = np.divide(x, std)
|
||||
|
||||
return x
|
|
@ -0,0 +1,192 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .. import FairseqDataset
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RawAudioDataset(FairseqDataset):
|
||||
def __init__(
|
||||
self,
|
||||
sample_rate,
|
||||
max_sample_size=None,
|
||||
min_sample_size=None,
|
||||
shuffle=True,
|
||||
min_length=0,
|
||||
pad=False,
|
||||
normalize=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.sizes = []
|
||||
self.max_sample_size = (
|
||||
max_sample_size if max_sample_size is not None else sys.maxsize
|
||||
)
|
||||
self.min_sample_size = min_sample_size
|
||||
self.min_length = min_length
|
||||
self.pad = pad
|
||||
self.shuffle = shuffle
|
||||
self.normalize = normalize
|
||||
|
||||
def __getitem__(self, index):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sizes)
|
||||
|
||||
def postprocess(self, feats, curr_sample_rate):
|
||||
if feats.dim() == 2:
|
||||
feats = feats.mean(-1)
|
||||
|
||||
if curr_sample_rate != self.sample_rate:
|
||||
raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")
|
||||
|
||||
assert feats.dim() == 1, feats.dim()
|
||||
|
||||
if self.normalize:
|
||||
with torch.no_grad():
|
||||
feats = F.layer_norm(feats, feats.shape)
|
||||
return feats
|
||||
|
||||
def crop_to_max_size(self, wav, target_size):
|
||||
size = len(wav)
|
||||
diff = size - target_size
|
||||
if diff <= 0:
|
||||
return wav
|
||||
|
||||
start = np.random.randint(0, diff + 1)
|
||||
end = size - diff + start
|
||||
return wav[start:end]
|
||||
|
||||
def collater(self, samples):
|
||||
samples = [s for s in samples if s["source"] is not None]
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
|
||||
sources = [s["source"] for s in samples]
|
||||
sizes = [len(s) for s in sources]
|
||||
|
||||
if self.pad:
|
||||
target_size = min(max(sizes), self.max_sample_size)
|
||||
else:
|
||||
target_size = min(min(sizes), self.max_sample_size)
|
||||
|
||||
collated_sources = sources[0].new_zeros(len(sources), target_size)
|
||||
padding_mask = (
|
||||
torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
|
||||
)
|
||||
for i, (source, size) in enumerate(zip(sources, sizes)):
|
||||
diff = size - target_size
|
||||
if diff == 0:
|
||||
collated_sources[i] = source
|
||||
elif diff < 0:
|
||||
assert self.pad
|
||||
collated_sources[i] = torch.cat(
|
||||
[source, source.new_full((-diff,), 0.0)]
|
||||
)
|
||||
padding_mask[i, diff:] = True
|
||||
else:
|
||||
collated_sources[i] = self.crop_to_max_size(source, target_size)
|
||||
|
||||
input = {"source": collated_sources}
|
||||
if self.pad:
|
||||
input["padding_mask"] = padding_mask
|
||||
return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input}
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.size(index)
|
||||
|
||||
def size(self, index):
|
||||
"""Return an example's size as a float or tuple. This value is used when
|
||||
filtering a dataset with ``--max-positions``."""
|
||||
if self.pad:
|
||||
return self.sizes[index]
|
||||
return min(self.sizes[index], self.max_sample_size)
|
||||
|
||||
def ordered_indices(self):
|
||||
"""Return an ordered list of indices. Batches will be constructed based
|
||||
on this order."""
|
||||
|
||||
if self.shuffle:
|
||||
order = [np.random.permutation(len(self))]
|
||||
else:
|
||||
order = [np.arange(len(self))]
|
||||
|
||||
order.append(self.sizes)
|
||||
return np.lexsort(order)[::-1]
|
||||
|
||||
|
||||
class FileAudioDataset(RawAudioDataset):
|
||||
def __init__(
|
||||
self,
|
||||
manifest_path,
|
||||
sample_rate,
|
||||
max_sample_size=None,
|
||||
min_sample_size=None,
|
||||
shuffle=True,
|
||||
min_length=0,
|
||||
pad=False,
|
||||
normalize=False,
|
||||
):
|
||||
super().__init__(
|
||||
sample_rate=sample_rate,
|
||||
max_sample_size=max_sample_size,
|
||||
min_sample_size=min_sample_size,
|
||||
shuffle=shuffle,
|
||||
min_length=min_length,
|
||||
pad=pad,
|
||||
normalize=normalize,
|
||||
)
|
||||
|
||||
self.fnames = []
|
||||
self.skipped = []
|
||||
|
||||
skipped = 0
|
||||
count = 0
|
||||
with open(manifest_path, "r") as f:
|
||||
self.root_dir = f.readline().strip()
|
||||
for line in f:
|
||||
count += 1
|
||||
items = line.strip().split("\t")
|
||||
#assert len(items) == 2, line
|
||||
sz = int(items[1])
|
||||
if len(items) == 3:
|
||||
sr = int(items[2])
|
||||
assert sz % (sr / self.sample_rate) == 0
|
||||
sz = sz / (sr / self.sample_rate)
|
||||
if min_length is not None and sz < min_length:
|
||||
skipped += 1
|
||||
self.skipped.append(count)
|
||||
continue
|
||||
if pad and max_sample_size is not None and sz > max_sample_size:
|
||||
skipped += 1
|
||||
self.skipped.append(count)
|
||||
continue
|
||||
self.fnames.append(items[0])
|
||||
self.sizes.append(int(sz))
|
||||
self.sizes = np.array(self.sizes)
|
||||
logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")
|
||||
|
||||
def __getitem__(self, index):
|
||||
import soundfile as sf
|
||||
|
||||
fname = os.path.join(self.root_dir, self.fnames[index])
|
||||
wav, curr_sample_rate = sf.read(fname)
|
||||
#wav, curr_sample_rate = librosa.load(fname, sr=self.sample_rate)
|
||||
feats = torch.from_numpy(wav).float()
|
||||
feats = self.postprocess(feats, curr_sample_rate)
|
||||
return {"id": index, "source": feats}
|
|
@ -0,0 +1,528 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import os.path as op
|
||||
import re
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairseq.data import (
|
||||
ConcatDataset,
|
||||
Dictionary,
|
||||
FairseqDataset,
|
||||
ResamplingDataset,
|
||||
data_utils as fairseq_data_utils,
|
||||
)
|
||||
from fairseq.data.audio.audio_utils import get_fbank, get_waveform
|
||||
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class S2TDataConfig(object):
|
||||
"""Wrapper class for data config YAML"""
|
||||
|
||||
def __init__(self, yaml_path):
|
||||
try:
|
||||
import yaml
|
||||
except ImportError:
|
||||
print("Please install PyYAML to load YAML files for " "S2T data config")
|
||||
self.config = {}
|
||||
if op.isfile(yaml_path):
|
||||
try:
|
||||
with open(yaml_path) as f:
|
||||
self.config = yaml.load(f, Loader=yaml.FullLoader)
|
||||
except Exception as e:
|
||||
logger.info(f"Failed to load config from {yaml_path}: {e}")
|
||||
else:
|
||||
logger.info(f"Cannot find {yaml_path}")
|
||||
|
||||
@property
|
||||
def vocab_filename(self):
|
||||
"""fairseq vocabulary file under data root"""
|
||||
return self.config.get("vocab_filename", "dict.txt")
|
||||
|
||||
@property
|
||||
def shuffle(self) -> bool:
|
||||
"""Shuffle dataset samples before batching"""
|
||||
return self.config.get("shuffle", False)
|
||||
|
||||
@property
|
||||
def pre_tokenizer(self) -> Dict:
|
||||
"""Pre-tokenizer to apply before subword tokenization. Returning
|
||||
a dictionary with `tokenizer` providing the tokenizer name and
|
||||
the other items providing the tokenizer-specific arguments.
|
||||
Tokenizers are defined in `fairseq.data.encoders.*`"""
|
||||
return self.config.get("pre_tokenizer", {"tokenizer": None})
|
||||
|
||||
@property
|
||||
def bpe_tokenizer(self) -> Dict:
|
||||
"""Subword tokenizer to apply after pre-tokenization. Returning
|
||||
a dictionary with `bpe` providing the tokenizer name and
|
||||
the other items providing the tokenizer-specific arguments.
|
||||
Tokenizers are defined in `fairseq.data.encoders.*`"""
|
||||
return self.config.get("bpe_tokenizer", {"bpe": None})
|
||||
|
||||
@property
|
||||
def prepend_tgt_lang_tag(self) -> bool:
|
||||
"""Prepend target lang ID token as the target BOS (e.g. for to-many
|
||||
multilingual setting). During inference, this requires `--prefix-size 1`
|
||||
to force BOS to be lang ID token."""
|
||||
return self.config.get("prepend_tgt_lang_tag", False)
|
||||
|
||||
@property
|
||||
def input_feat_per_channel(self):
|
||||
"""The dimension of input features (per audio channel)"""
|
||||
return self.config.get("input_feat_per_channel", 80)
|
||||
|
||||
@property
|
||||
def input_channels(self):
|
||||
"""The number of channels in the input audio"""
|
||||
return self.config.get("input_channels", 1)
|
||||
|
||||
@property
|
||||
def sampling_alpha(self):
|
||||
"""Hyper-parameter alpha = 1/T for temperature-based resampling.
|
||||
(alpha = 1 for no resampling)"""
|
||||
return self.config.get("sampling_alpha", 1.0)
|
||||
|
||||
@property
|
||||
def use_audio_input(self):
|
||||
"""Needed by the dataset loader to see if the model requires
|
||||
raw audio as inputs."""
|
||||
return self.config.get("use_audio_input", False)
|
||||
|
||||
@property
|
||||
def audio_root(self):
|
||||
"""Audio paths in the manifest TSV can be relative and this provides
|
||||
the root path. Set this to empty string when using absolute paths."""
|
||||
return self.config.get("audio_root", "")
|
||||
|
||||
def get_feature_transforms(self, split, is_train):
|
||||
"""Split-specific feature transforms. Allowing train set wildcard `_train`,
|
||||
evaluation set wildcard `_eval` and general wildcard `*` for matching."""
|
||||
from copy import deepcopy
|
||||
|
||||
cfg = deepcopy(self.config)
|
||||
_cur = cfg.get("transforms", {})
|
||||
cur = _cur.get(split)
|
||||
cur = _cur.get("_train") if cur is None and is_train else cur
|
||||
cur = _cur.get("_eval") if cur is None and not is_train else cur
|
||||
cur = _cur.get("*") if cur is None else cur
|
||||
cfg["transforms"] = cur
|
||||
return cfg
|
||||
|
||||
|
||||
def is_npy_data(data: bytes) -> bool:
|
||||
return data[0] == 147 and data[1] == 78
|
||||
|
||||
|
||||
def is_flac_or_wav_data(data: bytes) -> bool:
|
||||
is_flac = data[0] == 102 and data[1] == 76
|
||||
is_wav = data[0] == 82 and data[1] == 73
|
||||
return is_flac or is_wav
|
||||
|
||||
|
||||
def read_from_uncompressed_zip(file_path, offset, file_size) -> bytes:
|
||||
with open(file_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
data = f.read(file_size)
|
||||
return data
|
||||
|
||||
|
||||
def get_features_from_npy_or_audio(path):
|
||||
ext = op.splitext(op.basename(path))[1]
|
||||
if ext not in {".npy", ".flac", ".wav"}:
|
||||
raise ValueError(f'Unsupported file format for "{path}"')
|
||||
return np.load(path) if ext == ".npy" else get_fbank(path)
|
||||
|
||||
|
||||
def get_features_or_waveform_from_uncompressed_zip(
|
||||
path, byte_offset, byte_size, need_waveform=False
|
||||
):
|
||||
assert path.endswith(".zip")
|
||||
data = read_from_uncompressed_zip(path, byte_offset, byte_size)
|
||||
f = io.BytesIO(data)
|
||||
if is_npy_data(data):
|
||||
features_or_waveform = np.load(f)
|
||||
elif is_flac_or_wav_data(data):
|
||||
features_or_waveform = get_waveform(f)[0] if need_waveform else get_fbank(f)
|
||||
else:
|
||||
raise ValueError(f'Unknown file format for "{path}"')
|
||||
return features_or_waveform
|
||||
|
||||
|
||||
def get_features_or_waveform(path: str, need_waveform=False):
|
||||
"""Get speech features from .npy file or waveform from .wav/.flac file.
|
||||
The file may be inside an uncompressed ZIP file and is accessed via byte
|
||||
offset and length.
|
||||
|
||||
Args:
|
||||
path (str): File path in the format of "<.npy/.wav/.flac path>" or
|
||||
"<zip path>:<byte offset>:<byte length>".
|
||||
need_waveform (bool): return waveform instead of features.
|
||||
|
||||
Returns:
|
||||
features_or_waveform (numpy.ndarray): speech features or waveform.
|
||||
"""
|
||||
_path, *extra = path.split(":")
|
||||
if not op.exists(_path):
|
||||
raise FileNotFoundError(f"File not found: {_path}")
|
||||
|
||||
if len(extra) == 0:
|
||||
if need_waveform:
|
||||
return get_waveform(_path)
|
||||
return get_features_from_npy_or_audio(_path)
|
||||
elif len(extra) == 2:
|
||||
extra = [int(i) for i in extra]
|
||||
features_or_waveform = get_features_or_waveform_from_uncompressed_zip(
|
||||
_path, extra[0], extra[1], need_waveform=need_waveform
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid path: {path}")
|
||||
|
||||
return features_or_waveform
|
||||
|
||||
|
||||
def _collate_frames(
|
||||
frames: List[torch.Tensor], is_audio_input: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert a list of 2D frames into a padded 3D tensor
|
||||
Args:
|
||||
frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
|
||||
length of i-th frame and f_dim is static dimension of features
|
||||
Returns:
|
||||
3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
|
||||
"""
|
||||
max_len = max(frame.size(0) for frame in frames)
|
||||
if is_audio_input:
|
||||
out = frames[0].new_zeros((len(frames), max_len))
|
||||
else:
|
||||
out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
|
||||
for i, v in enumerate(frames):
|
||||
out[i, : v.size(0)] = v
|
||||
return out
|
||||
|
||||
|
||||
class SpeechToTextDataset(FairseqDataset):
|
||||
LANG_TAG_TEMPLATE = "<lang:{}>"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
split: str,
|
||||
is_train_split: bool,
|
||||
data_cfg: S2TDataConfig,
|
||||
audio_paths: List[str],
|
||||
n_frames: List[int],
|
||||
src_texts: Optional[List[str]] = None,
|
||||
tgt_texts: Optional[List[str]] = None,
|
||||
speakers: Optional[List[str]] = None,
|
||||
src_langs: Optional[List[str]] = None,
|
||||
tgt_langs: Optional[List[str]] = None,
|
||||
ids: Optional[List[str]] = None,
|
||||
tgt_dict: Optional[Dictionary] = None,
|
||||
pre_tokenizer=None,
|
||||
bpe_tokenizer=None,
|
||||
):
|
||||
self.split, self.is_train_split = split, is_train_split
|
||||
self.data_cfg = data_cfg
|
||||
self.audio_paths, self.n_frames = audio_paths, n_frames
|
||||
self.n_samples = len(audio_paths)
|
||||
assert len(n_frames) == self.n_samples > 0
|
||||
assert src_texts is None or len(src_texts) == self.n_samples
|
||||
assert tgt_texts is None or len(tgt_texts) == self.n_samples
|
||||
assert speakers is None or len(speakers) == self.n_samples
|
||||
assert src_langs is None or len(src_langs) == self.n_samples
|
||||
assert tgt_langs is None or len(tgt_langs) == self.n_samples
|
||||
assert ids is None or len(ids) == self.n_samples
|
||||
assert (tgt_dict is None and tgt_texts is None) or (
|
||||
tgt_dict is not None and tgt_texts is not None
|
||||
)
|
||||
self.src_texts, self.tgt_texts = src_texts, tgt_texts
|
||||
self.src_langs, self.tgt_langs = src_langs, tgt_langs
|
||||
self.tgt_dict = tgt_dict
|
||||
self.check_tgt_lang_tag()
|
||||
self.ids = ids
|
||||
self.shuffle = data_cfg.shuffle if is_train_split else False
|
||||
|
||||
self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
|
||||
self.data_cfg.get_feature_transforms(split, is_train_split)
|
||||
)
|
||||
|
||||
self.pre_tokenizer = pre_tokenizer
|
||||
self.bpe_tokenizer = bpe_tokenizer
|
||||
|
||||
logger.info(self.__repr__())
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
self.__class__.__name__
|
||||
+ f'(split="{self.split}", n_samples={self.n_samples}, '
|
||||
f"prepend_tgt_lang_tag={self.data_cfg.prepend_tgt_lang_tag}, "
|
||||
f"shuffle={self.shuffle}, transforms={self.feature_transforms})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_lang_tag(cls, token):
|
||||
pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
|
||||
return re.match(pattern, token)
|
||||
|
||||
def check_tgt_lang_tag(self):
|
||||
if self.data_cfg.prepend_tgt_lang_tag:
|
||||
assert self.tgt_langs is not None and self.tgt_dict is not None
|
||||
tgt_lang_tags = [
|
||||
self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
|
||||
]
|
||||
assert all(t in self.tgt_dict for t in tgt_lang_tags)
|
||||
|
||||
def tokenize_text(self, text: str):
|
||||
if self.pre_tokenizer is not None:
|
||||
text = self.pre_tokenizer.encode(text)
|
||||
if self.bpe_tokenizer is not None:
|
||||
text = self.bpe_tokenizer.encode(text)
|
||||
return text
|
||||
|
||||
def __getitem__(
|
||||
self, index: int
|
||||
) -> Tuple[int, torch.Tensor, Optional[torch.Tensor]]:
|
||||
source = get_features_or_waveform(
|
||||
self.audio_paths[index], need_waveform=self.data_cfg.use_audio_input
|
||||
)
|
||||
if self.feature_transforms is not None:
|
||||
assert not self.data_cfg.use_audio_input
|
||||
source = self.feature_transforms(source)
|
||||
source = torch.from_numpy(source).float()
|
||||
|
||||
target = None
|
||||
if self.tgt_texts is not None:
|
||||
tokenized = self.tokenize_text(self.tgt_texts[index])
|
||||
target = self.tgt_dict.encode_line(
|
||||
tokenized, add_if_not_exist=False, append_eos=True
|
||||
).long()
|
||||
if self.data_cfg.prepend_tgt_lang_tag:
|
||||
lang_tag = self.LANG_TAG_TEMPLATE.format(self.tgt_langs[index])
|
||||
lang_tag_idx = self.tgt_dict.index(lang_tag)
|
||||
target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)
|
||||
return index, source, target
|
||||
|
||||
def __len__(self):
|
||||
return self.n_samples
|
||||
|
||||
def collater(self, samples: List[Tuple[int, torch.Tensor, torch.Tensor]]) -> Dict:
|
||||
if len(samples) == 0:
|
||||
return {}
|
||||
indices = torch.tensor([i for i, _, _ in samples], dtype=torch.long)
|
||||
frames = _collate_frames(
|
||||
[s for _, s, _ in samples], self.data_cfg.use_audio_input
|
||||
)
|
||||
# sort samples by descending number of frames
|
||||
n_frames = torch.tensor([s.size(0) for _, s, _ in samples], dtype=torch.long)
|
||||
n_frames, order = n_frames.sort(descending=True)
|
||||
indices = indices.index_select(0, order)
|
||||
frames = frames.index_select(0, order)
|
||||
|
||||
target, target_lengths = None, None
|
||||
prev_output_tokens = None
|
||||
ntokens = None
|
||||
if self.tgt_texts is not None:
|
||||
target = fairseq_data_utils.collate_tokens(
|
||||
[t for _, _, t in samples],
|
||||
self.tgt_dict.pad(),
|
||||
self.tgt_dict.eos(),
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=False,
|
||||
)
|
||||
target = target.index_select(0, order)
|
||||
target_lengths = torch.tensor(
|
||||
[t.size(0) for _, _, t in samples], dtype=torch.long
|
||||
).index_select(0, order)
|
||||
prev_output_tokens = fairseq_data_utils.collate_tokens(
|
||||
[t for _, _, t in samples],
|
||||
self.tgt_dict.pad(),
|
||||
self.tgt_dict.eos(),
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=True,
|
||||
)
|
||||
prev_output_tokens = prev_output_tokens.index_select(0, order)
|
||||
ntokens = sum(t.size(0) for _, _, t in samples)
|
||||
|
||||
out = {
|
||||
"id": indices,
|
||||
"net_input": {
|
||||
"src_tokens": frames,
|
||||
"src_lengths": n_frames,
|
||||
"prev_output_tokens": prev_output_tokens,
|
||||
},
|
||||
"target": target,
|
||||
"target_lengths": target_lengths,
|
||||
"ntokens": ntokens,
|
||||
"nsentences": len(samples),
|
||||
}
|
||||
return out
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.n_frames[index]
|
||||
|
||||
def size(self, index):
|
||||
t_len = 0
|
||||
if self.tgt_texts is not None:
|
||||
tokenized = self.tokenize_text(self.tgt_texts[index])
|
||||
t_len = len(tokenized.split(" "))
|
||||
return self.n_frames[index], t_len
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return np.array(self.n_frames)
|
||||
|
||||
@property
|
||||
def can_reuse_epoch_itr_across_epochs(self):
|
||||
return True
|
||||
|
||||
def ordered_indices(self):
|
||||
if self.shuffle:
|
||||
order = [np.random.permutation(len(self))]
|
||||
else:
|
||||
order = [np.arange(len(self))]
|
||||
# first by descending order of # of frames then by original/random order
|
||||
order.append([-n for n in self.n_frames])
|
||||
return np.lexsort(order)
|
||||
|
||||
def prefetch(self, indices):
|
||||
raise False
|
||||
|
||||
|
||||
class SpeechToTextDatasetCreator(object):
|
||||
# mandatory columns
|
||||
KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
|
||||
KEY_TGT_TEXT = "tgt_text"
|
||||
# optional columns
|
||||
KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
|
||||
KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
|
||||
# default values
|
||||
DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""
|
||||
|
||||
@classmethod
|
||||
def _from_list(
|
||||
cls,
|
||||
split_name: str,
|
||||
is_train_split,
|
||||
samples: List[List[Dict]],
|
||||
data_cfg: S2TDataConfig,
|
||||
tgt_dict,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
) -> SpeechToTextDataset:
|
||||
audio_paths, n_frames, src_texts, tgt_texts, ids = [], [], [], [], []
|
||||
speakers, src_langs, tgt_langs = [], [], []
|
||||
for s in samples:
|
||||
ids.extend([ss[cls.KEY_ID] for ss in s])
|
||||
audio_paths.extend(
|
||||
[op.join(data_cfg.audio_root, ss[cls.KEY_AUDIO]) for ss in s]
|
||||
)
|
||||
n_frames.extend([int(ss[cls.KEY_N_FRAMES]) for ss in s])
|
||||
tgt_texts.extend([ss[cls.KEY_TGT_TEXT] for ss in s])
|
||||
src_texts.extend(
|
||||
[ss.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for ss in s]
|
||||
)
|
||||
speakers.extend([ss.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for ss in s])
|
||||
src_langs.extend([ss.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for ss in s])
|
||||
tgt_langs.extend([ss.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for ss in s])
|
||||
return SpeechToTextDataset(
|
||||
split_name,
|
||||
is_train_split,
|
||||
data_cfg,
|
||||
audio_paths,
|
||||
n_frames,
|
||||
src_texts,
|
||||
tgt_texts,
|
||||
speakers,
|
||||
src_langs,
|
||||
tgt_langs,
|
||||
ids,
|
||||
tgt_dict,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_size_ratios(cls, ids: List[str], sizes: List[int], alpha: float = 1.0):
|
||||
"""Size ratios for temperature-based sampling
|
||||
(https://arxiv.org/abs/1907.05019)"""
|
||||
_sizes = np.array(sizes)
|
||||
prob = _sizes / _sizes.sum()
|
||||
smoothed_prob = prob ** alpha
|
||||
smoothed_prob = smoothed_prob / smoothed_prob.sum()
|
||||
size_ratio = (smoothed_prob * _sizes.sum()) / _sizes
|
||||
|
||||
o_str = str({_i: f"{prob[i]:.3f}" for i, _i in enumerate(ids)})
|
||||
logger.info(f"original sampling probability: {o_str}")
|
||||
p_str = str({_i: f"{smoothed_prob[i]:.3f}" for i, _i in enumerate(ids)})
|
||||
logger.info(f"balanced sampling probability: {p_str}")
|
||||
sr_str = str({_id: f"{size_ratio[i]:.3f}" for i, _id in enumerate(ids)})
|
||||
logger.info(f"balanced sampling size ratio: {sr_str}")
|
||||
return size_ratio.tolist()
|
||||
|
||||
@classmethod
|
||||
def from_tsv(
|
||||
cls,
|
||||
root: str,
|
||||
data_cfg: S2TDataConfig,
|
||||
splits: str,
|
||||
tgt_dict,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
is_train_split: bool,
|
||||
epoch: int,
|
||||
seed: int,
|
||||
) -> SpeechToTextDataset:
|
||||
samples = []
|
||||
_splits = splits.split(",")
|
||||
for split in _splits:
|
||||
tsv_path = op.join(root, f"{split}.tsv")
|
||||
if not op.isfile(tsv_path):
|
||||
raise FileNotFoundError(f"Dataset not found: {tsv_path}")
|
||||
with open(tsv_path) as f:
|
||||
reader = csv.DictReader(
|
||||
f,
|
||||
delimiter="\t",
|
||||
quotechar=None,
|
||||
doublequote=False,
|
||||
lineterminator="\n",
|
||||
quoting=csv.QUOTE_NONE,
|
||||
)
|
||||
samples.append([dict(e) for e in reader])
|
||||
assert len(samples) > 0
|
||||
|
||||
datasets = [
|
||||
cls._from_list(
|
||||
name,
|
||||
is_train_split,
|
||||
[s],
|
||||
data_cfg,
|
||||
tgt_dict,
|
||||
pre_tokenizer,
|
||||
bpe_tokenizer,
|
||||
)
|
||||
for name, s in zip(_splits, samples)
|
||||
]
|
||||
|
||||
if is_train_split and len(_splits) > 1 and data_cfg.sampling_alpha != 1.0:
|
||||
# temperature-based sampling
|
||||
size_ratios = cls._get_size_ratios(
|
||||
_splits, [len(s) for s in samples], alpha=data_cfg.sampling_alpha
|
||||
)
|
||||
datasets = [
|
||||
ResamplingDataset(
|
||||
d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
|
||||
)
|
||||
for d, r in zip(datasets, size_ratios)
|
||||
]
|
||||
return ConcatDataset(datasets)
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
|
||||
from . import FairseqDataset
|
||||
|
||||
|
||||
class BaseWrapperDataset(FairseqDataset):
|
||||
def __init__(self, dataset):
|
||||
super().__init__()
|
||||
self.dataset = dataset
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.dataset[index]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def collater(self, samples):
|
||||
if hasattr(self.dataset, "collater"):
|
||||
return self.dataset.collater(samples)
|
||||
else:
|
||||
return default_collate(samples)
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return self.dataset.sizes
|
||||
|
||||
def num_tokens(self, index):
|
||||
return self.dataset.num_tokens(index)
|
||||
|
||||
def size(self, index):
|
||||
return self.dataset.size(index)
|
||||
|
||||
def ordered_indices(self):
|
||||
return self.dataset.ordered_indices()
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return getattr(self.dataset, "supports_prefetch", False)
|
||||
|
||||
def attr(self, attr: str, index: int):
|
||||
return self.dataset.attr(attr, index)
|
||||
|
||||
def prefetch(self, indices):
|
||||
self.dataset.prefetch(indices)
|
||||
|
||||
def get_batch_shapes(self):
|
||||
return self.dataset.get_batch_shapes()
|
||||
|
||||
def batch_by_size(
|
||||
self,
|
||||
indices,
|
||||
max_tokens=None,
|
||||
max_sentences=None,
|
||||
required_batch_size_multiple=1,
|
||||
):
|
||||
return self.dataset.batch_by_size(
|
||||
indices,
|
||||
max_tokens=max_tokens,
|
||||
max_sentences=max_sentences,
|
||||
required_batch_size_multiple=required_batch_size_multiple,
|
||||
)
|
||||
|
||||
def filter_indices_by_size(self, indices, max_sizes):
|
||||
return self.dataset.filter_indices_by_size(indices, max_sizes)
|
||||
|
||||
@property
|
||||
def can_reuse_epoch_itr_across_epochs(self):
|
||||
return self.dataset.can_reuse_epoch_itr_across_epochs
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
super().set_epoch(epoch)
|
||||
if hasattr(self.dataset, "set_epoch"):
|
||||
self.dataset.set_epoch(epoch)
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import bisect
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
|
||||
from . import FairseqDataset
|
||||
|
||||
|
||||
class ConcatDataset(FairseqDataset):
|
||||
@staticmethod
|
||||
def cumsum(sequence, sample_ratios):
|
||||
r, s = [], 0
|
||||
for e, ratio in zip(sequence, sample_ratios):
|
||||
curr_len = int(ratio * len(e))
|
||||
r.append(curr_len + s)
|
||||
s += curr_len
|
||||
return r
|
||||
|
||||
def __init__(self, datasets, sample_ratios=1):
|
||||
super(ConcatDataset, self).__init__()
|
||||
assert len(datasets) > 0, "datasets should not be an empty iterable"
|
||||
self.datasets = list(datasets)
|
||||
if isinstance(sample_ratios, int):
|
||||
sample_ratios = [sample_ratios] * len(self.datasets)
|
||||
self.sample_ratios = sample_ratios
|
||||
self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
|
||||
self.real_sizes = [len(d) for d in self.datasets]
|
||||
|
||||
def __len__(self):
|
||||
return self.cumulative_sizes[-1]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
||||
return self.datasets[dataset_idx][sample_idx]
|
||||
|
||||
def _get_dataset_and_sample_index(self, idx: int):
|
||||
dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
|
||||
if dataset_idx == 0:
|
||||
sample_idx = idx
|
||||
else:
|
||||
sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
|
||||
sample_idx = sample_idx % self.real_sizes[dataset_idx]
|
||||
return dataset_idx, sample_idx
|
||||
|
||||
def collater(self, samples, **extra_args):
|
||||
# For now only supports datasets with same underlying collater implementations
|
||||
if hasattr(self.datasets[0], "collater"):
|
||||
return self.datasets[0].collater(samples, **extra_args)
|
||||
else:
|
||||
return default_collate(samples, **extra_args)
|
||||
|
||||
def size(self, idx: int):
|
||||
"""
|
||||
Return an example's size as a float or tuple.
|
||||
"""
|
||||
dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
|
||||
return self.datasets[dataset_idx].size(sample_idx)
|
||||
|
||||
def num_tokens(self, index: int):
|
||||
return np.max(self.size(index))
|
||||
|
||||
def attr(self, attr: str, index: int):
|
||||
dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
|
||||
return getattr(self.datasets[dataset_idx], attr, None)
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
_dataset_sizes = []
|
||||
for ds, sr in zip(self.datasets, self.sample_ratios):
|
||||
if isinstance(ds.sizes, np.ndarray):
|
||||
_dataset_sizes.append(np.tile(ds.sizes, sr))
|
||||
else:
|
||||
# Only support underlying dataset with single size array.
|
||||
assert isinstance(ds.sizes, list)
|
||||
_dataset_sizes.append(np.tile(ds.sizes[0], sr))
|
||||
return np.concatenate(_dataset_sizes)
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return all(d.supports_prefetch for d in self.datasets)
|
||||
|
||||
def ordered_indices(self):
|
||||
"""
|
||||
Returns indices sorted by length. So less padding is needed.
|
||||
"""
|
||||
if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
|
||||
# special handling for concatenating lang_pair_datasets
|
||||
indices = np.arange(len(self))
|
||||
sizes = self.sizes
|
||||
tgt_sizes = (
|
||||
sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
|
||||
)
|
||||
src_sizes = (
|
||||
sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
|
||||
)
|
||||
# sort by target length, then source length
|
||||
if tgt_sizes is not None:
|
||||
indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
|
||||
return indices[np.argsort(src_sizes[indices], kind="mergesort")]
|
||||
else:
|
||||
return np.argsort(self.sizes)
|
||||
|
||||
def prefetch(self, indices):
|
||||
frm = 0
|
||||
for to, ds in zip(self.cumulative_sizes, self.datasets):
|
||||
real_size = len(ds)
|
||||
if getattr(ds, "supports_prefetch", False):
|
||||
ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
|
||||
frm = to
|
||||
|
||||
@property
|
||||
def can_reuse_epoch_itr_across_epochs(self):
|
||||
return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
super().set_epoch(epoch)
|
||||
for ds in self.datasets:
|
||||
if hasattr(ds, "set_epoch"):
|
||||
ds.set_epoch(epoch)
|
|
@ -0,0 +1,54 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from . import FairseqDataset
|
||||
|
||||
|
||||
class ConcatSentencesDataset(FairseqDataset):
|
||||
def __init__(self, *datasets):
|
||||
super().__init__()
|
||||
self.datasets = datasets
|
||||
assert all(
|
||||
len(ds) == len(datasets[0]) for ds in datasets
|
||||
), "datasets must have the same length"
|
||||
|
||||
def __getitem__(self, index):
|
||||
return torch.cat([ds[index] for ds in self.datasets])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.datasets[0])
|
||||
|
||||
def collater(self, samples):
|
||||
return self.datasets[0].collater(samples)
|
||||
|
||||
@property
|
||||
def sizes(self):
|
||||
return sum(ds.sizes for ds in self.datasets)
|
||||
|
||||
def num_tokens(self, index):
|
||||
return sum(ds.num_tokens(index) for ds in self.datasets)
|
||||
|
||||
def size(self, index):
|
||||
return sum(ds.size(index) for ds in self.datasets)
|
||||
|
||||
def ordered_indices(self):
|
||||
return self.datasets[0].ordered_indices()
|
||||
|
||||
@property
|
||||
def supports_prefetch(self):
|
||||
return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)
|
||||
|
||||
def prefetch(self, indices):
|
||||
for ds in self.datasets:
|
||||
if getattr(ds, "supports_prefetch", False):
|
||||
ds.prefetch(indices)
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
super().set_epoch(epoch)
|
||||
for ds in self.datasets:
|
||||
if hasattr(ds, "set_epoch"):
|
||||
ds.set_epoch(epoch)
|
|
@ -0,0 +1,499 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
try:
|
||||
from collections.abc import Iterable
|
||||
except ImportError:
|
||||
from collections import Iterable
|
||||
import contextlib
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def infer_language_pair(path):
|
||||
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
|
||||
src, dst = None, None
|
||||
for filename in os.listdir(path):
|
||||
parts = filename.split(".")
|
||||
if len(parts) >= 3 and len(parts[1].split("-")) == 2:
|
||||
return parts[1].split("-")
|
||||
return src, dst
|
||||
|
||||
|
||||
def collate_tokens(
|
||||
values,
|
||||
pad_idx,
|
||||
eos_idx=None,
|
||||
left_pad=False,
|
||||
move_eos_to_beginning=False,
|
||||
pad_to_length=None,
|
||||
pad_to_multiple=1,
|
||||
):
|
||||
"""Convert a list of 1d tensors into a padded 2d tensor."""
|
||||
size = max(v.size(0) for v in values)
|
||||
size = size if pad_to_length is None else max(size, pad_to_length)
|
||||
if pad_to_multiple != 1 and size % pad_to_multiple != 0:
|
||||
size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple)
|
||||
res = values[0].new(len(values), size).fill_(pad_idx)
|
||||
|
||||
def copy_tensor(src, dst):
|
||||
assert dst.numel() == src.numel()
|
||||
if move_eos_to_beginning:
|
||||
if eos_idx is None:
|
||||
# if no eos_idx is specified, then use the last token in src
|
||||
dst[0] = src[-1]
|
||||
else:
|
||||
dst[0] = eos_idx
|
||||
dst[1:] = src[:-1]
|
||||
else:
|
||||
dst.copy_(src)
|
||||
|
||||
for i, v in enumerate(values):
|
||||
copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)])
|
||||
return res
|
||||
|
||||
|
||||
def load_indexed_dataset(
|
||||
path, dictionary=None, dataset_impl=None, combine=False, default="cached"
|
||||
):
|
||||
"""A helper function for loading indexed datasets.
|
||||
|
||||
Args:
|
||||
path (str): path to indexed dataset (e.g., 'data-bin/train')
|
||||
dictionary (~fairseq.data.Dictionary): data dictionary
|
||||
dataset_impl (str, optional): which dataset implementation to use. If
|
||||
not provided, it will be inferred automatically. For legacy indexed
|
||||
data we use the 'cached' implementation by default.
|
||||
combine (bool, optional): automatically load and combine multiple
|
||||
datasets. For example, if *path* is 'data-bin/train', then we will
|
||||
combine 'data-bin/train', 'data-bin/train1', ... and return a
|
||||
single ConcatDataset instance.
|
||||
"""
|
||||
from fairseq.data.concat_dataset import ConcatDataset
|
||||
import fairseq.data.indexed_dataset as indexed_dataset
|
||||
|
||||
datasets = []
|
||||
for k in itertools.count():
|
||||
path_k = path + (str(k) if k > 0 else "")
|
||||
path_k = indexed_dataset.get_indexed_dataset_to_local(path_k)
|
||||
|
||||
dataset_impl_k = dataset_impl
|
||||
if dataset_impl_k is None:
|
||||
dataset_impl_k = indexed_dataset.infer_dataset_impl(path_k)
|
||||
dataset = indexed_dataset.make_dataset(
|
||||
path_k,
|
||||
impl=dataset_impl_k or default,
|
||||
fix_lua_indexing=True,
|
||||
dictionary=dictionary,
|
||||
)
|
||||
if dataset is None:
|
||||
break
|
||||
logger.info("loaded {} examples from: {}".format(len(dataset), path_k))
|
||||
datasets.append(dataset)
|
||||
if not combine:
|
||||
break
|
||||
if len(datasets) == 0:
|
||||
return None
|
||||
elif len(datasets) == 1:
|
||||
return datasets[0]
|
||||
else:
|
||||
return ConcatDataset(datasets)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def numpy_seed(seed, *addl_seeds):
|
||||
"""Context manager which seeds the NumPy PRNG with the specified seed and
|
||||
restores the state afterward"""
|
||||
if seed is None:
|
||||
yield
|
||||
return
|
||||
if len(addl_seeds) > 0:
|
||||
seed = int(hash((seed, *addl_seeds)) % 1e6)
|
||||
state = np.random.get_state()
|
||||
np.random.seed(seed)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
np.random.set_state(state)
|
||||
|
||||
|
||||
def collect_filtered(function, iterable, filtered):
|
||||
"""
|
||||
Similar to :func:`filter` but collects filtered elements in ``filtered``.
|
||||
|
||||
Args:
|
||||
function (callable): function that returns ``False`` for elements that
|
||||
should be filtered
|
||||
iterable (iterable): iterable to filter
|
||||
filtered (list): list to store filtered elements
|
||||
"""
|
||||
for el in iterable:
|
||||
if function(el):
|
||||
yield el
|
||||
else:
|
||||
filtered.append(el)
|
||||
|
||||
|
||||
def _filter_by_size_dynamic(indices, size_fn, max_positions, raise_exception=False):
|
||||
def compare_leq(a, b):
|
||||
return a <= b if not isinstance(a, tuple) else max(a) <= b
|
||||
|
||||
def check_size(idx):
|
||||
if isinstance(max_positions, float) or isinstance(max_positions, int):
|
||||
return size_fn(idx) <= max_positions
|
||||
elif isinstance(max_positions, dict):
|
||||
idx_size = size_fn(idx)
|
||||
assert isinstance(idx_size, dict)
|
||||
intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
|
||||
return all(
|
||||
all(
|
||||
a is None or b is None or a <= b
|
||||
for a, b in zip(idx_size[key], max_positions[key])
|
||||
)
|
||||
for key in intersect_keys
|
||||
)
|
||||
else:
|
||||
# Hacky as heck, for the specific case of multilingual training with RoundRobin.
|
||||
if isinstance(size_fn(idx), dict) and isinstance(max_positions, tuple):
|
||||
return all(
|
||||
a is None or b is None or compare_leq(a, b)
|
||||
for a, b in zip(size_fn(idx).values(), max_positions)
|
||||
)
|
||||
# For MultiCorpusSampledDataset, will generalize it later
|
||||
if not isinstance(size_fn(idx), Iterable):
|
||||
return all(size_fn(idx) <= b for b in max_positions)
|
||||
return all(
|
||||
a is None or b is None or a <= b
|
||||
for a, b in zip(size_fn(idx), max_positions)
|
||||
)
|
||||
|
||||
ignored = []
|
||||
itr = collect_filtered(check_size, indices, ignored)
|
||||
indices = np.fromiter(itr, dtype=np.int64, count=-1)
|
||||
return indices, ignored
|
||||
|
||||
|
||||
def filter_by_size(indices, dataset, max_positions, raise_exception=False):
|
||||
"""
|
||||
[deprecated] Filter indices based on their size.
|
||||
Use `FairseqDataset::filter_indices_by_size` instead.
|
||||
|
||||
Args:
|
||||
indices (List[int]): ordered list of dataset indices
|
||||
dataset (FairseqDataset): fairseq dataset instance
|
||||
max_positions (tuple): filter elements larger than this size.
|
||||
Comparisons are done component-wise.
|
||||
raise_exception (bool, optional): if ``True``, raise an exception if
|
||||
any elements are filtered (default: False).
|
||||
"""
|
||||
warnings.warn(
|
||||
"data_utils.filter_by_size is deprecated. "
|
||||
"Use `FairseqDataset::filter_indices_by_size` instead.",
|
||||
stacklevel=2,
|
||||
)
|
||||
if isinstance(max_positions, float) or isinstance(max_positions, int):
|
||||
if hasattr(dataset, "sizes") and isinstance(dataset.sizes, np.ndarray):
|
||||
ignored = indices[dataset.sizes[indices] > max_positions].tolist()
|
||||
indices = indices[dataset.sizes[indices] <= max_positions]
|
||||
elif (
|
||||
hasattr(dataset, "sizes")
|
||||
and isinstance(dataset.sizes, list)
|
||||
and len(dataset.sizes) == 1
|
||||
):
|
||||
ignored = indices[dataset.sizes[0][indices] > max_positions].tolist()
|
||||
indices = indices[dataset.sizes[0][indices] <= max_positions]
|
||||
else:
|
||||
indices, ignored = _filter_by_size_dynamic(
|
||||
indices, dataset.size, max_positions
|
||||
)
|
||||
else:
|
||||
indices, ignored = _filter_by_size_dynamic(indices, dataset.size, max_positions)
|
||||
|
||||
if len(ignored) > 0 and raise_exception:
|
||||
raise Exception(
|
||||
(
|
||||
"Size of sample #{} is invalid (={}) since max_positions={}, "
|
||||
"skip this example with --skip-invalid-size-inputs-valid-test"
|
||||
).format(ignored[0], dataset.size(ignored[0]), max_positions)
|
||||
)
|
||||
if len(ignored) > 0:
|
||||
logger.warning(
|
||||
(
|
||||
"{} samples have invalid sizes and will be skipped, "
|
||||
"max_positions={}, first few sample ids={}"
|
||||
).format(len(ignored), max_positions, ignored[:10])
|
||||
)
|
||||
return indices
|
||||
|
||||
|
||||
def filter_paired_dataset_indices_by_size(src_sizes, tgt_sizes, indices, max_sizes):
|
||||
"""Filter a list of sample indices. Remove those that are longer
|
||||
than specified in max_sizes.
|
||||
|
||||
Args:
|
||||
indices (np.array): original array of sample indices
|
||||
max_sizes (int or list[int] or tuple[int]): max sample size,
|
||||
can be defined separately for src and tgt (then list or tuple)
|
||||
|
||||
Returns:
|
||||
np.array: filtered sample array
|
||||
list: list of removed indices
|
||||
"""
|
||||
if max_sizes is None:
|
||||
return indices, []
|
||||
if type(max_sizes) in (int, float):
|
||||
max_src_size, max_tgt_size = max_sizes, max_sizes
|
||||
else:
|
||||
max_src_size, max_tgt_size = max_sizes
|
||||
if tgt_sizes is None:
|
||||
ignored = indices[src_sizes[indices] > max_src_size]
|
||||
else:
|
||||
ignored = indices[
|
||||
(src_sizes[indices] > max_src_size) | (tgt_sizes[indices] > max_tgt_size)
|
||||
]
|
||||
if len(ignored) > 0:
|
||||
if tgt_sizes is None:
|
||||
indices = indices[src_sizes[indices] <= max_src_size]
|
||||
else:
|
||||
indices = indices[
|
||||
(src_sizes[indices] <= max_src_size)
|
||||
& (tgt_sizes[indices] <= max_tgt_size)
|
||||
]
|
||||
return indices, ignored.tolist()
|
||||
|
||||
|
||||
def batch_by_size(
|
||||
indices,
|
||||
num_tokens_fn,
|
||||
max_tokens=None,
|
||||
max_sentences=None,
|
||||
required_batch_size_multiple=1,
|
||||
fixed_shapes=None,
|
||||
):
|
||||
"""
|
||||
Yield mini-batches of indices bucketed by size. Batches may contain
|
||||
sequences of different lengths.
|
||||
|
||||
Args:
|
||||
indices (List[int]): ordered list of dataset indices
|
||||
num_tokens_fn (callable): function that returns the number of tokens at
|
||||
a given index
|
||||
max_tokens (int, optional): max number of tokens in each batch
|
||||
(default: None).
|
||||
max_sentences (int, optional): max number of sentences in each
|
||||
batch (default: None).
|
||||
required_batch_size_multiple (int, optional): require batch size to
|
||||
be less than N or a multiple of N (default: 1).
|
||||
fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
|
||||
only be created with the given shapes. *max_sentences* and
|
||||
*required_batch_size_multiple* will be ignored (default: None).
|
||||
"""
|
||||
try:
|
||||
from fairseq.data.data_utils_fast import (
|
||||
batch_by_size_fast,
|
||||
batch_fixed_shapes_fast,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please build Cython components with: `pip install --editable .` "
|
||||
"or `python setup.py build_ext --inplace`"
|
||||
)
|
||||
|
||||
max_tokens = max_tokens if max_tokens is not None else -1
|
||||
max_sentences = max_sentences if max_sentences is not None else -1
|
||||
bsz_mult = required_batch_size_multiple
|
||||
|
||||
if not isinstance(indices, np.ndarray):
|
||||
indices = np.fromiter(indices, dtype=np.int64, count=-1)
|
||||
|
||||
if fixed_shapes is None:
|
||||
return batch_by_size_fast(
|
||||
indices,
|
||||
num_tokens_fn,
|
||||
max_tokens,
|
||||
max_sentences,
|
||||
bsz_mult,
|
||||
)
|
||||
else:
|
||||
fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
|
||||
sort_order = np.lexsort(
|
||||
[
|
||||
fixed_shapes[:, 1].argsort(), # length
|
||||
fixed_shapes[:, 0].argsort(), # bsz
|
||||
]
|
||||
)
|
||||
fixed_shapes_sorted = fixed_shapes[sort_order]
|
||||
return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
|
||||
|
||||
|
||||
def post_process(sentence: str, symbol: str):
|
||||
if symbol == "sentencepiece":
|
||||
sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
|
||||
elif symbol == "wordpiece":
|
||||
sentence = sentence.replace(" ", "").replace("_", " ").strip()
|
||||
elif symbol == "letter":
|
||||
sentence = sentence.replace(" ", "").replace("|", " ").strip()
|
||||
elif symbol == "_EOW":
|
||||
sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
|
||||
elif symbol is not None and symbol != "none":
|
||||
sentence = (sentence + " ").replace(symbol, "").rstrip()
|
||||
return sentence
|
||||
|
||||
|
||||
def compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
padding_mask: Optional[torch.Tensor],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
mask_type: str = "static",
|
||||
mask_other: float = 0.0,
|
||||
min_masks: int = 0,
|
||||
no_overlap: bool = False,
|
||||
min_space: int = 0,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes random mask spans for a given shape
|
||||
|
||||
Args:
|
||||
shape: the the shape for which to compute masks.
|
||||
should be of size 2 where first element is batch size and 2nd is timesteps
|
||||
padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
||||
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
||||
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
||||
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
||||
mask_type: how to compute mask lengths
|
||||
static = fixed size
|
||||
uniform = sample from uniform distribution [mask_other, mask_length*2]
|
||||
normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
|
||||
poisson = sample from possion distribution with lambda = mask length
|
||||
min_masks: minimum number of masked spans
|
||||
no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
|
||||
min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
|
||||
"""
|
||||
|
||||
bsz, all_sz = shape
|
||||
mask = np.full((bsz, all_sz), False)
|
||||
|
||||
all_num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * all_sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
|
||||
all_num_mask = max(min_masks, all_num_mask)
|
||||
|
||||
mask_idcs = []
|
||||
for i in range(bsz):
|
||||
if padding_mask is not None:
|
||||
sz = all_sz - padding_mask[i].long().sum().item()
|
||||
num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
num_mask = max(min_masks, num_mask)
|
||||
else:
|
||||
sz = all_sz
|
||||
num_mask = all_num_mask
|
||||
|
||||
if mask_type == "static":
|
||||
lengths = np.full(num_mask, mask_length)
|
||||
elif mask_type == "uniform":
|
||||
lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
|
||||
elif mask_type == "normal":
|
||||
lengths = np.random.normal(mask_length, mask_other, size=num_mask)
|
||||
lengths = [max(1, int(round(x))) for x in lengths]
|
||||
elif mask_type == "poisson":
|
||||
lengths = np.random.poisson(mask_length, size=num_mask)
|
||||
lengths = [int(round(x)) for x in lengths]
|
||||
else:
|
||||
raise Exception("unknown mask selection " + mask_type)
|
||||
|
||||
if sum(lengths) == 0:
|
||||
lengths[0] = min(mask_length, sz - 1)
|
||||
|
||||
if no_overlap:
|
||||
mask_idc = []
|
||||
|
||||
def arrange(s, e, length, keep_length):
|
||||
span_start = np.random.randint(s, e - length)
|
||||
mask_idc.extend(span_start + i for i in range(length))
|
||||
|
||||
new_parts = []
|
||||
if span_start - s - min_space >= keep_length:
|
||||
new_parts.append((s, span_start - min_space + 1))
|
||||
if e - span_start - keep_length - min_space > keep_length:
|
||||
new_parts.append((span_start + length + min_space, e))
|
||||
return new_parts
|
||||
|
||||
parts = [(0, sz)]
|
||||
min_length = min(lengths)
|
||||
for length in sorted(lengths, reverse=True):
|
||||
lens = np.fromiter(
|
||||
(e - s if e - s >= length + min_space else 0 for s, e in parts),
|
||||
np.int,
|
||||
)
|
||||
l_sum = np.sum(lens)
|
||||
if l_sum == 0:
|
||||
break
|
||||
probs = lens / np.sum(lens)
|
||||
c = np.random.choice(len(parts), p=probs)
|
||||
s, e = parts.pop(c)
|
||||
parts.extend(arrange(s, e, length, min_length))
|
||||
mask_idc = np.asarray(mask_idc)
|
||||
else:
|
||||
min_len = min(lengths)
|
||||
if sz - min_len <= num_mask:
|
||||
min_len = sz - num_mask - 1
|
||||
|
||||
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
||||
|
||||
mask_idc = np.asarray(
|
||||
[
|
||||
mask_idc[j] + offset
|
||||
for j in range(len(mask_idc))
|
||||
for offset in range(lengths[j])
|
||||
]
|
||||
)
|
||||
|
||||
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
||||
|
||||
min_len = min([len(m) for m in mask_idcs])
|
||||
for i, mask_idc in enumerate(mask_idcs):
|
||||
if len(mask_idc) > min_len:
|
||||
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
||||
mask[i, mask_idc] = True
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def get_mem_usage():
|
||||
try:
|
||||
import psutil
|
||||
|
||||
mb = 1024 * 1024
|
||||
return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
|
||||
except ImportError:
|
||||
return "N/A"
|
||||
|
||||
|
||||
def lengths_to_padding_mask(lens: torch.LongTensor) -> torch.BoolTensor:
|
||||
bsz, max_lens = lens.size(0), torch.max(lens).item()
|
||||
mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
|
||||
mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
|
||||
return mask
|
||||
|
||||
|
||||
def lengths_to_mask(lens: torch.LongTensor) -> torch.BoolTensor:
|
||||
return ~lengths_to_padding_mask(lens)
|
|
@ -0,0 +1,123 @@
|
|||
# cython: language_level=3
|
||||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
|
||||
cimport cython
|
||||
cimport numpy as np
|
||||
|
||||
from libc.stdint cimport int32_t, int64_t
|
||||
|
||||
ctypedef int64_t DTYPE_t
|
||||
|
||||
|
||||
cdef _is_batch_full(int64_t num_sentences, int64_t num_tokens, int64_t max_tokens, int64_t max_sentences):
|
||||
if num_sentences == 0:
|
||||
return 0
|
||||
if max_sentences > 0 and num_sentences == max_sentences:
|
||||
return 1
|
||||
if max_tokens > 0 and num_tokens > max_tokens:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
@cython.cdivision(True)
|
||||
cpdef list batch_by_size_fast(
|
||||
np.ndarray[DTYPE_t, ndim=1] indices,
|
||||
num_tokens_fn,
|
||||
int64_t max_tokens,
|
||||
int64_t max_sentences,
|
||||
int32_t bsz_mult,
|
||||
):
|
||||
cdef int64_t sample_len = 0
|
||||
cdef list sample_lens = []
|
||||
cdef list batch = []
|
||||
cdef list batches = []
|
||||
cdef int64_t mod_len
|
||||
cdef int64_t i
|
||||
cdef int64_t idx
|
||||
cdef int64_t num_tokens
|
||||
cdef DTYPE_t[:] indices_view = indices
|
||||
|
||||
for i in range(len(indices_view)):
|
||||
idx = indices_view[i]
|
||||
num_tokens = num_tokens_fn(idx)
|
||||
sample_lens.append(num_tokens)
|
||||
sample_len = max(sample_len, num_tokens)
|
||||
|
||||
assert max_tokens <= 0 or sample_len <= max_tokens, (
|
||||
"sentence at index {} of size {} exceeds max_tokens "
|
||||
"limit of {}!".format(idx, sample_len, max_tokens)
|
||||
)
|
||||
num_tokens = (len(batch) + 1) * sample_len
|
||||
|
||||
if _is_batch_full(len(batch), num_tokens, max_tokens, max_sentences):
|
||||
mod_len = max(
|
||||
bsz_mult * (len(batch) // bsz_mult),
|
||||
len(batch) % bsz_mult,
|
||||
)
|
||||
batches.append(batch[:mod_len])
|
||||
batch = batch[mod_len:]
|
||||
sample_lens = sample_lens[mod_len:]
|
||||
sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
|
||||
batch.append(idx)
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
return batches
|
||||
|
||||
|
||||
cdef _find_valid_shape(
|
||||
DTYPE_t[:, :] shapes_view,
|
||||
int64_t num_sentences,
|
||||
int64_t num_tokens,
|
||||
):
|
||||
"""Return index of first valid shape of -1 if none is found."""
|
||||
for i in range(shapes_view.shape[0]):
|
||||
if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]:
|
||||
return i
|
||||
return -1
|
||||
|
||||
|
||||
@cython.cdivision(True)
|
||||
cpdef list batch_fixed_shapes_fast(
|
||||
np.ndarray[DTYPE_t, ndim=1] indices,
|
||||
num_tokens_fn,
|
||||
np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
|
||||
):
|
||||
cdef int64_t sample_len = 0
|
||||
cdef list sample_lens = []
|
||||
cdef list batch = []
|
||||
cdef list batches = []
|
||||
cdef int64_t mod_len
|
||||
cdef int64_t i
|
||||
cdef int64_t idx
|
||||
cdef int64_t num_tokens
|
||||
cdef DTYPE_t[:] indices_view = indices
|
||||
cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted
|
||||
|
||||
for i in range(len(indices_view)):
|
||||
idx = indices_view[i]
|
||||
num_tokens = num_tokens_fn(idx)
|
||||
sample_lens.append(num_tokens)
|
||||
sample_len = max(sample_len, num_tokens)
|
||||
|
||||
shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
|
||||
if shape_idx == -1:
|
||||
batches.append(batch)
|
||||
batch = []
|
||||
sample_lens = []
|
||||
sample_len = 0
|
||||
shapes_view = fixed_shapes_sorted
|
||||
elif shape_idx > 0:
|
||||
# small optimization for the next call to _find_valid_shape
|
||||
shapes_view = shapes_view[shape_idx:]
|
||||
|
||||
batch.append(idx)
|
||||
|
||||
if len(batch) > 0:
|
||||
batches.append(batch)
|
||||
|
||||
return batches
|
|
@ -0,0 +1,418 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from collections import Counter
|
||||
from multiprocessing import Pool
|
||||
|
||||
import torch
|
||||
from fairseq import utils
|
||||
from fairseq.binarizer import safe_readline
|
||||
from fairseq.data import data_utils
|
||||
from fairseq.file_io import PathManager
|
||||
from fairseq.tokenizer import tokenize_line
|
||||
|
||||
|
||||
class Dictionary(object):
|
||||
"""A mapping from symbols to consecutive integers"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_file=None, # begin keyword-only arguments
|
||||
bos="<s>",
|
||||
pad="<pad>",
|
||||
eos="</s>",
|
||||
unk="<unk>",
|
||||
extra_special_symbols=None,
|
||||
):
|
||||
|
||||
self.symbols = []
|
||||
self.count = []
|
||||
self.indices = {}
|
||||
if input_file is not None and 'json' in input_file:
|
||||
self.add_from_json(input_file)
|
||||
else:
|
||||
self.add_from_file(input_file)
|
||||
self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
|
||||
self.bos_index = self.add_symbol(bos)
|
||||
self.pad_index = self.add_symbol(pad)
|
||||
self.eos_index = self.add_symbol(eos)
|
||||
self.unk_index = self.add_symbol(unk)
|
||||
if extra_special_symbols:
|
||||
for s in extra_special_symbols:
|
||||
self.add_symbol(s)
|
||||
self.nspecial = len(self.symbols)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.indices == other.indices
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if idx < len(self.symbols):
|
||||
return self.symbols[idx]
|
||||
return self.unk_word
|
||||
|
||||
def __len__(self):
|
||||
"""Returns the number of symbols in the dictionary"""
|
||||
return len(self.symbols)
|
||||
|
||||
def __contains__(self, sym):
|
||||
return sym in self.indices
|
||||
|
||||
def index(self, sym):
|
||||
"""Returns the index of the specified symbol"""
|
||||
assert isinstance(sym, str)
|
||||
if sym in self.indices:
|
||||
return self.indices[sym]
|
||||
return self.unk_index
|
||||
|
||||
def string(
|
||||
self,
|
||||
tensor,
|
||||
bpe_symbol=None,
|
||||
escape_unk=False,
|
||||
extra_symbols_to_ignore=None,
|
||||
unk_string=None,
|
||||
):
|
||||
"""Helper for converting a tensor of token indices to a string.
|
||||
|
||||
Can optionally remove BPE symbols or escape <unk> words.
|
||||
"""
|
||||
if torch.is_tensor(tensor) and tensor.dim() == 2:
|
||||
return "\n".join(
|
||||
self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore)
|
||||
for t in tensor
|
||||
)
|
||||
|
||||
extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
|
||||
extra_symbols_to_ignore.add(self.eos())
|
||||
|
||||
def token_string(i):
|
||||
if i == self.unk():
|
||||
if unk_string is not None:
|
||||
return unk_string
|
||||
else:
|
||||
return self.unk_string(escape_unk)
|
||||
else:
|
||||
return self[i]
|
||||
|
||||
if hasattr(self, "bos_index"):
|
||||
extra_symbols_to_ignore.add(self.bos())
|
||||
|
||||
sent = " ".join(
|
||||
token_string(i)
|
||||
for i in tensor
|
||||
if utils.item(i) not in extra_symbols_to_ignore
|
||||
)
|
||||
|
||||
return data_utils.post_process(sent, bpe_symbol)
|
||||
|
||||
def unk_string(self, escape=False):
|
||||
"""Return unknown string, optionally escaped as: <<unk>>"""
|
||||
if escape:
|
||||
return "<{}>".format(self.unk_word)
|
||||
else:
|
||||
return self.unk_word
|
||||
|
||||
def add_symbol(self, word, n=1, overwrite=False):
|
||||
"""Adds a word to the dictionary"""
|
||||
if word in self.indices and not overwrite:
|
||||
idx = self.indices[word]
|
||||
self.count[idx] = self.count[idx] + n
|
||||
return idx
|
||||
else:
|
||||
idx = len(self.symbols)
|
||||
self.indices[word] = idx
|
||||
self.symbols.append(word)
|
||||
self.count.append(n)
|
||||
return idx
|
||||
|
||||
def update(self, new_dict):
|
||||
"""Updates counts from new dictionary."""
|
||||
for word in new_dict.symbols:
|
||||
idx2 = new_dict.indices[word]
|
||||
if word in self.indices:
|
||||
idx = self.indices[word]
|
||||
self.count[idx] = self.count[idx] + new_dict.count[idx2]
|
||||
else:
|
||||
idx = len(self.symbols)
|
||||
self.indices[word] = idx
|
||||
self.symbols.append(word)
|
||||
self.count.append(new_dict.count[idx2])
|
||||
|
||||
def finalize(self, threshold=-1, nwords=-1, padding_factor=8):
|
||||
"""Sort symbols by frequency in descending order, ignoring special ones.
|
||||
|
||||
Args:
|
||||
- threshold defines the minimum word count
|
||||
- nwords defines the total number of words in the final dictionary,
|
||||
including special symbols
|
||||
- padding_factor can be used to pad the dictionary size to be a
|
||||
multiple of 8, which is important on some hardware (e.g., Nvidia
|
||||
Tensor Cores).
|
||||
"""
|
||||
if nwords <= 0:
|
||||
nwords = len(self)
|
||||
|
||||
new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
|
||||
new_symbols = self.symbols[: self.nspecial]
|
||||
new_count = self.count[: self.nspecial]
|
||||
|
||||
c = Counter(
|
||||
dict(
|
||||
sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
|
||||
)
|
||||
)
|
||||
for symbol, count in c.most_common(nwords - self.nspecial):
|
||||
if count >= threshold:
|
||||
new_indices[symbol] = len(new_symbols)
|
||||
new_symbols.append(symbol)
|
||||
new_count.append(count)
|
||||
else:
|
||||
break
|
||||
|
||||
assert len(new_symbols) == len(new_indices)
|
||||
|
||||
self.count = list(new_count)
|
||||
self.symbols = list(new_symbols)
|
||||
self.indices = new_indices
|
||||
|
||||
self.pad_to_multiple_(padding_factor)
|
||||
|
||||
def pad_to_multiple_(self, padding_factor):
|
||||
"""Pad Dictionary size to be a multiple of *padding_factor*."""
|
||||
if padding_factor > 1:
|
||||
i = 0
|
||||
while len(self) % padding_factor != 0:
|
||||
symbol = "madeupword{:04d}".format(i)
|
||||
self.add_symbol(symbol, n=0)
|
||||
i += 1
|
||||
|
||||
def bos(self):
|
||||
"""Helper to get index of beginning-of-sentence symbol"""
|
||||
return self.bos_index
|
||||
|
||||
def pad(self):
|
||||
"""Helper to get index of pad symbol"""
|
||||
return self.pad_index
|
||||
|
||||
def eos(self):
|
||||
"""Helper to get index of end-of-sentence symbol"""
|
||||
return self.eos_index
|
||||
|
||||
def unk(self):
|
||||
"""Helper to get index of unk symbol"""
|
||||
return self.unk_index
|
||||
|
||||
@classmethod
|
||||
def load(cls, f):
|
||||
"""Loads the dictionary from a text file with the format:
|
||||
|
||||
```
|
||||
<symbol0> <count0>
|
||||
<symbol1> <count1>
|
||||
...
|
||||
```
|
||||
"""
|
||||
d = cls()
|
||||
d.add_from_file(f)
|
||||
return d
|
||||
|
||||
@classmethod
|
||||
def load_from_json(cls, f):
|
||||
d = cls()
|
||||
d.add_from_json(f)
|
||||
return d
|
||||
|
||||
def add_from_json(self, f):
|
||||
import json
|
||||
if isinstance(f, str):
|
||||
try:
|
||||
with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
|
||||
self.add_from_json(fd)
|
||||
except FileNotFoundError as fnfe:
|
||||
raise fnfe
|
||||
except UnicodeError:
|
||||
raise Exception(
|
||||
"Incorrect encoding detected in {}, please "
|
||||
"rebuild the dataset".format(f)
|
||||
)
|
||||
return
|
||||
|
||||
vocab = json.load(f)
|
||||
for k, v in vocab.items():
|
||||
self.add_symbol(k)
|
||||
|
||||
|
||||
def add_from_file(self, f):
|
||||
"""
|
||||
Loads a pre-existing dictionary from a text file and adds its symbols
|
||||
to this instance.
|
||||
"""
|
||||
if isinstance(f, str):
|
||||
try:
|
||||
with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
|
||||
self.add_from_file(fd)
|
||||
except FileNotFoundError as fnfe:
|
||||
raise fnfe
|
||||
except UnicodeError:
|
||||
raise Exception(
|
||||
"Incorrect encoding detected in {}, please "
|
||||
"rebuild the dataset".format(f)
|
||||
)
|
||||
return
|
||||
|
||||
lines = f.readlines()
|
||||
indices_start_line = self._load_meta(lines)
|
||||
|
||||
for line in lines[indices_start_line:]:
|
||||
try:
|
||||
line, field = line.rstrip().rsplit(" ", 1)
|
||||
if field == "#fairseq:overwrite":
|
||||
overwrite = True
|
||||
line, field = line.rsplit(" ", 1)
|
||||
else:
|
||||
overwrite = False
|
||||
count = int(field)
|
||||
word = line
|
||||
if word in self and not overwrite:
|
||||
raise RuntimeError(
|
||||
"Duplicate word found when loading Dictionary: '{}'. "
|
||||
"Duplicate words can overwrite earlier ones by adding the "
|
||||
"#fairseq:overwrite flag at the end of the corresponding row "
|
||||
"in the dictionary file. If using the Camembert model, please "
|
||||
"download an updated copy of the model file.".format(word)
|
||||
)
|
||||
self.add_symbol(word, n=count, overwrite=overwrite)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"Incorrect dictionary format, expected '<token> <cnt> [flags]'"
|
||||
)
|
||||
|
||||
def _save(self, f, kv_iterator):
|
||||
if isinstance(f, str):
|
||||
PathManager.mkdirs(os.path.dirname(f))
|
||||
with PathManager.open(f, "w", encoding="utf-8") as fd:
|
||||
return self.save(fd)
|
||||
for k, v in kv_iterator:
|
||||
print("{} {}".format(k, v), file=f)
|
||||
|
||||
def _get_meta(self):
|
||||
return [], []
|
||||
|
||||
def _load_meta(self, lines):
|
||||
return 0
|
||||
|
||||
def save(self, f):
|
||||
"""Stores dictionary into a text file"""
|
||||
ex_keys, ex_vals = self._get_meta()
|
||||
self._save(
|
||||
f,
|
||||
zip(
|
||||
ex_keys + self.symbols[self.nspecial :],
|
||||
ex_vals + self.count[self.nspecial :],
|
||||
),
|
||||
)
|
||||
|
||||
def dummy_sentence(self, length):
|
||||
t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
|
||||
t[-1] = self.eos()
|
||||
return t
|
||||
|
||||
def encode_line(
|
||||
self,
|
||||
line,
|
||||
line_tokenizer=tokenize_line,
|
||||
add_if_not_exist=True,
|
||||
consumer=None,
|
||||
append_eos=True,
|
||||
reverse_order=False,
|
||||
):
|
||||
words = line_tokenizer(line)
|
||||
if reverse_order:
|
||||
words = list(reversed(words))
|
||||
nwords = len(words)
|
||||
ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
|
||||
|
||||
for i, word in enumerate(words):
|
||||
if add_if_not_exist:
|
||||
idx = self.add_symbol(word)
|
||||
else:
|
||||
idx = self.index(word)
|
||||
if consumer is not None:
|
||||
consumer(word, idx)
|
||||
ids[i] = idx
|
||||
if append_eos:
|
||||
ids[nwords] = self.eos_index
|
||||
return ids
|
||||
|
||||
@staticmethod
|
||||
def _add_file_to_dictionary_single_worker(
|
||||
filename, tokenize, eos_word, worker_id=0, num_workers=1
|
||||
):
|
||||
counter = Counter()
|
||||
with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
|
||||
size = os.fstat(f.fileno()).st_size
|
||||
chunk_size = size // num_workers
|
||||
offset = worker_id * chunk_size
|
||||
end = offset + chunk_size
|
||||
f.seek(offset)
|
||||
if offset > 0:
|
||||
safe_readline(f) # drop first incomplete line
|
||||
line = f.readline()
|
||||
while line:
|
||||
for word in tokenize(line):
|
||||
counter.update([word])
|
||||
counter.update([eos_word])
|
||||
if f.tell() > end:
|
||||
break
|
||||
line = f.readline()
|
||||
return counter
|
||||
|
||||
@staticmethod
|
||||
def add_file_to_dictionary(filename, dict, tokenize, num_workers):
|
||||
def merge_result(counter):
|
||||
for w, c in sorted(counter.items()):
|
||||
dict.add_symbol(w, c)
|
||||
|
||||
if num_workers > 1:
|
||||
pool = Pool(processes=num_workers)
|
||||
results = []
|
||||
for worker_id in range(num_workers):
|
||||
results.append(
|
||||
pool.apply_async(
|
||||
Dictionary._add_file_to_dictionary_single_worker,
|
||||
(filename, tokenize, dict.eos_word, worker_id, num_workers),
|
||||
)
|
||||
)
|
||||
pool.close()
|
||||
pool.join()
|
||||
for r in results:
|
||||
merge_result(r.get())
|
||||
else:
|
||||
merge_result(
|
||||
Dictionary._add_file_to_dictionary_single_worker(
|
||||
filename, tokenize, dict.eos_word
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TruncatedDictionary(object):
|
||||
def __init__(self, wrapped_dict, length):
|
||||
self.__class__ = type(
|
||||
wrapped_dict.__class__.__name__,
|
||||
(self.__class__, wrapped_dict.__class__),
|
||||
{},
|
||||
)
|
||||
self.__dict__ = wrapped_dict.__dict__
|
||||
self.wrapped_dict = wrapped_dict
|
||||
self.length = min(len(self.wrapped_dict), length)
|
||||
|
||||
def __len__(self):
|
||||
return self.length
|
||||
|
||||
def __getitem__(self, i):
|
||||
if i < self.length:
|
||||
return self.wrapped_dict[i]
|
||||
return self.wrapped_dict.unk()
|
|
@ -0,0 +1,29 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import importlib
|
||||
import os
|
||||
|
||||
from fairseq import registry
|
||||
|
||||
|
||||
build_tokenizer, register_tokenizer, TOKENIZER_REGISTRY, _ = registry.setup_registry(
|
||||
"--tokenizer",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
build_bpe, register_bpe, BPE_REGISTRY, _ = registry.setup_registry(
|
||||
"--bpe",
|
||||
default=None,
|
||||
)
|
||||
|
||||
|
||||
# automatically import any Python files in the encoders/ directory
|
||||
for file in os.listdir(os.path.dirname(__file__)):
|
||||
if file.endswith(".py") and not file.startswith("_"):
|
||||
module = file[: file.find(".py")]
|
||||
importlib.import_module("fairseq.data.encoders." + module)
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq import file_utils
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.data.encoders.byte_utils import (
|
||||
SPACE,
|
||||
SPACE_ESCAPE,
|
||||
byte_encode,
|
||||
smart_byte_decode,
|
||||
)
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ByteBpeConfig(FairseqDataclass):
|
||||
sentencepiece_model_path: str = field(
|
||||
default="???", metadata={"help": "path to sentencepiece model"}
|
||||
)
|
||||
|
||||
|
||||
@register_bpe("byte_bpe", dataclass=ByteBpeConfig)
|
||||
class ByteBPE(object):
|
||||
def __init__(self, cfg):
|
||||
vocab = file_utils.cached_path(cfg.sentencepiece_model_path)
|
||||
try:
|
||||
import sentencepiece as spm
|
||||
|
||||
self.sp = spm.SentencePieceProcessor()
|
||||
self.sp.Load(vocab)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install sentencepiece with: pip install sentencepiece"
|
||||
)
|
||||
|
||||
def encode(self, x: str) -> str:
|
||||
byte_encoded = byte_encode(x)
|
||||
return SPACE.join(self.sp.EncodeAsPieces(byte_encoded))
|
||||
|
||||
@staticmethod
|
||||
def decode(x: str) -> str:
|
||||
unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
|
||||
return smart_byte_decode(unescaped)
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import re
|
||||
|
||||
|
||||
WHITESPACE_NORMALIZER = re.compile(r"\s+")
|
||||
SPACE = chr(32)
|
||||
SPACE_ESCAPE = chr(9601)
|
||||
# excluding non-breaking space (160) here
|
||||
PRINTABLE_LATIN = set(
|
||||
list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1))
|
||||
)
|
||||
BYTE_TO_BCHAR = {
|
||||
b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256)
|
||||
}
|
||||
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}
|
||||
|
||||
|
||||
def byte_encode(x: str) -> str:
|
||||
normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
|
||||
return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])
|
||||
|
||||
|
||||
def byte_decode(x: str) -> str:
|
||||
try:
|
||||
return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
|
||||
def smart_byte_decode(x: str) -> str:
|
||||
output = byte_decode(x)
|
||||
if output == "":
|
||||
# DP the best recovery (max valid chars) if it's broken
|
||||
n_bytes = len(x)
|
||||
f = [0 for _ in range(n_bytes + 1)]
|
||||
pt = [0 for _ in range(n_bytes + 1)]
|
||||
for i in range(1, n_bytes + 1):
|
||||
f[i], pt[i] = f[i - 1], i - 1
|
||||
for j in range(1, min(4, i) + 1):
|
||||
if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
|
||||
f[i], pt[i] = f[i - j] + 1, i - j
|
||||
cur_pt = n_bytes
|
||||
while cur_pt > 0:
|
||||
if f[cur_pt] == f[pt[cur_pt]] + 1:
|
||||
output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
|
||||
cur_pt = pt[cur_pt]
|
||||
return output
|
|
@ -0,0 +1,34 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.data.encoders.byte_utils import (
|
||||
SPACE,
|
||||
SPACE_ESCAPE,
|
||||
byte_encode,
|
||||
smart_byte_decode,
|
||||
)
|
||||
|
||||
|
||||
@register_bpe("bytes")
|
||||
class Bytes(object):
|
||||
def __init__(self, *unused):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def encode(x: str) -> str:
|
||||
encoded = byte_encode(x)
|
||||
escaped = encoded.replace(SPACE, SPACE_ESCAPE)
|
||||
return SPACE.join(list(escaped))
|
||||
|
||||
@staticmethod
|
||||
def decode(x: str) -> str:
|
||||
unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
|
||||
return smart_byte_decode(unescaped)
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
from fairseq.data.encoders import register_bpe
|
||||
|
||||
|
||||
SPACE = chr(32)
|
||||
SPACE_ESCAPE = chr(9601)
|
||||
|
||||
|
||||
@register_bpe("characters")
|
||||
class Characters(object):
|
||||
def __init__(self, *unused):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def add_args(parser):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def encode(x: str) -> str:
|
||||
escaped = x.replace(SPACE, SPACE_ESCAPE)
|
||||
return SPACE.join(list(escaped))
|
||||
|
||||
@staticmethod
|
||||
def decode(x: str) -> str:
|
||||
return x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq import file_utils
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class fastBPEConfig(FairseqDataclass):
|
||||
bpe_codes: str = field(default="???", metadata={"help": "path to fastBPE BPE"})
|
||||
|
||||
|
||||
@register_bpe("fastbpe", dataclass=fastBPEConfig)
|
||||
class fastBPE(object):
|
||||
def __init__(self, cfg):
|
||||
if cfg.bpe_codes is None:
|
||||
raise ValueError("--bpe-codes is required for --bpe=fastbpe")
|
||||
codes = file_utils.cached_path(cfg.bpe_codes)
|
||||
try:
|
||||
import fastBPE
|
||||
|
||||
self.bpe = fastBPE.fastBPE(codes)
|
||||
self.bpe_symbol = "@@ "
|
||||
except ImportError:
|
||||
raise ImportError("Please install fastBPE with: pip install fastBPE")
|
||||
|
||||
def encode(self, x: str) -> str:
|
||||
return self.bpe.apply([x])[0]
|
||||
|
||||
def decode(self, x: str) -> str:
|
||||
return (x + " ").replace(self.bpe_symbol, "").rstrip()
|
|
@ -0,0 +1,45 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq import file_utils
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
from .gpt2_bpe_utils import get_encoder
|
||||
|
||||
|
||||
DEFAULT_ENCODER_JSON = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"
|
||||
DEFAULT_VOCAB_BPE = "https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPT2BPEConfig(FairseqDataclass):
|
||||
gpt2_encoder_json: str = field(
|
||||
default=DEFAULT_ENCODER_JSON, metadata={"help": "path to encoder.json"}
|
||||
)
|
||||
gpt2_vocab_bpe: str = field(
|
||||
default=DEFAULT_VOCAB_BPE, metadata={"help": "path to vocab.bpe"}
|
||||
)
|
||||
|
||||
|
||||
@register_bpe("gpt2", dataclass=GPT2BPEConfig)
|
||||
class GPT2BPE(object):
|
||||
def __init__(self, cfg):
|
||||
encoder_json = file_utils.cached_path(cfg.gpt2_encoder_json)
|
||||
vocab_bpe = file_utils.cached_path(cfg.gpt2_vocab_bpe)
|
||||
self.bpe = get_encoder(encoder_json, vocab_bpe)
|
||||
|
||||
def encode(self, x: str) -> str:
|
||||
return " ".join(map(str, self.bpe.encode(x)))
|
||||
|
||||
def decode(self, x: str) -> str:
|
||||
return self.bpe.decode(
|
||||
[int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
|
||||
)
|
||||
|
||||
def is_beginning_of_word(self, x: str) -> bool:
|
||||
return self.decode(x).startswith(" ")
|
|
@ -0,0 +1,140 @@
|
|||
"""
|
||||
Byte pair encoding utilities from GPT-2.
|
||||
|
||||
Original source: https://github.com/openai/gpt-2/blob/master/src/encoder.py
|
||||
Original license: MIT
|
||||
"""
|
||||
|
||||
import json
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = (
|
||||
list(range(ord("!"), ord("~") + 1))
|
||||
+ list(range(ord("¡"), ord("¬") + 1))
|
||||
+ list(range(ord("®"), ord("ÿ") + 1))
|
||||
)
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2 ** 8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2 ** 8 + n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
class Encoder:
|
||||
def __init__(self, encoder, bpe_merges, errors="replace"):
|
||||
self.encoder = encoder
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.errors = errors # how to handle errors in decoding
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
|
||||
self.cache = {}
|
||||
|
||||
try:
|
||||
import regex as re
|
||||
|
||||
self.re = re
|
||||
except ImportError:
|
||||
raise ImportError("Please install regex with: pip install regex")
|
||||
|
||||
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
|
||||
self.pat = self.re.compile(
|
||||
r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
|
||||
)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = " ".join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
for token in self.re.findall(self.pat, text):
|
||||
token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
|
||||
bpe_tokens.extend(
|
||||
self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
|
||||
)
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = "".join([self.decoder.get(token, token) for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode(
|
||||
"utf-8", errors=self.errors
|
||||
)
|
||||
return text
|
||||
|
||||
|
||||
def get_encoder(encoder_json_path, vocab_bpe_path):
|
||||
with open(encoder_json_path, "r") as f:
|
||||
encoder = json.load(f)
|
||||
with open(vocab_bpe_path, "r", encoding="utf-8") as f:
|
||||
bpe_data = f.read()
|
||||
bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
|
||||
return Encoder(
|
||||
encoder=encoder,
|
||||
bpe_merges=bpe_merges,
|
||||
)
|
|
@ -0,0 +1,50 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BertBPEConfig(FairseqDataclass):
|
||||
bpe_cased: bool = field(default=False, metadata={"help": "set for cased BPE"})
|
||||
bpe_vocab_file: Optional[str] = field(
|
||||
default=None, metadata={"help": "bpe vocab file"}
|
||||
)
|
||||
|
||||
|
||||
@register_bpe("bert", dataclass=BertBPEConfig)
|
||||
class BertBPE(object):
|
||||
def __init__(self, cfg):
|
||||
try:
|
||||
from transformers import BertTokenizer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install transformers with: pip install transformers"
|
||||
)
|
||||
|
||||
if cfg.bpe_vocab_file:
|
||||
self.bert_tokenizer = BertTokenizer(
|
||||
cfg.bpe_vocab_file, do_lower_case=not cfg.bpe_cased
|
||||
)
|
||||
else:
|
||||
vocab_file_name = (
|
||||
"bert-base-cased" if cfg.bpe_cased else "bert-base-uncased"
|
||||
)
|
||||
self.bert_tokenizer = BertTokenizer.from_pretrained(vocab_file_name)
|
||||
|
||||
def encode(self, x: str) -> str:
|
||||
return " ".join(self.bert_tokenizer.tokenize(x))
|
||||
|
||||
def decode(self, x: str) -> str:
|
||||
return self.bert_tokenizer.clean_up_tokenization(
|
||||
self.bert_tokenizer.convert_tokens_to_string(x.split(" "))
|
||||
)
|
||||
|
||||
def is_beginning_of_word(self, x: str) -> bool:
|
||||
return not x.startswith("##")
|
|
@ -0,0 +1,46 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class HuggingFaceByteLevelBPEConfig(FairseqDataclass):
|
||||
bpe_merges: str = field(default="???", metadata={"help": "path to merges.txt"})
|
||||
bpe_vocab: str = field(default="???", metadata={"help": "path to vocab.json"})
|
||||
bpe_add_prefix_space: bool = field(
|
||||
default=False, metadata={"help": "add prefix space before encoding"}
|
||||
)
|
||||
|
||||
|
||||
@register_bpe("hf_byte_bpe", dataclass=HuggingFaceByteLevelBPEConfig)
|
||||
class HuggingFaceByteLevelBPE(object):
|
||||
def __init__(self, cfg):
|
||||
try:
|
||||
from tokenizers import ByteLevelBPETokenizer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install huggingface/tokenizers with: " "pip install tokenizers"
|
||||
)
|
||||
|
||||
self.bpe = ByteLevelBPETokenizer(
|
||||
cfg.bpe_vocab,
|
||||
cfg.bpe_merges,
|
||||
add_prefix_space=cfg.bpe_add_prefix_space,
|
||||
)
|
||||
|
||||
def encode(self, x: str) -> str:
|
||||
return " ".join(map(str, self.bpe.encode(x).ids))
|
||||
|
||||
def decode(self, x: str) -> str:
|
||||
return self.bpe.decode(
|
||||
[int(tok) if tok not in {"<unk>", "<mask>"} else tok for tok in x.split()]
|
||||
)
|
||||
|
||||
def is_beginning_of_word(self, x: str) -> bool:
|
||||
return self.decode(x).startswith(" ")
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq.data.encoders import register_tokenizer
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MosesTokenizerConfig(FairseqDataclass):
|
||||
source_lang: str = field(default="en", metadata={"help": "source language"})
|
||||
target_lang: str = field(default="en", metadata={"help": "target language"})
|
||||
moses_no_dash_splits: bool = field(
|
||||
default=False, metadata={"help": "don't apply dash split rules"}
|
||||
)
|
||||
moses_no_escape: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "don't perform HTML escaping on apostrophe, quotes, etc."},
|
||||
)
|
||||
|
||||
|
||||
@register_tokenizer("moses", dataclass=MosesTokenizerConfig)
|
||||
class MosesTokenizer(object):
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
|
||||
try:
|
||||
from sacremoses import MosesTokenizer, MosesDetokenizer
|
||||
|
||||
self.tok = MosesTokenizer(cfg.source_lang)
|
||||
self.detok = MosesDetokenizer(cfg.target_lang)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install Moses tokenizer with: pip install sacremoses"
|
||||
)
|
||||
|
||||
def encode(self, x: str) -> str:
|
||||
return self.tok.tokenize(
|
||||
x,
|
||||
aggressive_dash_splits=(not self.cfg.moses_no_dash_splits),
|
||||
return_str=True,
|
||||
escape=(not self.cfg.moses_no_escape),
|
||||
)
|
||||
|
||||
def decode(self, x: str) -> str:
|
||||
return self.detok.detokenize(x.split())
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from fairseq.data.encoders import register_tokenizer
|
||||
|
||||
|
||||
@register_tokenizer("nltk")
|
||||
class NLTKTokenizer(object):
|
||||
def __init__(self, *unused):
|
||||
try:
|
||||
from nltk.tokenize import word_tokenize
|
||||
|
||||
self.word_tokenize = word_tokenize
|
||||
except ImportError:
|
||||
raise ImportError("Please install nltk with: pip install nltk")
|
||||
|
||||
def encode(self, x: str) -> str:
|
||||
return " ".join(self.word_tokenize(x))
|
||||
|
||||
def decode(self, x: str) -> str:
|
||||
return x
|
|
@ -0,0 +1,48 @@
|
|||
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||
#
|
||||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fairseq import file_utils
|
||||
from fairseq.data.encoders import register_bpe
|
||||
from fairseq.dataclass import FairseqDataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class SentencepieceConfig(FairseqDataclass):
|
||||
sentencepiece_model: str = field(
|
||||
default="???", metadata={"help": "path to sentencepiece model"}
|
||||
)
|
||||
|
||||
|
||||
@register_bpe("sentencepiece", dataclass=SentencepieceConfig)
|
||||
class SentencepieceBPE(object):
|
||||
def __init__(self, cfg):
|
||||
sentencepiece_model = file_utils.cached_path(cfg.sentencepiece_model)
|
||||
try:
|
||||
import sentencepiece as spm
|
||||
|
||||
self.sp = spm.SentencePieceProcessor()
|
||||
self.sp.Load(sentencepiece_model)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install sentencepiece with: pip install sentencepiece"
|
||||
)
|
||||
|
||||
def encode(self, x: str) -> str:
|
||||
return " ".join(self.sp.EncodeAsPieces(x))
|
||||
|
||||
def decode(self, x: str) -> str:
|
||||
return x.replace(" ", "").replace("\u2581", " ").strip()
|
||||
|
||||
def is_beginning_of_word(self, x: str) -> bool:
|
||||
if x in ["<unk>", "<s>", "</s>", "<pad>"]:
|
||||
# special elements are always considered beginnings
|
||||
# HACK: this logic is already present in fairseq/tasks/masked_lm.py
|
||||
# but these special tokens are also contained in the sentencepiece
|
||||
# vocabulary which causes duplicate special tokens. This hack makes
|
||||
# sure that they are all taken into account.
|
||||
return True
|
||||
return x.startswith("\u2581")
|
Некоторые файлы не были показаны из-за слишком большого количества измененных файлов Показать больше
Загрузка…
Ссылка в новой задаче